mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1378 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
import functools
|
|
8
|
+
import sys
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import (
|
|
11
|
+
Any,
|
|
12
|
+
cast,
|
|
13
|
+
Dict,
|
|
14
|
+
Iterable,
|
|
15
|
+
List,
|
|
16
|
+
Optional,
|
|
17
|
+
Sequence,
|
|
18
|
+
Tuple,
|
|
19
|
+
Type,
|
|
20
|
+
TYPE_CHECKING,
|
|
21
|
+
Union,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
|
|
26
|
+
from ._triton.available import is_triton_available
|
|
27
|
+
from .attn_bias import (
|
|
28
|
+
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
|
|
29
|
+
BlockDiagonalCausalWithOffsetGappyKeysMask,
|
|
30
|
+
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
|
31
|
+
BlockDiagonalGappyKeysMask,
|
|
32
|
+
BlockDiagonalLocalAttentionPaddedKeysMask,
|
|
33
|
+
BlockDiagonalPaddedKeysMask,
|
|
34
|
+
PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
|
|
35
|
+
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
|
|
36
|
+
PagedBlockDiagonalGappyKeysMask,
|
|
37
|
+
PagedBlockDiagonalPaddedKeysMask,
|
|
38
|
+
)
|
|
39
|
+
from .common import AttentionFwOpBase, check_lastdim_alignment_stride1, Context, Inputs
|
|
40
|
+
from .utils.op_common import register_operator
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _strides(x: Optional[torch.Tensor], *stride_names: str):
    """
    Map each name in ``stride_names`` to the matching stride of ``x``.

    The returned dict is keyed ``"stride_<name>"``. When ``x`` is None every
    value is None; otherwise ``x`` must have exactly one dimension per name
    and the values are taken from ``x.stride()`` in order.
    """
    keys = [f"stride_{name}" for name in stride_names]
    if x is None:
        # No tensor: keep the same keys so callers can still unpack them.
        return dict.fromkeys(keys)
    assert x.ndim == len(stride_names)
    return dict(zip(keys, x.stride()))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _is_supported_causal_bias(attn_bias: Any) -> bool:
    """Return True if ``attn_bias`` is an instance of one of the causal mask classes below."""
    causal_bias_types = (
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        BlockDiagonalCausalWithOffsetGappyKeysMask,
        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
    )
    return isinstance(attn_bias, causal_bias_types)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _is_supported_local_bias(attn_bias: Any) -> bool:
    """Return True if ``attn_bias`` is an instance of one of the local-attention mask classes below."""
    local_bias_types = (
        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
        BlockDiagonalLocalAttentionPaddedKeysMask,
    )
    return isinstance(attn_bias, local_bias_types)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _is_supported_gappy_bias(attn_bias: Any) -> bool:
    """Return True if ``attn_bias`` is an instance of one of the gappy-keys mask classes below."""
    gappy_bias_types = (
        BlockDiagonalGappyKeysMask,
        PagedBlockDiagonalGappyKeysMask,
    )
    return isinstance(attn_bias, gappy_bias_types)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _is_supported_paged_bias(attn_bias: Any) -> bool:
    """Return True if ``attn_bias`` is an instance of one of the paged mask classes below."""
    paged_bias_types = (
        PagedBlockDiagonalGappyKeysMask,
        PagedBlockDiagonalPaddedKeysMask,
    )
    return isinstance(attn_bias, paged_bias_types)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class InputsFp8(Inputs):
    """
    Attention inputs extended with FP8 quantization coefficients.

    Each of k/v_fp8_scale_shift is an int32 tensor of shape (1, B * Mkv, Hq),
    or (1, page_size * max_pages_per_lane, Hq) in the paged case.
    Each int32 element contains two packed fp16 numbers
    - the scale and the shift for row-wise FP8 quantization.
    """

    k_fp8_scale_shift: Optional[torch.Tensor] = None
    v_fp8_scale_shift: Optional[torch.Tensor] = None
    # NOTE(review): presumably packed the same way as the K/V coefficients - confirm.
    q_fp8_scale_shift: Optional[torch.Tensor] = None
    quantize_pv_to_fp8: bool = False
    # FwOp.apply asserts that q_fp8_scale_shift is provided when this flag is True.
    quantize_qk_to_fp8: bool = False

    @property
    def nbytes(self) -> int:
        """
        Number of bytes in the input, not counting the attention bias.

        Adds the storage size of every scale/shift tensor that is present
        (K, V and Q) to the size reported by the base ``Inputs`` class.
        """
        scale_shift_tensors = (
            self.k_fp8_scale_shift,
            self.v_fp8_scale_shift,
            # The Q coefficients are an input tensor too, so they are counted as well.
            self.q_fp8_scale_shift,
        )
        scale_shift_bytes = sum(
            t.untyped_storage().nbytes() for t in scale_shift_tensors if t is not None
        )
        return super(InputsFp8, self).nbytes + scale_shift_bytes
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
if TYPE_CHECKING or is_triton_available():
|
|
129
|
+
from ._triton.splitk_kernels import _fwd_kernel_splitK, _splitK_reduce
|
|
130
|
+
else:
|
|
131
|
+
_fwd_kernel_splitK = None
|
|
132
|
+
_splitK_reduce = None
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _is_cuda() -> bool:
    """Report whether ``torch.version.cuda`` is set (i.e. is not None)."""
    cuda_version = torch.version.cuda
    return cuda_version is not None
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _is_cuda_at_least_sm80(device: torch.device) -> bool:
    """
    Return True when ``_is_cuda()`` holds and ``device`` reports compute capability >= (8, 0).

    The capability is only queried when ``_is_cuda()`` is true.
    """
    if not _is_cuda():
        return False
    min_capability = (8, 0)
    return torch.cuda.get_device_capability(device) >= min_capability
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@register_operator
|
|
147
|
+
class FwOp(AttentionFwOpBase):
|
|
148
|
+
"""Flash-Attention with Split-K. Supports fused int4 and fp8 K/V quantization.
|
|
149
|
+
Quantized path will be taken if input K/V have type int32.
|
|
150
|
+
|
|
151
|
+
Int4 quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
|
|
152
|
+
the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
|
|
153
|
+
Quantization coefficients (scale and shift) are represented as two
|
|
154
|
+
float16 constants per group, packed into int32. Quantization coefficients of
|
|
155
|
+
all groups are placed at the beginning of the row. So, if unquantized K/V have head
|
|
156
|
+
dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
|
|
157
|
+
and dtype int32.
|
|
158
|
+
Pseudocode for dequantizing one row can look like:
|
|
159
|
+
group_size = D // 8
|
|
160
|
+
for i in range(NUM_GROUPS):
|
|
161
|
+
group_start = NUM_GROUPS + i * group_size
|
|
162
|
+
group_quant = K[..., group_start: group_start + group_size]
|
|
163
|
+
scale, shift = unpack_int32_into_float16x2(group_quant[0])
|
|
164
|
+
group_dequant = group_quant[..., 1:] * scale + shift
|
|
165
|
+
...
|
|
166
|
+
|
|
167
|
+
For fp8 only row-wise quantization is supported. To use it, provide input of type
|
|
168
|
+
xformers.ops.fmha.triton_splitk.InputsFp8 (instead of the usual xformers.ops.fmha.Inputs) to
|
|
169
|
+
xformers.ops.fmha.triton_splitk.FwOp.apply or xformers.ops.fmha._memory_efficient_attention_forward.
|
|
170
|
+
|
|
171
|
+
This op uses Paged Attention when bias is one of the Paged* classes.
|
|
172
|
+
In this case bias has additional fields:
|
|
173
|
+
- block_tables of shape [batch_size, max_num_pages]
|
|
174
|
+
- K/V of shape [1, max_num_pages * page_size, num_heads, head_dim]
|
|
175
|
+
or [1, max_num_pages * page_size, num_groups, num_heads, head_dim]
|
|
176
|
+
|
|
177
|
+
The shape which the kernel takes the queries and the output
|
|
178
|
+
is quite different from the user interface. There are three
|
|
179
|
+
types of input (a) no bias / tensor bias, (b) variable q_len
|
|
180
|
+
(which is only for non causal) and (c) other bias objects.
|
|
181
|
+
From the interface to the kernel the following changes happen.
|
|
182
|
+
|
|
183
|
+
(0) In all cases, a group dimension may need to be added.
|
|
184
|
+
|
|
185
|
+
(1) For (c), a batch dimension is created, reshaping from (1, B*Mq, G, Hq, K)
|
|
186
|
+
to (B, Mq, G, Hq, K)
|
|
187
|
+
|
|
188
|
+
(2) For (a) and (c), in the case of multiquery (i.e. the head dimension
|
|
189
|
+
of keys and values is expanded), the head-swapping trick
|
|
190
|
+
reshaping from (B, Mq, G, Hq, K) to (B, M=Hq*Mq, G, H=1, K)
|
|
191
|
+
|
|
192
|
+
(3) For (b), in the case of multiquery, the head-swapping trick
|
|
193
|
+
reshaping from (1, Mq, G, Hq, K) to (1, Mq*Hq, G, H=1, K)
|
|
194
|
+
Note here that Mq is a single long dimension which spans all the queries
|
|
195
|
+
in the batch, unlike in case (c). Note also that Hq has to run faster than
|
|
196
|
+
Mq in order that the queries in a batch element remain evenly spaced.
|
|
197
|
+
|
|
198
|
+
In all cases, the shape as seen by the kernel is called (Bqq, Mqq, G, H, K).
|
|
199
|
+
The kernel operates on B batch elements and M queries per batch element.
|
|
200
|
+
"""
|
|
201
|
+
|
|
202
|
+
OPERATOR = True
|
|
203
|
+
SUPPORTED_DEVICES = {"cuda"}
|
|
204
|
+
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
|
205
|
+
SUPPORTED_DTYPES = {
|
|
206
|
+
torch.half,
|
|
207
|
+
torch.bfloat16,
|
|
208
|
+
torch.float8_e4m3fn,
|
|
209
|
+
} # Those are dtypes of Q. In the quantized case K/V has dtype int32
|
|
210
|
+
SUPPORTED_MAX_K = 512
|
|
211
|
+
SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
|
|
212
|
+
type(None),
|
|
213
|
+
torch.Tensor,
|
|
214
|
+
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
|
215
|
+
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
|
|
216
|
+
BlockDiagonalLocalAttentionPaddedKeysMask,
|
|
217
|
+
BlockDiagonalGappyKeysMask,
|
|
218
|
+
BlockDiagonalCausalWithOffsetGappyKeysMask,
|
|
219
|
+
BlockDiagonalPaddedKeysMask,
|
|
220
|
+
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
|
|
221
|
+
PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
|
|
222
|
+
PagedBlockDiagonalGappyKeysMask,
|
|
223
|
+
PagedBlockDiagonalPaddedKeysMask,
|
|
224
|
+
)
|
|
225
|
+
SUPPORTS_DROPOUT = False
|
|
226
|
+
SUPPORTS_CUSTOM_SCALE = True
|
|
227
|
+
SUPPORTS_BMGHK = True
|
|
228
|
+
SUPPORTS_OUTPUT_DTYPE = True
|
|
229
|
+
SUPPORTS_PARTIAL = True
|
|
230
|
+
NAME = "triton_splitKF"
|
|
231
|
+
|
|
232
|
+
SPLIT_K: Optional[int] = None
|
|
233
|
+
MAX_BLOCK_M = 32
|
|
234
|
+
|
|
235
|
+
# Whether blocks attending to no part of a variable sequence length
|
|
236
|
+
# should exit early. This requires extra kernels to run beforehand
|
|
237
|
+
# to initialise the outputs.
|
|
238
|
+
# TODO: avoid these by making the reduce kernel work out it doesn't need
|
|
239
|
+
# to look at the irrelevant places.
|
|
240
|
+
SPLIT_K_EARLY_EXIT: bool = False
|
|
241
|
+
|
|
242
|
+
# Perform kernel-level Triton autotune
|
|
243
|
+
AUTOTUNE = False
|
|
244
|
+
|
|
245
|
+
NUM_GROUPS = 1 # Default quantization is row-wise
|
|
246
|
+
NUM_GROUPS_VALUES = [1, 2, 4, 8]
|
|
247
|
+
|
|
248
|
+
# Values below are used when autotune=False.
|
|
249
|
+
# Note that under certain conditions different values might be used, see the code just before the kernel launch.
|
|
250
|
+
BLOCK_M: int = 16 # When M > 1, different BLOCK_M can be used.
|
|
251
|
+
BLOCK_N: int = 64
|
|
252
|
+
# On AMD or for M > 1 different NUM_STAGES and NUM_WARPS can be used.
|
|
253
|
+
NUM_STAGES: int = 1
|
|
254
|
+
NUM_WARPS: int = 2
|
|
255
|
+
|
|
256
|
+
@classmethod
def shape_not_supported_reasons(
    cls, Mq: int, Mkv: int, K: int, Kv: int
) -> List[str]:
    """
    Extend the base-class shape checks with this op's embed-dim restriction.

    Returns the list produced by the parent class, with one extra entry when
    ``K`` is not one of 16, 32, 64, 128, 256 or 512.
    """
    supported_embed_dims = {16, 32, 64, 128, 256, 512}
    reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
    if K in supported_embed_dims:
        return reasons
    reasons.append(f"Embed dim {K} not supported")
    return reasons
|
|
264
|
+
|
|
265
|
+
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:  # noqa: C901
    """
    Collect the reasons why this operator cannot handle the inputs ``d``.

    Starts from the parent-class reasons and then checks, in order: the
    Python version, last-dim alignment of Q (and of K/V unless K is int32),
    operator availability, CUDA compute capability, query-length rules for
    block-diagonal / paged / causal / local biases, page-size rules for
    paged biases, and shape / split-k rules for tensor biases.
    """
    reasons = super(FwOp, cls).not_supported_reasons(d)
    if sys.version_info[:2] < (3, 9):
        reasons.append("triton_splitk requires python 3.9 or above!")

    # K and V alignment is only checked when K is not int32.
    to_check = [("query", d.query)]
    if d.key.dtype != torch.int32:
        to_check += [("key", d.key), ("value", d.value)]
    for tensor_name, tensor in to_check:
        check_lastdim_alignment_stride1(reasons, tensor_name, tensor, 8)

    if cls.OPERATOR is None:
        reasons.append("triton is not available")

    # Has only been tested on 8.0 / 9.0.
    if (
        d.device.type == "cuda"
        and _is_cuda()
        and not _is_cuda_at_least_sm80(d.device)
    ):
        reasons.append(
            "requires NVidia GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
        )
    # TODO: AMD GPU support matrix needs to be figured out. MI300X is tested to work.

    bias = d.attn_bias
    is_block_diagonal = isinstance(
        bias, (BlockDiagonalPaddedKeysMask, BlockDiagonalGappyKeysMask)
    )
    is_paged = _is_supported_paged_bias(bias)
    is_causal = _is_supported_causal_bias(bias)
    is_local = _is_supported_local_bias(bias)

    q_len = d.query.shape[1]
    if is_block_diagonal or is_paged:
        seqinfo = bias.q_seqinfo  # type: ignore
        total_queries = seqinfo.seqstart_py[-1]
        if q_len != total_queries:
            reasons.append(
                f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
            )
        # From here on the per-sequence maximum is the query length that matters.
        q_len = seqinfo.max_seqlen
        if is_causal and q_len != seqinfo.min_seqlen:
            reasons.append(
                f"Variable query len is not supported for causal masks: got {seqinfo.max_seqlen=} {seqinfo.min_seqlen=}."
            )
        elif is_local and q_len != seqinfo.min_seqlen:
            reasons.append(
                f"Variable query len is not supported for local masks: got {seqinfo.max_seqlen=} {seqinfo.min_seqlen=}."
            )

    if (is_causal or is_local) and q_len > 16:
        # 16 is the minimum BLOCK_M which gets used
        # XXX I don't really understand why this is needed.
        reasons.append(
            "Query length should not be larger than 16 for causal or local attention biases"
        )

    if is_paged:
        page_size = bias.page_size  # type: ignore
        if d.key.shape[1] % page_size:
            reasons.append(
                "For paged attention, key.shape[1] should be divisible "
                "by the page size, "
                f"but got {d.key.shape[1]=}, {page_size=}."
            )
        if page_size % cls.BLOCK_N:
            reasons.append(
                "For paged attention, page size should be divisible "
                "by the block size, "
                f"but got {page_size=}, {cls.BLOCK_N=}."
            )

    if isinstance(bias, torch.Tensor):
        if bias.ndim not in (4, 5):
            reasons.append(
                "Additive attention bias has to have shape (B, G, H, Mq, Mkv) "
                f"or (B, H, Mq, Mkv), but got {d.attn_bias.shape}."
            )
        if cls.SPLIT_K is not None and cls.SPLIT_K > 1:
            reasons.append(
                "Additive attention bias is not supported with split-k > 1."
            )

    return reasons
|
|
340
|
+
|
|
341
|
+
@classmethod
|
|
342
|
+
def get_split_k(
|
|
343
|
+
cls, B: int, G: int, H: int, Mk: int, Mq: int, page_size: int, is_paged=False
|
|
344
|
+
) -> int:
|
|
345
|
+
"""Heuristic for the number of splits"""
|
|
346
|
+
bh = max(B * H, 1) # NOTE: Handle B*h=0 case
|
|
347
|
+
if torch.version.hip:
|
|
348
|
+
split_k = max(Mk + bh - 1, 1024) // bh
|
|
349
|
+
max_chunk_size = 64
|
|
350
|
+
split_k_stop_val = max(1024 / (B * G * H), 1)
|
|
351
|
+
while split_k > 1 and Mk / (split_k - 1) < max_chunk_size:
|
|
352
|
+
split_k = split_k - 1
|
|
353
|
+
|
|
354
|
+
while split_k > split_k_stop_val:
|
|
355
|
+
split_k = split_k // 2
|
|
356
|
+
|
|
357
|
+
split_size = (Mk + split_k - 1) // max(split_k, 1)
|
|
358
|
+
|
|
359
|
+
chunk_size = split_size // max_chunk_size * max_chunk_size
|
|
360
|
+
if chunk_size < split_size:
|
|
361
|
+
split_k += 1
|
|
362
|
+
|
|
363
|
+
split_k_upper_bound = 512
|
|
364
|
+
else:
|
|
365
|
+
if Mq > 1 and B * G * H > 64:
|
|
366
|
+
return 1
|
|
367
|
+
split_k = max(Mk, 1024) // bh
|
|
368
|
+
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
|
|
369
|
+
split_k_stop_val = Mk / max_chunk_size
|
|
370
|
+
split_k_upper_bound = 64
|
|
371
|
+
|
|
372
|
+
while split_k > split_k_stop_val:
|
|
373
|
+
split_k = split_k // 2
|
|
374
|
+
|
|
375
|
+
split_k = min(split_k, split_k_upper_bound)
|
|
376
|
+
split_k = max(split_k, 1)
|
|
377
|
+
|
|
378
|
+
# makes no sense that split_size is larger than page_size
|
|
379
|
+
if is_paged and torch.version.hip:
|
|
380
|
+
split_size = (Mk + split_k - 1) // split_k
|
|
381
|
+
if split_size > page_size:
|
|
382
|
+
split_size = page_size
|
|
383
|
+
split_k = (Mk + split_size - 1) // split_size
|
|
384
|
+
|
|
385
|
+
return split_k
|
|
386
|
+
|
|
387
|
+
@classmethod
def get_kernel(cls):
    """
    Return the split-K forward kernel for ``cls.NUM_GROUPS``.

    With ``cls.AUTOTUNE`` set, the kernel is looked up in
    ``_fwd_kernel_splitK_autotune``; otherwise it comes from
    ``_get_splitk_kernel``.
    """
    # Imported here rather than at module level (the module-level import of
    # the Triton kernels is guarded by is_triton_available()).
    from ._triton.splitk_kernels import (
        _fwd_kernel_splitK_autotune,
        _get_splitk_kernel,
    )

    if not cls.AUTOTUNE:
        return _get_splitk_kernel(cls.NUM_GROUPS)
    return _fwd_kernel_splitK_autotune[cls.NUM_GROUPS]
|
|
398
|
+
|
|
399
|
+
@classmethod
|
|
400
|
+
def get_fp8_scale_shift(
|
|
401
|
+
cls, inp: Inputs
|
|
402
|
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
|
403
|
+
if not hasattr(inp, "k_fp8_scale_shift"):
|
|
404
|
+
return None, None, None
|
|
405
|
+
inp_ = cast(InputsFp8, inp)
|
|
406
|
+
k_fp8_scale_shift = inp_.k_fp8_scale_shift
|
|
407
|
+
v_fp8_scale_shift = inp_.v_fp8_scale_shift
|
|
408
|
+
q_fp8_scale_shift = inp_.q_fp8_scale_shift
|
|
409
|
+
|
|
410
|
+
assert k_fp8_scale_shift is not None
|
|
411
|
+
assert v_fp8_scale_shift is not None
|
|
412
|
+
if k_fp8_scale_shift.ndim == 3:
|
|
413
|
+
k_fp8 = k_fp8_scale_shift.unsqueeze(2)
|
|
414
|
+
v_fp8 = v_fp8_scale_shift.unsqueeze(2)
|
|
415
|
+
q_fp8 = (
|
|
416
|
+
None if q_fp8_scale_shift is None else q_fp8_scale_shift.unsqueeze(2)
|
|
417
|
+
)
|
|
418
|
+
return k_fp8, v_fp8, q_fp8
|
|
419
|
+
if k_fp8_scale_shift.ndim == 4:
|
|
420
|
+
return k_fp8_scale_shift, v_fp8_scale_shift, q_fp8_scale_shift
|
|
421
|
+
raise ValueError(
|
|
422
|
+
"FP8 scales have to be provided in BMH or BMGH format, "
|
|
423
|
+
f"but got {k_fp8_scale_shift.shape=}"
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
@classmethod
|
|
427
|
+
def get_extra_args( # noqa: C901
|
|
428
|
+
cls,
|
|
429
|
+
*,
|
|
430
|
+
is_paged: bool,
|
|
431
|
+
B: int,
|
|
432
|
+
M: int,
|
|
433
|
+
Kkv: int,
|
|
434
|
+
Kq: int,
|
|
435
|
+
Mq: int,
|
|
436
|
+
split_k: int,
|
|
437
|
+
attn_bias: Any,
|
|
438
|
+
k_fp8_scale_shift: Any,
|
|
439
|
+
) -> Dict[str, Any]:
|
|
440
|
+
BLOCK_M = cls.BLOCK_M
|
|
441
|
+
BLOCK_N = cls.BLOCK_N
|
|
442
|
+
if cls.AUTOTUNE:
|
|
443
|
+
extra_args = {}
|
|
444
|
+
else:
|
|
445
|
+
# TODO: remove this when autotuning on AMD is working
|
|
446
|
+
num_warps = cls.NUM_WARPS
|
|
447
|
+
num_stages = cls.NUM_STAGES
|
|
448
|
+
if torch.version.hip and attn_bias is not None:
|
|
449
|
+
# TODO: Double check paged.
|
|
450
|
+
mkv = attn_bias.k_seqinfo.max_seqlen
|
|
451
|
+
# TODO: Determine heuristics for paged attention
|
|
452
|
+
use_fp8_path = k_fp8_scale_shift is not None
|
|
453
|
+
if B == 1:
|
|
454
|
+
if use_fp8_path:
|
|
455
|
+
# Use specialized configs for FP8
|
|
456
|
+
if mkv <= 256:
|
|
457
|
+
BLOCK_N = 16
|
|
458
|
+
num_warps = 4
|
|
459
|
+
num_stages = 1
|
|
460
|
+
elif mkv <= 2048:
|
|
461
|
+
BLOCK_N = 32
|
|
462
|
+
num_warps = 4
|
|
463
|
+
num_stages = 1
|
|
464
|
+
elif mkv <= 16384:
|
|
465
|
+
BLOCK_N = 64
|
|
466
|
+
num_warps = 4
|
|
467
|
+
num_stages = 1
|
|
468
|
+
elif mkv >= 131072:
|
|
469
|
+
BLOCK_N = 128
|
|
470
|
+
num_warps = 2
|
|
471
|
+
num_stages = 1
|
|
472
|
+
else:
|
|
473
|
+
# Note: We don't have data for when transitioning num_wraps works well
|
|
474
|
+
BLOCK_N = 64
|
|
475
|
+
num_warps = 4
|
|
476
|
+
num_stages = 1
|
|
477
|
+
else:
|
|
478
|
+
num_warps = 4
|
|
479
|
+
num_stages = 1 # TODO num_stages = 0 gives better perf on AMD, but sometimes produces NaNs
|
|
480
|
+
BLOCK_N = 32
|
|
481
|
+
elif B <= 4 and split_k <= 128:
|
|
482
|
+
num_warps = 2
|
|
483
|
+
num_stages = 1
|
|
484
|
+
BLOCK_N = 32
|
|
485
|
+
elif B <= 16:
|
|
486
|
+
if use_fp8_path:
|
|
487
|
+
if mkv <= 256:
|
|
488
|
+
BLOCK_N = 16
|
|
489
|
+
num_warps = 4
|
|
490
|
+
num_stages = 1
|
|
491
|
+
elif mkv <= 4096:
|
|
492
|
+
BLOCK_N = 32
|
|
493
|
+
num_warps = 4
|
|
494
|
+
num_stages = 1
|
|
495
|
+
elif mkv <= 8192:
|
|
496
|
+
BLOCK_N = 16
|
|
497
|
+
num_warps = 2
|
|
498
|
+
num_stages = 1
|
|
499
|
+
elif mkv < 131072:
|
|
500
|
+
# Note: This isn't benchmarked, but fp8 seems to scale well.
|
|
501
|
+
BLOCK_N = 64
|
|
502
|
+
num_warps = 1
|
|
503
|
+
num_stages = 1
|
|
504
|
+
else:
|
|
505
|
+
BLOCK_N = 128
|
|
506
|
+
num_warps = 1
|
|
507
|
+
num_stages = 1
|
|
508
|
+
else:
|
|
509
|
+
if M < 16:
|
|
510
|
+
num_warps = 2
|
|
511
|
+
num_stages = 1
|
|
512
|
+
else:
|
|
513
|
+
num_warps = 1
|
|
514
|
+
num_stages = 1
|
|
515
|
+
BLOCK_N = 32
|
|
516
|
+
elif B <= 64 and use_fp8_path:
|
|
517
|
+
if is_paged:
|
|
518
|
+
num_stages = 1
|
|
519
|
+
if mkv <= 256:
|
|
520
|
+
BLOCK_N = 64
|
|
521
|
+
num_warps = 8
|
|
522
|
+
elif mkv <= 8192:
|
|
523
|
+
BLOCK_N = 64
|
|
524
|
+
num_warps = 1
|
|
525
|
+
elif mkv <= 16384:
|
|
526
|
+
BLOCK_N = 128
|
|
527
|
+
num_warps = 2
|
|
528
|
+
else:
|
|
529
|
+
# Note: This isn't benchmarked, but fp8 seems to scale well.
|
|
530
|
+
BLOCK_N = 128
|
|
531
|
+
num_warps = 1
|
|
532
|
+
else:
|
|
533
|
+
if mkv <= 256:
|
|
534
|
+
BLOCK_N = 16
|
|
535
|
+
num_warps = 4
|
|
536
|
+
num_stages = 1
|
|
537
|
+
elif mkv < 131072:
|
|
538
|
+
# Note: This isn't benchmarked, but fp8 seems to scale well.
|
|
539
|
+
BLOCK_N = 64
|
|
540
|
+
num_warps = 1
|
|
541
|
+
num_stages = 1
|
|
542
|
+
else:
|
|
543
|
+
BLOCK_N = 128
|
|
544
|
+
num_warps = 1
|
|
545
|
+
num_stages = 1
|
|
546
|
+
elif B <= 128 and use_fp8_path:
|
|
547
|
+
num_stages = 1
|
|
548
|
+
if is_paged:
|
|
549
|
+
if mkv <= 256:
|
|
550
|
+
num_warps = 4
|
|
551
|
+
BLOCK_N = 16
|
|
552
|
+
elif mkv <= 2048:
|
|
553
|
+
num_warps = 1
|
|
554
|
+
BLOCK_N = 64
|
|
555
|
+
elif mkv < 131072:
|
|
556
|
+
num_warps = 2
|
|
557
|
+
BLOCK_N = 128
|
|
558
|
+
else:
|
|
559
|
+
# Note: This isn't benchmarked, but fp8 seems to scale well.
|
|
560
|
+
num_warps = 1
|
|
561
|
+
BLOCK_N = 128
|
|
562
|
+
else:
|
|
563
|
+
if mkv <= 128:
|
|
564
|
+
num_warps = 4
|
|
565
|
+
BLOCK_N = 16
|
|
566
|
+
else:
|
|
567
|
+
num_warps = 1
|
|
568
|
+
BLOCK_N = 64
|
|
569
|
+
elif B <= 256 and use_fp8_path:
|
|
570
|
+
num_stages = 1
|
|
571
|
+
if is_paged:
|
|
572
|
+
if mkv <= 2048:
|
|
573
|
+
num_warps = 1
|
|
574
|
+
BLOCK_N = 64
|
|
575
|
+
elif mkv < 131072:
|
|
576
|
+
num_warps = 2
|
|
577
|
+
BLOCK_N = 128
|
|
578
|
+
else:
|
|
579
|
+
# Note: This isn't benchmarked, but fp8 seems to scale well.
|
|
580
|
+
num_warps = 1
|
|
581
|
+
BLOCK_N = 128
|
|
582
|
+
else:
|
|
583
|
+
if mkv <= 256:
|
|
584
|
+
num_warps = 2
|
|
585
|
+
BLOCK_N = 32
|
|
586
|
+
else:
|
|
587
|
+
num_warps = 1
|
|
588
|
+
BLOCK_N = 64
|
|
589
|
+
else:
|
|
590
|
+
num_warps = 1
|
|
591
|
+
num_stages = 1
|
|
592
|
+
BLOCK_N = 64
|
|
593
|
+
else:
|
|
594
|
+
should_modify_warp_and_block = (
|
|
595
|
+
Kkv == 128
|
|
596
|
+
and Kq == 128
|
|
597
|
+
and torch.cuda.get_device_capability() >= (8, 9)
|
|
598
|
+
)
|
|
599
|
+
if should_modify_warp_and_block:
|
|
600
|
+
if Mq > 1:
|
|
601
|
+
num_warps = 4
|
|
602
|
+
# Choose minimal round block size which covers M.
|
|
603
|
+
if M > 16:
|
|
604
|
+
BLOCK_M = 32
|
|
605
|
+
if M > 32:
|
|
606
|
+
BLOCK_M = 64
|
|
607
|
+
if M > 64:
|
|
608
|
+
BLOCK_M = 128
|
|
609
|
+
extra_args = {
|
|
610
|
+
"BLOCK_M": BLOCK_M,
|
|
611
|
+
"BLOCK_N": BLOCK_N,
|
|
612
|
+
"num_warps": num_warps,
|
|
613
|
+
"num_stages": num_stages,
|
|
614
|
+
}
|
|
615
|
+
return extra_args
|
|
616
|
+
|
|
617
|
+
@classmethod
def apply(  # noqa: C901
    cls,
    inp: Inputs,
    needs_gradient: bool,
) -> Tuple[torch.Tensor, Optional[Context]]:
    """
    Run the split-K attention forward kernel on ``inp``.

    Note that inp can be of type InputsFp8, in which case K/V are assumed to be row-wise FP8-quantized.
    This is different from int4 quantization, where coefficients are kept together with the quantized
    values at the beginning of each row, and inp has type Inputs.

    Args:
        inp: Attention inputs. This method reads ``query``/``key``,
            ``attn_bias``, ``scale_float``, ``get_qkv_in_bmghk()``,
            ``get_output_dtype()`` and the ``quantize_*_to_fp8`` /
            ``use_fp32_scales`` flags from it.
        needs_gradient: When true, the log-sum-exp (LSE) is also produced
            and returned inside a ``Context``.

    Returns:
        ``(out, ctx)`` where ``ctx`` is ``None`` unless ``needs_gradient``
        is set, in which case it is ``Context(out=out, lse=...)``.
    """

    output_dtype = inp.get_output_dtype()
    # LSE may need higher precision than output
    output_f64_lse = output_dtype in (torch.float32, torch.float64)
    lse_dtype = torch.float64 if output_f64_lse else torch.float32

    # Degenerate case: nothing to attend from/to. No kernel is launched;
    # the output is all zeros and the LSE (if requested) is -inf everywhere.
    if inp.query.numel() == 0 or inp.key.numel() == 0:
        out = torch.zeros_like(inp.query)
        if needs_gradient:
            # LSE shape = query shape with dim 1 moved last and the final
            # (feature) dim dropped.
            lse_out = torch.full(
                (inp.query.shape[0],)
                + inp.query.shape[2:-1]
                + (inp.query.shape[1],),
                float("-inf"),
                device=inp.query.device,
                dtype=lse_dtype,
            )
            return out, Context(out=out, lse=lse_out)
        return out, None

    # Assert that if quantize_qk_to_fp8 is True, q_fp8_scale_shift must be provided
    if hasattr(inp, "quantize_qk_to_fp8") and getattr(
        inp, "quantize_qk_to_fp8", False
    ):
        assert (
            hasattr(inp, "q_fp8_scale_shift") and inp.q_fp8_scale_shift is not None  # type: ignore
        ), "q_fp8_scale_shift must be provided when quantize_qk_to_fp8 is True"

    k_fp8_scale_shift, v_fp8_scale_shift, q_fp8_scale_shift = (
        cls.get_fp8_scale_shift(inp)
    )

    # A tensor bias is handled as an additive bias; anything else is
    # treated as a structured mask object (or None).
    if not isinstance(inp.attn_bias, torch.Tensor):
        attn_bias_tensor = None
        attn_bias = cast(
            Optional[
                Union[
                    BlockDiagonalCausalWithOffsetPaddedKeysMask,
                    BlockDiagonalCausalLocalAttentionPaddedKeysMask,
                    BlockDiagonalLocalAttentionPaddedKeysMask,
                    BlockDiagonalGappyKeysMask,
                    BlockDiagonalCausalWithOffsetGappyKeysMask,
                    BlockDiagonalPaddedKeysMask,
                    PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
                    PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
                    PagedBlockDiagonalGappyKeysMask,
                    PagedBlockDiagonalPaddedKeysMask,
                ]
            ],
            inp.attn_bias,
        )
    else:
        attn_bias_tensor = inp.attn_bias
        attn_bias = None

    # Defaults used when there is no structured mask.
    seq_len = None
    seq_starts_k = None
    seq_starts_q = None
    seq_starts_q_multiplier = None
    q, k, v = inp.get_qkv_in_bmghk()
    IS_CAUSAL = False
    IS_LOCAL = False
    NUM_QUERIES_CAUSAL = 1
    variable_q = False
    window_left = -1
    window_right = -1

    is_block_diagonal = isinstance(attn_bias, BlockDiagonalPaddedKeysMask)
    is_gappy = _is_supported_gappy_bias(attn_bias)
    is_paged = _is_supported_paged_bias(attn_bias)
    if attn_bias is not None:
        assert is_paged or is_block_diagonal or is_gappy
        assert attn_bias.k_seqinfo.seqlen.device == inp.query.device
        seq_len = attn_bias.k_seqinfo.seqlen
        assert seq_len.stride(0) == 1
        if is_gappy:
            seq_starts_k = attn_bias.k_seqinfo.seqstart
            assert seq_starts_k.stride(0) == 1
        # With a structured mask, q must arrive with a leading dim of 1;
        # the batch size is taken from the per-sequence key lengths.
        assert q.shape[0] == 1
        B = len(seq_len)
        G, Hq, Kq = q.shape[-3:]
        # force a bool because triton cannot take np.bool_
        multiple_q = bool(attn_bias.q_seqinfo.max_seqlen > 1)
        IS_CAUSAL = multiple_q and _is_supported_causal_bias(attn_bias)
        IS_LOCAL = _is_supported_local_bias(attn_bias)
        # Several queries per sequence without a supported causal bias:
        # queries are addressed through seqstart offsets (below) instead of
        # being reshaped to (B, Mq, ...).
        variable_q = multiple_q and not IS_CAUSAL
        Kkv = v.shape[-1]
        if isinstance(attn_bias, BlockDiagonalLocalAttentionPaddedKeysMask):
            window_left = attn_bias.window_left
            window_right = attn_bias.window_right
        elif isinstance(attn_bias, BlockDiagonalCausalLocalAttentionPaddedKeysMask):
            window_left = attn_bias._window_size - 1

        if variable_q:
            seq_starts_q = attn_bias.q_seqinfo.seqstart
            seq_starts_q_multiplier = 1
            assert seq_starts_q.stride(0) == 1
        else:
            q = q.view(B, -1, G, Hq, Kq)
            if q_fp8_scale_shift is not None:
                q_fp8_scale_shift = q_fp8_scale_shift.view(B, -1, G, Hq)

        # Paged and gappy biases keep K/V with a leading dim of 1; otherwise
        # K/V are viewed with one row per batch element.
        kv_shape = (1 if is_paged or is_gappy else B, -1, G, Hq, Kkv)
        k = k.view(kv_shape)
        v = v.view(kv_shape)
        if k_fp8_scale_shift is not None and v_fp8_scale_shift is not None:
            k_fp8_scale_shift = k_fp8_scale_shift.view(kv_shape[:-1])
            v_fp8_scale_shift = v_fp8_scale_shift.view(kv_shape[:-1])

        Mq = q.shape[1]
        NUM_QUERIES_CAUSAL = Mq
    else:
        B, Mq, G, Hq, Kq = q.shape

    if attn_bias_tensor is not None and attn_bias_tensor.ndim == 4:
        # (B, H, Mq, Mkv) -> (B, G, H, Mq, Mkv)
        attn_bias_tensor = attn_bias_tensor.unsqueeze(1)

    # In the case of MQA/GQA, we make q have sequence length (H * Mq) and only one "head".
    mqa_swap_seqlen_head = False
    # A zero stride along dim 3 of K and V means every head reads the same
    # memory; the swap is only done when there is no additive bias tensor.
    if (
        k.shape[3] > 1
        and k.stride(3) == 0
        and v.stride(3) == 0
        and attn_bias_tensor is None
    ):
        mqa_swap_seqlen_head = True
        if q_fp8_scale_shift is not None:
            assert q_fp8_scale_shift.shape == q.shape[:-1], (
                f"{q.shape=}, {q_fp8_scale_shift.shape=}"
            )
            # Apply to the query scales the same permute/reshape that is
            # applied to q just below, so they stay aligned.
            if variable_q:
                q_fp8_scale_shift = q_fp8_scale_shift.permute(0, 1, 3, 2).reshape(
                    1, -1, G, 1
                )
            else:
                q_fp8_scale_shift = q_fp8_scale_shift.permute(0, 3, 1, 2).reshape(
                    q.shape[0], -1, G, 1
                )
        if variable_q:
            seq_starts_q_multiplier = Hq
            assert q.shape[0] == 1
            # The idea is Hq,Mq are reshaped to (M=Mq*Hq, H=1)
            q = q.permute(0, 1, 3, 2, 4).reshape(1, -1, G, 1, Kq)
        else:
            # This is a copy iff Mq, G and H are all > 1.
            # The idea is Hq,Mq are reshaped to (M=Hq*Mq, H=1)
            q = q.permute(0, 3, 1, 2, 4).reshape(q.shape[0], -1, G, 1, Kq)
        # Keep a single K/V head now that q carries only one head.
        k = k[:, :, :, :1]
        v = v[:, :, :, :1]
        if k_fp8_scale_shift is not None and v_fp8_scale_shift is not None:
            k_fp8_scale_shift = k_fp8_scale_shift[:, :, :, :1]
            v_fp8_scale_shift = v_fp8_scale_shift[:, :, :, :1]

    # Work out the logical head dim (Lk) of K when it is stored packed in
    # int32 words.
    if k.dtype == torch.int32:
        if k_fp8_scale_shift is not None:
            # Separate scale/shift tensors: 4 values per int32 word.
            Lk = k.shape[-1] * 4
            PACKED_PER_VAL = 4
        else:
            # Quantized K/V
            # 8 values per int32 word; cls.NUM_GROUPS leading words of each
            # row are excluded from the value count.
            PACKED_PER_VAL = 8
            Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
    else:
        Lk = k.shape[-1]
        PACKED_PER_VAL = 1
        assert cls.NUM_GROUPS == 1, f"{cls.NUM_GROUPS=}"

    # From here on G/H/Kq refer to the (possibly MQA-swapped) kernel view,
    # while Hq/Mq keep the caller-visible head count / query length.
    _, Mk, G, H, Kkv = k.shape
    Bqq, Mqq, G, H, Kq = q.shape
    assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"
    if variable_q:
        assert attn_bias is not None
        assert seq_starts_q_multiplier is not None
        M = attn_bias.q_seqinfo.max_seqlen * seq_starts_q_multiplier
    else:
        M = Mqq
    page_size = inp.attn_bias.page_size if is_paged else 0  # type: ignore
    block_tables = None
    kv_cache_blocks_per_row = 0
    if is_paged:
        block_tables = inp.attn_bias.block_tables  # type: ignore
        kv_cache_blocks_per_row = block_tables.shape[1]
        # Key length is the block-table width times the page size.
        Mk = block_tables.shape[1] * page_size
    elif attn_bias is not None:
        Mk = min(Mk, attn_bias.k_seqinfo.max_seqlen)

    if cls.SPLIT_K is not None:
        split_k = cls.SPLIT_K
    else:
        # Use heuristics
        # (an additive bias tensor always forces a single split)
        split_k = (
            cls.get_split_k(B, G, H, Mk, Mq, page_size, is_paged)
            if attn_bias_tensor is None
            else 1
        )

    # M_ceil = Mqq rounded up to a multiple of MAX_BLOCK_M
    M_ceil = (Mqq + cls.MAX_BLOCK_M - 1) // cls.MAX_BLOCK_M * cls.MAX_BLOCK_M
    IS_SPLITK = split_k > 1  # or cls.autotune?
    output_shape = (Bqq, Mq, G, Hq, Kq)
    # Per-split output buffer, laid out as (Bqq, G, H, split_k, M, Kq) in
    # both branches (the non-split branch gets there through a permute).
    if IS_SPLITK:
        o_splitk_dtype = (
            torch.float64 if output_dtype == torch.float64 else torch.float32
        )
        # With SPLIT_K_EARLY_EXIT the buffer is zero-initialised rather than
        # left uninitialised — presumably because the kernel may then skip
        # writing some splits; confirm against the kernel.
        if cls.SPLIT_K_EARLY_EXIT:
            o_splitk = torch.zeros(
                [Bqq, G, H, split_k, M_ceil, Kq],
                dtype=o_splitk_dtype,
                device=q.device,
            )
        else:
            o_splitk = torch.empty(
                [Bqq, G, H, split_k, M_ceil, Kq],
                dtype=o_splitk_dtype,
                device=q.device,
            )
    else:
        o_splitk = torch.empty(
            [Bqq, split_k, Mqq, G, H, Kq],
            dtype=output_dtype,
            device=q.device,
        ).permute(0, 3, 4, 1, 2, 5)
    lse, lse_splitk = None, None
    # The per-split LSE is needed for merging splits and/or for the backward.
    if IS_SPLITK or needs_gradient:
        if IS_SPLITK or output_f64_lse:
            lse_splitk_dtype = torch.float64
        else:
            lse_splitk_dtype = torch.float32
        if cls.SPLIT_K_EARLY_EXIT:
            lse_splitk = torch.full(
                [Bqq, G, H, split_k, Mqq],
                -float("inf"),
                dtype=lse_splitk_dtype,
                device=q.device,
            )
        else:
            lse_splitk = torch.empty(
                [Bqq, G, H, split_k, Mqq],
                dtype=lse_splitk_dtype,
                device=q.device,
            )

    # Launch grid: (query tiles of BLOCK_M, batch * group * head, split).
    def grid(META):
        import triton

        return triton.cdiv(M, META["BLOCK_M"]), B * G * H, split_k

    # Number of keys handled by each split (ceil division).
    split_size = (Mk + split_k - 1) // split_k
    use_seq_len = seq_len is not None

    kernel = cls.get_kernel()
    extra_args = cls.get_extra_args(
        is_paged=is_paged,
        B=B,
        M=M,
        Kkv=Kkv,
        Kq=Kq,
        Mq=Mq,
        split_k=split_k,
        attn_bias=attn_bias,
        k_fp8_scale_shift=k_fp8_scale_shift,
    )

    IS_HIP = torch.version.hip is not None

    if inp.quantize_pv_to_fp8:
        # Reinterpret V's storage as float8_e4m3fn, going through an int8 view.
        v = v.view(torch.int8)
        v = v.view(torch.float8_e4m3fn)

    kernel[grid](
        Q=q,
        K=k,
        V=v,
        sm_scale=inp.scale_float,
        Out_splitK=o_splitk,
        LSE_splitk=lse_splitk,
        block_tables=block_tables,
        Seq_len=seq_len,
        Seq_starts_k=seq_starts_k,
        Seq_starts_q=seq_starts_q,
        Seq_starts_q_multiplier=seq_starts_q_multiplier,
        additive_bias=attn_bias_tensor,
        K_fp8_scale_shift=k_fp8_scale_shift,
        V_fp8_scale_shift=v_fp8_scale_shift,
        q_fp8_scale_shift=q_fp8_scale_shift,
        **_strides(q, "qz", "qm", "qg", "qh", "qk"),
        **_strides(k, "kz", "kn", "kg", "kh", "kk"),
        **_strides(v, "vz", "vn", "vg", "vh", "vk"),
        **_strides(o_splitk, "osk_z", "osk_g", "osk_h", "osk_s", "osk_m", "osk_k"),
        **_strides(lse_splitk, "lsek_z", "lsek_g", "lsek_h", "lsek_s", "lsek_m"),
        **_strides(block_tables, "blocktablesz", "blocktablesl"),
        **_strides(
            attn_bias_tensor, "bias_b", "bias_g", "bias_h", "bias_qm", "bias_km"
        ),
        **_strides(
            k_fp8_scale_shift,
            "k_fp8_scale_shift_z",
            "k_fp8_scale_shift_n",
            "k_fp8_scale_shift_g",
            "k_fp8_scale_shift_h",
        ),
        **_strides(
            v_fp8_scale_shift,
            "v_fp8_scale_shift_z",
            "v_fp8_scale_shift_n",
            "v_fp8_scale_shift_g",
            "v_fp8_scale_shift_h",
        ),
        **_strides(
            q_fp8_scale_shift,
            "q_fp8_scale_shift_z",
            "q_fp8_scale_shift_m",
            "q_fp8_scale_shift_g",
            "q_fp8_scale_shift_h",
        ),
        kv_cache_blocks_per_row=kv_cache_blocks_per_row,
        Z=B,
        H=H,
        G=G,
        N_CTX_Q=M,
        N_CTX_K=Mk,
        BLOCK_N_PER_SPLIT=split_size,
        BLOCK_DMODEL=Lk,
        USE_SEQ_LEN=use_seq_len,
        PACKED_PER_VAL=PACKED_PER_VAL,
        N_GROUPS=cls.NUM_GROUPS,
        IS_CAUSAL=IS_CAUSAL,
        IS_LOCAL=IS_LOCAL,
        NUM_QUERIES_CAUSAL=NUM_QUERIES_CAUSAL,
        IS_SPLITK=IS_SPLITK,
        SPLIT_K_EARLY_EXIT=cls.SPLIT_K_EARLY_EXIT,
        USE_PAGED_ATTENTION=is_paged,
        PAGE_SIZE=page_size,
        WINDOW_LEFT=window_left,
        WINDOW_RIGHT=window_right,
        WRITE_LSE=IS_SPLITK or needs_gradient,
        HAS_ADDITIVE_BIAS=attn_bias_tensor is not None,
        NUM_PROGRAMS_DIM2_CONST=split_k,
        IS_HIP=IS_HIP,
        QUANTIZE_PV_TO_FP8=inp.quantize_pv_to_fp8,
        QUANTIZE_QK_TO_FP8=inp.quantize_qk_to_fp8,
        USE_FP32_SCALES=inp.use_fp32_scales,
        **extra_args,
    )
    # Single-split path: the only split already holds the final result, so
    # it just needs to be reshaped back (undoing the MQA swap if it happened).
    if not IS_SPLITK:
        out = o_splitk[:, :, :, 0]  # Bqq, G, H, Mqq, Kq
        if variable_q and mqa_swap_seqlen_head:
            out = out.view(1, G, Mq, Hq, Kq).permute(0, 2, 1, 3, 4).contiguous()
        else:
            out = out.view(Bqq, G, Hq, Mq, Kq)
            # This is a copy iff mqa_swap_seqlen_head and Mq, G and Hq are all > 1.
            out = out.permute(0, 3, 1, 2, 4).contiguous()
        if needs_gradient:
            assert lse_splitk is not None
            lse = lse_splitk[:, :, :, 0]  # Bqq, G, H, Mqq
            if variable_q and mqa_swap_seqlen_head:
                lse = lse.view(1, G, Mq, Hq).permute(0, 1, 3, 2)
            else:
                lse = lse.view(Bqq, G, Hq, Mq)
                if attn_bias is not None and not variable_q:
                    # Fold the batch dim back into the sequence dim.
                    lse = lse.permute(1, 2, 0, 3).reshape(1, G, Hq, B * Mq)
        else:
            lse = None

        if inp.query.ndim == 4:
            # BMGHK -> BMHK
            assert G == 1
            if lse is not None:
                lse = lse[:, 0]
            out = out[:, :, 0]

        if lse is None:
            return out, None
        return out, Context(out=out, lse=lse)

    out = torch.empty(output_shape, device=q.device, dtype=output_dtype)

    # Merge attention and LSE outputs from different split-k chunks
    assert lse_splitk is not None
    output_lse = None
    if needs_gradient:
        if attn_bias is None or variable_q:
            output_lse = torch.empty(
                (Bqq, G, Hq, Mq), device=q.device, dtype=lse_dtype
            )
            lse = output_lse
        else:
            # The returned LSE has the batch folded into the sequence dim;
            # `lse` is a (B, G, Hq, Mq) view of that same storage that is
            # handed to the merge below.
            output_lse = torch.empty(
                (1, G, Hq, B * Mq), device=q.device, dtype=lse_dtype
            )
            lse = output_lse.view(G, Hq, B, Mq).permute(2, 0, 1, 3)

    # Drop the rows that were only there to pad M up to M_ceil.
    o_splitk = o_splitk[:, :, :, :, :Mqq]

    if mqa_swap_seqlen_head:
        # Undo the (Hq, Mq) -> (M, H=1) swap on the per-split buffers.
        if variable_q:
            o_splitk = o_splitk.view(Bqq, G, split_k, Mq, Hq, Kq).permute(
                0, 1, 4, 2, 3, 5
            )
            lse_splitk = lse_splitk.view(Bqq, G, split_k, Mq, Hq).permute(
                0, 1, 4, 2, 3
            )
        else:
            o_splitk = o_splitk.view(Bqq, G, split_k, Hq, Mq, Kq).permute(
                0, 1, 3, 2, 4, 5
            )
            lse_splitk = lse_splitk.view(Bqq, G, split_k, Hq, Mq).permute(
                0, 1, 3, 2, 4
            )

    merge_attentions(out, lse, o_splitk, lse_splitk)

    if inp.query.ndim == 4:
        # BMGHK -> BMHK
        assert G == 1
        out = out[:, :, 0]
        if output_lse is not None:
            output_lse = output_lse[:, 0]
    # With no keys at all, force the output to zeros.
    if Mk == 0:
        out.zero_()

    if attn_bias is not None and not variable_q:
        out = out.view(1, B * Mq, G, Hq, Kq)

    if output_lse is None:
        return out, None

    return out, Context(out=out, lse=output_lse)
|
|
1056
|
+
|
|
1057
|
+
@classmethod
@functools.lru_cache
def get_operator(
    cls,
    splitk: int,
    *,
    block_m: Optional[int] = None,
    block_n: Optional[int] = None,
    num_warps: Optional[int] = None,
    num_stages: Optional[int] = None,
    split_k_early_exit: Optional[bool] = None,
) -> Type[AttentionFwOpBase]:
    """
    Build (and memoize) a subclass of ``cls`` pinned to a split-K factor.

    The subclass is named ``FwOp_S{splitk}`` and always defines ``NAME`` and
    ``SPLIT_K``. Each keyword argument that is not ``None`` additionally
    overrides the matching upper-case class attribute (``BLOCK_M``,
    ``BLOCK_N``, ``NUM_WARPS``, ``NUM_STAGES``, ``SPLIT_K_EARLY_EXIT``);
    arguments left as ``None`` fall through to whatever ``cls`` defines.
    """
    # (class attribute, requested value) pairs, in definition order.
    candidate_overrides = (
        ("BLOCK_M", block_m),
        ("BLOCK_N", block_n),
        ("NUM_WARPS", num_warps),
        ("NUM_STAGES", num_stages),
        ("SPLIT_K_EARLY_EXIT", split_k_early_exit),
    )
    namespace = {"NAME": f"triton_splitK{splitk}", "SPLIT_K": splitk}
    namespace.update(
        (attr, value) for attr, value in candidate_overrides if value is not None
    )
    return type(f"FwOp_S{splitk}", (cls,), namespace)
|
|
1088
|
+
|
|
1089
|
+
|
|
1090
|
+
def merge_attentions(
    attn_out: torch.Tensor,
    lse_out: Optional[torch.Tensor],
    attn_split: torch.Tensor,
    lse_split: torch.Tensor,
):
    """
    Combine per-split attention results by launching ``_splitK_reduce``.

    Shapes, as enforced by the assertions below:
        attn_out:   (B, M, G, H, Kq)
        attn_split: (B, G, H, split_k, M, Kq)
        lse_split:  (B, G, H, split_k, M)
        lse_out:    (B, G, H, M), or ``None`` to skip writing the merged LSE

    ``attn_out`` and ``lse_out`` are passed to the kernel as destinations;
    nothing is returned.
    """
    import triton

    from ._triton.splitk_kernels import _splitK_reduce

    batch, seq, groups, heads, dim = attn_out.shape
    s_batch, s_groups, s_heads, split_k, s_seq, s_dim = attn_split.shape
    l_batch, l_groups, l_heads, lse_split_k, l_seq = lse_split.shape

    shapes_agree = (
        batch == s_batch == l_batch
        and groups == s_groups == l_groups
        and heads == s_heads == l_heads
        and seq == s_seq == l_seq
        and dim == s_dim
    )
    assert shapes_agree, (
        f"Incompatible shapes: {attn_out.shape=}, {attn_split.shape=}, {lse_split.shape=}"
    )
    assert split_k == lse_split_k, (
        f"Incompatible shapes: {attn_split.shape=}, {lse_split.shape=}"
    )
    if lse_out is not None:
        o_batch, o_groups, o_heads, o_seq = lse_out.shape
        lse_out_agrees = (
            batch == o_batch
            and groups == o_groups
            and heads == o_heads
            and seq == o_seq
        )
        assert lse_out_agrees, (
            f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}"
        )

    # One program per (query row, batch * group * head).
    rows_of_programs = batch * groups * heads
    if rows_of_programs < 32 or torch.version.hip:
        num_warps = 4
    else:
        num_warps = 2

    stride_kwargs = {
        **_strides(attn_split, "osk_z", "osk_g", "osk_h", "osk_s", "osk_m", "osk_k"),
        **_strides(lse_split, "lsek_z", "lsek_g", "lsek_h", "lsek_s", "lsek_m"),
        **_strides(attn_out, "oz", "om", "og", "oh", "ok"),
        **_strides(lse_out, "lse_z", "lse_g", "lse_h", "lse_m"),
    }
    launch_grid = (seq, rows_of_programs, 1)
    # pyre-ignore[28]
    _splitK_reduce[launch_grid](
        attn_split,
        lse_split,
        attn_out,
        lse_out,
        split_k=split_k,
        splitK_pow2=triton.next_power_of_2(split_k),
        head_dim=dim,
        head_dim_pow_2=triton.next_power_of_2(dim),
        G=groups,
        H=heads,
        WRITE_LSE=lse_out is not None,
        num_warps=num_warps,
        **stride_kwargs,
    )
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
@torch.library.custom_op(
    "mslk::fmha_merge_attentions_varargs",
    mutates_args=(),
    device_types=["cuda"],
)
def merge_attentions_varargs(
    attn_split: Sequence[torch.Tensor],
    lse_split: Sequence[torch.Tensor],
    write_lse: bool,
    output_dtype: Optional[torch.dtype],
    B: int,
    M: int,
    G: int,
    H: int,
    Kq: int,
) -> List[torch.Tensor]:
    """
    Merge a list of partial attention outputs given as separate tensors.

    Allocates a ``(B, M, G, H, Kq)`` output (dtype ``output_dtype``, falling
    back to the dtype of ``attn_split[0]``) and, when ``write_lse`` is set, a
    ``(B, G, H, M)`` LSE with the dtype of ``lse_split[0]``; then launches
    ``_splitK_reduce_varargs`` unrolled for ``len(attn_split)`` inputs.

    Returns ``[attn_out, lse_out]`` when ``write_lse`` is set, else
    ``[attn_out]``.
    """
    import triton

    from ._triton.splitk_kernels import _splitK_reduce_varargs
    from ._triton.vararg_kernel import unroll_varargs

    first_chunk = attn_split[0]
    attn_out = torch.empty(
        (B, M, G, H, Kq),
        device=first_chunk.device,
        dtype=output_dtype or first_chunk.dtype,
    )
    lse_out = (
        torch.empty(
            (B, G, H, M),
            device=first_chunk.device,
            dtype=lse_split[0].dtype,
        )
        if write_lse
        else None
    )

    kernel_args, grid = _prepare_reduce_kernel_params(
        attn_out, lse_out, attn_split, lse_split
    )
    # Specialise the kernel for this particular number of input chunks.
    reduce_kernel = unroll_varargs(_splitK_reduce_varargs, N=len(attn_split))
    feature_dim = attn_out.shape[-1]
    reduce_kernel[grid](
        *attn_split,
        *lse_split,
        Out=attn_out,
        LSE=lse_out,
        **kernel_args,
        head_dim=feature_dim,
        head_dim_pow_2=triton.next_power_of_2(feature_dim),
        WRITE_LSE=lse_out is not None,
    )

    if lse_out is None:
        return [attn_out]
    return [attn_out, lse_out]
|
|
1200
|
+
|
|
1201
|
+
|
|
1202
|
+
@torch.library.register_fake("mslk::fmha_merge_attentions_varargs")
def merge_attentions_varargs_fake(
    attn_split: Sequence[torch.Tensor],
    lse_split: Sequence[torch.Tensor],
    write_lse: bool,
    output_dtype: Optional[torch.dtype],
    B: int,
    M: int,
    G: int,
    H: int,
    Kq: int,
) -> List[torch.Tensor]:
    """
    Fake implementation of ``mslk::fmha_merge_attentions_varargs``.

    Only allocates outputs of the right shape/dtype/device, without
    launching any kernel: a ``(B, M, G, H, Kq)`` attention output and, when
    ``write_lse`` is set, a ``(B, G, H, M)`` LSE.
    """
    target_device = attn_split[0].device
    outputs = [
        torch.empty(
            (B, M, G, H, Kq),
            device=target_device,
            dtype=output_dtype or attn_split[0].dtype,
        )
    ]
    if write_lse:
        outputs.append(
            torch.empty(
                (B, G, H, M),
                device=target_device,
                dtype=lse_split[0].dtype,
            )
        )
    return outputs
|
|
1227
|
+
|
|
1228
|
+
|
|
1229
|
+
def _merge_attentions_backward(
    ctx: torch.autograd.function.FunctionCtx,
    grad: List[torch.Tensor],
) -> Tuple[None, ...]:
    """
    Autograd hook that rejects differentiation through merge_attentions.

    Always raises ``NotImplementedError``; both arguments are ignored.
    """
    reason = (
        "Backward pass is not implemented for merge_attentions. "
        "If it was, it would be easy to get wrong attention gradients, "
        "because the gradients of the LSEs "
        "don't get propagated by attention backward."
    )
    raise NotImplementedError(reason)
|
|
1239
|
+
|
|
1240
|
+
|
|
1241
|
+
# Any attempt to differentiate through the merge_attentions_varargs custom op
# is routed to _merge_attentions_backward, which raises NotImplementedError.
merge_attentions_varargs.register_autograd(_merge_attentions_backward)
|
|
1242
|
+
|
|
1243
|
+
|
|
1244
|
+
@torch.library.custom_op(
    "mslk::merge_attentions_varargs_backward",
    mutates_args=(),
    device_types=["cuda"],
)
def merge_attentions_varargs_backward(
    attn_split: List[torch.Tensor],
    lse_split: List[torch.Tensor],
    attn_out: torch.Tensor,
    lse_out: torch.Tensor,
    grad_attn: torch.Tensor,
    grad_lse: torch.Tensor,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Compute gradients of the merged outputs w.r.t. each partial input.

    Launches ``_splitK_reduce_varargs_backward`` (unrolled for
    ``len(attn_split)`` chunks) with the forward inputs and outputs and the
    incoming gradients ``grad_attn`` / ``grad_lse``.

    Returns a pair of lists: one gradient tensor per entry of ``attn_split``
    and one per entry of ``lse_split``, each shaped like its counterpart.
    """
    from ._triton.splitk_kernels import _splitK_reduce_varargs_backward
    from ._triton.vararg_kernel import unroll_varargs

    kernel_args, grid = _prepare_reduce_kernel_params(
        attn_out, lse_out, attn_split, lse_split, grad_attn, grad_lse
    )
    backward_kernel = unroll_varargs(
        _splitK_reduce_varargs_backward, N=len(attn_split)
    )

    # Destination buffers for the kernel, one per forward input chunk.
    grad_attn_chunks = list(map(torch.empty_like, attn_split))
    grad_lse_chunks = list(map(torch.empty_like, lse_split))

    backward_kernel[grid](
        *attn_split,
        *lse_split,
        *grad_attn_chunks,
        *grad_lse_chunks,
        Out=attn_out,
        LSE=lse_out,
        DOut=grad_attn,
        DLSE=grad_lse,
        **kernel_args,
        BLOCK_SIZE=attn_out.shape[-1],
    )

    return grad_attn_chunks, grad_lse_chunks
|
|
1284
|
+
|
|
1285
|
+
|
|
1286
|
+
@torch.library.register_fake("mslk::merge_attentions_varargs_backward")
def merge_attentions_varargs_backward_fake(
    attn_split: List[torch.Tensor],
    lse_split: List[torch.Tensor],
    attn_out: torch.Tensor,
    lse_out: torch.Tensor,
    grad_attn: torch.Tensor,
    grad_lse: torch.Tensor,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Fake implementation of ``mslk::merge_attentions_varargs_backward``.

    Returns uninitialised gradient placeholders shaped like every tensor in
    ``attn_split`` and ``lse_split``; the remaining arguments are unused.
    """
    grad_attn_chunks = list(map(torch.empty_like, attn_split))
    grad_lse_chunks = list(map(torch.empty_like, lse_split))
    return grad_attn_chunks, grad_lse_chunks
|
|
1298
|
+
|
|
1299
|
+
|
|
1300
|
+
def _prepare_reduce_kernel_params(
    attn_out: torch.Tensor,
    lse_out: Optional[torch.Tensor],
    attn_split: Sequence[torch.Tensor],
    lse_split: Sequence[torch.Tensor],
    grad_attn: Optional[torch.Tensor] = None,
    grad_lse: Optional[torch.Tensor] = None,
) -> Tuple[Dict[str, int], Tuple[int, int, int]]:
    """
    Validate shapes and assemble the launch arguments of the vararg reduce kernels.

    Args:
        attn_out: Merged attention output, shape ``(B, M, G, H, Kq)``.
        lse_out: Merged LSE, shape ``(B, G, H, M)``, or ``None``.
        attn_split: Partial attention outputs; the first one must have shape
            ``(B, G, H, M, Kq)``.
        lse_split: Partial LSEs, one per entry of ``attn_split``; the first
            one must have shape ``(B, G, H, M)``.
        grad_attn: Gradient w.r.t. ``attn_out`` (backward pass only).
        grad_lse: Gradient w.r.t. ``lse_out``; its strides are added only
            when ``grad_attn`` is given.

    Returns:
        ``(kernel_args, grid)``: ``kernel_args`` holds ``G``, ``H``,
        ``num_warps`` and the stride entries of every tensor (per-chunk
        stride names carry the chunk index as a suffix); ``grid`` is
        ``(M, B * G * H, 1)``.

    Raises:
        AssertionError: if the number of attention and LSE chunks differ, or
            if the shapes listed above are inconsistent.
    """
    # Every attention chunk needs its own LSE chunk. Without this check a
    # shorter lse_split fails with a bare IndexError and a longer one has its
    # surplus entries silently ignored.
    assert len(attn_split) == len(lse_split), (
        f"Mismatched number of splits: {len(attn_split)=}, {len(lse_split)=}"
    )

    B, M, G, H, Kq = attn_out.shape
    B1, G1, H1, M1, Kq1 = attn_split[0].shape
    B2, G2, H2, M2 = lse_split[0].shape

    assert (
        B == B1 == B2
        and G == G1 == G2
        and H == H1 == H2
        and M == M1 == M2
        and Kq == Kq1
    ), (
        f"Incompatible shapes: {attn_out.shape=}, {attn_split[0].shape=}, {lse_split[0].shape=}"
    )
    if lse_out is not None:
        B3, G3, H3, M3 = lse_out.shape
        assert B == B3 and G == G3 and H == H3 and M == M3, (
            f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}"
        )

    # Each chunk gets its own stride entries, distinguished by an index suffix.
    attn_split_strides = {}
    lse_split_strides = {}
    for i, (attn_chunk, lse_chunk) in enumerate(zip(attn_split, lse_split)):
        suffix = str(i)
        attn_split_strides.update(
            _strides(
                attn_chunk,
                "osk_z" + suffix,
                "osk_g" + suffix,
                "osk_h" + suffix,
                "osk_m" + suffix,
                "osk_k" + suffix,
            )
        )
        lse_split_strides.update(
            _strides(
                lse_chunk,
                "lsek_z" + suffix,
                "lsek_g" + suffix,
                "lsek_h" + suffix,
                "lsek_m" + suffix,
            )
        )

    num_warps = 4 if B * G * H < 32 or torch.version.hip else 2
    # One program per (query row, batch * group * head).
    grid = (M, B * G * H, 1)

    kernel_args = {
        "G": G,
        "H": H,
        "num_warps": num_warps,
        **attn_split_strides,
        **lse_split_strides,
    }
    kernel_args.update(_strides(attn_out, "oz", "om", "og", "oh", "ok"))
    kernel_args.update(_strides(lse_out, "lse_z", "lse_g", "lse_h", "lse_m"))
    if grad_attn is not None:
        kernel_args.update(_strides(grad_attn, "doz", "dom", "dog", "doh", "dok"))
        kernel_args.update(_strides(grad_lse, "dlse_z", "dlse_g", "dlse_h", "dlse_m"))
    return kernel_args, grid
|
|
1366
|
+
|
|
1367
|
+
|
|
1368
|
+
# Operator subclasses with a fixed split-K factor, keyed by that factor.
# Each one is built by FwOp.get_operator with only `splitk` set.
FwOp_Map = {
    k: FwOp.get_operator(k) for k in [1, 2, 4, 8, 16, 32, 48, 64, 72, 80, 96, 112, 128]
}
# Named aliases for the power-of-two factors. The other factors in FwOp_Map
# (48, 72, 80, 96, 112) have no alias defined here.
FwOp_S1 = FwOp_Map[1]
FwOp_S2 = FwOp_Map[2]
FwOp_S4 = FwOp_Map[4]
FwOp_S8 = FwOp_Map[8]
FwOp_S16 = FwOp_Map[16]
FwOp_S32 = FwOp_Map[32]
FwOp_S64 = FwOp_Map[64]
FwOp_S128 = FwOp_Map[128]
|