mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,858 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple, TypeVar
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from torch.utils.flop_counter import (
|
|
14
|
+
_unpack_flash_attention_nested_shapes,
|
|
15
|
+
register_flop_formula,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from .attn_bias import (
|
|
19
|
+
BlockDiagonalCausalFromBottomRightMask,
|
|
20
|
+
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
|
21
|
+
BlockDiagonalCausalLocalAttentionMask,
|
|
22
|
+
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
|
|
23
|
+
BlockDiagonalCausalMask,
|
|
24
|
+
BlockDiagonalCausalWithOffsetGappyKeysMask,
|
|
25
|
+
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
|
26
|
+
BlockDiagonalGappyKeysMask,
|
|
27
|
+
BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
|
|
28
|
+
BlockDiagonalLocalAttentionPaddedKeysMask,
|
|
29
|
+
BlockDiagonalMask,
|
|
30
|
+
BlockDiagonalPaddedKeysMask,
|
|
31
|
+
LocalAttentionFromBottomRightMask,
|
|
32
|
+
LowerTriangularFromBottomRightLocalAttentionMask,
|
|
33
|
+
LowerTriangularFromBottomRightMask,
|
|
34
|
+
LowerTriangularMask,
|
|
35
|
+
PagedBlockDiagonalCausalLocalPaddedKeysMask,
|
|
36
|
+
PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
|
|
37
|
+
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
|
|
38
|
+
PagedBlockDiagonalGappyKeysMask,
|
|
39
|
+
PagedBlockDiagonalPaddedKeysMask,
|
|
40
|
+
VARLEN_BIASES,
|
|
41
|
+
)
|
|
42
|
+
from .common import (
|
|
43
|
+
AttentionBwOpBase,
|
|
44
|
+
AttentionFwOpBase,
|
|
45
|
+
check_lastdim_alignment_stride1,
|
|
46
|
+
Context,
|
|
47
|
+
Gradients,
|
|
48
|
+
Inputs,
|
|
49
|
+
ScaledTensor,
|
|
50
|
+
)
|
|
51
|
+
from .flash import (
|
|
52
|
+
_check_needs_no_topleft,
|
|
53
|
+
_convert_input_format,
|
|
54
|
+
_is_causal,
|
|
55
|
+
_post_process_lse,
|
|
56
|
+
_window_size,
|
|
57
|
+
)
|
|
58
|
+
from .utils.op_common import get_operator, register_operator
|
|
59
|
+
|
|
60
|
+
# Version string reported by the FwOp/BwOp names; it may be overwritten by the
# extension-import logic that follows.
FLASH_VERSION = "0.0.0"

T = TypeVar("T")


def maybe_contiguous(x: T) -> T:
    """Return ``x`` with a unit stride on its innermost dimension.

    ``None`` is passed straight through. A tensor whose last dimension is
    already densely packed (``stride(-1) == 1``) is returned unchanged, even if
    its outer dimensions are strided; anything else is copied with
    ``.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:  # type: ignore[attr-defined]
        return x
    return x.contiguous()  # type: ignore[attr-defined]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Locate the FlashAttention-3 C++ extension, in order of preference:
#   1. ``xformers._C_flashattention3``. FLASH_VERSION is then taken from
#      xformers' build metadata when it is importable and not None, and is set
#      to "unknown" when that metadata module cannot be imported.
#   2. ``ai_codesign``'s Hopper build (FLASH_VERSION keeps its prior value).
#   3. Neither importable -> ``_C_flashattention3 = None``; the module's
#      ``if _C_flashattention3 is not None`` section is then skipped.
try:
    from xformers import _C_flashattention3  # type: ignore[attr-defined]

    try:
        from xformers._cpp_lib import _build_metadata  # type: ignore[attr-defined]

        if _build_metadata is not None:
            FLASH_VERSION = _build_metadata.flash_version
    except ImportError:
        FLASH_VERSION = "unknown"
except ImportError:
    try:
        # type: ignore
        from ai_codesign.gen_ai.flash_attention_v2.hopper.flash_attn_interface import (
            flashattn_hopper_cuda as _C_flashattention3,
        )
    except ImportError:
        # We end up here if arch is not 90a (or if neither extension is
        # installed at all).
        _C_flashattention3 = None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _heuristic_kvsplit(
    inp: Inputs,
    enable_kvsplit_attn: bool,
) -> bool:
    """Decide whether the KV-split forward path should be used for ``inp``.

    KV splitting is only considered for decode-like workloads:

    * every query sequence must have the same length (no varlen Q), and
    * the longest query must be shorter/longer than the longest key
      sequence; equal maxima are treated as prefill.

    When both hold, the caller's ``enable_kvsplit_attn`` flag is returned
    as-is; otherwise the answer is ``False``. ``inp.attn_bias`` is expected
    to expose ``q_seqinfo`` and ``k_seqinfo``.
    """
    bias = inp.attn_bias
    # pyre-ignore Undefined attribute [16]
    q_info = bias.q_seqinfo  # type: ignore[union-attr]
    varlen_q = q_info.min_seqlen != q_info.max_seqlen
    # ``or`` short-circuits: key info is only consulted for uniform-length Q.
    # pyre-ignore Undefined attribute [16]
    if varlen_q or q_info.max_seqlen == bias.k_seqinfo.max_seqlen:  # type: ignore[union-attr]
        return False
    return enable_kvsplit_attn
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def mask_non_zeros(s_q: int, s_k: int, window_left: int, window_right: int) -> int:
    """Return how many entries of the ``s_q x s_k`` attention matrix are unmasked.

    A negative ``window_left`` / ``window_right`` means the window is
    unbounded on that side of the diagonal.

    * Full attention (both windows negative): exact for any shape.
    * Causal attention (``window_left < 0``, ``window_right == 0``), aligned
      to the bottom-right corner: exact for any shape, including
      ``s_q > s_k`` where the first ``s_q - s_k`` query rows see no key.
    * Any other local window: exact when ``s_q == s_k``, an approximation
      otherwise.

    Args:
        s_q: number of query positions.
        s_k: number of key positions.
        window_left: how many keys before the diagonal each query may attend
            to (negative = unlimited).
        window_right: how many keys after the diagonal each query may attend
            to (negative = unlimited).

    Returns:
        The number of (query, key) pairs for which attention is computed.
    """
    # Exact formula for easy cases
    if window_left < 0 and window_right < 0:  # full
        return s_q * s_k
    if window_left < 0 and window_right == 0:  # causal
        # (from bottom right) The bottom ``min(s_q, s_k)`` rows form a
        # triangle, and every row additionally sees the ``s_k - s_q`` extra
        # leading keys when there are more keys than queries. With more
        # queries than keys, the top ``s_q - s_k`` rows are fully masked, so
        # the triangle must be sized by ``s_k`` rather than ``s_q``.
        tri = min(s_q, s_k)
        return (tri * (tri + 1)) // 2 + s_q * max(0, s_k - s_q)

    # NOTE: Flops calculations here assume `s_q == s_k`
    # otherwise the local attention computations are too involved
    # See also https://docs.google.com/spreadsheets/d/1u1ItCZcHLArcqXLj7mwR4H1pI3lMKU1zlxCYi8JCYgk/edit?usp=sharing
    if window_left < 0:
        window_left = s_k
    if window_right < 0:
        window_right = s_k

    # below the diagonal
    # ┌───────┐
    # │ ╲     │
    # │  ╲    │ <- Upper triangle ("ut")
    # │┄┄┄╲   │ <--- `lastq_ut`
    # │╲   ╲  │
    # │ ╲   ╲ │ <- Lower part
    # │  ╲   ╲│
    # └───────┘
    mask_nz = min(s_q, s_k)  # diagonal
    # Below diagonal (with `window_left`): row ``i`` sees
    # ``min(i, window_left)`` keys to the left of the diagonal.
    lastq_ut = min(window_left, s_q)
    mask_nz += ((lastq_ut - 1) * lastq_ut) // 2  # upper triangle
    mask_nz += (s_q - lastq_ut) * window_left  # lower part
    # Above diagonal (with `window_right`)
    # (counting rows from the bottom for symmetry)
    firstq_bt = min(window_right + 1, s_q)
    mask_nz += ((firstq_bt - 1) * firstq_bt) // 2  # bottom triangle
    mask_nz += (s_q - firstq_bt) * window_right

    return mask_nz
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
# Copied from PyTorch, modified to support MQA/GQA and local attention
# No need to take care of this for the bwd because we don't "unexpand" the keys
# and values (in the fwd we expand to help with the seqlen/headdim swap trick).
def sdpa_flop_count(
    query_shape, key_shape, value_shape, window_left: int, window_right: int
):
    """
    Count the forward FLOPs of (possibly local, MQA/GQA) attention.

    Each shape is unpacked as ``(batch, heads, seqlen, head_dim)``. Keys and
    values must agree on batch, head count and sequence length, and keys must
    share the query head dim; the value head dim ``d_v`` may differ. The query
    head count must be a multiple of the key/value head count. FLOPs are
    counted per query head.

    Returns ``2 * b * h_q * (d_q + d_v) * nnz``, where ``nnz`` is the number
    of unmasked scores given by ``mask_non_zeros`` (one multiply-add per
    unmasked score for ``q @ k.T`` and for ``attn @ v``).
    """
    b, h_q, s_q, d_q = query_shape
    _b2, h_kv, s_k, _d2 = key_shape
    _b3, _h2, _s3, d_v = value_shape
    assert b == _b2 == _b3
    assert h_kv == _h2
    assert d_q == _d2
    assert s_k == _s3
    assert h_q % h_kv == 0
    # How many values are computed in the attention?
    mask_nz = mask_non_zeros(s_q, s_k, window_left, window_right)

    # q@k.T
    total_flops = 2 * b * h_q * d_q * mask_nz
    # attn@v
    total_flops += 2 * b * h_q * d_v * mask_nz
    return total_flops
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
if _C_flashattention3 is not None:  # noqa: C901
    # Compatibility check for FAv3 APIs: the ops below call
    # ``_C_flashattention3.fwd`` / ``.bwd`` with a fixed positional argument
    # list, so refuse to load against a build whose advertised signature (in
    # the function's ``__doc__``) has a different number of arguments.
    EXPECTED_NUM_OF_ARGS = [
        ("fwd", 33),
        ("bwd", 22),
    ]

    import re

    def count_args_from_doc(docstring) -> int:
        """Return the number of comma-separated entries in the first ``(...)``
        group found in ``docstring``.

        Raises:
            ValueError: if ``docstring`` contains no parenthesised group.
        """
        # Use a regular expression to find the argument list inside parentheses
        match = re.search(r"\((.*?)\)", docstring)
        if match:
            # Extract the argument list and split by commas
            args_list = match.group(1).split(",")
            # Count the number of arguments
            return len(args_list)
        else:
            raise ValueError("No valid argument list found in the docstring.")

    for name, num_of_args in EXPECTED_NUM_OF_ARGS:
        num_of_args_from_doc = count_args_from_doc(
            getattr(_C_flashattention3, name).__doc__
        )
        assert num_of_args_from_doc == num_of_args, (
            f"Found func signature mismatch for {name}. Expected {num_of_args}, "
            f"actual: {num_of_args_from_doc}. Please update the version of Flash Attention3."
        )

    # returns: (out, softmax_lse)
    @torch.library.custom_op(
        "mslk_flash3::flash_fwd", mutates_args=(), device_types=["cuda"]
    )
    def mha_fwd(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_k: Optional[torch.Tensor],
        leftpad_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        p: float,
        softmax_scale: float,
        is_causal: bool,
        descale_q: Optional[torch.Tensor] = None,
        descale_k: Optional[torch.Tensor] = None,
        descale_v: Optional[torch.Tensor] = None,
        block_table: Optional[torch.Tensor] = None,
        use_kvsplit: bool = False,
        window_left: int = -1,
        window_right: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """FlashAttention-3 forward, registered as ``mslk_flash3::flash_fwd``.

        Returns ``(out, softmax_lse)`` laid out as described by
        ``mha_fwd_fake``: ``out`` mirrors ``query``'s layout with the value
        head dim last; ``softmax_lse`` is ``(B, H, M)`` for batched queries or
        ``(H, total_M)`` for varlen queries.

        ``p`` (dropout) is accepted for interface compatibility but is not
        forwarded to the kernel. When ``block_table`` is given (paged KV),
        ``cu_seqlens_k`` is not passed to the kernel.
        """
        query, key = [maybe_contiguous(x) for x in (query, key)]
        # Only copy ``value`` when neither its last dim nor dim -3 has unit
        # stride (presumably both layouts are accepted by the kernel -
        # TODO confirm).
        value = (
            value.contiguous()
            if value.stride(-1) != 1 and value.stride(-3) != 1
            else value
        )
        cu_seqlens_q, cu_seqlens_k, seqused_k = [
            maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, seqused_k)
        ]
        block_table = maybe_contiguous(block_table)

        def _get_batch():
            if cu_seqlens_q is not None:
                return cu_seqlens_q.shape[0] - 1
            return query.shape[0]

        is_paged = block_table is not None
        bs = _get_batch()
        orig_query_shape = query.shape

        pack_gqa = None
        if use_kvsplit:
            # For KV split, we need to make sure query in shape [batch, seqlen, num_heads, head_dim_q]
            # to be compatible with `pack_gqa` feature
            query = query.view(bs, -1, query.shape[-2], query.shape[-1])
            cu_seqlens_q = None

            # Auto-detect if we should use GQA parallel mode
            if query.shape[1] <= 64 and query.shape[2] != key.shape[2]:
                pack_gqa = True

        assert _C_flashattention3 is not None
        out, softmax_lse, *rest = _C_flashattention3.fwd(
            query,
            key,
            value,
            None,
            None,  # k_new, v_new
            None,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k if not is_paged else None,
            None,  # cu_seqlens_k_new
            None,  # seqused_q
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            block_table,
            None,  # kv_batch_idx
            leftpad_k,
            None,  # rotary_cos
            None,  # rotary_sin
            None,  # seqlens_rotary
            descale_q,
            descale_k,
            descale_v,
            softmax_scale,
            is_causal,
            window_left,
            window_right,
            0.0,  # softcap
            not use_kvsplit,  # rotary_interleaved
            None,  # scheduler_metadata
            1 if not use_kvsplit else 0,  # num_splits
            pack_gqa,  # pack_gqa
            0,  # sm_margin
        )

        if query.shape != orig_query_shape:
            # The query was re-viewed as [batch, seqlen, ...] for KV split.
            # Restore the caller's layout on both outputs so they agree with
            # ``mha_fwd_fake``: ``out`` keeps the original leading dims, and
            # softmax_lse goes from (B, H, M) to (H, B * M).
            out = out.reshape(*orig_query_shape[:-1], out.shape[-1])
            num_heads_q = query.shape[-2]
            orig_lse_shape = softmax_lse.shape
            softmax_lse = softmax_lse.view(
                orig_lse_shape[0], num_heads_q, -1, orig_lse_shape[2]
            )
            softmax_lse = softmax_lse.permute(1, 0, 2, 3).reshape(num_heads_q, -1)

        return out, softmax_lse

    @torch.library.register_fake("mslk_flash3::flash_fwd")
    def mha_fwd_fake(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_k: Optional[torch.Tensor],
        leftpad_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        p: float,
        softmax_scale: float,
        is_causal: bool,
        descale_q: Optional[torch.Tensor] = None,
        descale_k: Optional[torch.Tensor] = None,
        descale_v: Optional[torch.Tensor] = None,
        block_table: Optional[torch.Tensor] = None,
        use_kvsplit: bool = False,
        window_left: int = -1,
        window_right: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Shape/dtype-only implementation of ``flash_fwd`` for tracing.

        ``out`` has query's leading dims and value's head dim; FP8 inputs
        produce a bfloat16 ``out``. ``lse`` is float32.
        """
        query_shape = query.shape
        out_shape = (*query_shape[:-1], value.shape[-1])
        if query.dtype == torch.float8_e4m3fn or query.dtype == torch.float8_e5m2:
            out = query.new_empty(out_shape, dtype=torch.bfloat16)
        else:
            out = query.new_empty(out_shape)
        # Query is (B, M, H, K) or (total_M, H, K)
        # LSE is (B, H, M) or (H, total_M)
        lse_shape = (
            (query_shape[0], query_shape[2], query_shape[1])
            if cu_seqlens_q is None
            else (query_shape[1], query_shape[0])
        )
        lse = query.new_empty(lse_shape, dtype=torch.float32)
        return out, lse

    @register_flop_formula(torch.ops.mslk_flash3.flash_fwd, get_raw=True)
    def mha_fwd_flops(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_k: Optional[torch.Tensor],
        leftpad_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        p: float,
        softmax_scale: float,
        is_causal: bool,
        descale_q: Optional[torch.Tensor] = None,
        descale_k: Optional[torch.Tensor] = None,
        descale_v: Optional[torch.Tensor] = None,
        block_table: Optional[torch.Tensor] = None,
        use_kvsplit: bool = False,
        window_left: int = -1,
        window_right: int = -1,
        # The FLOPs counter might pass more args (out_val, out_shape, ...)
        *args,
        **kwargs,
    ):
        """FLOP formula for ``flash_fwd``.

        Inputs are transposed to ``(B, H, M, K)`` (4-D) or left as varlen
        ``(total, H, K)`` (3-D), split into per-sequence shapes by
        ``_unpack_flash_attention_nested_shapes``, and summed with
        ``sdpa_flop_count``. ``is_causal`` forces ``window_right = 0``.
        """
        assert 3 <= query.ndim <= 4
        assert 3 <= key.ndim <= 4
        assert 3 <= value.ndim <= 4
        # This FLOP formula is used by torch.compile's partitioner "automatic
        # activation checkpointing" (AutoAC) to decide which ops to preserve
        # for backward or to recompute. However, this formula is data-dependent!
        # This makes all invocations reuse the choices made based on the first
        # inputs, which may be sub-optimal but also lead to inconsistent
        # behavior across runs. In the presence of tensor parallelism it might
        # also lead to deadlocks if AutoAC recomputes different collectives
        # on different ranks. For distributed jobs it seems more robust to have
        # all ranks always use the "worst case" FLOP estimate. Ranks are in
        # lockstep anyways and will be going as fast as the slowest one.
        if os.environ.get("XFORMERS_FLOP_FORMULA_WORST_CASE", "0") == "1":
            # Treat the whole (possibly varlen) batch as one dense sequence.
            cu_seqlens_q = cu_seqlens_k = max_seqlen_q = max_seqlen_k = None  # type: ignore[assignment]
            query = query.unsqueeze(0) if query.ndim == 3 else query
            key = key.unsqueeze(0) if key.ndim == 3 else key
            value = value.unsqueeze(0) if value.ndim == 3 else value
        sizes = _unpack_flash_attention_nested_shapes(
            query=query.transpose(-2, -3) if query.ndim == 4 else query,
            key=key.transpose(-2, -3) if key.ndim == 4 else key,
            value=value.transpose(-2, -3) if value.ndim == 4 else value,
            cum_seq_q=cu_seqlens_q,
            cum_seq_k=cu_seqlens_k,
            max_q=max_seqlen_q,
            max_k=max_seqlen_k,
        )
        if is_causal:
            window_right = 0
        res = sum(
            sdpa_flop_count(
                query_shape,
                key_shape,
                value_shape,
                window_left=window_left,
                window_right=window_right,
            )
            for query_shape, key_shape, value_shape, _ in sizes
        )
        return res

    def _create_dq_dk_dv(
        grads_share_storage: bool, query, key, value
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Allocate gradient buffers for query, key and value.

        With ``grads_share_storage`` a single ``(..., 3, H, K)`` buffer shaped
        after ``query`` is allocated, and dq/dk/dv are its three slices along
        dim -3 (so key/value are assumed to have query's shape). Otherwise each
        gradient is an independent ``empty_like`` of its input.
        """
        # Create dq,dk,dv
        # If Q/K/V come from a single QKV tensor, let's put the gradient in the
        # right strides, so we can avoid a `cat`
        if grads_share_storage:
            chunk = torch.empty(
                (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
                dtype=query.dtype,
                device=query.device,
            )
            return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2)
        return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

    @torch.library.custom_op(
        "mslk_flash3::flash_bwd", mutates_args=(), device_types=["cuda"]
    )
    def mha_bwd(
        grads_share_storage: bool,
        dout: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        softmax_lse: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        is_causal: bool,
        window_left: int,
        window_right: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """FlashAttention-3 backward, registered as ``mslk_flash3::flash_bwd``.

        ``cu_seqlens_q``/``cu_seqlens_k`` are either both given (varlen) or
        both ``None``. The kernel runs non-deterministically; it receives the
        preallocated dq/dk/dv buffers and the tensors it returns are returned.
        """
        dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
        is_deterministic = False
        if cu_seqlens_q is None:
            assert cu_seqlens_k is None

        assert _C_flashattention3 is not None
        dq, dk, dv, softmax_d, *rest = _C_flashattention3.bwd(
            dout,
            query,
            key,
            value,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # seqused_q
            None,  # seqused_k
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            is_causal,
            window_left,
            window_right,
            0.0,  # not used, softcap
            is_deterministic,
            0,  # not used, sm_margin
        )
        return dq, dk, dv

    @torch.library.register_fake("mslk_flash3::flash_bwd")
    def mha_bwd_fake(
        grads_share_storage: bool,
        dout: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        softmax_lse: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        is_causal: bool,
        window_left: int,
        window_right: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Shape-only implementation of ``flash_bwd``: allocates dq/dk/dv."""
        return _create_dq_dk_dv(grads_share_storage, query, key, value)

    @register_flop_formula(torch.ops.mslk_flash3.flash_bwd, get_raw=True)
    def mha_bwd_flops(
        grads_share_storage: bool,
        dout: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        softmax_lse: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        is_causal: bool,
        window_left: int,
        window_right: int,
        # The FLOPs counter might pass more args (out_val, out_shape, ...)
        *args,
        **kwargs,
    ):
        """FLOP formula for ``flash_bwd``: 2.5x the forward FLOPs."""
        return (
            5
            * mha_fwd_flops(
                query,
                key,
                value,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                seqused_k=None,
                leftpad_k=None,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                p=0.0,
                softmax_scale=1.0,
                is_causal=is_causal,
                descale_q=None,
                descale_k=None,
                descale_v=None,
                block_table=None,
                use_kvsplit=False,
                window_left=window_left,
                window_right=window_right,
            )
            // 2
        )
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
    `Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
    implementation.
    """

    # Backed by the ``mslk_flash3::flash_fwd`` custom op.
    OPERATOR = get_operator("mslk_flash3", "flash_fwd")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    # Only compute capability 9.0 is accepted.
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (9, 0)
    CUDA_MAXIMUM_COMPUTE_CAPABILITY = (9, 0)
    SUPPORTED_DTYPES: Set[torch.dtype] = {
        torch.half,
        torch.bfloat16,
        torch.float8_e4m3fn,
    }
    SUPPORTED_MAX_K = 256
    SUPPORTED_MIN_K = 64
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        LowerTriangularMask,
        LowerTriangularFromBottomRightMask,
        LowerTriangularFromBottomRightLocalAttentionMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
        BlockDiagonalCausalFromBottomRightMask,
        BlockDiagonalCausalWithOffsetGappyKeysMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        BlockDiagonalLocalAttentionPaddedKeysMask,
        BlockDiagonalGappyKeysMask,
        BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
        BlockDiagonalPaddedKeysMask,
        LocalAttentionFromBottomRightMask,
        PagedBlockDiagonalCausalLocalPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
        PagedBlockDiagonalGappyKeysMask,
        PagedBlockDiagonalPaddedKeysMask,
    )

    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = False
    SUPPORTS_BMGHK = True
    SUPPORTS_PARTIAL = True
    UNPADDED_LSE = True
    NAME = f"fa3F@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Extend the base checks with FA3's alignment / head-dim limits."""
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        if d.query.shape[-1] not in [64, 128, 192, 256]:
            reasons.append("only head-dim 64, 128, 192 or 256 is supported")

        _check_needs_no_topleft(d, reasons)

        return reasons

    @classmethod
    def apply(
        cls,
        inp: Inputs,
        needs_gradient: bool,
        use_kvsplit: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the FA3 forward pass.

        Args:
            inp: attention inputs. NOTE: ``inp.query``/``key``/``value`` are
                replaced in place by their unpacked tensors when they are
                ``ScaledTensor`` (the scales become the kernel's descales).
            needs_gradient: when False no ``Context`` is built.
            use_kvsplit: forwarded to ``_convert_input_format`` and the op.

        Returns:
            ``(out, ctx)``: ``out`` has the original query's leading dims with
            the value head dim last; ``ctx`` is ``None`` unless
            ``needs_gradient``.
        """
        original_query_shape = inp.query.shape
        out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]

        def unpack_func(x) -> Tuple[torch.Tensor, Any]:
            return x.unpack() if isinstance(x, ScaledTensor) else (x, None)

        inp.query, descale_q = unpack_func(inp.query)
        inp.key, descale_k = unpack_func(inp.key)
        inp.value, descale_v = unpack_func(inp.value)
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
            seqused_k,
        ) = _convert_input_format(inp, supports_mqa=True, use_kvsplit=use_kvsplit)

        q = inp.query
        k = inp.key
        v = inp.value

        if inp.query.numel() > 0 and inp.key.numel() > 0:
            win_left, win_right = _window_size(inp.attn_bias)
            block_tables = (
                inp.attn_bias.block_tables
                if isinstance(
                    inp.attn_bias,
                    (PagedBlockDiagonalGappyKeysMask, PagedBlockDiagonalPaddedKeysMask),
                )
                else None
            )
            leftpad_k = None
            if isinstance(inp.attn_bias, PagedBlockDiagonalGappyKeysMask):
                assert cu_seqlens_q is not None
                assert cu_seqlens_k is not None
                if len(cu_seqlens_q) == len(cu_seqlens_k):
                    # case #1: len(cu_seqlens_k) = batch_size + 1
                    leftpad_k = cu_seqlens_k[:-1]
                else:
                    # case #2: len(cu_seqlens_k) = batch_size
                    assert len(cu_seqlens_q) - len(cu_seqlens_k) == 1, (
                        f"{len(cu_seqlens_q)=} {len(cu_seqlens_k)=}"
                    )
                    leftpad_k = cu_seqlens_k
            out, softmax_lse = cls.OPERATOR(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                seqused_k=seqused_k,
                leftpad_k=leftpad_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                p=inp.p,
                softmax_scale=inp.scale_float,
                is_causal=_is_causal(inp.attn_bias),
                descale_q=descale_q,
                descale_k=descale_k,
                descale_v=descale_v,
                block_table=block_tables,
                use_kvsplit=use_kvsplit,
                window_left=win_left,
                window_right=win_right,
            )
            out = out.reshape(out_shape)
        else:
            # Empty query or key: nothing to compute. Build ``out`` with the
            # same shape and dtype as the kernel path (original layout; FP8
            # inputs yield bfloat16, as in the op's fake implementation).
            out_dtype = (
                torch.bfloat16
                if inp.query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
                else inp.query.dtype
            )
            out = torch.zeros(out_shape, device=inp.query.device, dtype=out_dtype)
            if inp.is_partial:
                # Partial results are merged later; -inf LSE marks "no keys".
                softmax_lse = torch.full(
                    [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
                    float("-inf"),
                    device=inp.query.device,
                    dtype=torch.float32,
                )
            else:
                softmax_lse = torch.empty(
                    [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
                    device=inp.query.device,
                    dtype=torch.float32,
                )

        if not needs_gradient:
            return out, None
        ctx = Context(
            out=out,
            lse=_post_process_lse(softmax_lse, inp, tuple(original_query_shape)),
        )
        return (out, ctx)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
@register_operator
class BwOp(AttentionBwOpBase):
    # Backward (gradient) operator paired with FwOp; it shares FwOp's docs and
    # most of its capability limits.
    __doc__ = FwOp.__doc__

    OPERATOR = get_operator("mslk_flash3", "flash_bwd")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
    CUDA_MAXIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MAXIMUM_COMPUTE_CAPABILITY
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_MIN_K = FwOp.SUPPORTED_MIN_K
    SUPPORTED_ATTN_BIAS_TYPES = (
        # Exclude padded or gappy masks, since seqused_k is not supported by the kernel.
        type(None),
        LowerTriangularMask,
        LowerTriangularFromBottomRightMask,
        LowerTriangularFromBottomRightLocalAttentionMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalFromBottomRightMask,
        LocalAttentionFromBottomRightMask,
    )

    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    IS_DETERMINISTIC = False
    SUPPORTS_BMGHK = False
    SUPPORTS_LSE_FORMATS: Sequence[str] = ["", "varlen_flat"]
    NAME = f"fa3B@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons why this backward kernel cannot handle ``d``.

        Starts from the base-class checks. It then adds the query last-dim
        alignment check, the top-left mask check and the head-dim restriction.
        An empty list means the inputs are supported.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        # Called once only; the previous version called it twice, which could
        # report the same reason twice.
        _check_needs_no_topleft(d, reasons)
        if d.query.shape[-1] not in [64, 128, 192, 256]:
            reasons.append("only head-dim 64, 128, 192 or 256 is supported")
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Compute the query/key/value gradients for one attention call.

        Args:
            ctx: Forward context holding the attention output (``ctx.out``), its
                log-sum-exp (``ctx.lse``) and ``ctx.qkv_share_storage``.
            inp: The forward inputs (query, key, value, attn_bias, scale).
            grad: Gradient of the loss w.r.t. the attention output. Its dtype
                must be one of ``SUPPORTED_DTYPES``.

        Returns:
            ``Gradients`` whose ``dq``/``dk``/``dv`` are reshaped back to the
            original query/key/value shapes.
        """
        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
            _,  # seqused_k: unused, padded/gappy masks are not supported here
        ) = _convert_input_format(inp, supports_mqa=False)
        ctx_lse = ctx.lse

        if isinstance(inp.attn_bias, VARLEN_BIASES):
            # The varlen LSE is expected to carry a leading dim of size 1
            # (asserted below); strip it before handing it to the kernel.
            assert ctx_lse.shape[0] == 1
            ctx_lse = ctx_lse[0]
        else:
            # NOTE: cutlass pads the last dimension, we need to slice it
            assert ctx_lse.shape[2] >= max_seqlen_q
            ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()

        # The output (and its gradient) uses value's head dim, not query's.
        kernel_out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        assert grad.dtype in cls.SUPPORTED_DTYPES

        if inp.query.numel() and inp.key.numel():
            win_left, win_right = _window_size(inp.attn_bias)
            dq, dk, dv = cls.OPERATOR(
                ctx.qkv_share_storage,
                grad.reshape(kernel_out_shape).contiguous(),
                inp.query,
                inp.key,
                inp.value,
                ctx.out.reshape(kernel_out_shape),
                # Pass the normalized LSE computed above, not the raw ctx.lse.
                ctx_lse,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                window_left=win_left,
                window_right=win_right,
                softmax_scale=inp.scale_float,
                is_causal=_is_causal(inp.attn_bias),
            )
            grads = Gradients(dq, dk, dv)
        else:
            # Empty query or key: no contribution, so every gradient is zero.
            grads = Gradients(
                dq=torch.zeros_like(inp.query),
                dk=torch.zeros_like(inp.key),
                dv=torch.zeros_like(inp.value),
            )

        grads.dq = grads.dq.reshape(dq_shape)
        grads.dk = grads.dk.reshape(dk_shape)
        grads.dv = grads.dv.reshape(dv_shape)
        return grads
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
@register_operator
class FwOp_KVSplit(FwOp):
    """Operator that computes memory-efficient attention using the
    `Flash-Attention3 <https://github.com/Dao-AILab/flash-attention/tree/main/hopper>`_
    implementation, with heuristic rules that dispatch decoding shapes to
    KVSplit attention.
    """

    # Passed to _heuristic_kvsplit on every call to apply().
    enable_kvsplit_attn: bool = True
    NAME = f"fa3F_splitKV@{FLASH_VERSION}"

    # Attention-bias types accepted by this variant.
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        BlockDiagonalPaddedKeysMask,
        BlockDiagonalCausalWithOffsetGappyKeysMask,
        BlockDiagonalGappyKeysMask,
        BlockDiagonalLocalAttentionPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
        PagedBlockDiagonalGappyKeysMask,
        PagedBlockDiagonalPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
    )

    @classmethod
    def apply(  # type: ignore[override]
        cls,
        inp: Inputs,
        needs_gradient: bool,
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        # Ask the heuristic whether this call should take the KV-split path,
        # then delegate to the parent FwOp.apply with that decision.
        kvsplit_decision = _heuristic_kvsplit(inp, cls.enable_kvsplit_attn)
        return super().apply(inp, needs_gradient, kvsplit_decision)
|