mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/flash_fwd.py
@@ -0,0 +1,2471 @@
# @nolint # fbcode
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# A reimplementation of
# https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm80.h
# and https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm90.h
# from Cutlass C++ to Cute-DSL.
# Built on Cute-DSL example: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py

import math
from types import SimpleNamespace
from typing import Type, Callable, Optional, List
from functools import partial

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
from cutlass.cute.arch import ProxyKind, SharedSpace
import cutlass.utils as utils_basic
from cutlass.utils import LayoutEnum
import cutlass.utils.hopper_helpers as sm90_utils_basic

from mslk.attention.flash_attn import ampere_helpers as sm80_utils
from mslk.attention.flash_attn import hopper_helpers as sm90_utils
from mslk.attention.flash_attn import utils
from mslk.attention.flash_attn import copy_utils
from mslk.attention.flash_attn.mask import AttentionMask
from mslk.attention.flash_attn.softmax import Softmax, apply_score_mod_inner
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
from mslk.attention.flash_attn.block_info import BlockInfo
from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
from mslk.attention.flash_attn.block_sparse_utils import (
    produce_block_sparse_loads,
    consume_block_sparse_loads,
)
from mslk.attention.flash_attn import pipeline
from mslk.attention.flash_attn.pack_gqa import PackGQA
from mslk.attention.flash_attn.named_barrier import NamedBarrierFwd
from mslk.attention.flash_attn.tile_scheduler import (
    TileSchedulerArguments,
    SingleTileScheduler,
    SingleTileLPTScheduler,
    SingleTileVarlenScheduler,
    ParamsBase,
)
from cutlass.cute import FastDivmodDivisor


class FlashAttentionForwardBase:
    arch: int = 80

    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        head_dim_v: Optional[int] = None,
        qhead_per_kvhead: int = 1,
        is_causal: bool = False,
        is_local: bool = False,
        pack_gqa: bool = True,
        tile_m: int = 128,
        tile_n: int = 128,
        num_stages: int = 1,
        num_threads: int = 128,
        Q_in_regs: bool = False,
        score_mod: Optional[cutlass.Constexpr] = None,
        mask_mod: Optional[cutlass.Constexpr] = None,
        has_aux_tensors: bool = False,
    ):
        """Initializes the configuration for a flash attention kernel.

        All contiguous dimensions must be at least 16 bytes aligned, which means that the head dimension
        should be a multiple of 8.

        :param head_dim: head dimension
        :type head_dim: int
        :param tile_m: m block size
        :type tile_m: int
        :param tile_n: n block size
        :type tile_n: int
        :param num_threads: number of threads
        :type num_threads: int
        :param is_causal: is causal
        :param score_mod: A callable that takes the attention scores and applies a modification.
            Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any``
        :param mask_mod: A callable that takes the score position indices and returns a boolean representing whether that score should be masked.
            Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean``
        """
        self.dtype = dtype
        # padding head_dim to a multiple of 16 as k_block_size
        hdim_multiple_of = 16
        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
        head_dim_v = head_dim_v if head_dim_v is not None else head_dim
        self.same_hdim_kv = head_dim == head_dim_v
        self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
        # Can save registers (and hence be faster) if we don't have to check hdim predication
        self.check_hdim_oob = head_dim != self.tile_hdim
        self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
        self.qhead_per_kvhead = qhead_per_kvhead
        self.is_causal = is_causal
        self.is_local = is_local
        self.pack_gqa = pack_gqa
        self.tile_m = tile_m
        self.tile_n = tile_n
        self.num_threads = num_threads
        self.num_stages = num_stages
        self.Q_in_regs = Q_in_regs
        self.score_mod = score_mod
        self.mask_mod = mask_mod
        self.qk_acc_dtype = Float32
        if const_expr(has_aux_tensors):
            self.vec_size: cutlass.Constexpr = 1
        else:
            self.vec_size: cutlass.Constexpr = 2
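    # Illustrative sketch (hypothetical, for exposition only) of callables matching the
    # score_mod / mask_mod signatures documented above; the index arguments are coordinates
    # into the (batch, head, q position, kv position) score grid:
    #
    #     def distance_penalty_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    #         return score - (q_idx - kv_idx)  # e.g. an ALiBi-style linear distance penalty
    #
    #     def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    #         return kv_idx <= q_idx  # Boolean interpreted per the mask_mod contract above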

    @staticmethod
    def can_implement(
        dtype,
        head_dim,
        head_dim_v,
        tile_m,
        tile_n,
        num_stages,
        num_threads,
        is_causal,
        Q_in_regs=False,
    ) -> bool:
        """Check if the kernel can be implemented with the given parameters.

        :param dtype: data type
        :type dtype: cutlass.Numeric
        :param head_dim: head dimension
        :type head_dim: int
        :param tile_m: m block size
        :type tile_m: int
        :param tile_n: n block size
        :type tile_n: int
        :param num_threads: number of threads
        :type num_threads: int
        :param is_causal: is causal
        :type is_causal: bool

        :return: True if the kernel can be implemented, False otherwise
        :rtype: bool
        """
        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        if head_dim % 8 != 0:
            return False
        if head_dim_v % 8 != 0:
            return False
        if tile_n % 16 != 0:
            return False
        if num_threads % 32 != 0:
            return False
        # Check if block size setting is out of shared memory capacity
        # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size
        smem_usage_Q = tile_m * head_dim * 2
        smem_usage_K = tile_n * head_dim * num_stages * 2
        smem_usage_V = tile_n * head_dim_v * num_stages * 2
        smem_usage_QV = (
            (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V)
        )
        smem_usage = smem_usage_QV + smem_usage_K
        # TODO: sm86 and sm89
        smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
        if smem_usage > smem_capacity:
            return False
        # Check if twice the block size is divisible by the number of threads
        if (tile_m * 2) % num_threads != 0:
            return False
        return True
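    # Worked example of the shared-memory estimate in can_implement: for fp16 with
    # head_dim = head_dim_v = 128, tile_m = tile_n = 128, num_stages = 1 and Q kept in smem,
    #   Q: 128 * 128 * 2 B = 32 KiB, K: 32 KiB, V: 32 KiB  ->  96 KiB in total,
    # comfortably under the sm_80 shared-memory capacity (on the order of 160 KiB per SM).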

    def _check_type(
        self,
        mQ_type: Type[cutlass.Numeric],
        mK_type: Type[cutlass.Numeric],
        mV_type: Type[cutlass.Numeric],
        mO_type: Type[cutlass.Numeric],
        mLSE_type: Type[cutlass.Numeric] | None,
        mCuSeqlensQ_type: Type[cutlass.Numeric] | None,
        mCuSeqlensK_type: Type[cutlass.Numeric] | None,
        mSeqUsedQ_type: Type[cutlass.Numeric] | None,
        mSeqUsedK_type: Type[cutlass.Numeric] | None,
    ):
        # Get the data type and check if it is fp16 or bf16
        if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
            raise TypeError("All tensors must have the same data type")
        if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if const_expr(mLSE_type not in [None, Float32]):
            raise TypeError("LSE tensor must be Float32")
        if const_expr(mCuSeqlensQ_type not in [None, Int32]):
            raise TypeError("cu_seqlens_q tensor must be Int32")
        if const_expr(mCuSeqlensK_type not in [None, Int32]):
            raise TypeError("cu_seqlens_k tensor must be Int32")
        if const_expr(mSeqUsedQ_type not in [None, Int32]):
            raise TypeError("seqused_q tensor must be Int32")
        if const_expr(mSeqUsedK_type not in [None, Int32]):
            raise TypeError("seqused_k tensor must be Int32")
        assert mQ_type == self.dtype

    def _setup_attributes(self):
        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory layout: Q/K/V
        # ///////////////////////////////////////////////////////////////////////////////
        sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = (
            self._get_smem_layout_atom()
        )
        self.sQ_layout = cute.tile_to_shape(
            sQ_layout_atom,
            (self.tile_m, self.tile_hdim),
            (0, 1),
        )
        self.sK_layout = cute.tile_to_shape(
            sK_layout_atom,
            (self.tile_n, self.tile_hdim, self.num_stages),
            (0, 1, 2),
        )
        self.sV_layout = cute.tile_to_shape(
            sV_layout_atom,
            (self.tile_n, self.tile_hdimv, self.num_stages),
            (0, 1, 2),
        )
        self.sO_layout = cute.tile_to_shape(
            sO_layout_atom,
            (self.tile_m, self.tile_hdimv),
            (0, 1),
        )
        if const_expr(sP_layout_atom is not None):
            self.sP_layout = cute.tile_to_shape(
                sP_layout_atom,
                (self.tile_m, self.tile_n),
                (0, 1),
            )
        else:
            self.sP_layout = None

        # ///////////////////////////////////////////////////////////////////////////////
        # GMEM Tiled copy:
        # ///////////////////////////////////////////////////////////////////////////////
        # Thread layouts for copies
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // self.dtype.width
        # atom_async_copy: async copy atom for QKV load
        atom_async_copy = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            self.dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        # atom_universal_copy: universal copy atom for O store
        atom_universal_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            self.dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        # tQ_layout and tK_layout: thread layout for QK load
        tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
        assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, (
            "num_threads must be divisible by tQK_shape_dim_1"
        )
        assert self.num_producer_threads % tQK_shape_dim_1 == 0, (
            "num_threads must be divisible by tQK_shape_dim_1"
        )
        tQ_layout = cute.make_ordered_layout(
            (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1),
            order=(1, 0),
        )
        tK_layout = cute.make_ordered_layout(
            (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1),
            order=(1, 0),
        )
        # So that we don't have to check if we overshoot kBlockM when we load Q
        assert self.tile_m % tQ_layout.shape[0] == 0
        tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems
        tV_layout = cute.make_ordered_layout(
            (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1),
            order=(1, 0),
        )
        # TODO: need a different layout for O if O dtype is not the same as V dtype
        # tO_layout: thread layout for O store
        tO_layout = cute.make_ordered_layout(
            (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1),
            order=(1, 0),
        )
        # So that we don't have to check if we overshoot kBlockM when we store O
        assert self.tile_m % tO_layout.shape[0] == 0

        # Value layouts for copies
        vQKV_layout = cute.make_layout((1, async_copy_elems))
        vO_layout = vQKV_layout

        self.gmem_tiled_copy_Q = cute.make_tiled_copy_tv(atom_async_copy, tQ_layout, vQKV_layout)
        self.gmem_tiled_copy_K = cute.make_tiled_copy_tv(atom_async_copy, tK_layout, vQKV_layout)
        self.gmem_tiled_copy_V = cute.make_tiled_copy_tv(atom_async_copy, tV_layout, vQKV_layout)
        # gmem_tiled_copy_O: tiled copy for O store
        self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)

    def _get_smem_layout_atom(self):
        raise NotImplementedError()

    def _get_tiled_mma(self):
        raise NotImplementedError()

    def _get_shared_storage_cls(self):
        raise NotImplementedError()

    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        softmax_scale: Float32,
        stream: cuda.CUstream,
    ):
        """Configures and launches the flash attention kernel.

        mQ/mK/mV/mO have the same data type (fp16 and bf16 are supported) and the same layout:
        (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
        """
        raise NotImplementedError()

    @cute.jit
    def epilogue(
        self,
        acc_O: cute.Tensor,
        lse: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        sO: cute.Tensor,
        seqlen: SeqlenInfoQK,
        gmem_tiled_copy_O: cute.TiledCopy,
        tma_atom_O: Optional[cute.CopyAtom],
        tiled_mma: cute.TiledMma,
        tidx: Int32,
        m_block: Int32,
        head_idx: Int32,
        batch_idx: Int32,
    ):
        # store acc_O
        rO = cute.make_fragment_like(acc_O, self.dtype)
        rO.store(acc_O.load().to(self.dtype))
        # Make sure all threads have finished reading V
        cute.arch.barrier(
            barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
        )
        smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype)
        smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
        taccOrO = smem_thr_copy_O.retile(rO)
        taccOsO = smem_thr_copy_O.partition_D(sO)
        # taccOsO = quack_copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
        # copy acc O from rmem to smem with the smem copy atom
        cute.copy(smem_copy_atom_O, taccOrO, taccOsO)

        cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv))
        pack_gqa = PackGQA(
            self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead
        )

        # Write LSE from rmem -> gmem
        if const_expr(mLSE is not None):
            if const_expr(not seqlen.has_cu_seqlens_q):
                mLSE_cur = mLSE[None, head_idx, batch_idx]
            else:
                offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
                mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
            if const_expr(not self.pack_gqa):
                gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
                gLSE_expanded_layout = cute.append(
                    gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,))
                )
                gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout)
                thr_mma = tiled_mma.get_slice(tidx)
                taccOgLSE = utils.make_acc_tensor_mn_view(thr_mma.partition_C(gLSE_expanded))
                assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse)
                taccOcO = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cO))
                t0accOcO = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cO))
                # Only the thread corresponding to column 0 writes out the lse to gmem
                if taccOcO[0][1] == 0:
                    for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])):
                        if (
                            t0accOcO[m, 0][0]
                            < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]
                        ):
                            taccOgLSE[m, 0] = lse[m]
            else:
                pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)

        if const_expr(not seqlen.has_cu_seqlens_q):
            mO_cur = mO[None, None, head_idx, batch_idx]
        else:
            offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
            mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx])
        # thr_mma = tiled_mma.get_slice(tidx)
        # taccOgO = thr_mma.partition_C(gO)
        # cute.autovec_copy(rO, taccOgO)
        # sync to make sure all smem stores are done
        if const_expr(self.use_tma_O):
            # ensure smem writes are visible to TMA
            cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
            cute.arch.barrier_arrive(
                barrier_id=int(NamedBarrierFwd.Epilogue),
                number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
            )
            gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
            store_O, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
            )
            warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
            if warp_idx == 4:
                cute.arch.barrier(
                    barrier_id=int(NamedBarrierFwd.Epilogue),
                    number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
                )
                store_O()
                cute.arch.cp_async_bulk_commit_group()
                cute.arch.cp_async_bulk_wait_group(0, read=True)
        else:
            cute.arch.barrier(
                barrier_id=int(NamedBarrierFwd.Epilogue),
                number_of_threads=self.num_epilogue_threads,
            )
            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
            tOsO = gmem_thr_copy_O.partition_S(sO)
            tOrO = cute.make_fragment_like(tOsO, self.dtype)
            # load acc O from smem to rmem for wider vectorization
            cute.autovec_copy(tOsO, tOrO)
            if const_expr(not self.pack_gqa):
                gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
                tOgO = gmem_thr_copy_O.partition_D(gO)
                tOcO = gmem_thr_copy_O.partition_S(cO)
                t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
                tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
                # copy acc O from rmem to gmem
                for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
                    if (
                        t0OcO[0, rest_m, 0][0]
                        < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
                    ):
                        cute.copy(
                            gmem_tiled_copy_O,
                            tOrO[None, rest_m, None],
                            tOgO[None, rest_m, None],
                            pred=tOpO[None, rest_m, None]
                            if const_expr(self.check_hdim_v_oob)
                            else None,
                        )
            else:
                pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)
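    # Note on the epilogue above: when use_tma_O is set (SM90 path) the accumulator is staged
    # through shared memory and a single warp (warp_idx == 4) issues the bulk TMA store, while on
    # the cp.async path every epilogue thread copies its own fragment from smem to gmem with
    # row/column predication against seqlen_q and head_dim_v.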

    @cute.jit
    def advance_pipeline(self, pipeline_index):
        return pipeline_index + 1 if pipeline_index < self.num_stages - 1 else 0
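    # advance_pipeline treats the K/V smem stages as a ring buffer: with num_stages = 2 the index
    # sequence is 0 -> 1 -> 0 -> 1 -> ..., and with num_stages = 1 it always stays at 0.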

    @cute.jit
    def load_Q(
        self,
        gmem_thr_copy: cute.TiledCopy,
        gQ: cute.Tensor,
        sQ: cute.Tensor,
        block: Int32,
        seqlen: Int32,
        headdim: Int32,
    ):
        tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ)
        cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
        tQcQ = gmem_thr_copy.partition_S(cQ)
        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
        tQpQ = utils.predicate_k(tQcQ, limit=headdim)
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit
            # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.
            if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]:
                cute.copy(
                    gmem_thr_copy,
                    tQgQ[None, m, None],
                    tQsQ[None, m, None],
                    pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None,
                )
        # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs

    @cute.jit
    def load_K(
        self,
        gmem_tiled_copy: cute.TiledCopy,
        tKgK: cute.Tensor,
        tKsK: cute.Tensor,
        tKcK: cute.Tensor,
        t0KcK: cute.Tensor,
        tKpK: cute.Tensor,
        block: Int32,
        smem_pipe_write: Int32,
        seqlen: Int32,
        need_predicates: cutlass.Constexpr,
    ):
        # Do we need to check if we overshoot kBlockN when we load K?
        is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0
        if const_expr(need_predicates or not is_even_n_smem_k):
            # Instead of using tKcK, we use t0KcK and subtract the offset from the limit
            # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.
            if const_expr(is_even_n_smem_k):
                seqlen_limit = seqlen - block * self.tile_n
            else:
                if const_expr(not need_predicates):
                    seqlen_limit = self.tile_n
                else:
                    seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n)
            seqlen_limit -= tKcK[0][0]
            for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
                if t0KcK[0, n, 0][0] < seqlen_limit:
                    cute.copy(
                        gmem_tiled_copy,
                        tKgK[None, n, None, block],
                        tKsK[
                            None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0
                        ],
                        pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None,
                    )
            # We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
        else:
            cute.copy(
                gmem_tiled_copy,
                tKgK[None, None, None, block],
                tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0],
                pred=tKpK if const_expr(self.check_hdim_oob) else None,
            )
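    # The t0QcQ / t0KcK trick used in load_Q and load_K above: thread 0's identity-tensor
    # coordinates are compile-time constants, so comparing them against a limit from which the
    # current thread's own offset has been subtracted lets most bounds checks be resolved
    # statically instead of per element.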

    @cute.jit
    def load_V(
        self,
        gmem_tiled_copy: cute.TiledCopy,
        tVgV: cute.Tensor,
        tVsV: cute.Tensor,
        tVcV: cute.Tensor,
        t0VcV: cute.Tensor,
        tVpV: cute.Tensor,
        block: Int32,
        smem_pipe_write: Int32,
        seqlen: Int32,
        need_predicates: cutlass.Constexpr,
    ):
        # Do we need to check if we overshoot kBlockN when we load V?
        is_even_n_smem_v = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0
        if const_expr(need_predicates or not is_even_n_smem_v):
            for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
                # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
                if (
                    is_even_n_smem_v
                    or n < cute.size(tVsV.shape[1]) - 1
                    or tVcV[0, n, 0][0] < self.tile_n
                ):
                    predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None
                    if const_expr(need_predicates):
                        seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0]
                        predicate_n = t0VcV[0, n, 0][0] < seqlen_limit
                        predicate = cute.make_fragment_like(tVpV[None, 0, None])
                        for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                            for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                                predicate[i, k] = (
                                    tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True
                                ) and predicate_n
                    cute.copy(
                        gmem_tiled_copy,
                        tVgV[None, n, None, block],
                        tVsV[
                            None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0
                        ],
                        pred=predicate,
                    )
        else:
            cute.copy(
                gmem_tiled_copy,
                tVgV[None, None, None, block],
                tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0],
                pred=tVpV if const_expr(self.check_hdim_v_oob) else None,
            )


class FlashAttentionForwardSm80(FlashAttentionForwardBase):
    def _get_smem_layout_atom(self):
        sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim)
        sK_layout_atom = sQ_layout_atom
        sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv)
        sO_layout_atom = sV_layout_atom
        sP_layout_atom = None
        return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom

    def _get_tiled_mma(self):
        tiled_mma_qk = cute.make_tiled_mma(
            warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
            (self.num_threads // 32, 1, 1),
            permutation_mnk=(self.num_threads // 32 * 16, 16, 16),
        )
        tiled_mma_pv = cute.make_tiled_mma(
            warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
            (self.num_threads // 32, 1, 1),
            permutation_mnk=(self.num_threads // 32 * 16, 16, 16),
        )
        return tiled_mma_qk, tiled_mma_pv

    def _get_shared_storage_cls(self):
        sQ_struct, sK_struct, sV_struct = [
            cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]
            for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
        ]
        cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
        sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]

        @cute.struct
        class SharedStorageQKV:
            sV: sV_struct
            sQ: sQ_struct
            sK: sK_struct

        @cute.struct
        class SharedStorageSharedQV:
            sQ: sQV_struct
            sK: sK_struct

        return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
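    # When Q_in_regs is set, Q is consumed from registers after the prologue, so
    # SharedStorageSharedQV lets the V tiles alias the Q buffer (sized for the larger of the two)
    # instead of reserving separate shared memory for both.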

    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        stream: cuda.CUstream,
        softmax_scale: Optional[Float32] = None,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
        learnable_sink: Optional[cute.Tensor] = None,
        aux_tensors=None,
    ):
        """Configures and launches the flash attention kernel.

        mQ/mK/mV/mO have the same data type (fp16 and bf16 are supported) and the same layout:
        (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
        """
        assert learnable_sink is None, "Learnable sink is not supported in this kernel"
        self._check_type(
            *(
                t.element_type if t is not None else None
                for t in (mQ, mK, mV, mO, mLSE, None, None, None, None)
            )
        )
        tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
        self.num_mma_threads = tiled_mma_pv.size
        self.num_producer_threads = self.num_threads
        self.num_Q_load_threads = self.num_threads
        self.num_epilogue_threads = self.num_threads
        # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None
        self.use_tma_O = self.arch >= 90
        self._setup_attributes()
        SharedStorage = self._get_shared_storage_cls()
        # Assume all strides are divisible by 128 bits except the last stride
        new_stride = lambda t: (
            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
            t.stride[-1],
        )
        mQ, mK, mV, mO = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            for t in (mQ, mK, mV, mO)
        ]
        mQ, mK, mV, mO = [
            cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0]))
            for t in (mQ, mK, mV, mO)
        ]
        mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0]))
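        # The two re-layouts above turn the user-facing (batch, seqlen, num_head, head_dim) tensors
        # into (seqlen, head_dim, num_head, batch) views (and mLSE into (seqlen, num_head, batch)),
        # so the kernel can tile modes 0/1 and pick the head/batch with the block coordinates.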
        # grid_dim: (m_block, num_head, batch_size)
        grid_dim = (
            cute.ceil_div(mQ.shape[0], self.tile_m),
            cute.size(mQ.shape[2]),
            cute.size(mQ.shape[3]),
        )
        LOG2_E = math.log2(math.e)
        if const_expr(self.score_mod is None):
            softmax_scale_log2 = Float32(softmax_scale * LOG2_E)
            softmax_scale = None
        else:
            # NB: If a user passes in a score mod, we want to apply the score-mod to the sm-scaled qk,
            # but in the original base. We hijack softmax_scale_log2 to just be the change of base
            # and correctly apply the softmax_scale prior to score_mod in the softmax step
            softmax_scale_log2 = Float32(LOG2_E)
            softmax_scale = Float32(softmax_scale)
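        # Why the change of base above works: the softmax is evaluated with exp2, and
        # exp2(x * softmax_scale * log2(e)) == exp(x * softmax_scale), so the scale can be folded
        # into the base-2 exponent for free; with a score_mod the raw softmax_scale is kept
        # separate so the modification sees the scaled scores before the base-2 conversion.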

        fastdiv_mods = None
        if const_expr(aux_tensors is not None):
            seqlen_q = cute.size(mQ.shape[0]) // (
                self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
            )
            seqlen_k = cute.size(mK.shape[0])
            seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
            seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
            fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)

        self.kernel(
            mQ,
            mK,
            mV,
            mO,
            mLSE,
            softmax_scale_log2,
            softmax_scale,
            window_size_left,
            window_size_right,
            self.sQ_layout,
            self.sK_layout,
            self.sV_layout,
            self.sO_layout,
            self.sP_layout,
            self.gmem_tiled_copy_Q,
            self.gmem_tiled_copy_K,
            self.gmem_tiled_copy_V,
            self.gmem_tiled_copy_O,
            tiled_mma_qk,
            tiled_mma_pv,
            SharedStorage,
            aux_tensors,
            fastdiv_mods,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            smem=SharedStorage.size_in_bytes(),
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        softmax_scale_log2: Float32,
        softmax_scale: Optional[Float32],
        window_size_left: Optional[Int32],
        window_size_right: Optional[Int32],
        sQ_layout: cute.ComposedLayout,
        sK_layout: cute.ComposedLayout,
        sV_layout: cute.ComposedLayout,
        sO_layout: cute.ComposedLayout,
        sP_layout: cute.ComposedLayout | None,
        gmem_tiled_copy_Q: cute.TiledCopy,
        gmem_tiled_copy_K: cute.TiledCopy,
        gmem_tiled_copy_V: cute.TiledCopy,
        gmem_tiled_copy_O: cute.TiledCopy,
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        SharedStorage: cutlass.Constexpr,
        aux_tensors=None,
        fastdiv_mods=None,
    ):
        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()
        m_block, num_head, batch_size = cute.arch.block_idx()

        block_info = BlockInfo(
            self.tile_m,
            self.tile_n,
            self.is_causal,
            self.is_local,
            False,  # is_split_kv
            window_size_left,
            window_size_right,
            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
        )
        seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0])
        n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
        # TODO: return early if n_block_max == 0
        # if self.is_causal:
        #     if n_block_max <= 0:
        #         return
        n_block = n_block_max - 1
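        # Key/value blocks are walked from n_block_max - 1 down to n_block_min, so the iterations
        # that need sequence-length (and, for causal/local, diagonal) masking come first and the
        # bulk of the mainloop below can run without masking checks.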

        # ///////////////////////////////////////////////////////////////////////////////
        # Get the appropriate tiles for this thread block.
        # ///////////////////////////////////////////////////////////////////////////////
        blkQ_shape = (self.tile_m, self.tile_hdim)
        blkK_shape = (self.tile_n, self.tile_hdim)
        blkV_shape = (self.tile_n, self.tile_hdimv)
        gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0))
        num_head_kv = num_head // self.qhead_per_kvhead
        gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0))
        gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0))

        # ///////////////////////////////////////////////////////////////////////////////
        # Get shared memory buffer
        # ///////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)
        sQ = storage.sQ.get_tensor(sQ_layout)
        sK = storage.sK.get_tensor(sK_layout)
        if const_expr(not self.Q_in_regs):
            sV = storage.sV.get_tensor(sV_layout)
        else:
            sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)
        # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
        sVt = utils.transpose_view(sV)

        gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx)
        gmem_thr_copy_V = gmem_tiled_copy_V.get_slice(tidx)
        # (CPY_Atom, CPY_N, CPY_K, n_block)
        tKsK, tKgK = gmem_thr_copy_K.partition_D(sK), gmem_thr_copy_K.partition_S(gK)
        # (CPY_Atom, CPY_N, CPY_K, n_block)
        tVsV, tVgV = gmem_thr_copy_V.partition_D(sV), gmem_thr_copy_V.partition_S(gV)

        # ///////////////////////////////////////////////////////////////////////////////
        # Tile MMA compute thread partitions and allocate accumulators
        # ///////////////////////////////////////////////////////////////////////////////
        thr_mma_qk = tiled_mma_qk.get_slice(tidx)
        thr_mma_pv = tiled_mma_pv.get_slice(tidx)
        tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ))
        tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0]))
        tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0]))
        acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv))
        acc_O = cute.make_fragment(acc_shape_O, Float32)
        acc_O.fill(0.0)

        # ///////////////////////////////////////////////////////////////////////////////
        # Smem copy atom tiling
        # ///////////////////////////////////////////////////////////////////////////////
        smem_copy_atom_QK = cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4),
            self.dtype,
        )
        smem_copy_atom_V = cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
            self.dtype,
        )
        smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx)
        smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx)
        smem_thr_copy_V = utils.make_tiled_copy_B(smem_copy_atom_V, tiled_mma_pv).get_slice(tidx)

        tSsQ = smem_thr_copy_Q.partition_S(sQ)
        tSsK = smem_thr_copy_K.partition_S(sK)
        tOsVt = smem_thr_copy_V.partition_S(sVt)

        # ///////////////////////////////////////////////////////////////////////////////
        # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
        # of tile_shape
        # ///////////////////////////////////////////////////////////////////////////////
        # Construct identity layout for KV
        cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim))
        tKcK = gmem_thr_copy_K.partition_S(cK)
        t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK)
        if const_expr(self.tile_hdim == self.tile_hdimv):
            tVcV = tKcK
            t0VcV = t0KcK
        else:
            cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv))
            tVcV = gmem_thr_copy_V.partition_S(cV)
            t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV)
        # Allocate predicate tensors for m and n, here we only allocate the tile of k, and
        # use "if" on the mn dimension.
        # This is to reduce register pressure and gets 2-3% performance gain.
        tKpK = utils.predicate_k(tKcK, limit=mK.shape[1])
        if const_expr(self.same_hdim_kv):
            tVpV = tKpK
        else:
            tVpV = utils.predicate_k(tVcV, limit=mV.shape[1])

        # shape: (atom_v_m * rest_m)
        softmax = Softmax.create(
            softmax_scale_log2,
            num_rows=acc_O.shape[0][0] * acc_O.shape[1],
            softmax_scale=softmax_scale,
        )
        softmax.reset()

        # group parameters for compute_one_n_block
        mma_params = SimpleNamespace(
            thr_mma_qk=thr_mma_qk,
            thr_mma_pv=thr_mma_pv,
            tSrQ=tSrQ,
            tSrK=tSrK,
            tOrVt=tOrVt,
            acc_O=acc_O,
        )
        smem_copy_params = SimpleNamespace(
            smem_thr_copy_Q=smem_thr_copy_Q,
            smem_thr_copy_K=smem_thr_copy_K,
            smem_thr_copy_V=smem_thr_copy_V,
            tSsQ=tSsQ,
            tSsK=tSsK,
            tOsVt=tOsVt,
        )
        load_K = partial(
            self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k
        )
        load_V = partial(
            self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k
        )

        compute_one_n_block = partial(
            self.compute_one_n_block,
            mma_params=mma_params,
            smem_copy_params=smem_copy_params,
            softmax=softmax,
            load_K=load_K,
            load_V=load_V,
            score_mod=self.score_mod,
            batch_idx=batch_size,
            head_idx=num_head,
            m_block=m_block,
            seqlen=seqlen,
            aux_tensors=aux_tensors,
            fastdiv_mods=fastdiv_mods,
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Prologue
        # ///////////////////////////////////////////////////////////////////////////////
        # Start async loads of the last mn-tile, where we take care of the mn residue
        gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
        self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, headdim=mQ.shape[1])
        cute.arch.cp_async_commit_group()

        def preprocess_Q():
            cute.arch.cp_async_wait_group(self.num_stages * 2 - 1)
            if const_expr(self.Q_in_regs):
                cute.arch.barrier()
                tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ)
                cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view)

        # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and
        # read from smem_q to registers, then load V.
        # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q.
        if const_expr(self.Q_in_regs):
            load_K(n_block, smem_pipe_write=0, need_predicates=True)
            cute.arch.cp_async_commit_group()
            preprocess_Q()
            cute.arch.barrier()  # Make sure all threads have read smem_q before loading V

        for stage in cutlass.range_constexpr(self.num_stages):
            if const_expr(not self.Q_in_regs or stage > 0):
                if stage == 0 or n_block - stage >= 0:
                    load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0)
                cute.arch.cp_async_commit_group()
            if const_expr(stage < self.num_stages - 1):
                if stage == 0 or n_block - stage >= 0:
                    load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0)
                cute.arch.cp_async_commit_group()
        if const_expr(not self.Q_in_regs):
            preprocess_Q()
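        # Prologue schedule for the multi-stage case (e.g. num_stages = 2): the loop above commits
        # K[n_block], V[n_block] and K[n_block - 1] as separate cp.async groups (V is only
        # prefetched for stages < num_stages - 1), and preprocess_Q waits until at most
        # num_stages * 2 - 1 groups are still in flight, i.e. until the Q tile has arrived in smem.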

        # ///////////////////////////////////////////////////////////////////////////////
        # Mainloop
        # ///////////////////////////////////////////////////////////////////////////////
        # Start processing of the first n-block.
        # For performance reasons, we separate out two kinds of iterations:
        # those that need masking on S, and those that don't.
        # We need masking on S for the very last block when the K and V length is not a multiple of tile_n.
        # We also need masking on S if it's causal, for the last several blocks.
        mask = AttentionMask(
            self.tile_m,
            self.tile_n,
            seqlen.seqlen_q,
            seqlen.seqlen_k,
            window_size_left,
            window_size_right,
            self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
        )
        mask_fn = partial(
            mask.apply_mask,
            m_block=m_block,
            thr_mma=thr_mma_qk,
            mask_causal=self.is_causal,
            mask_local=self.is_local,
            fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,
        )

        # First iteration with seqlen masking
        smem_pipe_read = Int32(0)
        smem_pipe_write = Int32(self.num_stages - 1)
        compute_one_n_block(
            n_block,
            smem_pipe_read,
            smem_pipe_write,
            is_first_n_block=True,
            check_inf=True,
            mask_fn=partial(mask_fn, mask_seqlen=True),
        )
        smem_pipe_read = self.advance_pipeline(smem_pipe_read)
        smem_pipe_write = self.advance_pipeline(smem_pipe_write)
        # Next couple of iterations with causal masking
        if const_expr(self.is_causal or self.is_local):
            n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
                seqlen, m_block, n_block_min
            )
            for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1):
                n_block = n_block_max - 2 - n_tile
                compute_one_n_block(
                    n_block,
                    smem_pipe_read,
                    smem_pipe_write,
                    check_inf=True,
                    mask_fn=partial(mask_fn, mask_seqlen=False),
                )
                smem_pipe_read = self.advance_pipeline(smem_pipe_read)
                smem_pipe_write = self.advance_pipeline(smem_pipe_write)
        # The remaining iterations have no masking
        for n_tile in cutlass.range(n_block, unroll=1):
            compute_one_n_block(
                n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True
            )
            smem_pipe_read = self.advance_pipeline(smem_pipe_read)
            smem_pipe_write = self.advance_pipeline(smem_pipe_write)
        # TODO: local

        # normalize acc_O by row_sum and calculate the lse
        row_scale = softmax.finalize()
        softmax.rescale_O(acc_O, row_scale)

        # ///////////////////////////////////////////////////////////////////////////////
        # Epilogue
        # ///////////////////////////////////////////////////////////////////////////////
        # reuse sQ's data iterator
        sO = cute.make_tensor(sQ.iterator, sO_layout)
        self.epilogue(
            acc_O,
            softmax.row_sum,
            mO,
            mLSE,
            sO,
            seqlen,
            gmem_tiled_copy_O,
            None,
            tiled_mma_pv,
            tidx,
            m_block,
            num_head,
            batch_size,
        )

    @cute.jit
    def compute_one_n_block(
        self,
        n_block: Int32,
        smem_pipe_read: Int32,
        smem_pipe_write: Int32,
        mma_params: SimpleNamespace,
        smem_copy_params: SimpleNamespace,
        softmax: Softmax,
        load_K: Callable,
        load_V: Callable,
        score_mod: Callable | None,
        batch_idx: cutlass.Int32,
        head_idx: cutlass.Int32,
        m_block: cutlass.Int32,
        seqlen: SeqlenInfoQK,
        aux_tensors=None,
        fastdiv_mods=None,
        mask_fn: Optional[Callable] = None,
        is_first_n_block: cutlass.Constexpr = False,
        check_inf: cutlass.Constexpr = True,
    ):
        """Compute one n_block of S/O.

        This function provides different variants for processing the first n block versus
        subsequent blocks.
        """

        def sync():
            cute.arch.cp_async_wait_group(self.num_stages * 2 - 2)
            cute.arch.barrier()

        acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n))
        acc_S = cute.make_fragment(acc_shape_S, Float32)
        acc_S.fill(0.0)
        # wait for smem tile QK before mma calculation for S
        sync()

        # need predicates for the first tile
        def load_V_next():
            if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0:
                load_V(
                    n_block - self.num_stages + 1,
                    smem_pipe_write,
                    need_predicates=is_first_n_block and self.num_stages == 1,
                )
            cute.arch.cp_async_commit_group()

        load_V_next()
        sm80_utils.gemm(
            mma_params.thr_mma_qk,
            acc_S,
            mma_params.tSrQ,
            mma_params.tSrK,
            smem_copy_params.tSsQ,
            smem_copy_params.tSsK[
                None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0
            ],
            smem_copy_params.smem_thr_copy_Q,
            smem_copy_params.smem_thr_copy_K,
            # hook_fn=load_V_next,
            A_in_regs=self.Q_in_regs,
        )
        if const_expr(score_mod is not None):
            self.apply_score_mod(
                mma_params.thr_mma_qk,
                batch_idx,
                head_idx,
                m_block,
                acc_S,
                n_block,
                seqlen,
                softmax_scale=softmax.softmax_scale,
                aux_tensors=aux_tensors,
                fastdiv_mods=fastdiv_mods,
            )

        smem_pipe_write = self.advance_pipeline(smem_pipe_write)

        def load_K_next():
            if n_block - self.num_stages >= 0:
                load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False)
            cute.arch.cp_async_commit_group()

        # wait for smem tile V for O
        if const_expr(self.num_stages == 1):
            sync()
            load_K_next()
        if const_expr(mask_fn is not None):
            mask_fn(acc_S, n_block=n_block)
        row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
        softmax.rescale_O(mma_params.acc_O, row_scale)
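        # Sketch of the streaming-softmax recurrence behind online_softmax / rescale_O (the exact
        # bookkeeping lives in the Softmax helper): with running row max m and row sum l, a new
        # score block S gives m_new = max(m, rowmax(S)), row_scale = exp2((m - m_new) *
        # softmax_scale_log2), l = l * row_scale + rowsum(P), and acc_O is multiplied by the same
        # row_scale before accumulating P @ V.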
        rP = cute.make_fragment_like(acc_S, self.dtype)
        rP.store(acc_S.load().to(self.dtype))
        tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
        if const_expr(self.num_stages > 1):
            sync()
            load_K_next()
        sm80_utils.gemm_rs(
            mma_params.thr_mma_pv,
            mma_params.acc_O,
            tOrP,
            mma_params.tOrVt,
            smem_copy_params.tOsVt[
                None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0
            ],
            smem_copy_params.smem_thr_copy_V,
            # hook_fn=load_K_next,
        )
        # if const_expr(self.num_stages > 1):
        #     load_K_next()


class FlashAttentionForwardSm90(FlashAttentionForwardBase):
    arch = 90

    def __init__(
        self,
        *args,
        intra_wg_overlap: bool = True,
        mma_pv_is_rs: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.intra_wg_overlap = intra_wg_overlap
        self.mma_pv_is_rs = mma_pv_is_rs
        self.buffer_align_bytes = 1024
|
|
1167
|
+
|
|
1168
|
+
def _get_smem_layout_atom(self):
|
|
1169
|
+
sQ_layout_atom = warpgroup.make_smem_layout_atom(
|
|
1170
|
+
sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),
|
|
1171
|
+
self.dtype,
|
|
1172
|
+
)
|
|
1173
|
+
sK_layout_atom = sQ_layout_atom
|
|
1174
|
+
sV_layout_atom = warpgroup.make_smem_layout_atom(
|
|
1175
|
+
sm90_utils_basic.get_smem_layout_atom(
|
|
1176
|
+
LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv
|
|
1177
|
+
),
|
|
1178
|
+
self.dtype,
|
|
1179
|
+
)
|
|
1180
|
+
sO_layout_atom = sV_layout_atom
|
|
1181
|
+
if not self.mma_pv_is_rs:
|
|
1182
|
+
sP_layout_atom = warpgroup.make_smem_layout_atom(
|
|
1183
|
+
sm90_utils_basic.get_smem_layout_atom(
|
|
1184
|
+
LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n
|
|
1185
|
+
),
|
|
1186
|
+
self.dtype,
|
|
1187
|
+
)
|
|
1188
|
+
else:
|
|
1189
|
+
sP_layout_atom = None
|
|
1190
|
+
return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom
|
|
1191
|
+
|
|
1192
|
+
def _get_tiled_mma(self):
|
|
1193
|
+
tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
|
|
1194
|
+
self.dtype,
|
|
1195
|
+
self.dtype,
|
|
1196
|
+
warpgroup.OperandMajorMode.K,
|
|
1197
|
+
warpgroup.OperandMajorMode.K,
|
|
1198
|
+
Float32,
|
|
1199
|
+
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
|
|
1200
|
+
tiler_mn=(64, self.tile_n),
|
|
1201
|
+
)
|
|
1202
|
+
tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
|
|
1203
|
+
self.dtype,
|
|
1204
|
+
self.dtype,
|
|
1205
|
+
warpgroup.OperandMajorMode.K,
|
|
1206
|
+
warpgroup.OperandMajorMode.MN,
|
|
1207
|
+
Float32,
|
|
1208
|
+
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
|
|
1209
|
+
tiler_mn=(64, self.tile_hdimv),
|
|
1210
|
+
a_source=warpgroup.OperandSource.RMEM
|
|
1211
|
+
if self.mma_pv_is_rs
|
|
1212
|
+
else warpgroup.OperandSource.SMEM,
|
|
1213
|
+
)
|
|
1214
|
+
tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma(
|
|
1215
|
+
self.dtype,
|
|
1216
|
+
self.dtype,
|
|
1217
|
+
warpgroup.OperandMajorMode.K,
|
|
1218
|
+
warpgroup.OperandMajorMode.MN,
|
|
1219
|
+
Float32,
|
|
1220
|
+
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
|
|
1221
|
+
tiler_mn=(64, self.tile_hdimv),
|
|
1222
|
+
a_source=warpgroup.OperandSource.RMEM,
|
|
1223
|
+
)
|
|
1224
|
+
return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs
|
|
1225
|
+
|
|
1226
|
+
def _get_shared_storage_cls(self):
|
|
1227
|
+
# If we use cp.async to load Q, we want sQ to align to 1024 bytes
|
|
1228
|
+
sQ_struct, sK_struct, sV_struct = [
|
|
1229
|
+
cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes]
|
|
1230
|
+
for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
|
|
1231
|
+
|
|
1232
|
+
]
|
|
1233
|
+
cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
|
|
1234
|
+
sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
|
|
1235
|
+
cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0
|
|
1236
|
+
sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
|
|
1237
|
+
# mbarriers: 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V.
|
|
1238
|
+
mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2]
|
|
1239
|
+
mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
|
|
1240
|
+
mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
|
|
1241
|
+
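# Illustrative count (not part of the kernel; num_stages = 2 is an assumed value): the MemRange
# structs above would then hold 2 Int64 mbarriers for Q/O, 2 * 2 = 4 for K and 4 for V,
# i.e. (2 + 4 + 4) * 8 = 80 bytes of barrier storage inside the SharedStorage structs below.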
|
|
1242
|
+
@cute.struct
|
|
1243
|
+
class SharedStorageQKV:
|
|
1244
|
+
mbar_ptr: mbar_ptr_QO_struct
|
|
1245
|
+
mbar_ptr_K: mbar_ptr_K_struct
|
|
1246
|
+
mbar_ptr_V: mbar_ptr_V_struct
|
|
1247
|
+
sV: sV_struct
|
|
1248
|
+
sQ: sQ_struct
|
|
1249
|
+
sK: sK_struct
|
|
1250
|
+
sP: sP_struct
|
|
1251
|
+
|
|
1252
|
+
@cute.struct
|
|
1253
|
+
class SharedStorageSharedQV:
|
|
1254
|
+
mbar_ptr: mbar_ptr_QO_struct
|
|
1255
|
+
mbar_ptr_K: mbar_ptr_K_struct
|
|
1256
|
+
mbar_ptr_V: mbar_ptr_V_struct
|
|
1257
|
+
sQ: sQV_struct
|
|
1258
|
+
sK: sK_struct
|
|
1259
|
+
sP: sP_struct
|
|
1260
|
+
|
|
1261
|
+
return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
|
|
1262
|
+
|
|
1263
|
+
@cute.jit
|
|
1264
|
+
def __call__(
|
|
1265
|
+
self,
|
|
1266
|
+
mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
|
1267
|
+
mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
|
|
1268
|
+
mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
|
|
1269
|
+
mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
|
1270
|
+
mLSE: Optional[cute.Tensor],
|
|
1271
|
+
softmax_scale: Float32,
|
|
1272
|
+
stream: cuda.CUstream,
|
|
1273
|
+
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
|
1274
|
+
mCuSeqlensK: Optional[cute.Tensor] = None,
|
|
1275
|
+
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
1276
|
+
mSeqUsedK: Optional[cute.Tensor] = None,
|
|
1277
|
+
mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
|
|
1278
|
+
window_size_left: Int32 | int | None = None,
|
|
1279
|
+
window_size_right: Int32 | int | None = None,
|
|
1280
|
+
learnable_sink: Optional[cute.Tensor] = None,
|
|
1281
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
1282
|
+
aux_tensors: Optional[list] = None,
|
|
1283
|
+
):
|
|
1284
|
+
"""Configures and launches the flash attention kernel.
|
|
1285
|
+
|
|
1286
|
+
mQ/mK/mV/mO have the same data type (fp16 and bf16 are supported) and the same layout:
|
|
1287
|
+
(batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
|
|
1288
|
+
"""
|
|
1289
|
+
|
|
1290
|
+
self._check_type(
|
|
1291
|
+
*(
|
|
1292
|
+
t.element_type if t is not None else None
|
|
1293
|
+
for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
|
|
1294
|
+
)
|
|
1295
|
+
)
|
|
1296
|
+
|
|
1297
|
+
# Assume all strides are divisible by 128 bits except the last stride
|
|
1298
|
+
new_stride = lambda t: (
|
|
1299
|
+
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
|
|
1300
|
+
t.stride[-1],
|
|
1301
|
+
)
|
|
1302
|
+
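# Illustrative note (not part of the kernel): for a 16-bit dtype such as fp16/bf16,
# divby = 128 // 16 = 8, so every stride except the contiguous last one is asserted to be
# a multiple of 8 elements (16 bytes), matching the 128-bit alignment assumption stated above.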
|
|
1303
|
+
mQ, mK, mV, mO = [
|
|
1304
|
+
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
|
1305
|
+
for t in (mQ, mK, mV, mO)
|
|
1306
|
+
]
|
|
1307
|
+
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
|
1308
|
+
mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
|
|
1309
|
+
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
|
|
1310
|
+
mK, mV = [utils.select(t, KV_layout_transpose) for t in (mK, mV)]
|
|
1311
|
+
LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
|
1312
|
+
mLSE = utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None
|
|
1313
|
+
|
|
1314
|
+
tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma()
|
|
1315
|
+
self.num_mma_threads = tiled_mma_qk.size
|
|
1316
|
+
self.num_threads_per_warp_group = 128
|
|
1317
|
+
self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group
|
|
1318
|
+
self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1)
|
|
1319
|
+
self.num_producer_threads = 32
|
|
1320
|
+
self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q
|
|
1321
|
+
self.num_epilogue_threads = self.num_mma_threads
|
|
1322
|
+
self.num_mma_regs = (
|
|
1323
|
+
256
|
|
1324
|
+
if self.num_mma_warp_groups == 1
|
|
1325
|
+
else (240 if self.num_mma_warp_groups == 2 else 160)
|
|
1326
|
+
)
|
|
1327
|
+
self.num_producer_regs = (
|
|
1328
|
+
56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32)
|
|
1329
|
+
)
|
|
1330
|
+
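# Illustrative arithmetic (not part of the kernel): these per-thread register targets keep the
# whole CTA within the SM90 register file of 64K 32-bit registers, e.g. with 3 MMA warp groups
# 128 * (3 * 160 + 32) = 65536, and with 2 MMA warp groups 128 * (2 * 240 + 24) = 64512.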
# self.num_mma_regs = 232
|
|
1331
|
+
# self.num_producer_regs = 40
|
|
1332
|
+
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
|
1333
|
+
|
|
1334
|
+
self.use_scheduler_barrier = (
|
|
1335
|
+
(self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128)
|
|
1336
|
+
if const_expr(self.intra_wg_overlap)
|
|
1337
|
+
else (self.num_mma_warp_groups == 2)
|
|
1338
|
+
)
|
|
1339
|
+
self.use_tma_Q = self.arch >= 90 and not (
|
|
1340
|
+
self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
|
|
1341
|
+
)
|
|
1342
|
+
self.use_tma_O = (
|
|
1343
|
+
self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa
|
|
1344
|
+
)
|
|
1345
|
+
# TODO: rescale_O_before_gemm
|
|
1346
|
+
self._setup_attributes()
|
|
1347
|
+
# TODO: we prob don't need most of what's in _setup_attributes
|
|
1348
|
+
self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [
|
|
1349
|
+
sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)
|
|
1350
|
+
for mX, shape, stage in [
|
|
1351
|
+
(mQ, (self.tile_m, self.tile_hdim), None),
|
|
1352
|
+
(mK, (self.tile_n, self.tile_hdim), self.num_stages),
|
|
1353
|
+
(mV, (self.tile_n, self.tile_hdimv), self.num_stages),
|
|
1354
|
+
(mO, (self.tile_m, self.tile_hdimv), None),
|
|
1355
|
+
]
|
|
1356
|
+
]
|
|
1357
|
+
self.sP_layout = None
|
|
1358
|
+
if const_expr(not self.mma_pv_is_rs):
|
|
1359
|
+
self.sP_layout = sm90_utils.make_smem_layout(
|
|
1360
|
+
mV.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
|
|
1361
|
+
)
|
|
1362
|
+
|
|
1363
|
+
SharedStorage = self._get_shared_storage_cls()
|
|
1364
|
+
|
|
1365
|
+
if const_expr(self.pack_gqa):
|
|
1366
|
+
shape_Q_packed = (
|
|
1367
|
+
(self.qhead_per_kvhead, mQ.shape[0]),
|
|
1368
|
+
mQ.shape[1],
|
|
1369
|
+
mK.shape[2],
|
|
1370
|
+
*mQ.shape[3:],
|
|
1371
|
+
)
|
|
1372
|
+
stride_Q_packed = (
|
|
1373
|
+
(mQ.stride[2], mQ.stride[0]),
|
|
1374
|
+
mQ.stride[1],
|
|
1375
|
+
mQ.stride[2] * self.qhead_per_kvhead,
|
|
1376
|
+
*mQ.stride[3:],
|
|
1377
|
+
)
|
|
1378
|
+
mQ = cute.make_tensor(
|
|
1379
|
+
mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
|
|
1380
|
+
)
|
|
1381
|
+
shape_O_packed = (
|
|
1382
|
+
(self.qhead_per_kvhead, mO.shape[0]),
|
|
1383
|
+
mK.shape[1],
|
|
1384
|
+
mK.shape[2],
|
|
1385
|
+
*mO.shape[3:],
|
|
1386
|
+
)
|
|
1387
|
+
stride_O_packed = (
|
|
1388
|
+
(mO.stride[2], mO.stride[0]),
|
|
1389
|
+
mO.stride[1],
|
|
1390
|
+
mO.stride[2] * self.qhead_per_kvhead,
|
|
1391
|
+
*mO.stride[3:],
|
|
1392
|
+
)
|
|
1393
|
+
mO = cute.make_tensor(
|
|
1394
|
+
mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
|
|
1395
|
+
)
|
|
1396
|
+
if const_expr(mLSE is not None):
|
|
1397
|
+
shape_LSE_packed = (
|
|
1398
|
+
(self.qhead_per_kvhead, mLSE.shape[0]),
|
|
1399
|
+
mK.shape[2],
|
|
1400
|
+
*mLSE.shape[2:],
|
|
1401
|
+
)
|
|
1402
|
+
stride_LSE_packed = (
|
|
1403
|
+
(mLSE.stride[1], mLSE.stride[0]),
|
|
1404
|
+
mLSE.stride[1] * self.qhead_per_kvhead,
|
|
1405
|
+
*mLSE.stride[2:],
|
|
1406
|
+
)
|
|
1407
|
+
mLSE = cute.make_tensor(
|
|
1408
|
+
mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
|
|
1409
|
+
)
|
|
1410
|
+
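# Illustrative example (not part of the kernel; assumed values): with qhead_per_kvhead = 4 and,
# after the transposes above, mQ of shape (seqlen_q, head_dim, num_q_heads, batch), the packed
# view has shape ((4, seqlen_q), head_dim, num_kv_heads, batch); its strides are chosen so that
# packed element ((g, s), d, h_kv, b) aliases the original element (s, d, h_kv * 4 + g, b),
# letting one m_block cover rows from all 4 query heads that share a KV head.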
|
|
1411
|
+
# TMA
|
|
1412
|
+
gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()
|
|
1413
|
+
gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast
|
|
1414
|
+
gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()
|
|
1415
|
+
self.tma_copy_bytes = {
|
|
1416
|
+
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
|
|
1417
|
+
for name, mX, layout in [
|
|
1418
|
+
("Q", mQ, self.sQ_layout),
|
|
1419
|
+
("K", mK, self.sK_layout),
|
|
1420
|
+
("V", mV, self.sV_layout),
|
|
1421
|
+
]
|
|
1422
|
+
}
|
|
1423
|
+
tma_atom_Q, tma_tensor_Q = None, None
|
|
1424
|
+
if const_expr(self.use_tma_Q):
|
|
1425
|
+
tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
|
|
1426
|
+
gmem_tiled_copy_Q,
|
|
1427
|
+
mQ,
|
|
1428
|
+
self.sQ_layout,
|
|
1429
|
+
(self.tile_m, self.tile_hdim), # No mcast
|
|
1430
|
+
)
|
|
1431
|
+
tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
|
|
1432
|
+
gmem_tiled_copy_KV,
|
|
1433
|
+
mK,
|
|
1434
|
+
cute.select(self.sK_layout, mode=[0, 1]),
|
|
1435
|
+
(self.tile_n, self.tile_hdim),
|
|
1436
|
+
1, # No mcast for now
|
|
1437
|
+
)
|
|
1438
|
+
tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
|
|
1439
|
+
gmem_tiled_copy_KV,
|
|
1440
|
+
mV,
|
|
1441
|
+
cute.select(self.sV_layout, mode=[0, 1]),
|
|
1442
|
+
(self.tile_n, self.tile_hdimv),
|
|
1443
|
+
1, # No mcast for now
|
|
1444
|
+
)
|
|
1445
|
+
tma_atom_O, tma_tensor_O = None, None
|
|
1446
|
+
if const_expr(self.use_tma_O):
|
|
1447
|
+
tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom(
|
|
1448
|
+
gmem_tiled_copy_O,
|
|
1449
|
+
mO,
|
|
1450
|
+
self.sO_layout,
|
|
1451
|
+
(self.tile_m, self.tile_hdimv), # No mcast
|
|
1452
|
+
)
|
|
1453
|
+
if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
|
|
1454
|
+
TileScheduler = SingleTileVarlenScheduler
|
|
1455
|
+
else:
|
|
1456
|
+
TileScheduler = (
|
|
1457
|
+
SingleTileScheduler
|
|
1458
|
+
if const_expr(not self.is_causal or self.is_local)
|
|
1459
|
+
else SingleTileLPTScheduler
|
|
1460
|
+
)
|
|
1461
|
+
tile_sched_args = TileSchedulerArguments(
|
|
1462
|
+
cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),
|
|
1463
|
+
cute.size(mQ.shape[2]),
|
|
1464
|
+
cute.size(mQ.shape[3])
|
|
1465
|
+
if const_expr(mCuSeqlensQ is None)
|
|
1466
|
+
else cute.size(mCuSeqlensQ.shape[0] - 1),
|
|
1467
|
+
1, # num_splits
|
|
1468
|
+
cute.size(mK.shape[0]),
|
|
1469
|
+
mQ.shape[1],
|
|
1470
|
+
mV.shape[1],
|
|
1471
|
+
total_q=cute.size(mQ.shape[0])
|
|
1472
|
+
if const_expr(mCuSeqlensQ is not None)
|
|
1473
|
+
else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
|
|
1474
|
+
tile_shape_mn=(self.tile_m, self.tile_n),
|
|
1475
|
+
mCuSeqlensQ=mCuSeqlensQ,
|
|
1476
|
+
mSeqUsedQ=mSeqUsedQ,
|
|
1477
|
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
1478
|
+
element_size=self.dtype.width // 8,
|
|
1479
|
+
is_persistent=False,
|
|
1480
|
+
lpt=self.is_causal or self.is_local,
|
|
1481
|
+
)
|
|
1482
|
+
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
|
1483
|
+
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
|
1484
|
+
LOG2_E = math.log2(math.e)
|
|
1485
|
+
if const_expr(self.score_mod is None):
|
|
1486
|
+
softmax_scale_log2 = softmax_scale * LOG2_E
|
|
1487
|
+
softmax_scale = None
|
|
1488
|
+
else:
|
|
1489
|
+
# NB: If a user passes in a score mod, we want to apply the score mod to the softmax-scaled qk scores,
|
|
1490
|
+
# but still in the original base (e). We hijack softmax_scale_log2 to be just the change-of-base factor
|
|
1491
|
+
# and correctly apply the softmax_scale prior to score_mod in the softmax step
|
|
1492
|
+
softmax_scale_log2 = LOG2_E
|
|
1493
|
+
softmax_scale = softmax_scale
|
|
1494
|
+
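# Illustrative sketch (not part of the kernel; s and scale are assumed example values): the online
# softmax works in base 2 via the identity exp(s * scale) == 2 ** (s * scale * math.log2(math.e)), e.g.
#   import math; s, scale = 1.7, 0.125
#   assert abs(math.exp(s * scale) - 2 ** (s * scale * math.log2(math.e))) < 1e-12
# so folding log2(e) into softmax_scale_log2 lets the kernel use exp2 throughout. With a score_mod,
# only the change-of-base factor is kept here and the raw softmax_scale is applied before the mod.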
if const_expr(window_size_left is not None):
|
|
1495
|
+
window_size_left = Int32(window_size_left)
|
|
1496
|
+
if const_expr(window_size_right is not None):
|
|
1497
|
+
window_size_right = Int32(window_size_right)
|
|
1498
|
+
|
|
1499
|
+
fastdiv_mods = None
|
|
1500
|
+
if const_expr(aux_tensors is not None):
|
|
1501
|
+
seqlen_q = cute.size(mQ.shape[0]) // (
|
|
1502
|
+
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
|
|
1503
|
+
)
|
|
1504
|
+
seqlen_k = (
|
|
1505
|
+
cute.size(mK.shape[0])
|
|
1506
|
+
if const_expr(mPageTable is None)
|
|
1507
|
+
else mK.shape[0] * mPageTable.shape[1]
|
|
1508
|
+
)
|
|
1509
|
+
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
|
|
1510
|
+
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
|
|
1511
|
+
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
|
|
1512
|
+
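# Illustrative note (assumed values, not part of the kernel): with a page table, mK.shape[0] is the
# page size after the transposes above, so e.g. page_size = 128 and max_num_pages_per_seq = 32 give
# seqlen_k = 128 * 32 = 4096. The FastDivmodDivisor objects precompute constants so that, as the name
# suggests, the kernel can divide/modulo by these sequence lengths cheaply.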
|
|
1513
|
+
self.kernel(
|
|
1514
|
+
tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,
|
|
1515
|
+
tma_tensor_K,
|
|
1516
|
+
tma_tensor_V,
|
|
1517
|
+
tma_tensor_O if const_expr(self.use_tma_O) else mO,
|
|
1518
|
+
mLSE,
|
|
1519
|
+
mCuSeqlensQ,
|
|
1520
|
+
mCuSeqlensK,
|
|
1521
|
+
mSeqUsedQ,
|
|
1522
|
+
mSeqUsedK,
|
|
1523
|
+
tma_atom_Q,
|
|
1524
|
+
tma_atom_K,
|
|
1525
|
+
tma_atom_V,
|
|
1526
|
+
tma_atom_O,
|
|
1527
|
+
softmax_scale_log2,
|
|
1528
|
+
softmax_scale,
|
|
1529
|
+
window_size_left,
|
|
1530
|
+
window_size_right,
|
|
1531
|
+
learnable_sink,
|
|
1532
|
+
blocksparse_tensors,
|
|
1533
|
+
self.sQ_layout,
|
|
1534
|
+
self.sK_layout,
|
|
1535
|
+
self.sV_layout,
|
|
1536
|
+
self.sO_layout,
|
|
1537
|
+
self.sP_layout,
|
|
1538
|
+
self.gmem_tiled_copy_Q,
|
|
1539
|
+
self.gmem_tiled_copy_K,
|
|
1540
|
+
self.gmem_tiled_copy_V,
|
|
1541
|
+
self.gmem_tiled_copy_O,
|
|
1542
|
+
tiled_mma_qk,
|
|
1543
|
+
tiled_mma_pv,
|
|
1544
|
+
tiled_mma_pv_rs,
|
|
1545
|
+
tile_sched_params,
|
|
1546
|
+
TileScheduler,
|
|
1547
|
+
SharedStorage,
|
|
1548
|
+
aux_tensors,
|
|
1549
|
+
fastdiv_mods,
|
|
1550
|
+
).launch(
|
|
1551
|
+
grid=grid_dim,
|
|
1552
|
+
block=[self.num_threads, 1, 1],
|
|
1553
|
+
stream=stream,
|
|
1554
|
+
min_blocks_per_mp=1,
|
|
1555
|
+
)
|
|
1556
|
+
|
|
1557
|
+
@cute.kernel
|
|
1558
|
+
def kernel(
|
|
1559
|
+
self,
|
|
1560
|
+
mQ: cute.Tensor,
|
|
1561
|
+
mK: cute.Tensor,
|
|
1562
|
+
mV: cute.Tensor,
|
|
1563
|
+
mO: cute.Tensor,
|
|
1564
|
+
mLSE: Optional[cute.Tensor],
|
|
1565
|
+
mCuSeqlensQ: Optional[cute.Tensor],
|
|
1566
|
+
mCuSeqlensK: Optional[cute.Tensor],
|
|
1567
|
+
mSeqUsedQ: Optional[cute.Tensor],
|
|
1568
|
+
mSeqUsedK: Optional[cute.Tensor],
|
|
1569
|
+
tma_atom_Q: Optional[cute.CopyAtom],
|
|
1570
|
+
tma_atom_K: Optional[cute.CopyAtom],
|
|
1571
|
+
tma_atom_V: Optional[cute.CopyAtom],
|
|
1572
|
+
tma_atom_O: Optional[cute.CopyAtom],
|
|
1573
|
+
softmax_scale_log2: Float32,
|
|
1574
|
+
softmax_scale: Optional[Float32],
|
|
1575
|
+
window_size_left: Optional[Int32],
|
|
1576
|
+
window_size_right: Optional[Int32],
|
|
1577
|
+
learnable_sink: Optional[cute.Tensor],
|
|
1578
|
+
blocksparse_tensors: Optional[BlockSparseTensors],
|
|
1579
|
+
sQ_layout: cute.ComposedLayout,
|
|
1580
|
+
sK_layout: cute.ComposedLayout,
|
|
1581
|
+
sV_layout: cute.ComposedLayout,
|
|
1582
|
+
sO_layout: cute.ComposedLayout,
|
|
1583
|
+
sP_layout: cute.ComposedLayout | None,
|
|
1584
|
+
gmem_tiled_copy_Q: cute.TiledCopy,
|
|
1585
|
+
gmem_tiled_copy_K: cute.TiledCopy,
|
|
1586
|
+
gmem_tiled_copy_V: cute.TiledCopy,
|
|
1587
|
+
gmem_tiled_copy_O: cute.TiledCopy,
|
|
1588
|
+
tiled_mma_qk: cute.TiledMma,
|
|
1589
|
+
tiled_mma_pv: cute.TiledMma,
|
|
1590
|
+
tiled_mma_pv_rs: cute.TiledMma,
|
|
1591
|
+
tile_sched_params: ParamsBase,
|
|
1592
|
+
TileScheduler: cutlass.Constexpr[Callable],
|
|
1593
|
+
SharedStorage: cutlass.Constexpr[Callable],
|
|
1594
|
+
aux_tensors: Optional[list[cute.Tensor]] = None,
|
|
1595
|
+
fastdiv_mods=None,
|
|
1596
|
+
):
|
|
1597
|
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
1598
|
+
# Prefetch tma descriptor
|
|
1599
|
+
if warp_idx == 0:
|
|
1600
|
+
for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):
|
|
1601
|
+
if const_expr(tma_atom is not None):
|
|
1602
|
+
cpasync.prefetch_descriptor(tma_atom)
|
|
1603
|
+
|
|
1604
|
+
smem = cutlass.utils.SmemAllocator()
|
|
1605
|
+
storage = smem.allocate(SharedStorage)
|
|
1606
|
+
|
|
1607
|
+
# Mbarrier init
|
|
1608
|
+
mbar_ptr_Q = storage.mbar_ptr.data_ptr()
|
|
1609
|
+
if warp_idx == 1:
|
|
1610
|
+
# if tidx < 2:
|
|
1611
|
+
# # barrierO num threads should be self.num_mma_threads
|
|
1612
|
+
# cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads)
|
|
1613
|
+
if const_expr(not self.use_tma_Q):
|
|
1614
|
+
cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads)
|
|
1615
|
+
# cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads)
|
|
1616
|
+
# We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync
|
|
1617
|
+
pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(
|
|
1618
|
+
cutlass.pipeline.Agent.Thread
|
|
1619
|
+
)
|
|
1620
|
+
pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup(
|
|
1621
|
+
cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
|
|
1622
|
+
)
|
|
1623
|
+
pipeline_k = pipeline.PipelineTmaAsync.create(
|
|
1624
|
+
barrier_storage=storage.mbar_ptr_K.data_ptr(),
|
|
1625
|
+
num_stages=self.num_stages,
|
|
1626
|
+
producer_group=pipeline_kv_producer_group,
|
|
1627
|
+
consumer_group=pipeline_kv_consumer_group,
|
|
1628
|
+
tx_count=self.tma_copy_bytes["K"],
|
|
1629
|
+
defer_sync=True,
|
|
1630
|
+
)
|
|
1631
|
+
pipeline_v = pipeline.PipelineTmaAsync.create(
|
|
1632
|
+
barrier_storage=storage.mbar_ptr_V.data_ptr(),
|
|
1633
|
+
num_stages=self.num_stages,
|
|
1634
|
+
producer_group=pipeline_kv_producer_group,
|
|
1635
|
+
consumer_group=pipeline_kv_consumer_group,
|
|
1636
|
+
tx_count=self.tma_copy_bytes["V"],
|
|
1637
|
+
defer_sync=False,
|
|
1638
|
+
)
|
|
1639
|
+
|
|
1640
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
1641
|
+
# Get shared memory buffer
|
|
1642
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
1643
|
+
sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
|
|
1644
|
+
sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
|
|
1645
|
+
if const_expr(not self.Q_in_regs):
|
|
1646
|
+
sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
|
|
1647
|
+
else:
|
|
1648
|
+
sV = storage.sQ.get_tensor(
|
|
1649
|
+
sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type
|
|
1650
|
+
)
|
|
1651
|
+
# Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
|
|
1652
|
+
sVt = utils.transpose_view(sV)
|
|
1653
|
+
sP = None
|
|
1654
|
+
if const_expr(sP_layout is not None):
|
|
1655
|
+
sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner)
|
|
1656
|
+
# reuse sQ's data iterator
|
|
1657
|
+
sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype)
|
|
1658
|
+
|
|
1659
|
+
block_info = BlockInfo(
|
|
1660
|
+
self.tile_m,
|
|
1661
|
+
self.tile_n,
|
|
1662
|
+
self.is_causal,
|
|
1663
|
+
self.is_local,
|
|
1664
|
+
False, # is_split_kv
|
|
1665
|
+
window_size_left,
|
|
1666
|
+
window_size_right,
|
|
1667
|
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
1668
|
+
)
|
|
1669
|
+
SeqlenInfoCls = partial(
|
|
1670
|
+
SeqlenInfoQK.create,
|
|
1671
|
+
seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
|
|
1672
|
+
seqlen_k_static=mK.shape[0],
|
|
1673
|
+
mCuSeqlensQ=mCuSeqlensQ,
|
|
1674
|
+
mCuSeqlensK=mCuSeqlensK,
|
|
1675
|
+
mSeqUsedQ=mSeqUsedQ,
|
|
1676
|
+
mSeqUsedK=mSeqUsedK,
|
|
1677
|
+
)
|
|
1678
|
+
AttentionMaskCls = partial(
|
|
1679
|
+
AttentionMask,
|
|
1680
|
+
self.tile_m,
|
|
1681
|
+
self.tile_n,
|
|
1682
|
+
window_size_left=window_size_left,
|
|
1683
|
+
window_size_right=window_size_right,
|
|
1684
|
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
1685
|
+
)
|
|
1686
|
+
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
|
|
1687
|
+
|
|
1688
|
+
if warp_idx < 4: # Producer
|
|
1689
|
+
cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
|
|
1690
|
+
self.load(
|
|
1691
|
+
mQ,
|
|
1692
|
+
mK,
|
|
1693
|
+
mV,
|
|
1694
|
+
sQ,
|
|
1695
|
+
sK,
|
|
1696
|
+
sV,
|
|
1697
|
+
tma_atom_Q,
|
|
1698
|
+
tma_atom_K,
|
|
1699
|
+
tma_atom_V,
|
|
1700
|
+
pipeline_k,
|
|
1701
|
+
pipeline_v,
|
|
1702
|
+
mbar_ptr_Q,
|
|
1703
|
+
blocksparse_tensors,
|
|
1704
|
+
block_info,
|
|
1705
|
+
SeqlenInfoCls,
|
|
1706
|
+
TileSchedulerCls,
|
|
1707
|
+
)
|
|
1708
|
+
|
|
1709
|
+
else: # Consumer
|
|
1710
|
+
cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
|
|
1711
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
1712
|
+
# Tile MMA compute thread partitions and allocate accumulators
|
|
1713
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
1714
|
+
tidx, _, _ = cute.arch.thread_idx()
|
|
1715
|
+
tidx = tidx - 128
|
|
1716
|
+
self.mma(
|
|
1717
|
+
tiled_mma_qk,
|
|
1718
|
+
tiled_mma_pv,
|
|
1719
|
+
tiled_mma_pv_rs,
|
|
1720
|
+
mQ,
|
|
1721
|
+
mO,
|
|
1722
|
+
mLSE,
|
|
1723
|
+
sQ,
|
|
1724
|
+
sK,
|
|
1725
|
+
sVt,
|
|
1726
|
+
sP,
|
|
1727
|
+
sO,
|
|
1728
|
+
learnable_sink,
|
|
1729
|
+
pipeline_k,
|
|
1730
|
+
pipeline_v,
|
|
1731
|
+
mbar_ptr_Q,
|
|
1732
|
+
gmem_tiled_copy_Q,
|
|
1733
|
+
gmem_tiled_copy_O,
|
|
1734
|
+
tma_atom_O,
|
|
1735
|
+
tidx,
|
|
1736
|
+
softmax_scale_log2,
|
|
1737
|
+
softmax_scale,
|
|
1738
|
+
block_info,
|
|
1739
|
+
SeqlenInfoCls,
|
|
1740
|
+
AttentionMaskCls,
|
|
1741
|
+
TileSchedulerCls,
|
|
1742
|
+
blocksparse_tensors,
|
|
1743
|
+
aux_tensors,
|
|
1744
|
+
fastdiv_mods,
|
|
1745
|
+
)
|
|
1746
|
+
|
|
1747
|
+
@cute.jit
|
|
1748
|
+
def load(
|
|
1749
|
+
self,
|
|
1750
|
+
mQ: cute.Tensor,
|
|
1751
|
+
mK: cute.Tensor,
|
|
1752
|
+
mV: cute.Tensor,
|
|
1753
|
+
sQ: cute.Tensor,
|
|
1754
|
+
sK: cute.Tensor,
|
|
1755
|
+
sV: cute.Tensor,
|
|
1756
|
+
tma_atom_Q: cute.CopyAtom,
|
|
1757
|
+
tma_atom_K: cute.CopyAtom,
|
|
1758
|
+
tma_atom_V: cute.CopyAtom,
|
|
1759
|
+
pipeline_k: cutlass.pipeline.PipelineAsync,
|
|
1760
|
+
pipeline_v: cutlass.pipeline.PipelineAsync,
|
|
1761
|
+
mbar_ptr_Q: cutlass.Pointer,
|
|
1762
|
+
blocksparse_tensors: Optional[BlockSparseTensors],
|
|
1763
|
+
block_info: BlockInfo,
|
|
1764
|
+
SeqlenInfoCls: Callable,
|
|
1765
|
+
TileSchedulerCls: Callable,
|
|
1766
|
+
):
|
|
1767
|
+
warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
|
|
1768
|
+
if warp_idx_in_wg == 0:
|
|
1769
|
+
q_producer_phase = Int32(1)
|
|
1770
|
+
kv_producer_state = pipeline.make_pipeline_state(
|
|
1771
|
+
cutlass.pipeline.PipelineUserType.Producer, self.num_stages
|
|
1772
|
+
)
|
|
1773
|
+
tile_scheduler = TileSchedulerCls()
|
|
1774
|
+
work_tile = tile_scheduler.initial_work_tile_info()
|
|
1775
|
+
while work_tile.is_valid_tile:
|
|
1776
|
+
# if work_tile.is_valid_tile:
|
|
1777
|
+
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
|
1778
|
+
seqlen = SeqlenInfoCls(batch_idx)
|
|
1779
|
+
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
|
1780
|
+
head_idx_kv = (
|
|
1781
|
+
head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
|
|
1782
|
+
)
|
|
1783
|
+
mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
|
|
1784
|
+
mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
|
|
1785
|
+
gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0))
|
|
1786
|
+
gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
|
|
1787
|
+
if const_expr(self.use_tma_Q):
|
|
1788
|
+
gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
|
|
1789
|
+
load_Q, _, _ = copy_utils.tma_get_copy_fn(
|
|
1790
|
+
tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True
|
|
1791
|
+
)
|
|
1792
|
+
# TODO: mcast
|
|
1793
|
+
# TODO check warp_idx if we have 128 producer threads
|
|
1794
|
+
load_K, _, _ = copy_utils.tma_get_copy_fn(
|
|
1795
|
+
tma_atom_K, 0, cute.make_layout(1), gK, sK
|
|
1796
|
+
)
|
|
1797
|
+
load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k)
|
|
1798
|
+
load_V, _, _ = copy_utils.tma_get_copy_fn(
|
|
1799
|
+
tma_atom_V, 0, cute.make_layout(1), gV, sV
|
|
1800
|
+
)
|
|
1801
|
+
load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v)
|
|
1802
|
+
|
|
1803
|
+
if const_expr(not self.use_block_sparsity):
|
|
1804
|
+
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
|
1805
|
+
# if cute.arch.thread_idx()[0] == 0:
|
|
1806
|
+
# cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max)
|
|
1807
|
+
# First iteration: load both Q & K with the same mbarrier
|
|
1808
|
+
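# In other words, producer_acquire below bumps the expected-transaction count of the first K stage's
# barrier by tma_copy_bytes["Q"] (the extra_tx_count argument), so a single barrier wait covers both
# the Q TMA and the first K TMA.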
n_block = n_block_max - 1
|
|
1809
|
+
pipeline_k.producer_acquire(
|
|
1810
|
+
kv_producer_state,
|
|
1811
|
+
extra_tx_count=self.tma_copy_bytes["Q"]
|
|
1812
|
+
if const_expr(self.use_tma_Q)
|
|
1813
|
+
else 0,
|
|
1814
|
+
)
|
|
1815
|
+
if const_expr(self.use_tma_Q):
|
|
1816
|
+
load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
|
|
1817
|
+
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
|
1818
|
+
|
|
1819
|
+
if const_expr(not self.intra_wg_overlap):
|
|
1820
|
+
pipeline_v.producer_acquire(kv_producer_state)
|
|
1821
|
+
load_V(src_idx=n_block, producer_state=kv_producer_state)
|
|
1822
|
+
kv_producer_state.advance()
|
|
1823
|
+
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
|
1824
|
+
n_block = n_block_max - 1 - i - 1
|
|
1825
|
+
pipeline_k.producer_acquire(kv_producer_state)
|
|
1826
|
+
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
|
1827
|
+
pipeline_v.producer_acquire(kv_producer_state)
|
|
1828
|
+
load_V(src_idx=n_block, producer_state=kv_producer_state)
|
|
1829
|
+
kv_producer_state.advance()
|
|
1830
|
+
else:
|
|
1831
|
+
for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
|
|
1832
|
+
n_block_prev = n_block_max - i - 1
|
|
1833
|
+
n_block = n_block_prev - 1
|
|
1834
|
+
kv_producer_state_prev = kv_producer_state.clone()
|
|
1835
|
+
kv_producer_state.advance()
|
|
1836
|
+
pipeline_k.producer_acquire(kv_producer_state)
|
|
1837
|
+
load_K(src_idx=n_block, producer_state=kv_producer_state)
|
|
1838
|
+
pipeline_v.producer_acquire(kv_producer_state_prev)
|
|
1839
|
+
load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
|
|
1840
|
+
n_block = n_block_min
|
|
1841
|
+
pipeline_v.producer_acquire(kv_producer_state)
|
|
1842
|
+
load_V(src_idx=n_block, producer_state=kv_producer_state)
|
|
1843
|
+
kv_producer_state.advance()
|
|
1844
|
+
else:
|
|
1845
|
+
kv_producer_state = produce_block_sparse_loads(
|
|
1846
|
+
blocksparse_tensors,
|
|
1847
|
+
batch_idx,
|
|
1848
|
+
head_idx,
|
|
1849
|
+
m_block,
|
|
1850
|
+
kv_producer_state,
|
|
1851
|
+
load_Q,
|
|
1852
|
+
load_K,
|
|
1853
|
+
load_V,
|
|
1854
|
+
pipeline_k,
|
|
1855
|
+
pipeline_v,
|
|
1856
|
+
self.use_tma_Q,
|
|
1857
|
+
self.tma_copy_bytes["Q"],
|
|
1858
|
+
self.intra_wg_overlap,
|
|
1859
|
+
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
1860
|
+
)
|
|
1861
|
+
|
|
1862
|
+
tile_scheduler.prefetch_next_work()
|
|
1863
|
+
tile_scheduler.advance_to_next_work()
|
|
1864
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1865
|
+
# End of persistent scheduler loop
|
|
1866
|
+
|
|
1867
|
+
@cute.jit
|
|
1868
|
+
def mma(
|
|
1869
|
+
self,
|
|
1870
|
+
tiled_mma_qk: cute.TiledMma,
|
|
1871
|
+
tiled_mma_pv: cute.TiledMma,
|
|
1872
|
+
tiled_mma_pv_rs: cute.TiledMma,
|
|
1873
|
+
# softmax: Softmax,
|
|
1874
|
+
# acc_O: cute.Tensor,
|
|
1875
|
+
mQ: cute.Tensor,
|
|
1876
|
+
mO: cute.Tensor,
|
|
1877
|
+
mLSE: Optional[cute.Tensor],
|
|
1878
|
+
sQ: cute.Tensor,
|
|
1879
|
+
sK: cute.Tensor,
|
|
1880
|
+
sVt: cute.Tensor,
|
|
1881
|
+
sP: Optional[cute.Tensor],
|
|
1882
|
+
sO: cute.Tensor,
|
|
1883
|
+
learnable_sink: Optional[cute.Tensor],
|
|
1884
|
+
pipeline_k: cutlass.pipeline.PipelineAsync,
|
|
1885
|
+
pipeline_v: cutlass.pipeline.PipelineAsync,
|
|
1886
|
+
mbar_ptr_Q: cutlass.Pointer,
|
|
1887
|
+
gmem_tiled_copy_Q: cute.TiledCopy,
|
|
1888
|
+
gmem_tiled_copy_O: cute.TiledCopy,
|
|
1889
|
+
tma_atom_O: Optional[cute.CopyAtom],
|
|
1890
|
+
tidx: Int32,
|
|
1891
|
+
softmax_scale_log2: Float32,
|
|
1892
|
+
softmax_scale: Optional[Float32],
|
|
1893
|
+
block_info: BlockInfo,
|
|
1894
|
+
SeqlenInfoCls: Callable,
|
|
1895
|
+
AttentionMaskCls: Callable,
|
|
1896
|
+
TileSchedulerCls: Callable,
|
|
1897
|
+
blocksparse_tensors: Optional[BlockSparseTensors],
|
|
1898
|
+
aux_tensors: Optional[list],
|
|
1899
|
+
fastdiv_mods=None,
|
|
1900
|
+
):
|
|
1901
|
+
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
|
|
1902
|
+
warp_group_thread_layout = cute.make_layout(
|
|
1903
|
+
self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
|
|
1904
|
+
)
|
|
1905
|
+
thr_mma_qk = tiled_mma_qk.get_slice(tidx)
|
|
1906
|
+
wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
|
|
1907
|
+
wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
|
|
1908
|
+
tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ))
|
|
1909
|
+
tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK))
|
|
1910
|
+
if const_expr(self.mma_pv_is_rs):
|
|
1911
|
+
acc_S_shape = tiled_mma_qk.partition_shape_C((self.tile_m, self.tile_n))
|
|
1912
|
+
tOrP = cute.make_fragment(
|
|
1913
|
+
utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype
|
|
1914
|
+
)
|
|
1915
|
+
else:
|
|
1916
|
+
tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP))
|
|
1917
|
+
tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt))
|
|
1918
|
+
|
|
1919
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
1920
|
+
# Smem copy atom tiling
|
|
1921
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
1922
|
+
smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype)
|
|
1923
|
+
smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
|
|
1924
|
+
# tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None
|
|
1925
|
+
tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
|
|
1926
|
+
# if cute.arch.thread_idx()[0] == 0:
|
|
1927
|
+
# cute.printf(sP_pi.layout, sP_pi.iterator)
|
|
1928
|
+
# cute.printf(sP.layout, sP.iterator)
|
|
1929
|
+
# cute.printf(tPsP.layout, tPsP.iterator)
|
|
1930
|
+
|
|
1931
|
+
self.mma_init()
|
|
1932
|
+
|
|
1933
|
+
acc_shape_O = tiled_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv))
|
|
1934
|
+
acc_O = cute.make_fragment(acc_shape_O, Float32)
|
|
1935
|
+
smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
|
|
1936
|
+
|
|
1937
|
+
mma_qk_fn = partial(
|
|
1938
|
+
sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
|
|
1939
|
+
)
|
|
1940
|
+
mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
|
|
1941
|
+
|
|
1942
|
+
mma_one_n_block_all = partial(
|
|
1943
|
+
self.mma_one_n_block_intrawg_overlap
|
|
1944
|
+
if const_expr(self.intra_wg_overlap)
|
|
1945
|
+
else self.mma_one_n_block,
|
|
1946
|
+
mma_qk_fn=mma_qk_fn,
|
|
1947
|
+
tiled_mma_pv_rs=tiled_mma_pv_rs,
|
|
1948
|
+
pipeline_k=pipeline_k,
|
|
1949
|
+
pipeline_v=pipeline_v,
|
|
1950
|
+
acc_O=acc_O,
|
|
1951
|
+
tOrP=tOrP,
|
|
1952
|
+
smem_copy_params=smem_copy_params,
|
|
1953
|
+
check_inf=True,
|
|
1954
|
+
)
|
|
1955
|
+
|
|
1956
|
+
q_consumer_phase = Int32(0)
|
|
1957
|
+
kv_consumer_state = pipeline.make_pipeline_state(
|
|
1958
|
+
cutlass.pipeline.PipelineUserType.Consumer, self.num_stages
|
|
1959
|
+
)
|
|
1960
|
+
|
|
1961
|
+
tile_scheduler = TileSchedulerCls()
|
|
1962
|
+
work_tile = tile_scheduler.initial_work_tile_info()
|
|
1963
|
+
softmax = Softmax.create(
|
|
1964
|
+
softmax_scale_log2,
|
|
1965
|
+
num_rows=acc_O.shape[0][0] * acc_O.shape[1],
|
|
1966
|
+
softmax_scale=softmax_scale,
|
|
1967
|
+
)
|
|
1968
|
+
|
|
1969
|
+
process_first_half_block = partial(
|
|
1970
|
+
self.first_half_block_overlap,
|
|
1971
|
+
mma_qk_fn=mma_qk_fn,
|
|
1972
|
+
pipeline_k=pipeline_k,
|
|
1973
|
+
tOrP=tOrP,
|
|
1974
|
+
smem_copy_params=smem_copy_params,
|
|
1975
|
+
softmax=softmax,
|
|
1976
|
+
)
|
|
1977
|
+
process_last_half_block = partial(
|
|
1978
|
+
self.last_half_block_overlap,
|
|
1979
|
+
pipeline_v=pipeline_v,
|
|
1980
|
+
mma_pv_fn=mma_pv_fn,
|
|
1981
|
+
)
|
|
1982
|
+
while work_tile.is_valid_tile:
|
|
1983
|
+
# if work_tile.is_valid_tile:
|
|
1984
|
+
|
|
1985
|
+
# shape: (atom_v_m * rest_m)
|
|
1986
|
+
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
|
1987
|
+
seqlen = SeqlenInfoCls(batch_idx)
|
|
1988
|
+
|
|
1989
|
+
# Recompute fastdiv_mods if necessary for varlen with aux_tensors
|
|
1990
|
+
recompute_fastdiv_mods_q = cutlass.const_expr(
|
|
1991
|
+
aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
|
|
1992
|
+
)
|
|
1993
|
+
recompute_fastdiv_mods_k = cutlass.const_expr(
|
|
1994
|
+
aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
|
|
1995
|
+
)
|
|
1996
|
+
if cutlass.const_expr(fastdiv_mods is not None):
|
|
1997
|
+
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
|
|
1998
|
+
fastdiv_mods = (
|
|
1999
|
+
seqlen_q_divmod
|
|
2000
|
+
if not recompute_fastdiv_mods_q
|
|
2001
|
+
else FastDivmodDivisor(seqlen.seqlen_q),
|
|
2002
|
+
seqlen_k_divmod
|
|
2003
|
+
if not recompute_fastdiv_mods_k
|
|
2004
|
+
else FastDivmodDivisor(seqlen.seqlen_k),
|
|
2005
|
+
)
|
|
2006
|
+
|
|
2007
|
+
mask = AttentionMaskCls(seqlen)
|
|
2008
|
+
mask_fn = partial(
|
|
2009
|
+
mask.apply_mask,
|
|
2010
|
+
batch_idx=batch_idx,
|
|
2011
|
+
head_idx=head_idx,
|
|
2012
|
+
m_block=m_block,
|
|
2013
|
+
thr_mma=thr_mma_qk,
|
|
2014
|
+
mask_causal=self.is_causal,
|
|
2015
|
+
mask_local=self.is_local,
|
|
2016
|
+
aux_tensors=aux_tensors,
|
|
2017
|
+
fastdiv_mods=fastdiv_mods,
|
|
2018
|
+
)
|
|
2019
|
+
score_mod_fn = None
|
|
2020
|
+
if const_expr(self.score_mod is not None):
|
|
2021
|
+
score_mod_fn = partial(
|
|
2022
|
+
self.apply_score_mod,
|
|
2023
|
+
thr_mma_qk,
|
|
2024
|
+
batch_idx,
|
|
2025
|
+
head_idx,
|
|
2026
|
+
m_block,
|
|
2027
|
+
softmax_scale=softmax_scale,
|
|
2028
|
+
aux_tensors=aux_tensors,
|
|
2029
|
+
fastdiv_mods=fastdiv_mods,
|
|
2030
|
+
)
|
|
2031
|
+
mma_one_n_block = partial(
|
|
2032
|
+
mma_one_n_block_all,
|
|
2033
|
+
seqlen=seqlen,
|
|
2034
|
+
softmax=softmax,
|
|
2035
|
+
score_mod_fn=score_mod_fn,
|
|
2036
|
+
)
|
|
2037
|
+
# Load Q if not TMA_Q
|
|
2038
|
+
if const_expr(not self.use_tma_Q):
|
|
2039
|
+
pack_gqa = PackGQA(
|
|
2040
|
+
self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
|
|
2041
|
+
)
|
|
2042
|
+
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
|
2043
|
+
# gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
|
|
2044
|
+
# gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
|
|
2045
|
+
# self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q,
|
|
2046
|
+
# headdim=mQ.shape[1])
|
|
2047
|
+
pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q)
|
|
2048
|
+
cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q)
|
|
2049
|
+
|
|
2050
|
+
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
|
|
2051
|
+
if const_expr(not self.use_tma_Q):
|
|
2052
|
+
cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase)
|
|
2053
|
+
q_consumer_phase ^= 1
|
|
2054
|
+
# For performance reasons, we separate out two kinds of iterations:
|
|
2055
|
+
# those that need masking on S, and those that don't.
|
|
2056
|
+
# We need masking on S for the very last block when the K/V length is not a multiple of tile_n.
|
|
2057
|
+
# We also need masking on S if it's causal, for the last several blocks.
|
|
2058
|
+
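# Illustrative example (assumed values): with seqlen_k = 1000 and tile_n = 128, only the last key
# block (keys 896..999) is ragged and needs seqlen masking; with causal attention the block(s)
# straddling the diagonal for the current m_block additionally need causal masking, and every block
# to their left runs through the unmasked loop below.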
# softmax.reset()  # No reset needed since we explicitly call softmax with is_first=True
|
|
2059
|
+
O_should_accumulate = False
|
|
2060
|
+
|
|
2061
|
+
# ==========================================
|
|
2062
|
+
# MAINLOOP
|
|
2063
|
+
# ==========================================
|
|
2064
|
+
if const_expr(not self.use_block_sparsity):
|
|
2065
|
+
# ==========================================
|
|
2066
|
+
# No block-sparsity (original path)
|
|
2067
|
+
# ==========================================
|
|
2068
|
+
# First iteration with seqlen masking
|
|
2069
|
+
if const_expr(self.intra_wg_overlap):
|
|
2070
|
+
kv_consumer_state = process_first_half_block(
|
|
2071
|
+
n_block=n_block_max - 1,
|
|
2072
|
+
seqlen=seqlen,
|
|
2073
|
+
kv_consumer_state=kv_consumer_state,
|
|
2074
|
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
|
|
2075
|
+
score_mod_fn=score_mod_fn,
|
|
2076
|
+
is_first_block=True,
|
|
2077
|
+
)
|
|
2078
|
+
# Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter
|
|
2079
|
+
# acc_O.fill(0.0)
|
|
2080
|
+
else:
|
|
2081
|
+
self.warp_scheduler_barrier_sync()
|
|
2082
|
+
kv_consumer_state = mma_one_n_block(
|
|
2083
|
+
kv_consumer_state,
|
|
2084
|
+
n_block=n_block_max - 1,
|
|
2085
|
+
seqlen=seqlen,
|
|
2086
|
+
mma_pv_fn=partial(mma_pv_fn, zero_init=True),
|
|
2087
|
+
is_first_n_block=True,
|
|
2088
|
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
|
|
2089
|
+
)
|
|
2090
|
+
O_should_accumulate = True
|
|
2091
|
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min)
|
|
2092
|
+
n_block_max -= 1
|
|
2093
|
+
# Next couple of iterations with causal masking
|
|
2094
|
+
if const_expr(self.is_causal or self.is_local):
|
|
2095
|
+
n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
|
|
2096
|
+
seqlen, m_block, n_block_min
|
|
2097
|
+
)
|
|
2098
|
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask)
|
|
2099
|
+
for n_tile in cutlass.range(
|
|
2100
|
+
n_block_max - n_block_min_causal_local_mask, unroll=1
|
|
2101
|
+
):
|
|
2102
|
+
kv_consumer_state = mma_one_n_block(
|
|
2103
|
+
kv_consumer_state,
|
|
2104
|
+
n_block=n_block_max - 1 - n_tile,
|
|
2105
|
+
seqlen=seqlen,
|
|
2106
|
+
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
|
2107
|
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
|
2108
|
+
)
|
|
2109
|
+
O_should_accumulate = True
|
|
2110
|
+
n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
|
|
2111
|
+
# The remaining iterations have no masking
|
|
2112
|
+
n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
|
|
2113
|
+
seqlen, m_block, n_block_min
|
|
2114
|
+
)
|
|
2115
|
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min)
|
|
2116
|
+
for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
|
|
2117
|
+
kv_consumer_state = mma_one_n_block(
|
|
2118
|
+
kv_consumer_state,
|
|
2119
|
+
n_block=n_block_max - 1 - n_tile,
|
|
2120
|
+
seqlen=seqlen,
|
|
2121
|
+
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
|
2122
|
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
|
2123
|
+
)
|
|
2124
|
+
O_should_accumulate = True
|
|
2125
|
+
# Separate iterations with local masking on the left
|
|
2126
|
+
if const_expr(self.is_local and block_info.window_size_left is not None):
|
|
2127
|
+
n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
|
|
2128
|
+
for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):
|
|
2129
|
+
kv_consumer_state = mma_one_n_block(
|
|
2130
|
+
kv_consumer_state,
|
|
2131
|
+
n_block=n_block_max - 1 - n_tile,
|
|
2132
|
+
seqlen=seqlen,
|
|
2133
|
+
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
|
|
2134
|
+
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
|
|
2135
|
+
)
|
|
2136
|
+
O_should_accumulate = True
|
|
2137
|
+
# Last "half" iteration
|
|
2138
|
+
if const_expr(self.intra_wg_overlap):
|
|
2139
|
+
kv_consumer_state = process_last_half_block(
|
|
2140
|
+
kv_consumer_state=kv_consumer_state,
|
|
2141
|
+
zero_init=not O_should_accumulate,
|
|
2142
|
+
)
|
|
2143
|
+
O_should_accumulate = True
|
|
2144
|
+
else:
|
|
2145
|
+
self.warp_scheduler_barrier_arrive()
|
|
2146
|
+
|
|
2147
|
+
else:
|
|
2148
|
+
# ==========================================
|
|
2149
|
+
# Block sparsity
|
|
2150
|
+
# ==========================================
|
|
2151
|
+
kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads(
|
|
2152
|
+
blocksparse_tensors,
|
|
2153
|
+
batch_idx,
|
|
2154
|
+
head_idx,
|
|
2155
|
+
m_block,
|
|
2156
|
+
seqlen,
|
|
2157
|
+
kv_consumer_state,
|
|
2158
|
+
mma_pv_fn,
|
|
2159
|
+
mma_one_n_block,
|
|
2160
|
+
process_first_half_block,
|
|
2161
|
+
process_last_half_block,
|
|
2162
|
+
mask_fn,
|
|
2163
|
+
score_mod_fn,
|
|
2164
|
+
O_should_accumulate,
|
|
2165
|
+
self.mask_mod,
|
|
2166
|
+
fastdiv_mods,
|
|
2167
|
+
self.intra_wg_overlap,
|
|
2168
|
+
self.warp_scheduler_barrier_sync,
|
|
2169
|
+
self.warp_scheduler_barrier_arrive,
|
|
2170
|
+
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
2171
|
+
)
|
|
2172
|
+
|
|
2173
|
+
# Handle empty case (when no blocks to process)
|
|
2174
|
+
if not processed_any:
|
|
2175
|
+
softmax.reset()
|
|
2176
|
+
acc_O.fill(0.0)
|
|
2177
|
+
|
|
2178
|
+
sink_val = None
|
|
2179
|
+
if const_expr(learnable_sink is not None):
|
|
2180
|
+
if const_expr(not self.pack_gqa):
|
|
2181
|
+
sink_val = Float32(learnable_sink[head_idx])
|
|
2182
|
+
else: # Each thread might have a different sink value due to different q_head
|
|
2183
|
+
sink_val = cute.make_fragment_like(softmax.row_max, Float32)
|
|
2184
|
+
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
|
|
2185
|
+
tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS))
|
|
2186
|
+
for r in cutlass.range(cute.size(sink_val), unroll_full=True):
|
|
2187
|
+
row = m_block * self.tile_m + tScS_mn[r][0]
|
|
2188
|
+
q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
|
|
2189
|
+
sink_val[r] = Float32(learnable_sink[q_head_idx])
|
|
2190
|
+
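# Illustrative example (assumed values): with qhead_per_kvhead = 4, tile_m = 128, m_block = 2 and
# head_idx = 3, a fragment row with tile-local index 5 maps to packed row 2 * 128 + 5 = 261, so
# q_head_idx = 261 % 4 + 3 * 4 = 13 and that row uses learnable_sink[13].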
|
|
2191
|
+
# normalize acc_O by row_sum and calculate the lse
|
|
2192
|
+
row_scale = softmax.finalize(sink_val=sink_val)
|
|
2193
|
+
softmax.rescale_O(acc_O, row_scale)
|
|
2194
|
+
|
|
2195
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
2196
|
+
# Epilogue
|
|
2197
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
2198
|
+
self.epilogue(
|
|
2199
|
+
acc_O,
|
|
2200
|
+
softmax.row_sum,
|
|
2201
|
+
mO,
|
|
2202
|
+
mLSE,
|
|
2203
|
+
sO,
|
|
2204
|
+
seqlen,
|
|
2205
|
+
gmem_tiled_copy_O,
|
|
2206
|
+
tma_atom_O,
|
|
2207
|
+
tiled_mma_pv,
|
|
2208
|
+
tidx,
|
|
2209
|
+
m_block,
|
|
2210
|
+
head_idx,
|
|
2211
|
+
batch_idx,
|
|
2212
|
+
)
|
|
2213
|
+
|
|
2214
|
+
tile_scheduler.advance_to_next_work()
|
|
2215
|
+
work_tile = tile_scheduler.get_current_work()
|
|
2216
|
+
|
|
2217
|
+
|
|
2218
|
+
@cute.jit
|
|
2219
|
+
def first_half_block_overlap(
|
|
2220
|
+
self,
|
|
2221
|
+
n_block: Int32,
|
|
2222
|
+
mma_qk_fn: Callable,
|
|
2223
|
+
kv_consumer_state,
|
|
2224
|
+
pipeline_k,
|
|
2225
|
+
tOrP: cute.Tensor,
|
|
2226
|
+
smem_copy_params: SimpleNamespace,
|
|
2227
|
+
softmax: Softmax,
|
|
2228
|
+
seqlen: SeqlenInfoQK,
|
|
2229
|
+
mask_fn: Optional[Callable] = None,
|
|
2230
|
+
score_mod_fn: Optional[Callable] = None,
|
|
2231
|
+
is_first_block: bool = False,
|
|
2232
|
+
):
|
|
2233
|
+
"""Processes the first half block when using intra-warpgroup-overlap"""
|
|
2234
|
+
|
|
2235
|
+
pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))
|
|
2236
|
+
acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)
|
|
2237
|
+
pipeline_k.consumer_release(kv_consumer_state)
|
|
2238
|
+
|
|
2239
|
+
# Apply score modification if present
|
|
2240
|
+
if const_expr(score_mod_fn is not None):
|
|
2241
|
+
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
|
2242
|
+
|
|
2243
|
+
# Apply mask; mask_seqlen always True for first block
|
|
2244
|
+
# Caveat: if a full block lies further right than the masked block, seqlen masking is redundant;
|
|
2245
|
+
# however, masking is being applied anyway, so essentially no perf hit
|
|
2246
|
+
mask_fn(acc_S, n_block=n_block, mask_seqlen=True)
|
|
2247
|
+
|
|
2248
|
+
softmax.online_softmax(acc_S, is_first=is_first_block)
|
|
2249
|
+
|
|
2250
|
+
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
|
|
2251
|
+
tOrP_cur = (
|
|
2252
|
+
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
|
2253
|
+
)
|
|
2254
|
+
tOrP_cur.store(tOrP_acc.load().to(self.dtype))
|
|
2255
|
+
|
|
2256
|
+
# if pv gemm not rs
|
|
2257
|
+
if const_expr(not self.mma_pv_is_rs):
|
|
2258
|
+
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
|
2259
|
+
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
|
2260
|
+
# Fence and barrier to make smem store visible to WGMMA
|
|
2261
|
+
cute.arch.fence_proxy(
|
|
2262
|
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
2263
|
+
)
|
|
2264
|
+
cute.arch.sync_warp()
|
|
2265
|
+
|
|
2266
|
+
return kv_consumer_state
|
|
2267
|
+
|
|
2268
|
+
@cute.jit
|
|
2269
|
+
def last_half_block_overlap(
|
|
2270
|
+
self,
|
|
2271
|
+
kv_consumer_state,
|
|
2272
|
+
pipeline_v,
|
|
2273
|
+
mma_pv_fn: Callable,
|
|
2274
|
+
zero_init: bool,
|
|
2275
|
+
):
|
|
2276
|
+
"""Processes the final PV GEMM when using intra-warpgroup-overlap"""
|
|
2277
|
+
|
|
2278
|
+
pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
|
|
2279
|
+
mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)
|
|
2280
|
+
pipeline_v.consumer_release(kv_consumer_state)
|
|
2281
|
+
kv_consumer_state.advance()
|
|
2282
|
+
return kv_consumer_state
|
|
2283
|
+
|
|
2284
|
+
@cute.jit
|
|
2285
|
+
def mma_one_n_block(
|
|
2286
|
+
self,
|
|
2287
|
+
smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
|
|
2288
|
+
n_block: Int32,
|
|
2289
|
+
mma_qk_fn: Callable,
|
|
2290
|
+
mma_pv_fn: Callable,
|
|
2291
|
+
tiled_mma_pv_rs: cute.TiledMma,
|
|
2292
|
+
pipeline_k: cutlass.pipeline.PipelineAsync,
|
|
2293
|
+
pipeline_v: cutlass.pipeline.PipelineAsync,
|
|
2294
|
+
acc_O: cute.Tensor,
|
|
2295
|
+
tOrP: cute.Tensor,
|
|
2296
|
+
smem_copy_params: SimpleNamespace,
|
|
2297
|
+
softmax: Softmax,
|
|
2298
|
+
seqlen: SeqlenInfoQK,
|
|
2299
|
+
score_mod_fn: Optional[Callable] = None,
|
|
2300
|
+
mask_fn: Optional[Callable] = None,
|
|
2301
|
+
is_first_n_block: cutlass.Constexpr = False,
|
|
2302
|
+
check_inf: cutlass.Constexpr = True,
|
|
2303
|
+
):
|
|
2304
|
+
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
|
|
2305
|
+
# S = Q @ K.T
|
|
2306
|
+
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
|
|
2307
|
+
self.warp_scheduler_barrier_arrive()
|
|
2308
|
+
warpgroup.wait_group(0)
|
|
2309
|
+
pipeline_k.consumer_release(smem_pipe_read)
|
|
2310
|
+
|
|
2311
|
+
# handle score mods and masking
|
|
2312
|
+
if const_expr(score_mod_fn is not None):
|
|
2313
|
+
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
|
2314
|
+
if const_expr(mask_fn is not None):
|
|
2315
|
+
mask_fn(acc_S=acc_S, n_block=n_block)
|
|
2316
|
+
|
|
2317
|
+
row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
|
|
2318
|
+
# if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
|
|
2319
|
+
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
|
|
2320
|
+
tOrP_cur = (
|
|
2321
|
+
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
|
2322
|
+
)
|
|
2323
|
+
# tOrP.store(tOrP_acc.load().to(self.dtype))
|
|
2324
|
+
# the "to(self.dtype)" conversion fails to vectorize for block sizes other
|
|
2325
|
+
# than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
|
|
2326
|
+
# 2 elements. So we just call PTX directly.
|
|
2327
|
+
utils.cvt_f16(tOrP_acc, tOrP_cur)
|
|
2328
|
+
if const_expr(not self.mma_pv_is_rs):
|
|
2329
|
+
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
|
2330
|
+
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
|
2331
|
+
softmax.rescale_O(acc_O, row_scale)
|
|
2332
|
+
if const_expr(not self.mma_pv_is_rs):
|
|
2333
|
+
# Fence and barrier to make sure smem store is visible to WGMMA
|
|
2334
|
+
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
|
|
2335
|
+
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
|
|
2336
|
+
pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
|
|
2337
|
+
self.warp_scheduler_barrier_sync()
|
|
2338
|
+
# O += P @ V
|
|
2339
|
+
mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)
|
|
2340
|
+
pipeline_v.consumer_release(smem_pipe_read)
|
|
2341
|
+
smem_pipe_read.advance()
|
|
2342
|
+
return smem_pipe_read
|
|
2343
|
+
|
|
2344
|
+
@cute.jit
|
|
2345
|
+
def mma_one_n_block_intrawg_overlap(
|
|
2346
|
+
self,
|
|
2347
|
+
smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
|
|
2348
|
+
n_block: Int32,
|
|
2349
|
+
mma_qk_fn: Callable,
|
|
2350
|
+
mma_pv_fn: Callable,
|
|
2351
|
+
tiled_mma_pv_rs: cute.TiledMma,
|
|
2352
|
+
pipeline_k: cutlass.pipeline.PipelineAsync,
|
|
2353
|
+
pipeline_v: cutlass.pipeline.PipelineAsync,
|
|
2354
|
+
acc_O: cute.Tensor,
|
|
2355
|
+
tOrP: cute.Tensor,
|
|
2356
|
+
smem_copy_params: SimpleNamespace,
|
|
2357
|
+
softmax: Softmax,
|
|
2358
|
+
seqlen: SeqlenInfoQK,
|
|
2359
|
+
score_mod_fn: Optional[Callable] = None,
|
|
2360
|
+
mask_fn: Optional[Callable] = None,
|
|
2361
|
+
check_inf: cutlass.Constexpr = True,
|
|
2362
|
+
):
|
|
2363
|
+
smem_pipe_read_v = smem_pipe_read.clone()
|
|
2364
|
+
smem_pipe_read.advance()
|
|
2365
|
+
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
|
|
2366
|
+
self.warp_scheduler_barrier_sync()
|
|
2367
|
+
# S = Q @ K.T
|
|
2368
|
+
acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
|
|
2369
|
+
pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))
|
|
2370
|
+
# O += P @ V
|
|
2371
|
+
mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)
|
|
2372
|
+
self.warp_scheduler_barrier_arrive()
|
|
2373
|
+
warpgroup.wait_group(1)
|
|
2374
|
+
pipeline_k.consumer_release(smem_pipe_read)
|
|
2375
|
+
|
|
2376
|
+
# handle score mods and masking
|
|
2377
|
+
if const_expr(score_mod_fn is not None):
|
|
2378
|
+
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
|
2379
|
+
if const_expr(mask_fn is not None):
|
|
2380
|
+
mask_fn(acc_S=acc_S, n_block=n_block)
|
|
2381
|
+
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
|
|
2382
|
+
|
|
2383
|
+
row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
|
|
2384
|
+
warpgroup.wait_group(0)
|
|
2385
|
+
pipeline_v.consumer_release(smem_pipe_read_v)
|
|
2386
|
+
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
|
|
2387
|
+
tOrP_cur = (
|
|
2388
|
+
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
|
2389
|
+
)
|
|
2390
|
+
# tOrP_cur.store(tOrP_acc.load().to(self.dtype))
|
|
2391
|
+
# the "to(self.dtype)" conversion fails to vectorize for block sizes other
|
|
2392
|
+
# than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
|
|
2393
|
+
# 2 elements. So we just call PTX directly.
|
|
2394
|
+
utils.cvt_f16(tOrP_acc, tOrP_cur)
|
|
2395
|
+
if const_expr(not self.mma_pv_is_rs):
|
|
2396
|
+
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
|
|
2397
|
+
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
|
|
2398
|
+
softmax.rescale_O(acc_O, row_scale)
|
|
2399
|
+
if const_expr(not self.mma_pv_is_rs):
|
|
2400
|
+
# Fence and barrier to make sure smem store is visible to WGMMA
|
|
2401
|
+
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
|
|
2402
|
+
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
|
|
2403
|
+
return smem_pipe_read
|
|
2404
|
+
|
|
2405
|
+
@cute.jit
|
|
2406
|
+
def mma_init(self):
|
|
2407
|
+
warp_group_idx = utils.canonical_warp_group_idx(sync=False)
|
|
2408
|
+
if const_expr(self.use_scheduler_barrier):
|
|
2409
|
+
if warp_group_idx == 1:
|
|
2410
|
+
cute.arch.barrier_arrive(
|
|
2411
|
+
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1),
|
|
2412
|
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
2413
|
+
)
|
|
2414
|
+
|
|
2415
|
+
@cute.jit
|
|
2416
|
+
def apply_score_mod(
|
|
2417
|
+
self,
|
|
2418
|
+
thr_mma_qk,
|
|
2419
|
+
batch_idx,
|
|
2420
|
+
head_idx,
|
|
2421
|
+
m_block,
|
|
2422
|
+
acc_S,
|
|
2423
|
+
n_block,
|
|
2424
|
+
softmax_scale,
|
|
2425
|
+
seqlen,
|
|
2426
|
+
aux_tensors: Optional[list] = None,
|
|
2427
|
+
fastdiv_mods=None,
|
|
2428
|
+
):
|
|
2429
|
+
# Prepare index tensor
|
|
2430
|
+
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
|
|
2431
|
+
cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
|
|
2432
|
+
tScS = thr_mma_qk.partition_C(cS)
|
|
2433
|
+
|
|
2434
|
+
apply_score_mod_inner(
|
|
2435
|
+
acc_S,
|
|
2436
|
+
tScS,
|
|
2437
|
+
self.score_mod,
|
|
2438
|
+
batch_idx,
|
|
2439
|
+
head_idx,
|
|
2440
|
+
softmax_scale,
|
|
2441
|
+
self.vec_size,
|
|
2442
|
+
self.qk_acc_dtype,
|
|
2443
|
+
aux_tensors,
|
|
2444
|
+
fastdiv_mods,
|
|
2445
|
+
seqlen_info=seqlen,
|
|
2446
|
+
constant_q_idx=None,
|
|
2447
|
+
qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
2448
|
+
)
|
|
2449
|
+
|
|
2450
|
+
def warp_scheduler_barrier_sync(self):
|
|
2451
|
+
if const_expr(self.use_scheduler_barrier):
|
|
2452
|
+
cute.arch.barrier(
|
|
2453
|
+
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1)
|
|
2454
|
+
- 1
|
|
2455
|
+
+ utils.canonical_warp_group_idx(sync=False),
|
|
2456
|
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
2457
|
+
)
|
|
2458
|
+
|
|
2459
|
+
def warp_scheduler_barrier_arrive(self):
|
|
2460
|
+
if const_expr(self.use_scheduler_barrier):
|
|
2461
|
+
assert self.num_mma_warp_groups in [2, 3]
|
|
2462
|
+
cur_wg = utils.canonical_warp_group_idx(sync=False) - 1
|
|
2463
|
+
if const_expr(self.num_mma_warp_groups == 2):
|
|
2464
|
+
next_wg = 1 - cur_wg
|
|
2465
|
+
else:
|
|
2466
|
+
t = cur_wg + 1
|
|
2467
|
+
next_wg = t % self.num_mma_warp_groups
|
|
2468
|
+
cute.arch.barrier_arrive(
|
|
2469
|
+
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
|
|
2470
|
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
2471
|
+
)
|