mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/interface.py
@@ -0,0 +1,1771 @@
# @nolint # fbcode
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need to install nvidia-cutlass-dsl==4.2.0.

# Supported features:
# - BF16 & FP16 dtype
# - noncausal & causal attention
# - MHA, GQA, MQA
# - hdim 64, 96, 128.
# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape)
# - varlen
# - sliding window
# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow)

# Features not supported yet:
# - split (i.e. FlashDecoding)
# - tuned block sizes
# - paged KV
# - append KV to existing KV cache
# - FP8
# - bwd pass optimized for Hopper/Blackwell

import math
from functools import lru_cache
from typing import Optional, Tuple, Callable

import torch


import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute

from mslk.attention.flash_attn import utils
from mslk.attention.flash_attn.cute_dsl_utils import to_cute_tensor
from mslk.attention.flash_attn.flash_fwd import FlashAttentionForwardSm90
from mslk.attention.flash_attn.flash_fwd_sm100 import FlashAttentionForwardSm100
from mslk.attention.flash_attn.flash_bwd_preprocess import FlashAttentionBackwardPreprocess
from mslk.attention.flash_attn.flash_bwd import FlashAttentionBackwardSm80
from mslk.attention.flash_attn.flash_bwd_sm90 import FlashAttentionBackwardSm90
from mslk.attention.flash_attn.flash_bwd_sm100 import FlashAttentionBackwardSm100
from mslk.attention.flash_attn.flash_bwd_postprocess import FlashAttentionBackwardPostprocess
from mslk.attention.flash_attn.flash_fwd_combine import FlashAttentionForwardCombine

from mslk.attention.flash_attn.block_sparsity import (
    BlockSparseTensorsTorch,
    to_cute_block_sparse_tensors,
    normalize_block_sparse_tensors,
    get_block_sparse_expected_shapes,
    get_block_sparse_expected_shapes_bwd,
)

@lru_cache(maxsize=None)
def _get_device_capability():
    """Cached device capability check."""
    return torch.cuda.get_device_capability()[0]

def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
    assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
    assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
    assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
    assert t.is_cuda, f"{name} must be on CUDA"


torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}


def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
    # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
    if num_n_blocks <= 4:
        return 1

    # NOTE: We should revisit this heuristic after persistence is supported for split KV.
    # Sometimes, it's ideal to over-schedule splits for better efficiency.
    return min(num_SMs // total_mblocks, max_splits, num_n_blocks)

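# Worked example for num_splits_heuristic (illustrative only; the numbers below are
# assumed values, not figures taken from this package): with 4 total m-blocks on a GPU
# with 132 SMs, 16 KV blocks, and max_splits=128, the heuristic returns
# min(132 // 4, 128, 16) = 16 splits; with num_n_blocks <= 4 it always returns 1.
#
#     num_splits_heuristic(total_mblocks=4, num_SMs=132, num_n_blocks=16, max_splits=128)  # -> 16
#     num_splits_heuristic(total_mblocks=4, num_SMs=132, num_n_blocks=4, max_splits=128)   # -> 1
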
def _flash_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    learnable_sink: Optional[torch.Tensor] = None,
    # m_block_size: int = 128,
    # n_block_size: int = 64,
    # num_threads: int = 128,
    m_block_size: int = 128,
    n_block_size: int = 128,
    num_threads: int = 384,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    _compute_capability: Optional[int] = None,
    score_mod: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
    return_lse: bool = False,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    aux_tensors: Optional[list[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass for FlashAttention.

    Args:
        ...
        score_mod: A callable that takes the attention scores and applies a modification.
        mask_mod: A callable that takes token position information and selectively masks out scores.
        block_sparse_tensors: A tuple of tensors used for block sparsity.
        return_lse: Whether to return the log-sum-exp of the attention scores. If set to True, the LSE is always computed.
        out: Optional pre-allocated output tensor. If None, will be allocated internally.
        lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
        aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
    """
    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
    num_head, head_dim = q.shape[-2:]
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        batch_size = cu_seqlens_q.shape[0] - 1
        seqlen_q = None
        total_q = q.shape[0]
    if page_table is not None:
        assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
        assert page_table.dtype == torch.int32, "page_table must be int32"
        assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension"
        max_num_pages_per_seq = page_table.shape[1]
        assert page_table.shape == (batch_size, max_num_pages_per_seq)
        num_pages, page_size = k.shape[:2]
        seqlen_k = num_pages * page_size
    else:
        num_pages, page_size = None, None
        seqlen_k = k.shape[-3]
    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]
    if cu_seqlens_k is None:
        if page_table is None:
            assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
            assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
        else:
            assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
            assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
    else:
        assert k.shape == (seqlen_k, num_head_kv, head_dim)
        assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (batch_size + 1,), (
            "cu_seqlens_k must have shape (batch_size + 1,)"
        )

    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (batch_size + 1,), (
            "cu_seqlens_q must have shape (batch_size + 1,)"
        )
    assert seqused_q is None or seqused_q.shape == (batch_size,), (
        "seqused_q must have shape (batch_size,)"
    )
    assert seqused_k is None or seqused_k.shape == (batch_size,), (
        "seqused_k must have shape (batch_size,)"
    )
    assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
    for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
        if t is not None:
            assert t.dtype == torch.int32, (
                "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
            )
            assert t.stride(0) == 1, (
                "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
            )
    if learnable_sink is not None:
        assert learnable_sink.shape == (num_head,)
        assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"

    assert all(
        t is None or t.is_cuda
        for t in (
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            learnable_sink,
        )
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    if softcap == 0.0:
        softcap = None
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1

    out_torch_dtype = q.dtype
    device = q.device
    q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
    lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q)
    requires_grad = q.requires_grad or k.requires_grad or v.requires_grad

    if out is None:
        out = torch.empty(
            *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device
        )
    else:
        _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device)

    if lse is None:
        lse = (
            torch.empty(lse_shape, dtype=torch.float32, device=device)
            if requires_grad or return_lse
            else None
        )
    elif lse is not None:
        _validate_tensor(lse, "lse", lse_shape, torch.float32, device)

    dtype = torch2cute_dtype_map[q.dtype]
    compute_capability = (
        _get_device_capability()
        if _compute_capability is None
        else _compute_capability
    )

    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"

    use_block_sparsity = block_sparse_tensors is not None

    if mask_mod is None:
        if causal:
            window_size_right = 0
        local = window_size_left is not None or window_size_right is not None
        if window_size_left is not None or window_size_right is not None:
            if window_size_left is None and window_size_right == 0:
                causal, local = True, False
                window_size_right = None
            else:
                causal, local = False, True
    else:
        causal, local = False, False

    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    if compute_capability == 9:  # TODO: tune block size according to hdim.
        if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
            n_block_size = 192

    if compute_capability in [10, 11]:
        if (
            pack_gqa
            and (128 % qhead_per_kvhead != 0)
        ):
            pack_gqa = False
        # TODO: fix GQA + SplitKV + non-varlen
        if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
            pack_gqa = False

    if max_seqlen_q is None:
        max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
    if max_seqlen_k is None:
        max_seqlen_k = seqlen_k
    seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
    if compute_capability == 10:
        q_stage = 2 if seqlen_q_packgqa > m_block_size else 1
    else:
        q_stage = 1

    if num_splits < 1:
        m_block_size_effective = q_stage * m_block_size
        seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size))
        num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
        num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
        total_mblocks = batch_size * num_head_kv * num_m_blocks
        num_splits = num_splits_heuristic(
            total_mblocks,
            torch.cuda.get_device_properties(device).multi_processor_count,
            num_n_blocks,
            128,
        )

    is_split_kv = num_splits > 1
    if is_split_kv:
        out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
        lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)

    # hash score and mask mods for compile cache
    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False

    if softcap is not None:
        assert score_mod is None, "softcap and score_mod cannot be used together"
        score_mod = utils.create_softcap_scoremod(softcap)

    is_varlen = (
        cu_seqlens_q is not None
        or cu_seqlens_k is not None
        or seqused_q is not None
        or seqused_k is not None
    )

    if mask_mod is not None:
        if is_varlen:
            raise NotImplementedError(
                "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
            )

    if use_block_sparsity:
        if is_varlen:
            raise NotImplementedError(
                "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR."
            )
        # NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
        if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
            pack_gqa = False
        if is_split_kv:
            raise NotImplementedError(
                "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
            )

    compile_key = (
        dtype,
        head_dim,
        head_dim_v,
        qhead_per_kvhead,
        causal,
        score_mod_hash,
        mask_mod_hash,
        use_block_sparsity,
        len(aux_tensors) if aux_tensors is not None else 0,
        lse is None,
        cu_seqlens_q is None,
        cu_seqlens_k is None,
        seqused_q is None,
        seqused_k is None,
        page_table is not None,
        window_size_left is not None,
        window_size_right is not None,
        learnable_sink is not None,
        m_block_size,
        n_block_size,
        q_stage,
        num_threads,
        is_split_kv,
        pack_gqa,
        compute_capability,
        page_size not in [None, 128],  # paged KV non-TMA
    )
    if compile_key not in _flash_attn_fwd.compile_cache:
        (
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            learnable_sink_tensor,
        ) = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0)
            if t is not None
            else None
            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
        ]
        page_table_tensor = (
            to_cute_tensor(page_table, assumed_align=4, leading_dim=1)
            if page_table is not None
            else None
        )
        q_tensor, k_tensor, v_tensor, o_tensor = [
            to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial)
        ]
        if is_split_kv:
            lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)
        elif lse is not None:
            lse_tensor = to_cute_tensor(lse, assumed_align=4)
        else:
            lse_tensor = None

        sparse_tensors = None
        if block_sparse_tensors is not None:
            if seqlen_q is None:
                raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
            expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
                batch_size, num_head, seqlen_q, seqlen_k,
                m_block_size, n_block_size, q_stage,
            )
            compile_time_normalized = normalize_block_sparse_tensors(
                block_sparse_tensors,
                expected_count_shape=expected_count_shape,
                expected_index_shape=expected_index_shape,
            )
            sparse_tensors = to_cute_block_sparse_tensors(compile_time_normalized)

        cute_aux_tensors = None
        if aux_tensors is not None:
            cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]

        if compute_capability == 9:
            assert page_table is None, "paged KV not supported on SM 9.0"
            assert not is_split_kv, "SplitKV not supported on SM 9.0"
            # fa_fwd = FlashAttentionForwardSm80(
            fa_fwd = FlashAttentionForwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                tile_m=m_block_size,
                tile_n=n_block_size,
                # num_stages=1,
                num_stages=2,
                num_threads=num_threads,
                Q_in_regs=False,
                intra_wg_overlap=True,
                mma_pv_is_rs=True,
                mask_mod=mask_mod,
                score_mod=score_mod,
                has_aux_tensors=aux_tensors is not None,
            )
        elif compute_capability in [10, 11]:
            fa_fwd = FlashAttentionForwardSm100(
                head_dim,
                head_dim_v,
                qhead_per_kvhead=qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                is_split_kv=is_split_kv,
                pack_gqa=pack_gqa,
                m_block_size=m_block_size,
                n_block_size=n_block_size,
                q_stage=q_stage,
                is_persistent=not causal
                and not local
                and cu_seqlens_q is None
                and seqused_q is None
                and not is_split_kv,
                score_mod=score_mod,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                paged_kv_non_tma=page_size not in [None, 128],
                is_varlen_q=cu_seqlens_q is not None
                or seqused_q is not None,
            )
        else:
            raise ValueError(
                f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x, 11.x"
            )
        # TODO: check @can_implement
        _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
            fa_fwd,
            q_tensor,
            k_tensor,
            v_tensor,
            o_tensor,
            lse_tensor,
            softmax_scale,
            current_stream,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            page_table_tensor,
            window_size_left,
            window_size_right,
            learnable_sink_tensor,
            sparse_tensors,
            cute_aux_tensors,
            options="--enable-tvm-ffi",
        )

    # Expand block sparse tensors to match actual head count (may be broadcast from 1)
    normalized_block_sparse_tensors = None
    if block_sparse_tensors is not None:
        expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
            batch_size, num_head, seqlen_q, seqlen_k,
            m_block_size, n_block_size, q_stage,
        )
        normalized_block_sparse_tensors = normalize_block_sparse_tensors(
            block_sparse_tensors,
            expected_count_shape=expected_count_shape,
            expected_index_shape=expected_index_shape,
        )
    _flash_attn_fwd.compile_cache[compile_key](
        q,
        k,
        v,
        out if not is_split_kv else out_partial,
        lse_partial if is_split_kv else lse,
        softmax_scale,
        current_stream,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        page_table,
        window_size_left,
        window_size_right,
        learnable_sink,
        normalized_block_sparse_tensors,
        aux_tensors,
    )
    if is_split_kv:
        _flash_attn_fwd_combine(
            out_partial,
            lse_partial.transpose(-1, -2),
            out,
            lse.transpose(-1, -2) if lse is not None else None,
            cu_seqlens_q,
            seqused_q,
        )
    return out, lse


_flash_attn_fwd.compile_cache = {}

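# Minimal usage sketch for _flash_attn_fwd (illustrative; the shapes and dtype below are
# assumed values, chosen to satisfy the asserts above, not values taken from this package).
# q is (batch, seqlen_q, num_head, head_dim); k/v are (batch, seqlen_k, num_head_kv, head_dim)
# with num_head % num_head_kv == 0 (GQA):
#
#     q = torch.randn(2, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")
#     k = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
#     v = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
#     out, lse = _flash_attn_fwd(q, k, v, causal=True, return_lse=True)
#     # out: (2, 1024, 32, 128) in bf16; lse: (2, 32, 1024) in float32
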
def _flash_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    dout: torch.Tensor,
    lse: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: float = 0.0,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    m_block_size: int = 64,
    n_block_size: int = 128,
    num_threads: int = 256,
    pack_gqa: bool = False,
    num_stages_Q: int = 2,
    num_stages_dO: int = 2,
    SdP_swapAB: bool = False,
    dKV_swapAB: bool = False,
    dQ_swapAB: bool = False,
    AtomLayoutMSdP: int = 2,
    AtomLayoutNdKV: int = 2,
    AtomLayoutMdQ: int = 2,
    V_in_regs: bool = False,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    deterministic: bool = False,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    score_mod: Optional[Callable] = None,
    score_mod_bwd: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    aux_tensors: Optional[list[torch.Tensor]] = None,
    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    compute_capability = _get_device_capability()
    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"

    if compute_capability == 9:
        m_block_size = 80 if not causal else 64
        n_block_size = 128
        num_stages_Q = 2
        num_stages_dO = 2
        num_stages_PdS = 2
        SdP_swapAB = True
        dKV_swapAB = False
        dQ_swapAB = not causal
        AtomLayoutMSdP = 1
        AtomLayoutNdKV = 2
        AtomLayoutMdQ = 1
        cluster_size = 1
        assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
    else:
        m_block_size = 128
        n_block_size = 128
        dQ_swapAB = False
        dKV_swapAB = False
        AtomLayoutMdQ = 1
        AtomLayoutNdKV = 1
        # TODO: support cluster size 2
        cluster_size = 1
    q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
        maybe_contiguous(t)
        for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
    ]
    num_head, head_dim = q.shape[-2:]
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        batch_size = cu_seqlens_q.shape[0] - 1
        total_q = q.shape[0]
        seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q

    if cu_seqlens_k is None:
        batch_size, seqlen_k = k.shape[:2]
        total_k = batch_size * seqlen_k
    else:
        batch_size = cu_seqlens_k.shape[0] - 1
        total_k = k.shape[0]
        seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k

    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]

    if causal:
        window_size_right = 0
    local = window_size_left is not None or window_size_right is not None
    if local:
        if window_size_left is None and window_size_right == 0:
            causal, local = True, False
            window_size_right = None
        else:
            causal, local = False, True

    use_block_sparsity = block_sparse_tensors is not None

    # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits,
    # the base block_m of 128 from forward, and block-sparse size for subtiling.
    if compute_capability == 9 and use_block_sparsity:
        m_block_size = 64
        # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)
        dQ_swapAB = False

    # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
    subtile_factor = 2
    sparse_block_size_q = subtile_factor * m_block_size

    seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
    seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size

    if cu_seqlens_k is None:
        assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
        assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
    else:
        assert k.shape == (total_k, num_head_kv, head_dim)
        assert v.shape == (total_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (batch_size + 1,), (
            "cu_seqlens_k must have shape (batch_size + 1,)"
        )

    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (batch_size + 1,), (
            "cu_seqlens_q must have shape (batch_size + 1,)"
        )

        assert out.shape == (total_q, num_head, head_dim_v)
        assert dout.shape == (total_q, num_head, head_dim_v)
        assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)"
    else:
        assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v)
        assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v)
        assert lse.shape == (batch_size, num_head, seqlen_q), (
            "lse must have shape (batch_size, num_head, seqlen_q)"
        )

    assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, (
        "inputs must have the same dtype"
    )
    for t in [cu_seqlens_q, cu_seqlens_k]:
        if t is not None:
            assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
    assert lse.dtype == torch.float32, "lse must be float32"
    assert all(
        t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1
    # pack_gqa is not yet supported in the bwd pass
    pack_gqa = False
    if compute_capability not in [10, 11]:
        assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now"

    if score_mod is not None:
        assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
        assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
        assert cu_seqlens_q is None and cu_seqlens_k is None, (
            "varlen + score_mod not supported in bwd yet"
        )

    device = q.device
    out_torch_dtype = q.dtype

    if dq is None:
        dq = torch.empty_like(q)
    else:
        _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device)

    if dk is None:
        dk = torch.empty_like(k)
    else:
        _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device)

    if dv is None:
        dv = torch.empty_like(v)
    else:
        _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device)

    head_dim_rounded = (head_dim + 32 - 1) // 32 * 32

    if cu_seqlens_q is None:
        dq_accum = torch.empty(
            batch_size,
            num_head,
            seqlen_q_rounded * head_dim_rounded,
            dtype=torch.float32,
            device=device,
        )
        dpsum = torch.empty(
            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
        )
        lse_log2 = torch.empty(
            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
        )
    else:
        total_q_rounded_padded = (
            (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size
        )
        dq_accum = torch.empty(
            num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device
        )
        dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
        lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)

    dKV_postprocess = qhead_per_kvhead > 1
    if dKV_postprocess:
        head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
        if cu_seqlens_k is None:
            num_n_blocks = seqlen_k_rounded // n_block_size
            if cluster_size == 2 and num_n_blocks % cluster_size != 0:
                seqlen_k_rounded = seqlen_k_rounded + n_block_size
            dk_accum = torch.zeros(
                batch_size,
                num_head_kv,
                seqlen_k_rounded * head_dim_rounded,
                dtype=torch.float32,
                device=device,
            )
            dv_accum = torch.zeros(
                batch_size,
                num_head_kv,
                seqlen_k_rounded * head_dim_v_rounded,
                dtype=torch.float32,
                device=device,
            )
        else:
            total_k_rounded_padded = (
                (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size
            )
            num_n_blocks = total_k_rounded_padded // n_block_size
            if cluster_size == 2 and num_n_blocks % cluster_size != 0:
                total_k_rounded_padded = total_k_rounded_padded + n_block_size
            dk_accum = torch.zeros(
                num_head_kv,
                total_k_rounded_padded * head_dim_rounded,
                dtype=torch.float32,
                device=device,
            )
            dv_accum = torch.zeros(
                num_head_kv,
                total_k_rounded_padded * head_dim_v_rounded,
                dtype=torch.float32,
                device=device,
            )

    dtype = torch2cute_dtype_map[q.dtype]
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    if deterministic:
        dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda")
    else:
        dQ_semaphore = None

    if deterministic and qhead_per_kvhead > 1:
        dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
        dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
    else:
        dK_semaphore = None
        dV_semaphore = None

    # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
    compile_key_pre = (
        compute_capability,
        dtype,
        head_dim_v,
        m_block_size,
        num_threads,
        cu_seqlens_q is None,
        seqused_q is None,
    )
    if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
        o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
        dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
            to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
        ]
        lse_tensor = to_cute_tensor(lse, assumed_align=4)
        cu_seqlens_q_tensor, seqused_q_tensor = [
            to_cute_tensor(t, assumed_align=4) if t is not None else None
            for t in (cu_seqlens_q, seqused_q)
        ]
        arch = compute_capability * 10
        fa_bwd_pre = FlashAttentionBackwardPreprocess(
            dtype,
            head_dim_v,
            arch,
            m_block_size,
            num_threads=num_threads,
        )
        # TODO: check @can_implement
        _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile(
            fa_bwd_pre,
            o_tensor,
            do_tensor,
            dpsum_tensor,
            lse_tensor,
            lse_log2_tensor,
            dq_accum_tensor,
            cu_seqlens_q_tensor,
            seqused_q_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )
    _flash_attn_bwd.compile_cache_pre[compile_key_pre](
        out,
        dout,
        dpsum,
        lse,
        lse_log2,
        dq_accum,
        cu_seqlens_q,
        seqused_q,
        current_stream,
    )

    # NB: num_threads handling across the 3 backward kernels.
    # There are pre-, main, and post-processing kernels. Currently num_threads is only actually
    # used for the preprocess kernel; the main and postprocess kernels are hard-coded to 384,
    # set here before their cache keys are generated.
    num_threads = 384

    # Backward kernel: compute dk, dv, dq_accum.
    score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
    score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False
    num_aux_tensors = len(aux_tensors) if aux_tensors else 0
    cute_aux_tensors = None
    if aux_tensors is not None:
        cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]

    if compute_capability == 9:
        compile_key = (
            compute_capability,
            dtype,
            head_dim,
            head_dim_v,
            qhead_per_kvhead,
            causal,
            softcap != 0.0,
            m_block_size,
            n_block_size,
            num_threads,
            pack_gqa,
            num_stages_Q,
            num_stages_dO,
            SdP_swapAB,
            dKV_swapAB,
            dQ_swapAB,
            AtomLayoutMSdP,
            AtomLayoutNdKV,
            AtomLayoutMdQ,
            V_in_regs,
            cu_seqlens_q is None,
            cu_seqlens_k is None,
            seqused_q is None,
            seqused_k is None,
            score_mod_hash,
            score_mod_bwd_hash,
            mask_mod_hash,
            num_aux_tensors,
            use_block_sparsity,
        )
    else:
        compile_key = (
            compute_capability,
            dtype,
            head_dim,
            head_dim_v,
            qhead_per_kvhead,
            causal,
            window_size_left is not None,
            window_size_right is not None,
            softcap != 0.0,
            m_block_size,
            n_block_size,
            num_threads,
            pack_gqa,
            cluster_size,
            deterministic,
            score_mod_hash,
            score_mod_bwd_hash,
            mask_mod_hash,
            num_aux_tensors,
            use_block_sparsity,
            cu_seqlens_q is None,
            cu_seqlens_k is None,
            seqused_q is None,
            seqused_k is None,
        )
    if compile_key not in _flash_attn_bwd.compile_cache:
        q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
            to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)
        ]
        dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
            to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
        ]
        if dKV_postprocess:
            dk_accum_tensor, dv_accum_tensor = [
                to_cute_tensor(t) for t in (dk_accum, dv_accum)
            ]
        cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [
            to_cute_tensor(t, assumed_align=4) if t is not None else None
            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ]
        dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
            utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order())
            if t is not None else None
            for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
        ]
        fa_bwd_sm80 = FlashAttentionBackwardSm80(
            dtype,
            head_dim,
            head_dim_v,
            qhead_per_kvhead,
            m_block_size,
            n_block_size,
            num_stages_Q,
            num_stages_dO,
            num_threads,
            pack_gqa,
            causal,
            SdP_swapAB,
            dKV_swapAB,
            dQ_swapAB,
            AtomLayoutMSdP,
            AtomLayoutNdKV,
            AtomLayoutMdQ,
            V_in_regs=V_in_regs,
        )
        if compute_capability == 9:
            fa_bwd_obj = FlashAttentionBackwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                causal,
                m_block_size,
                n_block_size,
                num_stages_Q,
                num_stages_dO,
                num_stages_PdS,
                SdP_swapAB,
                dKV_swapAB,
                dQ_swapAB,
                AtomLayoutMSdP,
                AtomLayoutNdKV,
                AtomLayoutMdQ,
                num_threads,
                V_in_regs=V_in_regs,
                score_mod=score_mod,
                score_mod_bwd=score_mod_bwd,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                subtile_factor=subtile_factor,
            )
        else:
            fa_bwd_obj = FlashAttentionBackwardSm100(
                head_dim,
                head_dim_v,
                is_causal=causal,
                is_local=local,
                qhead_per_kvhead=qhead_per_kvhead,
                # tile_m=m_block_size,
                # tile_n=n_block_size,
                cluster_size=cluster_size,
                # cluster_size=1,
                deterministic=deterministic,
                score_mod=score_mod,
                score_mod_bwd=score_mod_bwd,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                subtile_factor=subtile_factor,
            )

        # Block sparse tensors for backward use Q-direction indexing (transposed from forward).
        # sparse_block_size_q = subtile_factor * tile_m matches BlockMask granularity.
        sparse_tensors_compile = None
        if block_sparse_tensors is not None:
            expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
                batch_size, num_head, seqlen_q, seqlen_k,
                m_block_size, n_block_size, subtile_factor,
            )
            compile_time_normalized = normalize_block_sparse_tensors(
                block_sparse_tensors,
                expected_count_shape=expected_count_shape,
                expected_index_shape=expected_index_shape,
                context="_flash_attn_bwd",
                hint=lambda: (
                    f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
                    f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
                    f"(sparse_block_size_q={sparse_block_size_q})."
                ),
            )
            sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized)

        # TODO: check @can_implement
        _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
            fa_bwd_obj,
            q_tensor,
            k_tensor,
            v_tensor,
            do_tensor,
            lse_log2_tensor,
            dpsum_tensor,
            dq_accum_tensor,
            dk_tensor if not dKV_postprocess else dk_accum_tensor,
            dv_tensor if not dKV_postprocess else dv_accum_tensor,
            softmax_scale,
            current_stream,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            None,  # softcap - not yet supported in backward
            window_size_left,
            window_size_right,
            dQ_semaphore_tensor,
            dK_semaphore_tensor,
            dV_semaphore_tensor,
            cute_aux_tensors,
            sparse_tensors_compile,
            options="--enable-tvm-ffi",
        )
    # Runtime normalization of block sparse tensors for both SM90 and SM100
    normalized_block_sparse_tensors = None
    if block_sparse_tensors is not None:
        expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
            batch_size, num_head, seqlen_q, seqlen_k,
            m_block_size, n_block_size, subtile_factor,
        )
        normalized_block_sparse_tensors = normalize_block_sparse_tensors(
            block_sparse_tensors,
            expected_count_shape=expected_count_shape,
            expected_index_shape=expected_index_shape,
            context="_flash_attn_bwd",
            hint=lambda: (
                f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
                f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
                f"(sparse_block_size_q={sparse_block_size_q})."
            ),
        )

    _flash_attn_bwd.compile_cache[compile_key](
        q,
        k,
        v,
        dout,
        lse_log2,
        dpsum,
        dq_accum,
        dk if not dKV_postprocess else dk_accum,
        dv if not dKV_postprocess else dv_accum,
        softmax_scale,
        current_stream,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        None,  # softcap - not yet supported in backward
        window_size_left,
        window_size_right,
        dQ_semaphore,
        dK_semaphore,
        dV_semaphore,
        aux_tensors,
        normalized_block_sparse_tensors,
    )

    num_threads = 256 if compute_capability == 9 else 128
    arch = compute_capability * 10
    # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
    compile_key_post = (
        compute_capability,
        dtype,
        head_dim,
        m_block_size,
        num_threads,
        AtomLayoutMdQ,
        dQ_swapAB,
        cu_seqlens_q is None,
        seqused_q is None,
    )
    if compile_key_post not in _flash_attn_bwd.compile_cache_post:
        dq_accum_tensor = to_cute_tensor(dq_accum)
        dq_tensor = to_cute_tensor(dq)
        cu_seqlens_q_tensor, seqused_q_tensor = [
            to_cute_tensor(t, assumed_align=4) if t is not None else None
            for t in (cu_seqlens_q, seqused_q)
        ]
        fa_bwd_post = FlashAttentionBackwardPostprocess(
            dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB
        )
        # TODO: check @can_implement
        _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
            fa_bwd_post,
            dq_accum_tensor,
            dq_tensor,
            softmax_scale,
            cu_seqlens_q_tensor,
            seqused_q_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )
    _flash_attn_bwd.compile_cache_post[compile_key_post](
        dq_accum,
        dq,
        softmax_scale,
        cu_seqlens_q,
        seqused_q,
        current_stream,
    )

    if dKV_postprocess:
        # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
        compile_key_post = (
            compute_capability,
            dtype,
            head_dim,
            n_block_size,
            num_threads,
            AtomLayoutNdKV,
            dKV_swapAB,
            cu_seqlens_k is None,
            seqused_k is None,
        )
        if compile_key_post not in _flash_attn_bwd.compile_cache_post:
            dk_accum_tensor = to_cute_tensor(dk_accum)
            dk_tensor = to_cute_tensor(dk)
            cu_seqlens_k_tensor, seqused_k_tensor = [
                to_cute_tensor(t, assumed_align=4) if t is not None else None
                for t in (cu_seqlens_k, seqused_k)
            ]
            arch = compute_capability * 10
            fa_bwd_post = FlashAttentionBackwardPostprocess(
                dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
            )
            # TODO: check @can_implement
            _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
                fa_bwd_post,
                dk_accum_tensor,
                dk_tensor,
                softmax_scale,
                cu_seqlens_k_tensor,
                seqused_k_tensor,
                current_stream,
                options="--enable-tvm-ffi",
            )
        _flash_attn_bwd.compile_cache_post[compile_key_post](
            dk_accum,
            dk,
            softmax_scale,
            cu_seqlens_k,
            seqused_k,
            current_stream,
        )
compile_key_post = (
|
|
1210
|
+
compute_capability,
|
|
1211
|
+
dtype,
|
|
1212
|
+
head_dim_v,
|
|
1213
|
+
n_block_size,
|
|
1214
|
+
num_threads,
|
|
1215
|
+
AtomLayoutNdKV,
|
|
1216
|
+
dKV_swapAB,
|
|
1217
|
+
cu_seqlens_k is None,
|
|
1218
|
+
seqused_k is None,
|
|
1219
|
+
)
|
|
1220
|
+
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
|
|
1221
|
+
dv_accum_tensor = to_cute_tensor(dv_accum)
|
|
1222
|
+
dv_tensor = to_cute_tensor(dv)
|
|
1223
|
+
cu_seqlens_k_tensor, seqused_k_tensor = [
|
|
1224
|
+
to_cute_tensor(t, assumed_align=4) if t is not None else None
|
|
1225
|
+
for t in (cu_seqlens_k, seqused_k)
|
|
1226
|
+
]
|
|
1227
|
+
arch = compute_capability * 10
|
|
1228
|
+
fa_bwd_post = FlashAttentionBackwardPostprocess(
|
|
1229
|
+
dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
|
|
1230
|
+
)
|
|
1231
|
+
# TODO: check @can_implement
|
|
1232
|
+
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
|
|
1233
|
+
fa_bwd_post,
|
|
1234
|
+
dv_accum_tensor,
|
|
1235
|
+
dv_tensor,
|
|
1236
|
+
cutlass.Float32(1.0),
|
|
1237
|
+
cu_seqlens_k_tensor,
|
|
1238
|
+
seqused_k_tensor,
|
|
1239
|
+
current_stream,
|
|
1240
|
+
options="--enable-tvm-ffi",
|
|
1241
|
+
)
|
|
1242
|
+
_flash_attn_bwd.compile_cache_post[compile_key_post](
|
|
1243
|
+
dv_accum,
|
|
1244
|
+
dv,
|
|
1245
|
+
1.0,
|
|
1246
|
+
cu_seqlens_k,
|
|
1247
|
+
seqused_k,
|
|
1248
|
+
current_stream,
|
|
1249
|
+
)
|
|
1250
|
+
|
|
1251
|
+
return dq, dk, dv
|
|
1252
|
+
|
|
1253
|
+
|
|
1254
|
+
_flash_attn_bwd.compile_cache_pre = {}
|
|
1255
|
+
_flash_attn_bwd.compile_cache = {}
|
|
1256
|
+
_flash_attn_bwd.compile_cache_post = {}
|
|
1257
|
+
|
|
1258
|
+
|
|
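The three module-level dicts above act as compile caches: each kernel variant is compiled once per configuration tuple (compute capability, dtype, head dim, block sizes, ...) and reused on later calls. A minimal sketch of that cache-by-key pattern in plain Python; compile_fn and make_kernel are hypothetical stand-ins for cute.compile and the kernel classes, not names from this module:

_compile_cache = {}

def get_compiled_kernel(compute_capability, dtype, head_dim, block_size, make_kernel, compile_fn):
    # The key mirrors every parameter that changes the generated code.
    key = (compute_capability, dtype, head_dim, block_size)
    if key not in _compile_cache:
        # Pay the compilation cost once; later calls with the same config reuse the result.
        _compile_cache[key] = compile_fn(make_kernel(dtype, head_dim, block_size))
    return _compile_cache[key]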
+class FlashAttnFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        softmax_scale: Optional[float] = None,
+        causal: bool = False,
+        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
+        learnable_sink: Optional[torch.Tensor] = None,
+        softcap: float = 0.0,
+        num_splits: int = 1,
+        pack_gqa: Optional[bool] = None,
+        deterministic: bool = False,
+        mask_mod: Optional[Callable] = None,
+        full_block_cnt: Optional[torch.Tensor] = None,
+        full_block_idx: Optional[torch.Tensor] = None,
+        mask_block_cnt: Optional[torch.Tensor] = None,
+        mask_block_idx: Optional[torch.Tensor] = None,
+    ):
+        # Only create block sparse tensors if at least one block sparse parameter is provided
+        block_sparse_tensors = None
+        if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]):
+            block_sparse_tensors = BlockSparseTensorsTorch(
+                full_block_cnt=full_block_cnt,
+                full_block_idx=full_block_idx,
+                mask_block_cnt=mask_block_cnt,
+                mask_block_idx=mask_block_idx,
+            )
+        out, lse = _flash_attn_fwd(
+            q,
+            k,
+            v,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
+            learnable_sink=learnable_sink,
+            softcap=softcap,
+            num_splits=num_splits,
+            pack_gqa=pack_gqa,
+            mask_mod=mask_mod,
+            block_sparse_tensors=block_sparse_tensors
+        )
+        ctx.save_for_backward(q, k, v, out, lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        return out, lse
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, lse = ctx.saved_tensors
+        dq, dk, dv = _flash_attn_bwd(
+            q,
+            k,
+            v,
+            out,
+            dout,
+            lse,
+            ctx.softmax_scale,
+            ctx.causal,
+            ctx.softcap,
+            window_size_left=ctx.window_size[0],
+            window_size_right=ctx.window_size[1],
+            deterministic=ctx.deterministic,
+        )
+        return dq, dk, dv, *((None,) * 20)  # Extra Nones is fine
+
+
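A note on the *((None,) * 20) tail: torch.autograd.Function.backward returns one value per forward argument (after ctx), with None marking non-tensor or non-differentiable inputs such as softmax_scale, causal and window_size; as the inline comment says, trailing extra Nones are tolerated. A toy sketch of the same convention, unrelated to this module:

import torch

class ScaleByConst(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale: float, flag: bool = False):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, dout):
        # One slot per forward input: gradient for x, None for scale and flag.
        return dout * ctx.scale, None, None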
+class FlashAttnVarlenFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
+        cu_seqlens_k: Optional[torch.Tensor],
+        seqused_q: Optional[torch.Tensor] = None,
+        seqused_k: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_k: Optional[int] = None,
+        page_table: Optional[torch.Tensor] = None,
+        softmax_scale: Optional[float] = None,
+        causal: bool = False,
+        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
+        learnable_sink: Optional[torch.Tensor] = None,
+        softcap: float = 0.0,
+        num_splits: int = 1,
+        pack_gqa: Optional[bool] = None,
+        deterministic: bool = False,
+        score_mod: Optional[Callable] = None,
+        aux_tensors: Optional[list] = None,
+    ):
+        out, lse = _flash_attn_fwd(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            seqused_q,
+            seqused_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            page_table=page_table,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
+            learnable_sink=learnable_sink,
+            softcap=softcap,
+            num_splits=num_splits,
+            pack_gqa=pack_gqa,
+            score_mod=score_mod,
+            aux_tensors=aux_tensors,
+        )
+        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        return out, lse
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
+        assert ctx.softcap == 0.0
+        dq, dk, dv = _flash_attn_bwd(
+            q,
+            k,
+            v,
+            out,
+            dout,
+            lse,
+            ctx.softmax_scale,
+            ctx.causal,
+            ctx.softcap,
+            window_size_left=ctx.window_size[0],
+            window_size_right=ctx.window_size[1],
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
+            max_seqlen_q=ctx.max_seqlen_q,
+            max_seqlen_k=ctx.max_seqlen_k,
+            deterministic=ctx.deterministic,
+        )
+
+        return dq, dk, dv, *((None,) * 20)
+
+
+def flash_attn_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
+    learnable_sink: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+    num_splits: int = 1,
+    pack_gqa: Optional[bool] = None,
+    deterministic: bool = False,
+    mask_mod: Optional[Callable] = None,
+    full_block_cnt: Optional[torch.Tensor] = None,
+    full_block_idx: Optional[torch.Tensor] = None,
+    mask_block_cnt: Optional[torch.Tensor] = None,
+    mask_block_idx: Optional[torch.Tensor] = None,
+):
+    return FlashAttnFunc.apply(
+        q,
+        k,
+        v,
+        softmax_scale,
+        causal,
+        window_size,
+        learnable_sink,
+        softcap,
+        num_splits,
+        pack_gqa,
+        deterministic,
+        mask_mod,
+        full_block_cnt,
+        full_block_idx,
+        mask_block_cnt,
+        mask_block_idx,
+    )
+
+
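Assuming q, k and v are laid out as (batch, seqlen, nheads, headdim) on a CUDA device with a supported head dimension (the same layout the combine docstring further below uses), a call would look roughly like this; the import path is inferred from this wheel's file listing:

import torch
from mslk.attention.flash_attn.interface import flash_attn_func  # path assumed

q = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)

out, lse = flash_attn_func(q, k, v, causal=True)  # out matches q's shape; lse is the log-sum-exp
out.sum().backward()  # gradients for q, k, v flow through FlashAttnFunc.backward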
+def flash_attn_varlen_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    max_seqlen_k: Optional[int] = None,
+    seqused_q: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
+    learnable_sink: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+    num_splits: int = 1,
+    pack_gqa: Optional[bool] = None,
+    deterministic: bool = False,
+    score_mod: Optional[Callable] = None,
+    aux_tensors: Optional[list] = None,
+):
+    return FlashAttnVarlenFunc.apply(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_q,
+        seqused_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        page_table,
+        softmax_scale,
+        causal,
+        window_size,
+        learnable_sink,
+        softcap,
+        num_splits,
+        pack_gqa,
+        deterministic,
+        score_mod,
+        aux_tensors,
+    )
+
+
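For the varlen entry point, the sequences of a batch are packed along one total_q dimension and delimited by cumulative sequence lengths. A sketch under the same layout assumptions as the previous example, with the import path again assumed from this wheel's file listing:

import torch
from mslk.attention.flash_attn.interface import flash_attn_varlen_func  # path assumed

# Two packed sequences of lengths 512 and 1536, so total_q = 2048.
seqlens = torch.tensor([512, 1536], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))  # [0, 512, 2048]

q = torch.randn(2048, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2048, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2048, 8, 128, dtype=torch.bfloat16, device="cuda")

out, lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=1536, max_seqlen_k=1536,
    causal=True,
)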
+def _flash_attn_fwd_combine(
+    out_partial: torch.Tensor,
+    lse_partial: torch.Tensor,
+    out: torch.Tensor,
+    lse: Optional[torch.Tensor] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    seqused: Optional[torch.Tensor] = None,
+    num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
+    semaphore_to_reset: Optional[torch.Tensor] = None,
+) -> None:
+    """Forward combine kernel for split attention computation.
+
+    Combines partial outputs and log-sum-exp values from multiple splits
+    of attention computation into final outputs.
+
+    Args:
+        out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or
+            (num_splits, total_q, nheads, headdim) if there's cu_seqlens
+        lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or
+            (num_splits, total_q, nheads) if there's cu_seqlens
+        out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens
+        lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens.
+        cu_seqlens: Cumulative sequence lengths for variable length sequences
+        seqused: Used sequence lengths for each batch
+        num_splits_dynamic_ptr: Dynamic number of splits per batch
+        semaphore_to_reset: Semaphore for synchronization
+
+    Returns:
+        None
+    """
+    # Input validation
+    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
+    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
+    assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
+        "out_partial must be fp16, bf16, or fp32"
+    )
+    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
+    assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
+    assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
+    assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
+    assert lse_partial.shape == out_partial.shape[:-1]
+
+    # Determine if this is variable length based on dimensions
+    is_varlen = out_partial.dim() == 4
+
+    # Validate output tensor shapes and types
+    assert out.shape == out_partial.shape[1:], "out shape mismatch"
+    if lse is not None:
+        assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
+        assert lse.dtype == torch.float32, "lse must be fp32"
+
+    # Validate optional tensors
+    for t, name in [
+        (cu_seqlens, "cu_seqlens"),
+        (seqused, "seqused"),
+        (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
+    ]:
+        if t is not None:
+            assert t.dtype == torch.int32, f"{name} must be int32"
+            assert t.is_cuda, f"{name} must be on CUDA device"
+            assert t.is_contiguous(), f"{name} must be contiguous"
+
+    head_dim = out_partial.shape[-1]
+    num_splits = out_partial.shape[0]
+    assert num_splits <= 256
+    # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
+    # so that kBlockM is smaller and we have more parallelism.
+    k_block_size = 64 if head_dim <= 64 else 128
+    # We want kBlockM to be as small as possible to maximize parallelism.
+    # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
+    m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
+    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
+    if m_block_size == 8:
+        # If kBlockM == 8 then the minimum number of splits is 32.
+        # TODO: we can deal w this by using 128 threads instead
+        log_max_splits = max(log_max_splits, 5)
+
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    # Create combine kernel configuration
+    dtype = torch2cute_dtype_map[out.dtype]
+    dtype_partial = torch2cute_dtype_map[out_partial.dtype]
+
+    compile_key = (
+        dtype,
+        dtype_partial,
+        head_dim,
+        m_block_size,
+        k_block_size,
+        log_max_splits,
+        cu_seqlens is not None,
+        seqused is not None,
+        lse is not None,
+    )
+
+    if compile_key not in _flash_attn_fwd_combine.compile_cache:
+        out_partial_tensor = to_cute_tensor(
+            out_partial, leading_dim=4 if not is_varlen else 3
+        )
+        lse_partial_tensor = to_cute_tensor(
+            lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
+        )
+        out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
+        lse_tensor = (
+            to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
+            if lse is not None
+            else None
+        )
+
+        optional_tensors = [
+            to_cute_tensor(t, assumed_align=4, leading_dim=0)
+            if t is not None
+            else None
+            for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
+        ]
+        cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
+            optional_tensors
+        )
+        fa_combine = FlashAttentionForwardCombine(
+            dtype=dtype,
+            dtype_partial=dtype_partial,
+            head_dim=head_dim,
+            m_block_size=m_block_size,
+            k_block_size=k_block_size,
+            log_max_splits=log_max_splits,
+        )
+
+        # Check if implementation is supported
+        if not fa_combine.can_implement(
+            dtype,
+            dtype_partial,
+            head_dim,
+            m_block_size,
+            k_block_size,
+            log_max_splits,
+            num_threads=256,
+        ):
+            raise RuntimeError(
+                "FlashAttention combine kernel cannot be implemented with given parameters"
+            )
+
+        _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
+            fa_combine,
+            out_partial_tensor,
+            lse_partial_tensor,
+            out_tensor,
+            lse_tensor,
+            cu_seqlens_tensor,
+            seqused_tensor,
+            num_splits_dynamic_tensor,
+            semaphore_tensor,
+            current_stream,
+            options="--enable-tvm-ffi",
+        )
+    _flash_attn_fwd_combine.compile_cache[compile_key](
+        out_partial,
+        lse_partial,
+        out,
+        lse,
+        cu_seqlens,
+        seqused,
+        num_splits_dynamic_ptr,
+        semaphore_to_reset,
+        current_stream,
+    )
+
+
+_flash_attn_fwd_combine.compile_cache = {}
+
+
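For reference, the combine step merges per-split partial results by renormalizing with the log-sum-exp: lse = logsumexp_i(lse_i) and out = sum_i exp(lse_i - lse) * out_i. A plain-PyTorch sketch of that math with the docstring's shapes (the kernel above additionally handles dtype conversion, varlen indexing and edge cases that this sketch ignores):

import torch

def combine_reference(out_partial: torch.Tensor, lse_partial: torch.Tensor):
    # out_partial: (num_splits, ..., headdim); lse_partial: (num_splits, ...)
    lse = torch.logsumexp(lse_partial, dim=0)               # combined log-sum-exp over splits
    weights = torch.exp(lse_partial - lse.unsqueeze(0))     # renormalization factor per split
    out = (weights.unsqueeze(-1) * out_partial).sum(dim=0)  # reweighted sum of partial outputs
    return out, lse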
+def flash_attn_combine(
+    out_partial: torch.Tensor,
+    lse_partial: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    seqused: Optional[torch.Tensor] = None,
+    return_lse: bool = True,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """Flash Attention combine function for split attention computation.
+
+    Combines partial outputs and log-sum-exp values from multiple splits
+    of attention computation into final outputs. This is the main user-facing
+    interface for the combine kernel.
+
+    Args:
+        out_partial: Partial outputs tensor with shape:
+            - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input
+            - (num_splits, total_q, num_heads, head_size) for variable length input
+        lse_partial: Partial LSE tensor with shape:
+            - (num_splits, batch_size, seqlen, num_heads) for regular batched input
+            - (num_splits, total_q, num_heads) for variable length input
+        out: Optional output tensor. If None, will be created automatically.
+        out_dtype: Optional output dtype. If None, defaults to the dtype of out_partial.
+        cu_seqlens: Cumulative sequence lengths for variable length sequences
+        seqused: Used sequence lengths for each batch
+        return_lse: Whether to return the combined LSE tensor. Default is True.
+
+    Returns:
+        Tuple of (out, lse) where:
+            - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size)
+              or (total_q, num_heads, head_size) for varlen
+            - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads)
+              or (total_q, num_heads) for varlen. None if return_lse=False
+
+    Note:
+        This function expects the input tensors to be in the format produced by
+        split attention computation, where the first dimension is num_splits.
+        Permuting from the user format to the kernel format is now done inside the kernel.
+    """
+    # Input validation
+    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
+    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
+    assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)"
+    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
+
+    # Determine if this is variable length based on dimensions
+    is_varlen = out_partial.dim() == 4
+
+    if is_varlen:
+        # Variable length: (num_splits, total_q, num_heads, head_size)
+        num_splits, total_q, num_heads, head_size = out_partial.shape
+        assert lse_partial.shape == (num_splits, total_q, num_heads), (
+            "lse_partial shape mismatch for varlen"
+        )
+        batch_size = 1  # Treat as single batch for varlen
+        seqlen = total_q
+    else:
+        # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
+        num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
+        assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), (
+            "lse_partial shape mismatch"
+        )
+
+    # Determine output dtype
+    if out_dtype is None:
+        out_dtype = out_partial.dtype
+
+    # Create output if not provided
+    device = out_partial.device
+    if out is None:
+        if is_varlen:
+            out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device)
+        else:
+            out = torch.empty(
+                batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
+            )
+
+    # Create lse output only if requested
+    if return_lse:
+        if is_varlen:
+            lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(
+                0, 1
+            )
+        else:
+            lse = torch.empty(
+                batch_size, num_heads, seqlen, dtype=torch.float32, device=device
+            ).transpose(1, 2)
+    else:
+        lse = None
+
+    _flash_attn_fwd_combine(
+        out_partial,
+        lse_partial,
+        out,
+        lse,
+        cu_seqlens,
+        seqused,
+    )
+    return out, lse