mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2727 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Supported features:
|
|
3
|
+
# - BF16 & FP16 dtype
|
|
4
|
+
# - noncausal & causal attention
|
|
5
|
+
# - MHA, GQA, MQA
|
|
6
|
+
# - hdim 64, 96, 128, (192, 128).
|
|
7
|
+
# - varlen
|
|
8
|
+
# - sliding window
|
|
9
|
+
# - split-kv
|
|
10
|
+
# Unsupported features that will be added later:
|
|
11
|
+
# - page size != 128
|
|
12
|
+
# - more hdim (192, 256)
|
|
13
|
+
# Based on the cutlass example and cute-dsl example:
|
|
14
|
+
# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
|
|
15
|
+
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py
|
|
16
|
+
|
|
17
|
+
import enum
|
|
18
|
+
import math
|
|
19
|
+
from typing import Type, Tuple, Callable, Optional, Literal
|
|
20
|
+
from functools import partial
|
|
21
|
+
|
|
22
|
+
import cuda.bindings.driver as cuda
|
|
23
|
+
|
|
24
|
+
import cutlass
|
|
25
|
+
import cutlass.cute as cute
|
|
26
|
+
from cutlass import Float32, Int32, const_expr
|
|
27
|
+
from cutlass.cute.nvgpu import cpasync
|
|
28
|
+
import cutlass.cute.nvgpu.tcgen05 as tcgen05
|
|
29
|
+
import cutlass.utils.blackwell_helpers as sm100_utils_basic
|
|
30
|
+
|
|
31
|
+
from mslk.attention.flash_attn.paged_kv import PagedKVManager
|
|
32
|
+
import mslk.attention.flash_attn.utils as utils
|
|
33
|
+
from mslk.attention.flash_attn import copy_utils
|
|
34
|
+
import mslk.attention.flash_attn.pipeline as pipeline
|
|
35
|
+
from mslk.attention.flash_attn.mask import AttentionMask
|
|
36
|
+
from mslk.attention.flash_attn.softmax import SoftmaxSm100, apply_score_mod_inner
|
|
37
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
|
|
38
|
+
from mslk.attention.flash_attn.block_info import BlockInfo
|
|
39
|
+
from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
|
|
40
|
+
from mslk.attention.flash_attn.block_sparse_utils import (
|
|
41
|
+
get_total_block_count,
|
|
42
|
+
produce_block_sparse_loads_sm100,
|
|
43
|
+
softmax_block_sparse_sm100,
|
|
44
|
+
handle_block_sparse_empty_tile_correction_sm100,
|
|
45
|
+
)
|
|
46
|
+
from mslk.attention.flash_attn.pack_gqa import PackGQA
|
|
47
|
+
from mslk.attention.flash_attn import mma_sm100_desc as sm100_desc
|
|
48
|
+
from mslk.attention.flash_attn import blackwell_helpers as sm100_utils
|
|
49
|
+
from cutlass.cute import FastDivmodDivisor
|
|
50
|
+
from mslk.attention.flash_attn.tile_scheduler import (
|
|
51
|
+
TileSchedulerArguments,
|
|
52
|
+
SingleTileScheduler,
|
|
53
|
+
StaticPersistentTileScheduler,
|
|
54
|
+
SingleTileLPTScheduler,
|
|
55
|
+
SingleTileVarlenScheduler,
|
|
56
|
+
ParamsBase,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class NamedBarrierFwd(enum.IntEnum):
    """Named barrier IDs used by the SM100 forward kernel.

    Barrier 0 is reserved for ``sync_threads()``, so named barriers start at 1.
    Only the epilogue barrier is currently defined; warp-scheduler and
    P-full/P-empty barriers were sketched but are not in use.
    """

    Epilogue = 1
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class FlashAttentionForwardSm100:
|
|
70
|
+
arch = 100
|
|
71
|
+
|
|
72
|
+
def __init__(
    self,
    head_dim: int,
    head_dim_v: Optional[int] = None,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    is_causal: bool = False,
    is_local: bool = False,
    is_split_kv: bool = False,
    pack_gqa: bool = False,
    m_block_size: int = 128,
    n_block_size: int = 128,
    q_stage: cutlass.Constexpr[int] = 2,
    is_persistent: bool = True,
    score_mod: cutlass.Constexpr | None = None,
    mask_mod: cutlass.Constexpr | None = None,
    has_aux_tensors: cutlass.Constexpr = False,
    paged_kv_non_tma: bool = False,
    is_varlen_q: bool = False,
):
    """Configure static tiling, warp roles, TMEM layout and register budgets.

    Args:
        head_dim: Head dim of Q/K; padded up to a multiple of 16.
        head_dim_v: Head dim of V/O; defaults to ``head_dim``.
        qhead_per_kvhead: Query heads per KV head. With ``pack_gqa``,
            ``m_block_size`` must be divisible by it.
        is_causal: Causal masking flag (stored for the kernel).
        is_local: Local-window masking flag (stored for the kernel).
        is_split_kv: Enable split-KV; not supported for padded head_dim_v >= 192.
        pack_gqa: Pack query heads that share a KV head into the M dimension.
        m_block_size: Rows of each Q tile.
        n_block_size: Rows of each K/V tile.
        q_stage: Number of Q tiles per CTA (1 or 2).
        is_persistent: Use a persistent schedule; forced off when sO overlaps sQ.
        score_mod: Optional score modifier (stored for the kernel).
        mask_mod: Optional mask modifier (stored for the kernel).
        has_aux_tensors: When true, ``vec_size`` is 1 instead of 2.
        paged_kv_non_tma: Load K/V without TMA; requires head dims that are
            already multiples of 16.
        is_varlen_q: Variable-length Q; the correction warps also run the epilogue.

    Raises:
        AssertionError: On an unsupported configuration (see the asserts below).
    """
    self.use_tma_KV = not paged_kv_non_tma
    # Pad head dims up to a multiple of 16 (the k-block size).
    hdim_multiple_of = 16
    self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
    head_dim_v = head_dim_v if head_dim_v is not None else head_dim
    self.same_hdim_kv = head_dim == head_dim_v
    self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
    self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded
    # True when padding was added, i.e. loads/stores must guard the hdim tail.
    self.check_hdim_oob = head_dim != self.head_dim_padded
    self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
    self.m_block_size = m_block_size
    self.n_block_size = n_block_size
    self.q_stage = q_stage
    assert self.q_stage in [1, 2]

    # Each CTA covers q_stage Q tiles of m_block_size rows.
    self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded)
    self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded)
    self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size)
    self.qk_acc_dtype = Float32
    self.pv_acc_dtype = Float32
    self.cluster_shape_mn = (1, 1)
    self.is_persistent = is_persistent
    self.is_causal = is_causal
    self.is_local = is_local
    self.is_varlen_q = is_varlen_q
    self.use_correction_warps_for_epi = is_varlen_q
    self.qhead_per_kvhead = qhead_per_kvhead
    self.is_split_kv = is_split_kv
    self.pack_gqa = pack_gqa
    if pack_gqa:
        assert m_block_size % self.qhead_per_kvhead == 0, (
            "For PackGQA, m_block_size must be divisible by qhead_per_kvhead"
        )
    assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (
        "SplitKV is not supported for hdim >= 192"
    )
    self.score_mod = score_mod
    self.mask_mod = mask_mod
    if cutlass.const_expr(has_aux_tensors):
        self.vec_size: cutlass.Constexpr = 1
    else:
        self.vec_size: cutlass.Constexpr = 2
    # S1 does not wait for S0 to finish (the s0/s1 barrier is disabled).
    self.s0_s1_barrier = False
    # Whether the O smem buffer is overlapped with the Q smem buffer.
    self.overlap_sO_sQ = (
        (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or
        (self.head_dim_v_padded >= 128 and self.is_split_kv)
    )
    if self.overlap_sO_sQ:
        self.is_persistent = False

    assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (
        "Paged KV does not support irregular head dim"
    )

    # Default warp-role assignment: 16 warps, each with exactly one role.
    self.softmax0_warp_ids = (0, 1, 2, 3)
    self.softmax1_warp_ids = (4, 5, 6, 7)
    self.correction_warp_ids = (8, 9, 10, 11)
    self.mma_warp_id = 12
    self.epilogue_warp_ids = (13,)
    self.load_warp_ids = (14,)
    self.empty_warp_ids = (15,)
    SM100_TMEM_CAPACITY_COLUMNS = 512
    self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS

    # Computed from the default assignment BEFORE the remapping below: some
    # remaps list the same warp under two roles (e.g. softmax1 warps reused as
    # load warps), so recomputing afterwards would over-count.
    self.threads_per_cta = cute.arch.WARP_SIZE * len(
        (
            *self.softmax0_warp_ids,
            *self.softmax1_warp_ids,
            *self.correction_warp_ids,
            self.mma_warp_id,
            *self.load_warp_ids,
            *self.epilogue_warp_ids,
            *self.empty_warp_ids,
        )
    )

    if self.q_stage == 1:
        if not self.use_tma_KV:
            # Non-TMA KV loads run on the (otherwise unused) softmax1 warps.
            self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids
            self.load_warp_ids = self.softmax1_warp_ids
        else:
            # Only one softmax warpgroup is needed with a single Q stage.
            self.empty_warp_ids = self.empty_warp_ids + self.softmax1_warp_ids
            self.softmax1_warp_ids = ()
    elif not self.use_tma_KV:
        self.load_warp_ids = (14, 15)
        self.empty_warp_ids = ()

    if self.use_correction_warps_for_epi:
        # Correction warps also run the epilogue; the default epilogue warp idles.
        self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids
        self.epilogue_warp_ids = self.correction_warp_ids
    elif self.is_varlen_q:  # fallback
        # NOTE(review): unreachable while use_correction_warps_for_epi is set
        # from is_varlen_q; kept as the fallback if the two are decoupled.
        self.epilogue_warp_ids = (13, 14)

    # TMEM column layout: two S tiles, then one O accumulator per Q stage.
    self.tmem_s_offset = [0, self.n_block_size]  # e.g., 0, 128
    self.tmem_o_offset = [
        self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
        for i in range(self.q_stage)
    ]  # e.g., 256, 384
    self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded
    assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS
    # P lives inside its S tile's columns, starting halfway through the tile.
    self.tmem_s_to_p_offset = self.n_block_size // 2
    self.tmem_p_offset = [
        self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)
    ]  # e.g., 64, 192 for n_block_size = 128

    # vec buffer for row_max & row_sum (shares the S tile offsets)
    self.tmem_vec_offset = self.tmem_s_offset

    # Per-role register budgets; the same values are used for every head dim.
    self.num_regs_softmax = 200
    self.num_regs_correction = 64
    self.num_regs_other = 48
    self.num_regs_empty = 24

    self.buffer_align_bytes = 1024
|
|
225
|
+
|
|
226
|
+
def _setup_attributes(self):
    """Set up the pipeline-stage configuration for the FMHA kernel.

    Must run after ``self.q_dtype`` is set (``__call__`` sets it from ``mQ``).

    Sets:
        kv_stage: Number of K/V smem stages (4 for 8-bit inputs or a single
            Q stage, otherwise 3).
        acc_stage: Number of accumulator stages (1).
        uneven_kv_smem: Whether K/V stages use the uneven
            [large, small, large] smem packing described below.
        uneven_kv_smem_offset: Shift applied to the middle stage in that
            packing (0 when the packing is not used).

    Raises:
        AssertionError: If the uneven-stage shift is not a multiple of 1024.
    """
    self.kv_stage = 4 if self.q_dtype.width == 8 or self.q_stage == 1 else 3
    self.acc_stage = 1
    # For hdim (192, 128) there is not enough smem for 3 full-size KV stages:
    # 128 x 192 x 2 bytes x 3 stages = 144KB, and Q needs 96KB.
    # Instead stages are packed as [large, small, large] (or, on the next
    # phase, [small, large, small]), where large = n_block_size x 192 (K) and
    # small = n_block_size x 128 (V). The stage stride is set to the average,
    # n_block_size x 160, so stages 0 and 2 land on the right address, but
    # stage 1 must be shifted by +/- n_block_size x (192 - 128) / 2
    # (= n_block_size x 32) depending on phase.
    # K/V tiles have n_block_size rows (see mma_tiler_qk / mma_tiler_pv), so
    # the shift is sized by n_block_size, not m_block_size.
    self.uneven_kv_smem = (
        self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3
    )
    self.uneven_kv_smem_offset = (
        self.n_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2
        if self.uneven_kv_smem
        else 0
    )
    assert self.uneven_kv_smem_offset % 1024 == 0
|
|
253
|
+
|
|
254
|
+
@cute.jit
|
|
255
|
+
def __call__(
|
|
256
|
+
self,
|
|
257
|
+
mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
|
258
|
+
mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
|
|
259
|
+
mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
|
|
260
|
+
mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
|
261
|
+
mLSE: Optional[cute.Tensor],
|
|
262
|
+
softmax_scale: Float32,
|
|
263
|
+
stream: cuda.CUstream,
|
|
264
|
+
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
|
265
|
+
mCuSeqlensK: Optional[cute.Tensor] = None,
|
|
266
|
+
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
267
|
+
mSeqUsedK: Optional[cute.Tensor] = None,
|
|
268
|
+
mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
|
|
269
|
+
window_size_left: Int32 | int | None = None,
|
|
270
|
+
window_size_right: Int32 | int | None = None,
|
|
271
|
+
learnable_sink: Optional[cute.Tensor] = None,
|
|
272
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
273
|
+
aux_tensors: Optional[list] = None,
|
|
274
|
+
):
|
|
275
|
+
"""Execute the Fused Multi-Head Attention operation on the provided tensors.
|
|
276
|
+
|
|
277
|
+
This method prepares the input tensors for processing, validates their shapes and types,
|
|
278
|
+
configures the computation parameters, and launches the CUDA kernel.
|
|
279
|
+
|
|
280
|
+
The method handles:
|
|
281
|
+
1. Tensor layout transformations for specific memory access patterns
|
|
282
|
+
2. Validation of tensor shapes and data types
|
|
283
|
+
3. Initialization of hardware-specific parameters and memory layouts
|
|
284
|
+
4. Configuration of TMA (Tensor Memory Access) operations
|
|
285
|
+
5. Grid and work scheduling computation
|
|
286
|
+
6. Kernel launch with appropriate parameters
|
|
287
|
+
"""
|
|
288
|
+
# setup static attributes before smem/grid/tma computation
|
|
289
|
+
self.q_dtype = mQ.element_type
|
|
290
|
+
self.k_dtype = mK.element_type
|
|
291
|
+
self.v_dtype = mV.element_type
|
|
292
|
+
self.o_dtype = mO.element_type
|
|
293
|
+
# Assume all strides are divisible by 128 bits except the last stride
|
|
294
|
+
new_stride = lambda t: (
|
|
295
|
+
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
|
|
296
|
+
t.stride[-1],
|
|
297
|
+
)
|
|
298
|
+
mQ, mK, mV, mO = [
|
|
299
|
+
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
|
300
|
+
for t in (mQ, mK, mV, mO)
|
|
301
|
+
]
|
|
302
|
+
Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
|
303
|
+
mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose))
|
|
304
|
+
# (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table
|
|
305
|
+
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
|
|
306
|
+
mK, mV = [
|
|
307
|
+
cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose))
|
|
308
|
+
for t in (mK, mV)
|
|
309
|
+
]
|
|
310
|
+
if const_expr(self.is_split_kv):
|
|
311
|
+
O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0]
|
|
312
|
+
LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0]
|
|
313
|
+
num_splits = mO.shape[0]
|
|
314
|
+
else:
|
|
315
|
+
O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
|
316
|
+
LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
|
317
|
+
num_splits = Int32(1)
|
|
318
|
+
mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
|
|
319
|
+
mLSE = (
|
|
320
|
+
cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
|
|
321
|
+
if const_expr(mLSE is not None)
|
|
322
|
+
else None
|
|
323
|
+
)
|
|
324
|
+
# (s, d, h, b) -> (d, s, h, b)
|
|
325
|
+
V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2]
|
|
326
|
+
mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose))
|
|
327
|
+
|
|
328
|
+
self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode()
|
|
329
|
+
self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode()
|
|
330
|
+
self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode()
|
|
331
|
+
self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO)
|
|
332
|
+
|
|
333
|
+
if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K):
|
|
334
|
+
raise RuntimeError("The layout of mQ is not supported")
|
|
335
|
+
if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K):
|
|
336
|
+
raise RuntimeError("The layout of mK is not supported")
|
|
337
|
+
if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN):
|
|
338
|
+
raise RuntimeError("The layout of mV is not supported")
|
|
339
|
+
|
|
340
|
+
# check type consistency
|
|
341
|
+
if const_expr(self.q_dtype != self.k_dtype):
|
|
342
|
+
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}")
|
|
343
|
+
if const_expr(self.q_dtype != self.v_dtype):
|
|
344
|
+
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
|
|
345
|
+
self._setup_attributes()
|
|
346
|
+
self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None
|
|
347
|
+
# This can be tuned
|
|
348
|
+
self.e2e_freq = 16
|
|
349
|
+
if const_expr(
|
|
350
|
+
self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa
|
|
351
|
+
):
|
|
352
|
+
self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10
|
|
353
|
+
|
|
354
|
+
cta_group = tcgen05.CtaGroup.ONE
|
|
355
|
+
# the intermediate tensor p is from tmem & mK-major
|
|
356
|
+
p_source = tcgen05.OperandSource.TMEM
|
|
357
|
+
p_major_mode = tcgen05.OperandMajorMode.K
|
|
358
|
+
tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma(
|
|
359
|
+
self.q_dtype,
|
|
360
|
+
self.q_major_mode,
|
|
361
|
+
self.k_major_mode,
|
|
362
|
+
self.qk_acc_dtype,
|
|
363
|
+
cta_group,
|
|
364
|
+
self.mma_tiler_qk[:2],
|
|
365
|
+
)
|
|
366
|
+
tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma(
|
|
367
|
+
self.v_dtype,
|
|
368
|
+
p_major_mode,
|
|
369
|
+
self.v_major_mode,
|
|
370
|
+
self.pv_acc_dtype,
|
|
371
|
+
cta_group,
|
|
372
|
+
self.mma_tiler_pv[:2],
|
|
373
|
+
p_source,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
|
|
377
|
+
self.cluster_layout_vmnk = cute.tiled_divide(
|
|
378
|
+
cute.make_layout(self.cluster_shape_mnk),
|
|
379
|
+
(tiled_mma_qk.thr_id.shape,),
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
self.epi_tile = self.mma_tiler_pv[:2]
|
|
383
|
+
|
|
384
|
+
sQ_layout = sm100_utils_basic.make_smem_layout_a(
|
|
385
|
+
tiled_mma_qk,
|
|
386
|
+
self.mma_tiler_qk,
|
|
387
|
+
self.q_dtype,
|
|
388
|
+
self.q_stage,
|
|
389
|
+
)
|
|
390
|
+
sK_layout = sm100_utils_basic.make_smem_layout_b(
|
|
391
|
+
tiled_mma_qk,
|
|
392
|
+
self.mma_tiler_qk,
|
|
393
|
+
self.k_dtype,
|
|
394
|
+
self.kv_stage,
|
|
395
|
+
)
|
|
396
|
+
tP_layout = sm100_utils_basic.make_smem_layout_a(
|
|
397
|
+
tiled_mma_pv,
|
|
398
|
+
self.mma_tiler_pv,
|
|
399
|
+
self.q_dtype,
|
|
400
|
+
self.acc_stage,
|
|
401
|
+
)
|
|
402
|
+
sV_layout = sm100_utils_basic.make_smem_layout_b(
|
|
403
|
+
tiled_mma_pv,
|
|
404
|
+
self.mma_tiler_pv,
|
|
405
|
+
self.v_dtype,
|
|
406
|
+
self.kv_stage,
|
|
407
|
+
)
|
|
408
|
+
sO_layout = sm100_utils_basic.make_smem_layout_epi(
|
|
409
|
+
self.o_dtype,
|
|
410
|
+
self.o_layout,
|
|
411
|
+
self.epi_tile,
|
|
412
|
+
self.q_stage,
|
|
413
|
+
)
|
|
414
|
+
if const_expr(not self.same_hdim_kv_padded):
|
|
415
|
+
# sK and sV are using the same physical smem so we need to adjust the stride so that they line up
|
|
416
|
+
stride_sK = const_expr(
|
|
417
|
+
max(sK_layout.outer.stride[-1], 0)
|
|
418
|
+
) # take max to turn tuple to Int32
|
|
419
|
+
stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0))
|
|
420
|
+
stage_stride = const_expr(
|
|
421
|
+
max(stride_sK, stride_sV)
|
|
422
|
+
if not self.uneven_kv_smem
|
|
423
|
+
else (stride_sK + stride_sV) // 2
|
|
424
|
+
)
|
|
425
|
+
sK_layout = cute.make_composed_layout(
|
|
426
|
+
sK_layout.inner,
|
|
427
|
+
0,
|
|
428
|
+
cute.make_layout(
|
|
429
|
+
(*sK_layout.outer.shape[:-1], self.kv_stage),
|
|
430
|
+
stride=(*sK_layout.outer.stride[:-1], stage_stride),
|
|
431
|
+
),
|
|
432
|
+
)
|
|
433
|
+
sV_layout = cute.make_composed_layout(
|
|
434
|
+
sV_layout.inner,
|
|
435
|
+
0,
|
|
436
|
+
cute.make_layout(
|
|
437
|
+
(*sV_layout.outer.shape[:-1], self.kv_stage),
|
|
438
|
+
stride=(*sV_layout.outer.stride[:-1], stage_stride),
|
|
439
|
+
),
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
if const_expr(self.pack_gqa):
|
|
443
|
+
shape_Q_packed = (
|
|
444
|
+
(self.qhead_per_kvhead, mQ.shape[0]),
|
|
445
|
+
mQ.shape[1],
|
|
446
|
+
mK.shape[2],
|
|
447
|
+
*mQ.shape[3:],
|
|
448
|
+
)
|
|
449
|
+
stride_Q_packed = (
|
|
450
|
+
(mQ.stride[2], mQ.stride[0]),
|
|
451
|
+
mQ.stride[1],
|
|
452
|
+
mQ.stride[2] * self.qhead_per_kvhead,
|
|
453
|
+
*mQ.stride[3:],
|
|
454
|
+
)
|
|
455
|
+
mQ = cute.make_tensor(
|
|
456
|
+
mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
|
|
457
|
+
)
|
|
458
|
+
shape_O_packed = (
|
|
459
|
+
(self.qhead_per_kvhead, mO.shape[0]),
|
|
460
|
+
mO.shape[1],
|
|
461
|
+
mK.shape[2],
|
|
462
|
+
*mO.shape[3:],
|
|
463
|
+
)
|
|
464
|
+
stride_O_packed = (
|
|
465
|
+
(mO.stride[2], mO.stride[0]),
|
|
466
|
+
mO.stride[1],
|
|
467
|
+
mO.stride[2] * self.qhead_per_kvhead,
|
|
468
|
+
*mO.stride[3:],
|
|
469
|
+
)
|
|
470
|
+
mO = cute.make_tensor(
|
|
471
|
+
mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
|
|
472
|
+
)
|
|
473
|
+
if const_expr(mLSE is not None):
|
|
474
|
+
shape_LSE_packed = (
|
|
475
|
+
(self.qhead_per_kvhead, mLSE.shape[0]),
|
|
476
|
+
mK.shape[2],
|
|
477
|
+
*mLSE.shape[2:],
|
|
478
|
+
)
|
|
479
|
+
stride_LSE_packed = (
|
|
480
|
+
(mLSE.stride[1], mLSE.stride[0]),
|
|
481
|
+
mLSE.stride[1] * self.qhead_per_kvhead,
|
|
482
|
+
*mLSE.stride[2:],
|
|
483
|
+
)
|
|
484
|
+
mLSE = cute.make_tensor(
|
|
485
|
+
mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
|
|
486
|
+
)
|
|
487
|
+
|
|
488
|
+
self.tma_copy_bytes = {
|
|
489
|
+
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
|
|
490
|
+
for name, mX, layout in [
|
|
491
|
+
("Q", mQ, sQ_layout),
|
|
492
|
+
("K", mK, sK_layout),
|
|
493
|
+
("V", mV, sV_layout),
|
|
494
|
+
]
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
# TMA load for Q
|
|
498
|
+
tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
|
|
499
|
+
tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
|
|
500
|
+
|
|
501
|
+
tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
|
|
502
|
+
tma_load_op,
|
|
503
|
+
mQ,
|
|
504
|
+
cute.select(sQ_layout, mode=[0, 1, 2]),
|
|
505
|
+
self.mma_tiler_qk,
|
|
506
|
+
tiled_mma_qk,
|
|
507
|
+
self.cluster_layout_vmnk.shape,
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
if const_expr(self.use_tma_KV):
|
|
511
|
+
# TMA load for K
|
|
512
|
+
tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B(
|
|
513
|
+
tma_load_op,
|
|
514
|
+
mK,
|
|
515
|
+
cute.select(sK_layout, mode=[0, 1, 2]),
|
|
516
|
+
self.mma_tiler_qk,
|
|
517
|
+
tiled_mma_qk,
|
|
518
|
+
self.cluster_layout_vmnk.shape,
|
|
519
|
+
)
|
|
520
|
+
# TMA load for V
|
|
521
|
+
tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B(
|
|
522
|
+
tma_load_op,
|
|
523
|
+
mV,
|
|
524
|
+
cute.select(sV_layout, mode=[0, 1, 2]),
|
|
525
|
+
self.mma_tiler_pv,
|
|
526
|
+
tiled_mma_pv,
|
|
527
|
+
self.cluster_layout_vmnk.shape,
|
|
528
|
+
)
|
|
529
|
+
else:
|
|
530
|
+
tma_atom_K = None
|
|
531
|
+
tma_atom_V = None
|
|
532
|
+
|
|
533
|
+
o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile)
|
|
534
|
+
|
|
535
|
+
self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
|
|
536
|
+
if const_expr(self.use_tma_O):
|
|
537
|
+
tma_atom_O, mO = cpasync.make_tiled_tma_atom(
|
|
538
|
+
tma_store_op,
|
|
539
|
+
mO,
|
|
540
|
+
cute.select(sO_layout, mode=[0, 1]),
|
|
541
|
+
o_cta_v_layout,
|
|
542
|
+
)
|
|
543
|
+
gmem_tiled_copy_O = None
|
|
544
|
+
else:
|
|
545
|
+
tma_atom_O = None
|
|
546
|
+
universal_copy_bits = 128
|
|
547
|
+
async_copy_elems = universal_copy_bits // self.o_dtype.width
|
|
548
|
+
atom_universal_copy = cute.make_copy_atom(
|
|
549
|
+
cute.nvgpu.CopyUniversalOp(),
|
|
550
|
+
self.o_dtype,
|
|
551
|
+
num_bits_per_copy=universal_copy_bits,
|
|
552
|
+
)
|
|
553
|
+
tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems
|
|
554
|
+
tO_layout = cute.make_ordered_layout(
|
|
555
|
+
(self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1),
|
|
556
|
+
order=(1, 0),
|
|
557
|
+
)
|
|
558
|
+
# So that we don't have to check if we overshoot kBlockM when we store O
|
|
559
|
+
assert self.m_block_size % tO_layout.shape[0] == 0
|
|
560
|
+
vO_layout = cute.make_layout((1, async_copy_elems))
|
|
561
|
+
gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
|
|
562
|
+
|
|
563
|
+
if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
|
|
564
|
+
TileScheduler = SingleTileVarlenScheduler
|
|
565
|
+
else:
|
|
566
|
+
if const_expr(self.is_causal or self.is_local):
|
|
567
|
+
TileScheduler = SingleTileLPTScheduler
|
|
568
|
+
else:
|
|
569
|
+
TileScheduler = (
|
|
570
|
+
SingleTileScheduler
|
|
571
|
+
if const_expr(not self.is_persistent)
|
|
572
|
+
else StaticPersistentTileScheduler
|
|
573
|
+
)
|
|
574
|
+
tile_sched_args = TileSchedulerArguments(
|
|
575
|
+
cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]),
|
|
576
|
+
cute.size(mQ.shape[2]),
|
|
577
|
+
cute.size(mQ.shape[3])
|
|
578
|
+
if const_expr(mCuSeqlensQ is None)
|
|
579
|
+
else cute.size(mCuSeqlensQ.shape[0] - 1),
|
|
580
|
+
num_splits,
|
|
581
|
+
cute.size(mK.shape[0])
|
|
582
|
+
if const_expr(mPageTable is None)
|
|
583
|
+
else mK.shape[0] * mPageTable.shape[1],
|
|
584
|
+
mQ.shape[1],
|
|
585
|
+
mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100
|
|
586
|
+
total_q=cute.size(mQ.shape[0])
|
|
587
|
+
if const_expr(mCuSeqlensQ is not None)
|
|
588
|
+
else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
|
|
589
|
+
tile_shape_mn=self.cta_tiler[:2],
|
|
590
|
+
mCuSeqlensQ=mCuSeqlensQ,
|
|
591
|
+
mSeqUsedQ=mSeqUsedQ,
|
|
592
|
+
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
593
|
+
element_size=self.k_dtype.width // 8,
|
|
594
|
+
is_persistent=self.is_persistent,
|
|
595
|
+
lpt=self.is_causal or self.is_local,
|
|
596
|
+
is_split_kv=self.is_split_kv,
|
|
597
|
+
)
|
|
598
|
+
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
|
599
|
+
self.tile_scheduler_cls = TileScheduler
|
|
600
|
+
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
|
601
|
+
|
|
602
|
+
self.mbar_load_q_full_offset = 0
|
|
603
|
+
self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage
|
|
604
|
+
self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage
|
|
605
|
+
self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage
|
|
606
|
+
self.mbar_P_full_O_rescaled_offset = self.mbar_load_kv_empty_offset + self.kv_stage
|
|
607
|
+
self.mbar_S_full_offset = self.mbar_P_full_O_rescaled_offset + self.q_stage
|
|
608
|
+
self.mbar_O_full_offset = self.mbar_S_full_offset + self.q_stage
|
|
609
|
+
self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + self.q_stage
|
|
610
|
+
self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + self.q_stage
|
|
611
|
+
self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.q_stage
|
|
612
|
+
self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.q_stage
|
|
613
|
+
self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + self.q_stage
|
|
614
|
+
self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8
|
|
615
|
+
self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1
|
|
616
|
+
self.mbar_total = self.mbar_P_full_2_offset + self.q_stage
|
|
617
|
+
|
|
618
|
+
sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0
|
|
619
|
+
sQ_size = (
|
|
620
|
+
cute.cosize(sQ_layout) if const_expr(not self.overlap_sO_sQ) else
|
|
621
|
+
cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width)
|
|
622
|
+
)
|
|
623
|
+
|
|
624
|
+
@cute.struct
|
|
625
|
+
class SharedStorage:
|
|
626
|
+
# m_barriers for pipelines
|
|
627
|
+
mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total]
|
|
628
|
+
# Tmem holding buffer
|
|
629
|
+
tmem_holding_buf: Int32
|
|
630
|
+
# Smem tensors
|
|
631
|
+
# store row max and row sum
|
|
632
|
+
sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]
|
|
633
|
+
sO: cute.struct.Align[
|
|
634
|
+
cute.struct.MemRange[self.o_dtype, sO_size],
|
|
635
|
+
self.buffer_align_bytes,
|
|
636
|
+
]
|
|
637
|
+
sQ: cute.struct.Align[
|
|
638
|
+
cute.struct.MemRange[self.q_dtype, sQ_size],
|
|
639
|
+
self.buffer_align_bytes,
|
|
640
|
+
]
|
|
641
|
+
sK: cute.struct.Align[
|
|
642
|
+
# cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem
|
|
643
|
+
cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)],
|
|
644
|
+
self.buffer_align_bytes,
|
|
645
|
+
]
|
|
646
|
+
|
|
647
|
+
self.shared_storage = SharedStorage
|
|
648
|
+
|
|
649
|
+
LOG2_E = math.log2(math.e)
|
|
650
|
+
if const_expr(self.score_mod is None):
|
|
651
|
+
softmax_scale_log2 = softmax_scale * LOG2_E
|
|
652
|
+
softmax_scale = None
|
|
653
|
+
else:
|
|
654
|
+
# NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk
|
|
655
|
+
# But in the original base 10. We hijack softmax_scale_log2 to just be the change of base
|
|
656
|
+
# and correctly apply the softmax_scale prior to score_mod in the softmax step
|
|
657
|
+
softmax_scale_log2 = LOG2_E
|
|
658
|
+
softmax_scale = softmax_scale
|
|
659
|
+
|
|
660
|
+
if const_expr(window_size_left is not None):
|
|
661
|
+
window_size_left = Int32(window_size_left)
|
|
662
|
+
if const_expr(window_size_right is not None):
|
|
663
|
+
window_size_right = Int32(window_size_right)
|
|
664
|
+
|
|
665
|
+
fastdiv_mods = None
|
|
666
|
+
if cutlass.const_expr(aux_tensors is not None):
|
|
667
|
+
seqlen_q = cute.size(mQ.shape[0]) // (
|
|
668
|
+
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
|
|
669
|
+
)
|
|
670
|
+
seqlen_k = (
|
|
671
|
+
cute.size(mK.shape[0])
|
|
672
|
+
if const_expr(mPageTable is None)
|
|
673
|
+
else mK.shape[0] * mPageTable.shape[1]
|
|
674
|
+
)
|
|
675
|
+
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
|
|
676
|
+
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
|
|
677
|
+
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
|
|
678
|
+
|
|
679
|
+
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
|
680
|
+
if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None):
|
|
681
|
+
raise NotImplementedError("Block sparsity + paged KV not supported on SM100")
|
|
682
|
+
|
|
683
|
+
# Launch the kernel synchronously
|
|
684
|
+
self.kernel(
|
|
685
|
+
mQ,
|
|
686
|
+
mK,
|
|
687
|
+
mV,
|
|
688
|
+
mO,
|
|
689
|
+
mLSE,
|
|
690
|
+
mCuSeqlensQ,
|
|
691
|
+
mCuSeqlensK,
|
|
692
|
+
mSeqUsedQ,
|
|
693
|
+
mSeqUsedK,
|
|
694
|
+
mPageTable,
|
|
695
|
+
tma_atom_Q,
|
|
696
|
+
tma_atom_K,
|
|
697
|
+
tma_atom_V,
|
|
698
|
+
tma_atom_O,
|
|
699
|
+
softmax_scale_log2,
|
|
700
|
+
softmax_scale,
|
|
701
|
+
window_size_left,
|
|
702
|
+
window_size_right,
|
|
703
|
+
learnable_sink,
|
|
704
|
+
blocksparse_tensors,
|
|
705
|
+
sQ_layout,
|
|
706
|
+
sK_layout,
|
|
707
|
+
tP_layout,
|
|
708
|
+
sV_layout,
|
|
709
|
+
sO_layout,
|
|
710
|
+
gmem_tiled_copy_O,
|
|
711
|
+
tiled_mma_qk,
|
|
712
|
+
tiled_mma_pv,
|
|
713
|
+
tile_sched_params,
|
|
714
|
+
num_splits,
|
|
715
|
+
aux_tensors,
|
|
716
|
+
fastdiv_mods,
|
|
717
|
+
).launch(
|
|
718
|
+
grid=grid_dim,
|
|
719
|
+
block=[self.threads_per_cta, 1, 1],
|
|
720
|
+
cluster=self.cluster_shape_mnk,
|
|
721
|
+
smem=self.shared_storage.size_in_bytes(),
|
|
722
|
+
stream=stream,
|
|
723
|
+
min_blocks_per_mp=1,
|
|
724
|
+
)
|
|
725
|
+
|
|
726
|
+
# GPU device kernel
|
|
727
|
+
@cute.kernel
def kernel(
    self,
    mQ: cute.Tensor,  # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q
    mK: cute.Tensor,  # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table
    mV: cute.Tensor,  # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table
    mO: cute.Tensor,
    mLSE: Optional[cute.Tensor],
    mCuSeqlensQ: Optional[cute.Tensor],
    mCuSeqlensK: Optional[cute.Tensor],
    mSeqUsedQ: Optional[cute.Tensor],
    mSeqUsedK: Optional[cute.Tensor],
    mPageTable: Optional[cute.Tensor],
    tma_atom_Q: cute.CopyAtom,
    tma_atom_K: Optional[cute.CopyAtom],
    tma_atom_V: Optional[cute.CopyAtom],
    tma_atom_O: Optional[cute.CopyAtom],
    softmax_scale_log2: Float32,
    softmax_scale: Float32 | None,
    window_size_left: Optional[Int32],
    window_size_right: Optional[Int32],
    learnable_sink: Optional[cute.Tensor],
    blocksparse_tensors: Optional[BlockSparseTensors],
    sQ_layout: cute.ComposedLayout,
    sK_layout: cute.ComposedLayout,
    tP_layout: cute.ComposedLayout,
    sV_layout: cute.ComposedLayout,
    sO_layout: cute.ComposedLayout,
    gmem_tiled_copy_O: Optional[cute.TiledCopy],
    tiled_mma_qk: cute.TiledMma,
    tiled_mma_pv: cute.TiledMma,
    tile_sched_params: ParamsBase,
    num_splits: Int32,
    aux_tensors: Optional[list] = None,
    fastdiv_mods=(None, None),
):
    """The device kernel implementation of the Fused Multi-Head Attention.

    This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation:
    1. Load warp(s): Load Q, K, V data from global memory to shared memory. Q always uses TMA;
       K/V use TMA when ``self.use_tma_KV`` is True, otherwise a ``PagedKVManager`` path.
    2. MMA warp: Performs matrix multiplications (Q*K^T and P*V). It also allocates and later
       frees the tensor-memory (TMEM) columns used by the accumulators.
    3. Softmax warps: Compute softmax normalization on attention scores. The softmax0 warps
       handle Q stage 0 and the softmax1 warps (when q_stage == 2) handle Q stage 1.
    4. Correction warps: Apply adjustments to intermediate results.
    5. Epilogue warp(s): Handle final output transformation and storage. They are only
       dispatched when ``self.use_correction_warps_for_epi`` is False.
       ``correction_loop`` also receives mO/sO and the O copy objects.
    Warps listed in ``self.empty_warp_ids`` only release registers and do no work.

    The kernel implements a complex pipeline with overlapping computation and memory operations,
    using the Tensor Memory Accelerator (TMA) for efficient data loading, warp specialization for
    different computation phases, and optional attention masking.

    Notes on selected arguments (as prepared by the host-side launcher):
        softmax_scale_log2 / softmax_scale: without a score_mod, ``softmax_scale_log2`` is
            ``softmax_scale * log2(e)`` and ``softmax_scale`` is None. With a score_mod,
            ``softmax_scale_log2`` is just ``log2(e)`` (change of base) and ``softmax_scale``
            carries the raw scale, to be applied before the score_mod.
        fastdiv_mods: ``(FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k))`` when
            ``aux_tensors`` is given. Otherwise the launcher passes None.
    """

    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

    # Prefetch tma descriptor
    if warp_idx == 0:
        cpasync.prefetch_descriptor(tma_atom_Q)
        # K/V/O atoms are optional (e.g. no K/V TMA atoms on the non-TMA paged-KV path)
        if const_expr(tma_atom_K is not None):
            cpasync.prefetch_descriptor(tma_atom_K)
        if const_expr(tma_atom_V is not None):
            cpasync.prefetch_descriptor(tma_atom_V)
        if const_expr(tma_atom_O is not None):
            cpasync.prefetch_descriptor(tma_atom_O)

    # Alloc
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(self.shared_storage)

    # Base of the mbarrier array. Individual barriers are addressed as
    # mbar_ptr + self.mbar_<name>_offset (+ stage), using offsets computed on the host.
    mbar_ptr = storage.mbar_ptr.data_ptr()
    # Barrier initialization is spread over warps 1..7. Each warp initializes a distinct
    # group of mbarriers. The second argument of mbarrier_init is the expected arrival count.
    if warp_idx == 1:
        # Init "full" barrier with number of producers, "empty" barrier with number of consumers
        # Q load barriers: "full" expects 1 arrival; "empty" expects one arrival from the MMA warp.
        for i in cutlass.range_constexpr(self.q_stage):
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_load_q_full_offset + i, 1
            )
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])
            )
    if warp_idx == 2:
        # Softmax <-> correction handoff barriers, one full/empty pair per Q stage,
        # each expecting 4 warps' worth of thread arrivals.
        for i in cutlass.range_constexpr(self.q_stage):
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4
            )
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4
            )
    if warp_idx == 3:
        # 8 single-warp sequencing barriers, only used when s0_s1_barrier is enabled
        # (matches the 8 slots reserved for mbar_s0_s1_sequence_offset on the host).
        if const_expr(self.s0_s1_barrier):
            for i in cutlass.range_constexpr(8):
                cute.arch.mbarrier_init(
                    mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE
                )
    if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4:
        # Correction -> epilogue handoff, only needed when there are dedicated epilogue warps:
        # "full" is arrived on by the correction threads, "empty" by the epilogue threads.
        for i in cutlass.range_constexpr(self.q_stage):
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_corr_epi_full_offset + i,
                cute.arch.WARP_SIZE * len(self.correction_warp_ids),
            )
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_corr_epi_empty_offset + i,
                cute.arch.WARP_SIZE * len(self.epilogue_warp_ids),
            )
    if warp_idx == 5:
        for i in cutlass.range_constexpr(self.q_stage):
            # Arrivals from both the softmax0 threads and the correction threads.
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_P_full_O_rescaled_offset + i,
                cute.arch.WARP_SIZE
                * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)),
            )
            # S/O accumulator "full" barriers are arrived on by the single MMA warp.
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])
            )
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])
            )
    if warp_idx == 6:
        for i in cutlass.range_constexpr(self.q_stage):
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_P_full_2_offset + i,
                cute.arch.WARP_SIZE * len(self.softmax0_warp_ids),
            )
    if warp_idx == 7:
        # TMEM-dealloc barrier: every softmax0/softmax1/correction thread arrives once when it is
        # done with tensor memory (see the mbarrier_arrive calls below), and the MMA warp waits
        # on it before freeing TMEM.
        cute.arch.mbarrier_init(
            mbar_ptr + self.mbar_tmem_dealloc_offset,
            cute.arch.WARP_SIZE
            * len(
                (
                    *self.softmax0_warp_ids,
                    *self.softmax1_warp_ids,
                    *self.correction_warp_ids,
                )
            ),
        )
    # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync
    pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset)

    # Generate smem tensor Q/K/V/O
    # (MMA, MMA_Q, MMA_D, PIPE)
    sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
    # (MMA, MMA_K, MMA_D, PIPE)
    sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
    # (MMA, MMA_K, MMA_D, PIPE)
    # Strip swizzle info to reuse smem
    # sV aliases the sK buffer: SharedStorage has no separate V field.
    sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer)
    if const_expr(not self.overlap_sO_sQ):
        sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner)
    else:
        # With overlap_sO_sQ, O is staged in the sQ buffer, reinterpreted as o_dtype.
        sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer)

    # Flat 1-D view of q_stage * m_block_size * 2 Float32 entries (host side: "store row max and row sum").
    sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2))

    thr_mma_qk = tiled_mma_qk.get_slice(0)  # default 1SM
    thr_mma_pv = tiled_mma_pv.get_slice(0)  # default 1SM

    qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2])
    tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape)
    # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always
    # request 512 columns of tmem, so we know that it starts at 0.
    # NOTE(review): the MMA warp allocates self.tmem_alloc_cols columns; this hard-coded base of 0
    # relies on that allocation always landing at TMEM address 0 — confirm against tmem_alloc_cols.
    tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16)
    tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout)

    pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2])
    tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape)

    # One S and one O accumulator view per Q stage, offset by the per-stage TMEM column offsets.
    tStSs = tuple(
        cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout)
        for stage in range(self.q_stage)
    )
    tOtOs = tuple(
        cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout)
        for stage in range(self.q_stage)
    )

    tP = cute.make_tensor(tStS.iterator, tP_layout.outer)
    tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0]

    # Per-stage P operand views in TMEM. The offset is scaled by qk_acc_dtype.width // q_dtype.width,
    # presumably because tOrP's iterator counts q_dtype elements while tmem_p_offset is expressed in
    # accumulator-width units — TODO confirm.
    tOrPs = [
        cute.make_tensor(
            tOrP.iterator
            + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage],
            tOrP.layout,
        )
        for stage in range(self.q_stage)
    ]

    block_info = BlockInfo(
        # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1])
        self.cta_tiler[0],
        self.cta_tiler[1],
        self.is_causal,
        self.is_local,
        self.is_split_kv,
        window_size_left,
        window_size_right,
        qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
    )
    # Per-batch sequence-length info factory; with a page table the static K length is
    # page_size * pages_per_sequence.
    SeqlenInfoCls = partial(
        SeqlenInfoQK.create,
        seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
        seqlen_k_static=mK.shape[0]
        if const_expr(mPageTable is None)
        else mK.shape[0] * mPageTable.shape[1],
        mCuSeqlensQ=mCuSeqlensQ,
        mCuSeqlensK=mCuSeqlensK,
        mSeqUsedQ=mSeqUsedQ,
        mSeqUsedK=mSeqUsedK,
    )
    AttentionMaskCls = partial(
        AttentionMask,
        self.m_block_size,
        self.n_block_size,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
    )
    TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)

    # ///////////////////////////////////////////////////////////////////////////////
    #  EMPTY
    # ///////////////////////////////////////////////////////////////////////////////
    # Idle warps just give their registers back so the softmax warps can grow theirs.
    for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
        if warp_idx == self.empty_warp_ids[i]:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)

    # ///////////////////////////////////////////////////////////////////////////////
    #  LOAD
    # ///////////////////////////////////////////////////////////////////////////////
    if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:
        cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
        self.load(
            thr_mma_qk,
            thr_mma_pv,
            mQ,
            mK,
            mV,
            sQ,
            sK,
            sV,
            mPageTable,
            tma_atom_Q,
            tma_atom_K,
            tma_atom_V,
            pipeline_kv,
            mbar_ptr,
            block_info,
            num_splits,
            SeqlenInfoCls,
            TileSchedulerCls,
            blocksparse_tensors,
        )

    # ///////////////////////////////////////////////////////////////////////////////
    #  MMA
    # ///////////////////////////////////////////////////////////////////////////////
    # The MMA warp owns the TMEM allocation: it allocates, runs the MMA loop, then waits on the
    # TMEM-dealloc barrier (arrived on by softmax/correction threads) before freeing.
    if warp_idx == self.mma_warp_id:
        # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids:
        cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
        # Alloc tmem buffer
        tmem_alloc_cols = Int32(self.tmem_alloc_cols)
        # NOTE(review): this check is always true here (already inside the same condition).
        if warp_idx == self.mma_warp_id:
            # The allocated TMEM address is written into storage.tmem_holding_buf.
            cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf)
            cute.arch.sync_warp()

        self.mma(
            tiled_mma_qk,
            tiled_mma_pv,
            sQ,
            sK,
            sV,
            tStSs,
            tOtOs,
            tOrPs,
            pipeline_kv,
            mbar_ptr,
            block_info,
            num_splits,
            SeqlenInfoCls,
            TileSchedulerCls,
            blocksparse_tensors,
        )

        # if warp_idx == self.mma_warp_id:
        # dealloc tmem buffer
        cute.arch.relinquish_tmem_alloc_permit()
        # Block (phase 0) until every softmax/correction thread has arrived on the dealloc barrier.
        cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0)
        # NOTE(review): recomputes the same value as above; harmless.
        tmem_alloc_cols = Int32(self.tmem_alloc_cols)
        # Retrieving tmem ptr and make acc
        tmem_ptr = cute.arch.retrieve_tmem_ptr(
            Float32,
            alignment=16,
            ptr_to_buffer_holding_addr=storage.tmem_holding_buf,
        )
        cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols)

    # ///////////////////////////////////////////////////////////////////////////////
    #  Epilogue
    # ///////////////////////////////////////////////////////////////////////////////
    # Dedicated epilogue warps exist only when the correction warps are not doing the epilogue.
    if const_expr(not self.use_correction_warps_for_epi):
        if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
            self.epilogue_s2g(
                mO,
                sO,
                gmem_tiled_copy_O,
                tma_atom_O,
                mbar_ptr,
                block_info,
                num_splits,
                SeqlenInfoCls,
                TileSchedulerCls,
            )

    # ///////////////////////////////////////////////////////////////////////////////
    #  Softmax
    # ///////////////////////////////////////////////////////////////////////////////
    # Softmax warps are the lowest-numbered warps: softmax0 (and softmax1 when q_stage == 2).
    if (
        (const_expr(self.q_stage == 2) and warp_idx <= self.softmax1_warp_ids[-1]) or
        (const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1])
    ):
        # increase register after decreasing
        cute.arch.warpgroup_reg_alloc(self.num_regs_softmax)
        softmax_loop = partial(
            self.softmax_loop,
            softmax_scale_log2=softmax_scale_log2,
            softmax_scale=softmax_scale,
            thr_mma_qk=thr_mma_qk,
            sScale=sScale,
            mLSE=mLSE,
            learnable_sink=learnable_sink,
            mbar_ptr=mbar_ptr,
            block_info=block_info,
            num_splits=num_splits,
            SeqlenInfoCls=SeqlenInfoCls,
            AttentionMaskCls=AttentionMaskCls,
            TileSchedulerCls=TileSchedulerCls,
            aux_tensors=aux_tensors,
            fastdiv_mods=fastdiv_mods,
            blocksparse_tensors=blocksparse_tensors,
        )

        if const_expr(not self.s0_s1_barrier):
            # Single code path: softmax0 warps take stage 0, softmax1 warps take stage 1.
            stage = Int32(0 if const_expr(self.q_stage == 1) or warp_idx < self.softmax1_warp_ids[0] else 1)
            softmax_loop(
                stage=stage,
                tStSi=cute.make_tensor(
                    tStS.iterator
                    + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]),
                    tStS.layout,
                ),
            )
            # Signal the MMA warp that this thread no longer needs TMEM.
            cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)
        else:
            # If there's s0_s1_barrier, it's faster to have 2 WGs having different code
            if warp_idx < self.softmax1_warp_ids[0]:
                tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[0], tStS.layout)
                softmax_loop(stage=0, tStSi=tStSi)
                cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)
            if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]:
                tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[1], tStS.layout)
                softmax_loop(stage=1, tStSi=tStSi)
                cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)

    # ///////////////////////////////////////////////////////////////////////////////
    #  Correction
    # ///////////////////////////////////////////////////////////////////////////////
    # Correction warps sit between the softmax warps and the MMA warp in warp-index order.
    if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id:
        cute.arch.warpgroup_reg_dealloc(self.num_regs_correction)
        self.correction_loop(
            thr_mma_qk,
            thr_mma_pv,
            tStS,
            tOtOs,
            sScale,
            mO,
            mLSE,
            sO,
            learnable_sink,
            gmem_tiled_copy_O,
            tma_atom_O,
            mbar_ptr,
            softmax_scale_log2,
            block_info,
            num_splits,
            SeqlenInfoCls,
            TileSchedulerCls,
            blocksparse_tensors,
        )
        # Signal the MMA warp that this thread no longer needs TMEM.
        cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)

    return
|
|
1116
|
+
|
|
1117
|
+
@cute.jit
def load(
    self,
    thr_mma_qk: cute.core.ThrMma,
    thr_mma_pv: cute.core.ThrMma,
    mQ: cute.Tensor,
    mK: cute.Tensor,
    mV: cute.Tensor,
    sQ: cute.Tensor,
    sK: cute.Tensor,
    sV: cute.Tensor,
    mPageTable: Optional[cute.Tensor],
    tma_atom_Q: cute.CopyAtom,
    tma_atom_K: Optional[cute.CopyAtom],
    tma_atom_V: Optional[cute.CopyAtom],
    pipeline_kv: cutlass.pipeline.PipelineAsync,
    mbar_ptr: cute.Pointer,
    block_info: BlockInfo,
    num_splits: Int32,
    SeqlenInfoCls: Callable,
    TileSchedulerCls: Callable,
    blocksparse_tensors: Optional[BlockSparseTensors],
):
    """Producer loop executed by the load warps.

    For every work tile returned by the tile scheduler, this loads the Q tile(s) for ``m_block``,
    one per Q stage at block index ``q_stage * m_block + stage``. It then loads the K/V tiles of
    the tile's n_block range, walking n_block from ``n_block_max - 1`` down to ``n_block_min``
    and alternating K and V loads. Each K or V load advances ``kv_producer_state`` by one stage.
    With split-KV, a tile whose n_block range is empty issues no loads. With block sparsity,
    the Q/K/V load schedule is delegated to ``produce_block_sparse_loads_sm100``.

    K/V use TMA when ``self.use_tma_KV`` is True. In that case, with a page table, the page index
    for each n_block is read from ``mPageTable``. Otherwise K/V go through a ``PagedKVManager``
    that is shared by all load threads, and only the first warp of the load threads issues
    the Q loads.
    """
    num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
    # Thread index relative to the load warps.
    tidx = cute.arch.thread_idx()[0] % num_load_threads
    # Producer phase for the Q full/empty barriers. It starts at 1, presumably so the first wait
    # on the (freshly initialized) Q "empty" barrier passes immediately — TODO confirm in load_Q.
    q_producer_phase = Int32(1)
    kv_producer_state = cutlass.pipeline.make_pipeline_state(
        cutlass.pipeline.PipelineUserType.Producer, self.kv_stage
    )
    tile_scheduler = TileSchedulerCls()
    work_tile = tile_scheduler.initial_work_tile_info()
    while work_tile.is_valid_tile:
        m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
        seqlen = SeqlenInfoCls(batch_idx)
        mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
        gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0))

        # Without pack_gqa, several Q heads share one KV head.
        head_idx_kv = (
            head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
        )
        if const_expr(mPageTable is None):
            if const_expr(not seqlen.has_cu_seqlens_k):
                mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)]
            else:
                # Varlen K/V: offset into the packed sequence dimension (mode 0 of K, mode 1 of V).
                mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv])
                mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv])
            gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0))
            gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None))
        else:
            # Need to keep batch coord None since we'll index into it with page idx
            mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)]
            gK = cute.local_tile(
                mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)
            )
            gV = cute.local_tile(
                mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
            )
        tSgQ = thr_mma_qk.partition_A(gQ)
        tSgK = thr_mma_qk.partition_B(gK)
        tOgV = thr_mma_pv.partition_B(gV)
        load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
            tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
        )

        if const_expr(self.use_tma_KV):
            tKsK, tKgK = cpasync.tma_partition(
                tma_atom_K,
                0,  # no multicast
                cute.make_layout(1),
                cute.group_modes(sK, 0, 3),
                cute.group_modes(tSgK, 0, 3),
            )
            tVsV, tVgV = cpasync.tma_partition(
                tma_atom_V,
                0,  # no multicast
                cute.make_layout(1),
                cute.group_modes(sV, 0, 3),
                cute.group_modes(tOgV, 0, 3),
            )
            paged_kv_manager = None
        else:
            # Non-TMA K/V path: all load threads cooperate through a PagedKVManager.
            page_size = mK.shape[0]
            paged_kv_manager = PagedKVManager.create(
                mPageTable,
                mK,
                mV,
                FastDivmodDivisor(page_size),
                batch_idx,
                head_idx_kv,
                tidx,
                seqlen.seqlen_k,
                0,  # leftpad_k
                self.n_block_size,
                self.head_dim_padded,
                self.head_dim_v_padded,
                num_load_threads,
                mK.element_type,
            )
            tKsK, tKgK = None, None
            tVsV, tVgV = None, None

        load_Q = partial(
            self.load_Q,
            load_Q_fn,
            mbar_ptr + self.mbar_load_q_full_offset,
            mbar_ptr + self.mbar_load_q_empty_offset,
            phase=q_producer_phase,
        )
        # We have to use mbarrier directly in the load for KV instead of relying on
        # pipeline_kv, because we could have different number of TMA bytes for K and V
        load_K = partial(
            self.load_KV,
            tma_atom_K,
            tKgK,
            tKsK,
            paged_kv_manager,
            sK,
            mbar_ptr + self.mbar_load_kv_full_offset,
            mbar_ptr + self.mbar_load_kv_empty_offset,
            K_or_V="K",
        )
        load_V = partial(
            self.load_KV,
            tma_atom_V,
            tVgV,
            tVsV,
            paged_kv_manager,
            sV,
            mbar_ptr + self.mbar_load_kv_full_offset,
            mbar_ptr + self.mbar_load_kv_empty_offset,
            K_or_V="V",
        )

        if const_expr(not self.use_block_sparsity):
            n_block_min, n_block_max = block_info.get_n_block_min_max(
                seqlen, m_block, split_idx, num_splits
            )
            # With split-KV a split may own no n_blocks; skip all loads for it in that case.
            if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
                if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE:
                    load_Q(block=self.q_stage * m_block + 0, stage=0)  # Q0
                # NOTE(review): n_block_first is clamped to 0 for the page lookup, but K0/V0 below
                # are issued with the unclamped n_block_max - 1.
                n_block_first = n_block_max - 1 if n_block_max > 0 else 0
                page_idx = (
                    mPageTable[batch_idx, n_block_first]
                    if const_expr(mPageTable is not None and self.use_tma_KV)
                    else None
                )
                if const_expr(not self.use_tma_KV):
                    paged_kv_manager.load_page_table(n_block_first)
                load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # K0
                kv_producer_state.advance()
                # Q1 is issued between K0 and V0.
                if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE):
                    load_Q(block=self.q_stage * m_block + 1, stage=1)  # Q1
                q_producer_phase ^= 1
                load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # V0
                kv_producer_state.advance()
                # Remaining blocks, walking n_block downward from n_block_max - 2 to n_block_min.
                for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
                    n_block = n_block_max - 2 - i
                    page_idx = (
                        mPageTable[batch_idx, n_block]
                        if const_expr(mPageTable is not None and self.use_tma_KV)
                        else None
                    )
                    if const_expr(not self.use_tma_KV):
                        paged_kv_manager.load_page_table(n_block)
                    # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
                    load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Ki
                    kv_producer_state.advance()
                    load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Vi
                    kv_producer_state.advance()

        else:
            # Block-sparse schedule; returns the updated KV pipeline state and Q phase.
            kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
                blocksparse_tensors,
                batch_idx,
                head_idx,
                m_block,
                kv_producer_state,
                load_Q,
                load_K,
                load_V,
                pipeline_kv,
                self.q_stage,
                q_producer_phase,
                self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
            )


        tile_scheduler.prefetch_next_work()
        tile_scheduler.advance_to_next_work()
        work_tile = tile_scheduler.get_current_work()
        # End of persistent scheduler loop
|
|
1308
|
+
|
|
1309
|
+
@cute.jit
def mma(
    self,
    tiled_mma_qk: cute.core.ThrMma,
    tiled_mma_pv: cute.core.ThrMma,
    sQ: cute.Tensor,
    sK: cute.Tensor,
    sV: cute.Tensor,
    tStSs: Tuple[cute.Tensor, cute.Tensor],
    tOtOs: tuple[cute.Tensor],
    tOrPs: Tuple[cute.Tensor, cute.Tensor],
    pipeline_kv: cutlass.pipeline.PipelineAsync,
    mbar_ptr: cute.Pointer,
    block_info: BlockInfo,
    num_splits: Int32,
    SeqlenInfoCls: Callable,
    TileSchedulerCls: Callable,
    blocksparse_tensors: Optional[BlockSparseTensors],
):
    """Issue the S = Q K^T and O = P V GEMMs for every work tile (MMA warp role).

    For each tile handed out by the persistent tile scheduler, and only when the
    tile has at least one KV block to process, this method:

    1. For each Q stage: waits for Q_stage (load_q_full barrier), waits once for
       the first K block, issues S_stage = Q_stage @ K^T and commits the
       S_full[stage] barrier. The first K block is then released.
    2. For each of the remaining ``block_iter_count - 1`` KV blocks: waits for
       the V block; per stage, waits on P_full_O_rescaled[stage] and issues
       O_stage (+)= P_stage @ V; releases V after the last stage; then waits
       for the next K block and, per stage, issues the next S_stage GEMM and
       commits S_full[stage]. The K block is released after both stages.
    3. Commits load_q_empty for every stage, issues the final PV GEMM per
       stage, commits O_full[stage], and releases the last V block.

    The KV consumer order (K, V, K, V, ...) mirrors the producer side, which
    alternates load_K / load_V on ``pipeline_kv``.

    Args:
        tiled_mma_qk: tiled MMA for S = Q K^T; used for its A/B fragment
            factories and ``.op``.
        tiled_mma_pv: tiled MMA for O = P V; used for its B fragment factory
            and ``.op``.
        sQ: Q operand; its last mode is indexed by Q stage.
        sK: K operand; its last mode is indexed by the KV pipeline stage.
        sV: V operand; its last mode is indexed by the KV pipeline stage.
        tStSs: not referenced in this body (only in commented-out code).
        tOtOs: not referenced in this body (only in commented-out code).
        tOrPs: per-Q-stage A operand of the PV GEMM.
        pipeline_kv: KV pipeline this warp consumes from.
        mbar_ptr: base pointer of the mbarrier array; barriers are addressed as
            ``mbar_ptr + self.mbar_*_offset + stage``.
        block_info: provides the KV block range of a tile (dense path).
        num_splits: number of KV splits (dense path, split-KV).
        SeqlenInfoCls: builds per-batch sequence-length info.
        TileSchedulerCls: constructs the persistent tile scheduler.
        blocksparse_tensors: consulted only when ``self.use_block_sparsity``.
    """
    tSrQ = tiled_mma_qk.make_fragment_A(sQ)
    tSrK = tiled_mma_qk.make_fragment_B(sK)
    tOrV = tiled_mma_pv.make_fragment_B(sV)
    # Per-Q-stage views of the Q fragment (the last mode indexes the Q stage).
    if const_expr(self.q_stage == 2):
        tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1])
    else:
        tSrQs = (tSrQ[None, None, None, 0],)

    qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op

    # One pre-bound GEMM per Q stage. Each call only supplies the B operand (a K
    # or V block). S GEMMs always overwrite their accumulator (zero_init=True).
    # The PV GEMMs take zero_init per call, so the first KV block of a tile
    # overwrites O instead of accumulating into the previous tile's result.
    gemm_Si = [
        partial(
            sm100_utils.gemm_ptx_partial,
            qk_mma_op,
            self.tmem_s_offset[stage],
            tSrQs[stage],
            sA=sQ[None, None, None, stage],
            zero_init=True,
        )
        for stage in range(self.q_stage)
    ]
    gemm_Pi = [
        partial(
            sm100_utils.gemm_ptx_partial,
            pv_mma_op,
            self.tmem_o_offset[stage],
            tOrPs[stage],
            sA=None,
        )
        for stage in range(self.q_stage)
    ]

    # Phases/states persist across work tiles because the barriers and the
    # KV pipeline are reused by the persistent loop.
    mma_q_consumer_phase = Int32(0)
    mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state(
        cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage
    )
    P_full_O_rescaled_phase = Int32(0)

    tile_scheduler = TileSchedulerCls()
    work_tile = tile_scheduler.initial_work_tile_info()
    while work_tile.is_valid_tile:
        m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
        seqlen = SeqlenInfoCls(batch_idx)

        # Number of KV blocks this tile consumes, and whether there is any work.
        block_iter_count = Int32(0)
        process_tile = False

        if const_expr(self.use_block_sparsity):
            block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
            process_tile = block_iter_count > Int32(0)
        else:
            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
            block_iter_count = n_block_max - n_block_min
            # Without split-KV every tile is processed; with split-KV a split can
            # be empty and must then issue no MMAs at all.
            if const_expr(not self.is_split_kv):
                process_tile = True
            else:
                process_tile = n_block_min < n_block_max

        if process_tile:
            for stage in cutlass.range_constexpr(self.q_stage):
                # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1)
                # 1. wait for Q0 / Q1
                cute.arch.mbarrier_wait(
                    mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase
                )
                # 2. wait for K0 (once; both Q stages multiply the same K block)
                if const_expr(stage == 0):
                    pipeline_kv.consumer_wait(mma_kv_consumer_state)
                tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index]
                # We don't need to acquire empty S0 / S1.
                # For the first iteration, we don't need to wait as we're guaranteed S0 / S1
                # are empty. For subsequent iterations, the wait happened at the end
                # of the while loop.
                # 3. gemm
                # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True)
                sK_cur = sK[None, None, None, mma_kv_consumer_state.index]
                if const_expr(self.uneven_kv_smem):
                    sK_cur = self.offset_kv_smem(
                        sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase
                    )
                gemm_Si[stage](tCrB=tSrKi, sB=sK_cur)
                # 4. release S0 / S1: commit to S_full[stage], which softmax_step
                # waits on before reading S.
                with cute.arch.elect_one():
                    tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage)
            # Both Q stages were waited on with the same phase, so flip it once.
            mma_q_consumer_phase ^= 1
            # 5. release K0
            pipeline_kv.consumer_release(mma_kv_consumer_state)
            mma_kv_consumer_state.advance()
            # End of GEMM (Q1 * K0 -> S1)
            # Note: Q0 & Q1 are still needed in the seqlen_kv loop
            # so we need to release them after the seqlen_kv loop

            # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate
            block_loop_count = block_iter_count - 1
            O_should_accumulate = False
            # Steady state: each iteration consumes V of the current block, then K
            # of the next block (matching the producer's K, V, K, V order).
            for i in cutlass.range(block_loop_count, unroll=1):
                # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop
                # 1. wait for V0
                pipeline_kv.consumer_wait(mma_kv_consumer_state)
                # Keep a handle on the V slot so it can be released after the last
                # stage, even though mma_kv_consumer_state moves on to the next K.
                mma_kv_release_state = mma_kv_consumer_state.clone()
                Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase
                tOrVi = tOrV[None, None, None, Vi_index]
                for stage in cutlass.range_constexpr(self.q_stage):
                    # 2. acquire corrected O0/O1_partial and P0 / P1
                    # For the first iteration in this work tile, waiting for O0/O1_partial
                    # means the correction warps have finished reading tO from
                    # the last iteration of the previous work tile.
                    cute.arch.mbarrier_wait(
                        mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage,
                        P_full_O_rescaled_phase,
                    )
                    # 3. gemm
                    # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True)
                    # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate)
                    sV_cur = sV[None, None, None, Vi_index]
                    if const_expr(self.uneven_kv_smem):
                        sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase)
                    # The P_full_2 barrier is arrived on by softmax_step after it
                    # stores the last quarter of P. Presumably gemm_ptx_partial waits
                    # on it before consuming that part of P — confirm in sm100_utils.
                    gemm_Pi[stage](
                        tCrB=tOrVi,
                        sB=sV_cur,
                        zero_init=not O_should_accumulate,
                        mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage,
                        mbar_phase=P_full_O_rescaled_phase,
                    )
                    # 4. release accumulated O0_partial / O1_partial
                    # Don't need to signal O_full to the correction warps anymore since the
                    # correction warps wait for the softmax warps anyway. By the time the softmax
                    # warps finished, S_i for the next iteration must have been done, so O_i-1
                    # must have been done as well.
                    # with cute.arch.elect_one():
                    #     tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage)
                    # 5. release V(i-1) once every Q stage has issued its PV GEMM
                    if const_expr(stage == self.q_stage - 1):
                        pipeline_kv.consumer_release(mma_kv_release_state)
                        mma_kv_release_state.advance()
                    # End of GEMM_PV00 (P0 * V0 -> O0_partial)

                    # GEMM_QK0i (Q0 * Ki -> S0)
                    # 1. wait for Ki (once; advance from the V slot to the next K slot)
                    if const_expr(stage == 0):
                        mma_kv_consumer_state.advance()
                        pipeline_kv.consumer_wait(mma_kv_consumer_state)
                    Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase
                    # 2. gemm
                    # Don't need to wait for the softmax warp to have finished reading the previous
                    # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si
                    # has been read and Pi has been written.
                    # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True)
                    sK_cur = sK[None, None, None, Ki_index]
                    if const_expr(self.uneven_kv_smem):
                        sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase)
                    gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur)
                    # 3. release S0
                    with cute.arch.elect_one():
                        tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage)
                    # End of GEMM_QK0i (Q0 * Ki -> S0)
                # 4. release Ki (after both Q stages have used it)
                pipeline_kv.consumer_release(mma_kv_consumer_state)
                mma_kv_consumer_state.advance()
                P_full_O_rescaled_phase ^= 1
                O_should_accumulate = True
            # End of seqlen_kv loop

            # release Q0 & Q1 (no further QK GEMMs for this tile)
            with cute.arch.elect_one():
                for stage in cutlass.range_constexpr(self.q_stage):
                    tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage)

            # Final PV GEMM for the last KV block of the tile.
            # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop
            # 1. wait for V0
            pipeline_kv.consumer_wait(mma_kv_consumer_state)
            Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase
            tOrVi = tOrV[None, None, None, Vi_index]
            for stage in cutlass.range_constexpr(self.q_stage):
                # 2. acquire corrected Oi_partial and Pi
                cute.arch.mbarrier_wait(
                    mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase
                )
                # 3. gemm
                # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True)
                # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate)
                sV_cur = sV[None, None, None, Vi_index]
                if const_expr(self.uneven_kv_smem):
                    sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase)
                # zero_init is still True here when the tile had exactly one KV block.
                gemm_Pi[stage](
                    tCrB=tOrVi,
                    sB=sV_cur,
                    zero_init=not O_should_accumulate,
                    mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage,
                    mbar_phase=P_full_O_rescaled_phase,
                )
                # 4. release accumulated O0_partial
                # We do need O_full here since for the last tile, by the time the softmax warp
                # has signaled to the correction warps, the softmax warp has just finished compute
                # the row sum of the current tile. It does not guarantee that the 1st tile
                # of the next work tile has been computed yet.
                with cute.arch.elect_one():
                    tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage)
                # End of GEMM_PV00 (P0 * V0 -> O0_partial)
            P_full_O_rescaled_phase ^= 1
            # 5. release Vi_end
            pipeline_kv.consumer_release(mma_kv_consumer_state)
            mma_kv_consumer_state.advance()
            # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1)

        # Advance to next tile
        tile_scheduler.advance_to_next_work()
        work_tile = tile_scheduler.get_current_work()
    # End of persistent scheduler loop
|
|
1537
|
+
|
|
1538
|
+
|
|
1539
|
+
# Shared by both the softmax0 and softmax1 warp groups; `stage` selects which one.
@cute.jit
def softmax_loop(
    self,
    stage: int | Int32,
    softmax_scale_log2: Float32,
    softmax_scale: Float32,
    thr_mma_qk: cute.core.ThrMma,
    tStSi: cute.Tensor,
    sScale: cute.Tensor,
    mLSE: Optional[cute.Tensor],
    learnable_sink: Optional[cute.Tensor],
    mbar_ptr: cute.Pointer,
    block_info: BlockInfo,
    num_splits: Int32,
    SeqlenInfoCls: Callable,
    AttentionMaskCls: Callable,
    TileSchedulerCls: Callable,
    aux_tensors: Optional[list] = None,
    fastdiv_mods=(None, None),
    blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
    """Run one softmax warp group (one Q stage) over every work tile.

    ``stage`` selects the Q stage this group owns: its S buffer ``tStSi``, its
    barriers (``mbar_*_offset + stage``) and its slice of ``sScale``.

    Before the tile loop it builds the TMEM copy partitions used by
    ``softmax_step``: a 32-column load of S, a 1-column scale store, and a
    16-repetition store of P, which aliases the S region at
    ``self.tmem_s_to_p_offset``.

    For each work tile it:

    - Builds the mask functions and a fresh ``SoftmaxSm100`` state. The mask
      functions are ``mask_fn`` (with ``self.mask_mod``) and, with block
      sparsity, ``mask_fn_none`` (without mask_mod, for full blocks).
    - If the tile has work, waits once on softmax_corr_empty[stage].
    - Iterates over the KV blocks from the highest index down. Dense path:
      first block with seqlen masking, then causal/local-masked blocks, then
      blocks with no positional masking (mask_mod still applied if set), then
      left-local-masked blocks. Block-sparse path: delegated to
      ``softmax_block_sparse_sm100``.
    - Writes the final row sum, plus the row max when ``mLSE`` or
      ``learnable_sink`` is given, into ``sScale`` and arrives on
      softmax_corr_full[stage].

    The P values written to TMEM by ``softmax_step`` are exp2-transformed but
    not normalized here. The row sum is handed to the correction warps via
    ``sScale``, presumably for the final normalization — confirm in
    correction_loop.

    ``sScale`` indexing used here: ``[tidx + stage * m_block_size]`` holds the
    row sum (and, inside softmax_step, the per-step rescale factor);
    ``[tidx + stage * m_block_size + 2 * m_block_size]`` holds the row max.
    """
    # Thread index within this softmax warp group. It always uses
    # len(self.softmax0_warp_ids) (the per-stage variant is commented out), so
    # both softmax groups are assumed to have the same number of warps.
    tidx = cute.arch.thread_idx()[0] % (
        cute.arch.WARP_SIZE
        # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids)
        * (len(self.softmax0_warp_ids))
    )

    tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1)))
    tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))
    # NOTE(review): tScScale is not used in this method.
    tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))

    # Number of 32-bit columns needed for one row of P: mma_tiler_qk[1] values of
    # v_dtype.width bits each, packed into 32-bit words.
    tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width
    tStP_layout = cute.composition(
        tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))
    )
    # P is written into the S tensor's address range, offset by tmem_s_to_p_offset.
    tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout)

    # TMEM -> register load of S.
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)),
        Float32,
    )
    thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx)
    tStS_t2r = thr_tmem_load.partition_S(tStSi)

    # Register -> TMEM store of a single scale column.
    tmem_store_scale_atom = cute.make_copy_atom(
        tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)),
        Float32,
    )
    thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(
        tidx
    )

    tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale)
    # Register -> TMEM store of P.
    tmem_store_atom = cute.make_copy_atom(
        tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)),
        Float32,
    )
    thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx)
    tStP_r2t = thr_tmem_store.partition_D(tStP)

    # Barrier phases; they persist across work tiles.
    mma_si_consumer_phase = Int32(0)
    si_corr_producer_phase = Int32(1)
    # Stages start on opposite phases of the S0/S1 sequence barrier, presumably
    # so that one stage goes first and the two then alternate — confirm.
    s0_s1_sequence_phase = Int32(1 if stage == 0 else 0)

    # self.warp_scheduler_barrier_init()

    # One sequence barrier per warp in the group. softmax_step waits at
    # +stage*4 and arrives at +(1 - stage)*4 relative to this offset.
    warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
    mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg

    tile_scheduler = TileSchedulerCls()
    work_tile = tile_scheduler.initial_work_tile_info()
    while work_tile.is_valid_tile:
        m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
        seqlen = SeqlenInfoCls(batch_idx)
        n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

        mask = AttentionMaskCls(seqlen)
        # The m_block passed to masking is the per-stage block index
        # (q_stage * m_block + stage), not the scheduler's tile index.
        shared_mask_kwargs = dict(
            m_block=self.q_stage * m_block + stage,
            thr_mma=thr_mma_qk,
            thr_tmem_load=thr_tmem_load,
            mask_causal=self.is_causal,
            mask_local=self.is_local,
            batch_idx=batch_idx,
            head_idx=head_idx,
            aux_tensors=aux_tensors,
        )

        # Recompute fastdiv_mods if necessary: with aux_tensors and variable-length
        # sequences, the divisors must match this batch's seqlen_q / seqlen_k.
        recompute_fastdiv_mods_q = cutlass.const_expr(
            aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
        )
        recompute_fastdiv_mods_k = cutlass.const_expr(
            aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
        )

        if cutlass.const_expr(fastdiv_mods is not None):
            seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
            fastdiv_mods = (
                seqlen_q_divmod
                if not recompute_fastdiv_mods_q
                else FastDivmodDivisor(seqlen.seqlen_q),
                seqlen_k_divmod
                if not recompute_fastdiv_mods_k
                else FastDivmodDivisor(seqlen.seqlen_k),
            )

        mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None
        mask_fn = partial(
            mask.apply_mask_sm100,
            mask_mod=mask_mod,
            fastdiv_mods=fastdiv_mods,
            **shared_mask_kwargs,
        )
        if const_expr(self.use_block_sparsity):
            # Full blocks dont need mask_mod
            mask_fn_none = partial(
                mask.apply_mask_sm100,
                mask_mod=None,
                fastdiv_mods=fastdiv_mods,
                **shared_mask_kwargs,
            )
        else:
            mask_fn_none = None

        # Fresh running max / sum for this tile.
        softmax = SoftmaxSm100.create(
            softmax_scale_log2,
            rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0,
            softmax_scale=softmax_scale,
        )
        softmax.reset()

        if const_expr(self.use_block_sparsity):
            tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
            has_work = tile_block_count > Int32(0)
        else:
            tile_block_count = n_block_max - n_block_min
            # Without split-KV every tile has work; a KV split may be empty.
            has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0)

        # Bind everything that is constant for this tile; callers only pass the
        # three barrier phases, n_block, and optionally mask_fn / is_first.
        softmax_step = partial(
            self.softmax_step,
            softmax=softmax,
            mbar_ptr=mbar_ptr,
            mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset,
            thr_mma_qk=thr_mma_qk,
            thr_tmem_load=thr_tmem_load,
            thr_tmem_store=thr_tmem_store,
            thr_tmem_store_scale=thr_tmem_store_scale,
            tStS_t2r=tStS_t2r,
            tStScale_r2t=tStScale_r2t,
            tStP_r2t=tStP_r2t,
            sScale=sScale,
            stage=stage,
            batch_idx=batch_idx,
            head_idx=head_idx,
            m_block=self.q_stage * m_block + stage,
            seqlen=seqlen,
            aux_tensors=aux_tensors,
            fastdiv_mods=fastdiv_mods,
        )

        if has_work:
            # Softmax acts as the producer: wait until correction signals the stage is empty
            cute.arch.mbarrier_wait(
                mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase
            )
            si_corr_producer_phase ^= 1

        # Block sparse or dense iteration
        if const_expr(self.use_block_sparsity):
            # When aux_tensors exist, Q indices beyond seqlen_q must be wrapped to avoid
            # OOB aux_tensor access. Only edge tiles (where m_tile_end > seqlen_q) need this.
            if const_expr(aux_tensors is not None):
                m_tile_end = (self.q_stage * m_block + stage + 1) * self.m_block_size
                check_m_boundary = m_tile_end > seqlen.seqlen_q
            else:
                check_m_boundary = False
            (
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                empty_tile,
            ) = softmax_block_sparse_sm100(
                blocksparse_tensors,
                batch_idx,
                head_idx,
                m_block,
                softmax_step,
                mask_fn,
                mask_fn_none,
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                mbar_ptr,
                self.mbar_softmax_corr_full_offset,
                self.mbar_softmax_corr_empty_offset,
                self.mbar_P_full_O_rescaled_offset,
                self.mbar_P_full_2_offset,
                self.q_stage,
                Int32(stage),
                check_m_boundary,
                self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
            )
            # Hand the final row statistics to the correction warps, unless the
            # sparse iterator reported an empty tile.
            if not empty_tile:
                sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0]
                if const_expr(mLSE is not None or learnable_sink is not None):
                    sScale[
                        tidx + stage * self.m_block_size + self.m_block_size * 2
                    ] = softmax.row_max[0]
                # if tidx == 0:
                #     cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0])
                cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage)
                # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0])
        else:
            if const_expr(not self.is_split_kv) or tile_block_count > Int32(0):
                # First (highest-index) KV block: the only one that can extend past
                # seqlen_k, so it is masked with mask_seqlen=True.
                mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    n_block_max - 1,
                    is_first=True,
                    mask_fn=partial(mask_fn, mask_seqlen=True),
                )
                n_block_max -= 1
                # Next couple of iterations with causal masking
                if const_expr(self.is_causal or self.is_local):
                    n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
                        seqlen, m_block, n_block_min
                    )
                    for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1):
                        n_block = n_block_max - 1 - n_tile
                        mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = (
                            softmax_step(
                                mma_si_consumer_phase,
                                si_corr_producer_phase,
                                s0_s1_sequence_phase,
                                n_block,
                                mask_fn=partial(mask_fn, mask_seqlen=False),
                            )
                        )
                    n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
                # The remaining iterations have no masking (but may still need mask_mod)
                n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
                    seqlen, m_block, n_block_min
                )
                for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
                    n_block = n_block_max - n_tile - 1
                    if const_expr(self.mask_mod is not None):
                        mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(
                            mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block,
                            mask_fn=partial(mask_fn, mask_seqlen=False),
                        )
                    else:
                        mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(
                            mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block,
                        )
                # Separate iterations with local masking on the left
                if const_expr(self.is_local and block_info.window_size_left is not None):
                    n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
                    for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1):
                        n_block = n_block_max - 1 - n_tile
                        mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = (
                            softmax_step(
                                mma_si_consumer_phase,
                                si_corr_producer_phase,
                                s0_s1_sequence_phase,
                                n_block,
                                mask_fn=partial(mask_fn, mask_seqlen=False),
                            )
                        )
                # NOTE(review): stale comment from an earlier version read "Now that we
                # no longer already have the 1st iteration, need mask_seqlen=True here".
                # In the current code the first block is always processed above with
                # mask_seqlen=True, and every later block uses mask_seqlen=False.

                # Dense path always writes scale / signals
                sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0]
                if const_expr(mLSE is not None or learnable_sink is not None):
                    sScale[
                        tidx + stage * self.m_block_size + self.m_block_size * 2
                    ] = softmax.row_max[0]
                cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage)

        # Disabled: LSE is not written from here. The row max / row sum are
        # published through sScale instead (presumably consumed by the
        # correction warps — confirm).
        # # Write LSE to gmem
        # if const_expr(mLSE is not None):
        #     acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0]
        #     scale = (
        #         cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0)
        #     )
        #     LN2 = math.log(2.0)
        #     lse = (
        #         (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2
        #         if not acc_O_mn_row_is_zero_or_nan else -Float32.inf
        #     )
        #     if const_expr(not seqlen.has_cu_seqlens_q):
        #         mLSE_cur = mLSE[None, head_idx, batch_idx]
        #     else:
        #         mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx])
        #     gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,))
        #     if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size:
        #         gLSE[tidx] = lse

        # Advance to next tile
        tile_scheduler.advance_to_next_work()
        work_tile = tile_scheduler.get_current_work()
    # End of persistent scheduler loop
|
|
1854
|
+
|
|
1855
|
+
@cute.jit
def softmax_step(
    self,
    mma_si_consumer_phase: Int32,
    si_corr_producer_phase: Int32,
    s0_s1_sequence_phase: Int32,
    n_block: Int32,
    softmax: SoftmaxSm100,
    mbar_ptr: cute.Pointer,
    mbar_s0_s1_sequence_offset: Int32,
    thr_mma_qk: cute.core.ThrMma,
    thr_tmem_load: cute.CopyAtom,
    thr_tmem_store: cute.CopyAtom,
    thr_tmem_store_scale: cute.CopyAtom,
    tStS_t2r: cute.Tensor,
    tStScale_r2t: cute.Tensor,
    tStP_r2t: cute.Tensor,
    sScale: cute.Tensor,
    stage: int | Int32,
    batch_idx: Int32,
    head_idx: Int32,
    m_block: Int32,
    seqlen,
    aux_tensors: Optional[list] = None,
    fastdiv_mods=(None, None),
    mask_fn: Optional[Callable] = None,
    is_first: bool = False,
) -> Tuple[cute.Int32, cute.Int32, cute.Int32]:
    """Process one KV block of S for this softmax stage.

    The steps run in this order (the barrier order is part of the protocol
    with the MMA and correction warps):

    1. Wait on S_full[stage], load S from TMEM into registers, then apply
       ``self.score_mod`` and ``mask_fn`` when they are set.
    2. Update the running row max. On non-first steps, write the rescale
       factor ``acc_scale`` to ``sScale``. Always arrive on
       softmax_corr_full[stage].
    3. Subtract the scaled row max. If ``self.s0_s1_barrier``, wait on this
       stage's sequence barrier. Apply exp2 and convert into a q_dtype register
       fragment. If ``self.s0_s1_barrier``, arrive on the other stage's
       sequence barrier.
    4. Store the first 3/4 of P to TMEM, fence, and arrive on
       P_full_O_rescaled[stage]. Store the remaining 1/4, fence, and arrive on
       P_full_2[stage].
    5. Wait on softmax_corr_empty[stage], then update the running row sum.

    This step does not normalize by the row sum. It only accumulates it in
    ``softmax`` via ``update_row_sum``.

    Args:
        mma_si_consumer_phase: phase to wait on for S_full[stage].
        si_corr_producer_phase: phase to wait on for softmax_corr_empty[stage].
        s0_s1_sequence_phase: phase to wait on for the sequence barrier.
        n_block: KV block index, forwarded to score_mod / mask_fn.
        softmax: running row max / row sum state, updated in place.
        mbar_s0_s1_sequence_offset: per-warp sequence-barrier offset.
        thr_tmem_load: TMEM copy slice used to load S.
        thr_tmem_store: TMEM copy slice used to store P.
        thr_tmem_store_scale: TMEM copy slice for the scale; only referenced in
            commented-out code here.
        tStS_t2r: TMEM source partition of S.
        tStScale_r2t: TMEM scale partition; only referenced in commented-out
            code here.
        tStP_r2t: TMEM destination partition of P; its mode 2 is split 3/4 + 1/4.
        sScale: receives acc_scale at ``[thr_idx + stage * m_block_size]``.
        stage: Q stage / softmax warp-group index (0 or 1 for the sequence
            barrier arithmetic).
        m_block: per-stage block index (already ``q_stage * m_block + stage``
            at the call site in softmax_loop).
        mask_fn: optional masking callable; None means no positional mask.
        is_first: first block of the tile (no rescale factor is written).
        thr_mma_qk, batch_idx, head_idx, seqlen, aux_tensors, fastdiv_mods:
            forwarded to partitioning / score_mod.

    Returns:
        The three phases, each flipped once. Each of the waited barriers is
        waited on exactly once per call.
    """
    # Number of 32-bit columns per row of P (v_dtype values packed into 32-bit words).
    tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width
    tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))
    # NOTE(review): tScScale is only referenced in commented-out code below.
    tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))
    tScP = cute.composition(tScS, cute.make_layout((self.m_block_size, tilePlikeFP32)))

    # Wait for Si
    cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase)
    tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype)
    cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r)
    if cutlass.const_expr(self.score_mod is not None):
        self.apply_score_mod(
            tSrS_t2r,
            thr_tmem_load,
            thr_mma_qk,
            batch_idx,
            head_idx,
            m_block,
            n_block,
            softmax,
            seqlen,
            aux_tensors,
            fastdiv_mods,
        )

    if const_expr(mask_fn is not None):
        mask_fn(tSrS_t2r, n_block=n_block)
    row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first)

    if const_expr(not is_first):
        # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScScale).shape, Float32)
        # tSrScale_r2t[0] = acc_scale
        # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t)
        # cute.arch.fence_view_async_tmem_store()
        # The rescale factor goes through sScale instead of TMEM (the TMEM path above is disabled).
        thread_idx = thr_tmem_load.thr_idx
        sScale[thread_idx + stage * self.m_block_size] = acc_scale
        # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max)
    # Notify correction wg that row_max is ready. This arrive happens on the
    # first step too (no acc_scale is written then).
    cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage)

    # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r)
    # print(tSrS_t2r)
    softmax.scale_subtract_rowmax(tSrS_t2r, row_max)
    # Sequence barrier wait: serializes this stage's exp2 section with the other stage's.
    if const_expr(self.s0_s1_barrier):
        cute.arch.mbarrier_wait(
            mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase
        )
    # P is produced in q_dtype but kept in Float32-typed storage so it can be
    # stored to TMEM with the 32-bit store atom. tSrP_r2t is a q_dtype view of
    # the same registers, laid out like tSrS_t2r.
    tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32)
    tSrP_r2t = cute.make_tensor(
        cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype),
        tSrS_t2r.layout,
    )
    # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t)
    # e2e is enabled only for unmasked steps with head_dim_padded <= 128; its exact
    # meaning is defined in SoftmaxSm100.apply_exp2_convert — confirm there.
    softmax.apply_exp2_convert(
        tSrS_t2r,
        tSrP_r2t,
        e2e=mask_fn is None and self.head_dim_padded <= 128,
        e2e_freq=self.e2e_freq,
    )
    # Sequence barrier arrive: release the other stage's warp group.
    if const_expr(self.s0_s1_barrier):
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4)
    # print(tSrP_r2t_f32, tStP_r2t)
    # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t)
    # Store the first 3/4 of P, then signal the MMA warp so it can start the PV GEMM
    # while the last quarter is still being stored.
    for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3):
        cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i])
    cute.arch.fence_view_async_tmem_store()
    # Notify mma warp that the first 3/4 of P is ready
    cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
    for i in cutlass.range_constexpr(
        cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2])
    ):
        cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i])
    cute.arch.fence_view_async_tmem_store()
    # Notify mma warp that the remaining 1/4 of P is ready
    cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage)
    # Wait until correction has consumed this step's sScale entry before the next
    # step may overwrite it.
    cute.arch.mbarrier_wait(
        mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase
    )
    # NOTE(review): presumes apply_exp2_convert left the exp2 values in tSrS_t2r,
    # so this sums exponentials — confirm in SoftmaxSm100.
    softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first)
    # acc_scale = cute.arch.exp2(acc_scale_)
    return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1
|
|
1980
|
+
|
|
1981
|
+
@cute.jit
def correction_loop(
    self,
    thr_mma_qk: cute.core.ThrMma,
    thr_mma_pv: cute.core.ThrMma,
    tStS: cute.Tensor,
    tOtOs: tuple[cute.Tensor],
    sScale: cute.Tensor,
    mO: cute.Tensor,
    mLSE: cute.Tensor,
    sO: cute.Tensor,
    learnable_sink: Optional[cute.Tensor],
    gmem_tiled_copy_O: cute.TiledCopy,
    tma_atom_O: cute.CopyAtom,
    mbar_ptr: cute.Pointer,
    softmax_scale_log2: Float32,
    block_info: BlockInfo,
    num_splits: Int32,
    SeqlenInfoCls: Callable,
    TileSchedulerCls: Callable,
    blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
    """Persistent loop executed by the correction warps.

    For every work tile produced by ``TileSchedulerCls``:

    1. For each KV block after the first, read the per-row factor that the softmax
       warps published in ``sScale``. If any lane of the warp sees ``scale < 1.0``
       (warp ballot), rescale that stage's O accumulator in tensor memory via
       ``correction_rescale``.
    2. After the last KV block, read the final row sum from ``sScale``. When ``mLSE``
       or ``learnable_sink`` is given, also read the row max. Fold in the learnable
       sink if present, then call ``correction_epilogue`` with ``scale = 1 / row_sum``
       (``1.0`` when the sum is 0 or NaN).
    3. If ``mLSE`` is given, write the natural-log LSE for the rows in range.

    Tiles with no work are handled in the ``else`` branch. These are split-KV tiles
    with an empty KV range, or block-sparse tiles with zero blocks. With block
    sparsity they are delegated to
    ``handle_block_sparse_empty_tile_correction_sm100``. In every case the LSE
    falls back to the ``-inf`` defaults stored in ``stats``.

    All synchronization goes through mbarriers at offsets from ``mbar_ptr``. The
    three phase counters below must stay in lock-step with the softmax, MMA and
    epilogue warps, so statement order here is significant.

    NOTE(review): ``tma_atom_O`` and the TMEM vector-load objects built at the
    top (``tStScales``, ``tStScales_t2r``, ``tSrScale_t2r_shape`` and the
    ``tSrScale_t2r`` fragment) are only referenced by commented-out code. The
    active path reads the factors from ``sScale`` instead.
    """
    # Thread index within the correction warp group.
    tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))
    tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))
    # One (m_block_size, 1) column view per q stage, at tmem_vec_offset[stage] from tStS.
    tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1)))
    tStScales = tuple(
        cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout)
        for stage in range(self.q_stage)
    )
    tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))
    tmem_load_v_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)),
        self.qk_acc_dtype,
    )
    thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx)

    tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(self.q_stage)]
    tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape

    # First iter: no correction is required
    for stage in cutlass.range_constexpr(self.q_stage):
        cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)

    softmax_corr_consumer_phase = Int32(0)
    o_corr_consumer_phase = Int32(0)
    # Producer-side phase on the corr->epi pipeline starts at 1, unlike the two
    # consumer phases (presumably the usual "producer starts at phase 1" convention).
    corr_epi_producer_phase = Int32(1)

    tile_scheduler = TileSchedulerCls()
    work_tile = tile_scheduler.initial_work_tile_info()
    while work_tile.is_valid_tile:
        m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
        seqlen = SeqlenInfoCls(batch_idx)
        n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

        if const_expr(self.is_split_kv):
            mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
        else:
            mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
        gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0))

        # Default LSE to -inf for invalid split_idx tiles
        # Each entry is (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) for one q stage.
        stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage

        if const_expr(self.use_block_sparsity):
            total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
            has_work = total_block_count > Int32(0)
        else:
            total_block_count = n_block_max - n_block_min
            # Without split-KV every tile is treated as having work.
            has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0)

        if has_work:
            # Ignore first signal from softmax as no correction is required
            cute.arch.mbarrier_wait(
                mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase
            )
            cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0)
            if const_expr(self.q_stage == 2):
                # Stage 1's matching "empty" arrive is deferred: inside the loop each stage
                # releases the *other* stage, and the leftover arrive for stage 1 is issued
                # after the loop.
                cute.arch.mbarrier_wait(
                    mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase
                )
            softmax_corr_consumer_phase ^= 1

            tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32)
            # One iteration per remaining KV block (the first block needed no correction).
            for i in cutlass.range(total_block_count - 1, unroll=1):
                for stage in cutlass.range_constexpr(self.q_stage):
                    # wait for S0 / S1
                    cute.arch.mbarrier_wait(
                        mbar_ptr + self.mbar_softmax_corr_full_offset + stage,
                        softmax_corr_consumer_phase,
                    )
                    # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r)
                    # cute.arch.fence_view_async_tmem_load()
                    # scale = tSrScale_t2r[0]
                    # Per-row rescale factor for this stage, indexed by this thread's row.
                    scale = sScale[tidx + stage * self.m_block_size]
                    # Warp-uniform decision: rescale if any lane in the warp has scale < 1.
                    should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0
                    # should_rescale = True
                    # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale)
                    # Don't need O_full anymore, since by the time softmax has signaled the correction
                    # warps, S_i must have been done, so O_i-1 must have been done as well.
                    # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase)
                    if should_rescale:
                        self.correction_rescale(
                            thr_mma_pv, tOtOs[stage], tidx, scale
                        )
                    cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
                    if const_expr(self.q_stage == 2):
                        cute.arch.mbarrier_arrive(
                            mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)
                        )
                    else:
                        cute.arch.mbarrier_arrive(
                            mbar_ptr + self.mbar_softmax_corr_empty_offset + stage
                        )
                softmax_corr_consumer_phase ^= 1
                # o_corr_consumer_phase ^= 1
            if const_expr(self.q_stage == 2):
                # Deferred release of stage 1 (see the note before the loop).
                cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1)
            # End of seqlen_corr_loop_steps

            # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without
            # additional sync because the MMA in the top half must have been done.
            # Similarly we can write to stage 1 of sO without additional sync.
            learnable_sink_val = [None] * self.q_stage
            if const_expr(learnable_sink is not None):
                if const_expr(not self.pack_gqa):
                    sink_val = Float32(learnable_sink[head_idx])
                    learnable_sink_val = [sink_val] * self.q_stage
                else:  # Each thread might have a different sink value due to different q_head
                    for stage in cutlass.range_constexpr(self.q_stage):
                        # Packed row index modulo qhead_per_kvhead selects the q head within the kv head.
                        q_head_idx = (
                            (self.q_stage * m_block + stage) * self.m_block_size + tidx
                        ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
                        learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx])
            # Final pass: normalize O and hand it to the epilogue.
            for stage in cutlass.range_constexpr(self.q_stage):
                cute.arch.mbarrier_wait(
                    mbar_ptr + self.mbar_softmax_corr_full_offset + stage,
                    softmax_corr_consumer_phase,
                )
                # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r)
                # cute.arch.fence_view_async_tmem_load()
                # scale = tSrScale_t2r[0]
                # At this point sScale carries the row sum in the same slot used for the
                # per-block factor; the row max sits 2 * m_block_size further on.
                row_sum = sScale[tidx + stage * self.m_block_size]
                if const_expr(mLSE is not None or learnable_sink is not None):
                    row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2]
                else:
                    row_max = None
                cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage)
                if const_expr(learnable_sink is not None):
                    LOG2_E = math.log2(math.e)
                    sink_val = learnable_sink_val[stage]
                    # With split-KV, only split 0 adds the sink so it is counted once.
                    if const_expr(not self.is_split_kv) or split_idx == 0:
                        if row_max == -Float32.inf:
                            # It's possible to have an empty row with splitKV.
                            # Chosen so row_max * softmax_scale_log2 == sink_val * LOG2_E,
                            # i.e. the sink alone defines the max, with row_sum = 1.
                            row_max = sink_val * (LOG2_E / softmax_scale_log2)
                            row_sum = Float32(1.0)
                        else:
                            # Adds exp(sink_val - row_max * softmax_scale) expressed in base 2.
                            row_sum += utils.exp2f(
                                sink_val * LOG2_E - row_max * softmax_scale_log2
                            )
                # row_sum != row_sum is the NaN test.
                acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum
                stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
                scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0)
                cute.arch.mbarrier_wait(
                    mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase
                )
                if const_expr(not self.use_correction_warps_for_epi):
                    # Wait until the epilogue warp has released this sO stage.
                    cute.arch.mbarrier_wait(
                        mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase
                    )
                self.correction_epilogue(
                    thr_mma_pv,
                    tOtOs[stage],
                    tidx,
                    stage,
                    m_block,
                    seqlen.seqlen_q,
                    scale,
                    sO[None, None, stage],
                    mO_cur,
                    gO,
                    gmem_tiled_copy_O,
                )
                if const_expr(not self.use_correction_warps_for_epi):
                    cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage)
                # Signal for the next work tile that O buffers in tmem are already read, so
                # mma warp can write to them
                cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
                # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale)

            o_corr_consumer_phase ^= 1
            softmax_corr_consumer_phase ^= 1
            corr_epi_producer_phase ^= 1
        else:
            # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781
            if const_expr(self.use_correction_warps_for_epi):
                gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O
            else:
                gmem_tiled_copy_O_for_empty_tile = None
            if const_expr(self.use_block_sparsity):
                # The helper performs the barrier traffic for an empty tile and returns the
                # advanced phases; it also receives `stats`.
                (
                    softmax_corr_consumer_phase,
                    o_corr_consumer_phase,
                    corr_epi_producer_phase,
                ) = handle_block_sparse_empty_tile_correction_sm100(
                    tidx,
                    self.q_stage,
                    self.m_block_size,
                    self.qhead_per_kvhead,
                    self.pack_gqa,
                    self.is_split_kv,
                    learnable_sink,
                    mLSE,
                    seqlen,
                    m_block,
                    head_idx,
                    batch_idx,
                    split_idx,
                    sScale,
                    stats,
                    self.correction_epilogue,
                    thr_mma_pv,
                    tOtOs,
                    sO,
                    mbar_ptr,
                    self.mbar_softmax_corr_full_offset,
                    self.mbar_softmax_corr_empty_offset,
                    self.mbar_P_full_O_rescaled_offset,
                    self.mbar_P_full_2_offset,
                    self.mbar_corr_epi_full_offset,
                    self.mbar_corr_epi_empty_offset,
                    softmax_corr_consumer_phase,
                    o_corr_consumer_phase,
                    corr_epi_producer_phase,
                    softmax_scale_log2,
                    mO_cur,
                    gO,
                    gmem_tiled_copy_O_for_empty_tile,
                )

        if const_expr(mLSE is not None):
            if const_expr(not seqlen.has_cu_seqlens_q):
                if const_expr(self.is_split_kv):
                    mLSE_cur = mLSE[None, head_idx, batch_idx, split_idx]
                else:
                    mLSE_cur = mLSE[None, head_idx, batch_idx]
            else:
                # Variable-length batch: shift the sequence mode by the batch's q offset.
                offset = (
                    seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
                )
                if const_expr(self.is_split_kv):
                    mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx, split_idx])
                else:
                    mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
            for stage in cutlass.range_constexpr(self.q_stage):
                gLSE = cute.local_tile(
                    mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,)
                )
                row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]
                # if tidx == 0 and stage <= 1:
                #     cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
                # Natural-log LSE: (row_max * scale_log2 + log2(row_sum)) * ln(2);
                # -inf for empty / NaN rows.
                LN2 = math.log(2.0)
                lse = (
                    (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2
                    if not acc_O_mn_row_is_zero_or_nan
                    else -Float32.inf
                )
                seqlen_q = (
                    seqlen.seqlen_q
                    if const_expr(not self.pack_gqa)
                    else seqlen.seqlen_q * self.qhead_per_kvhead
                )
                # Only rows inside the (possibly packed) sequence are written.
                if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size:
                    # This actually just works with PackGQA too
                    gLSE[tidx] = lse

        # Advance to next tile
        tile_scheduler.advance_to_next_work()
        work_tile = tile_scheduler.get_current_work()
    # End of persistent scheduler loop
|
|
2260
|
+
|
|
2261
|
+
@cute.jit
def correction_rescale(
    self,
    thr_mma: cute.core.ThrMma,
    tOtO: cute.Tensor,
    tidx: Int32,
    scale: Float32,
):
    """Rescale intermediate attention results based on softmax normalization factor.

    This method performs a crucial correction step in the attention computation pipeline.
    When processing attention in blocks, the softmax normalization factors may change
    as new blocks are processed. This method rescales previously computed partial
    output values to account for updated normalization factors.

    The O accumulator is processed in column chunks of ``corr_tile_size``. For each chunk:
    1. Load the existing partial attention output from tensor memory into registers.
    2. Multiply every element by ``scale``, two elements at a time with a packed f32x2 multiply.
    3. Store the rescaled values back to tensor memory.
    A single tensor-memory store fence is issued after all chunks.

    :param thr_mma: per-thread slice of the PV MMA, used to partition the coordinate tensor
    :param tOtO: tensor-memory O accumulator of one q stage
    :param tidx: thread index within the correction warp group
    :param scale: multiplicative factor applied to every element this thread owns
    """
    tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))
    # Columns moved per TMEM load/store round trip.
    corr_tile_size = 16  # tuneable parameter
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
        self.pv_acc_dtype,
    )
    tmem_store_atom = cute.make_copy_atom(
        tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
        self.pv_acc_dtype,
    )
    tOtO_i = cute.composition(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))
    tOcO_i = cute.composition(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))
    thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i).get_slice(tidx)
    thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i).get_slice(tidx)
    tOtO_t2r = thr_tmem_load.partition_S(tOtO_i)
    tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape
    tOtO_r2t = thr_tmem_store.partition_D(tOtO_i)

    frg_count = self.head_dim_v_padded // corr_tile_size
    # The original also allocated a fragment of shape (tOrO_t2r_shape, frg_count) here.
    # It was never used, because the loop below rebinds tOrO_frg on every chunk, so it
    # has been removed.
    for i in cutlass.range_constexpr(frg_count):
        tOrO_frg = cute.make_fragment(tOrO_t2r_shape, self.pv_acc_dtype)
        # Chunk i starts i * corr_tile_size columns further into tensor memory.
        tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout)
        cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg)
        # Packed multiply handles element pairs, hence the stride of 2.
        for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True):
            tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2(
                (tOrO_frg[j], tOrO_frg[j + 1]),
                (scale, scale),
            )
        tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout)
        cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i)
    cute.arch.fence_view_async_tmem_store()
|
|
2313
|
+
|
|
2314
|
+
@cute.jit
def correction_epilogue(
    self,
    thr_mma: cute.core.ThrMma,
    tOtO: cute.Tensor,
    tidx: Int32,
    stage: Int32,
    m_block: Int32,
    seqlen_q: Int32,
    scale: Float32,
    sO: cute.Tensor,
    mO_cur: Optional[cute.Tensor] = None,
    gO: Optional[cute.Tensor] = None,
    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
    """Scale the final O accumulator, convert it to ``o_dtype`` and stage it in shared memory.

    The accumulator is processed in column chunks of ``32 * 8 // o_dtype.width`` elements.
    For each chunk:
    1. Load the chunk from tensor memory into registers.
    2. Multiply it by ``scale`` (packed f32x2 multiply).
    3. Convert it to ``self.o_dtype``.
    4. Store it to ``sO``.

    After all chunks, issue an async-shared proxy fence so the shared-memory writes are
    visible to async-proxy readers, presumably the TMA store issued by the epilogue
    warp (TODO confirm).

    When ``self.use_correction_warps_for_epi`` is set, the correction warps also perform
    the global-memory store themselves. They sync on the ``Epilogue`` named barrier, read
    ``sO`` back into registers, and write to ``gO`` with a row predicate, or use
    ``PackGQA.store_O`` when ``pack_gqa`` is set.

    :param thr_mma: per-thread slice of the PV MMA, used to partition sO / coordinates
    :type thr_mma: cute.core.ThrMma
    :param tOtO: Tensor-memory accumulator holding this stage's attention output
    :type tOtO: cute.Tensor
    :param tidx: thread index within the participating warp group
    :param stage: q stage; tile index along M is ``q_stage * m_block + stage``
    :param m_block: M-block index from the tile scheduler
    :param seqlen_q: query length used to predicate rows in the global-memory store
    :param scale: Final scaling factor to apply to the output
    :type scale: Float32
    :param sO: Shared memory tensor for the final output of this stage
    :type sO: cute.Tensor
    :param mO_cur: output tensor for the current batch/head (and split); only read when
        ``use_correction_warps_for_epi``
    :param gO: ``mO_cur`` tiled by (m_block_size, head_dim_v_padded); only read when
        ``use_correction_warps_for_epi``
    :param gmem_tiled_copy_O: register-to-global tiled copy; must be non-None when
        ``use_correction_warps_for_epi``
    """

    corr_tile_size = 32 * 8 // self.o_dtype.width
    tOsO = thr_mma.partition_C(sO)
    tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))

    # Split the (M, N) tile into (m_block_size, corr_tile_size) column chunks.
    tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))
    tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))
    tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size)))

    epi_subtile = (self.epi_tile[0], corr_tile_size)
    tmem_copy_atom = sm100_utils_basic.get_tmem_load_op(
        self.mma_tiler_pv,
        self.o_layout,
        self.o_dtype,
        self.pv_acc_dtype,
        epi_subtile,
        use_2cta_instrs=False,
    )
    # NOTE(review): despite its name, `tiled_tmem_load` already holds the per-thread
    # slice (`.get_slice(tidx)`), and `thr_tmem_load` below slices it again with the
    # same tidx.
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice(
        tidx
    )
    thr_tmem_load = tiled_tmem_load.get_slice(tidx)
    smem_copy_atom = sm100_utils_basic.get_smem_store_op(
        self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
    )
    # Register->smem copy laid out to match the TMEM load's destination partitioning.
    tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)

    tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
    tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
    tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None])
    for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size):
        tOtO_t2r_i = tOtO_t2r[None, 0, 0, i]
        tOsO_r2s_i = tOsO_s2r[None, 0, 0, i]
        tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype)
        cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg)
        # Packed multiply handles element pairs, hence the stride of 2.
        for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2):
            tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2(
                (tOrO_frg[j], tOrO_frg[j + 1]),
                (scale, scale),
            )
        tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype)
        tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype))
        cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i)
    # fence view async shared
    cute.arch.fence_proxy(
        cute.arch.ProxyKind.async_shared,
        space=cute.arch.SharedSpace.shared_cta,
    )

    if const_expr(self.use_correction_warps_for_epi):
        assert(not self.use_tma_O)
        assert(gmem_tiled_copy_O is not None)
        # Make every thread's sO writes visible before reading sO back with a different
        # thread->element mapping.
        # NOTE(review): the thread count comes from epilogue_warp_ids although the
        # correction warps execute this path; presumably equal in size here — confirm.
        cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue),
                          number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
        gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
        tOsO = gmem_thr_copy_O.partition_S(sO)
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
        tOgO = gmem_thr_copy_O.partition_D(gO)
        tOcO = gmem_thr_copy_O.partition_S(cO)
        # Coordinates as seen by thread 0; used to build a row predicate whose
        # per-iteration part is a compile-time value.
        t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
        # Column predicate against the real head dim (only used when check_hdim_v_oob).
        tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1])
        pack_gqa = PackGQA(
            self.m_block_size,
            self.head_dim_v_padded,
            self.check_hdim_v_oob,
            self.qhead_per_kvhead,
        )

        # load acc O from smem to rmem for wider vectorization
        tOrO = cute.make_fragment_like(tOsO, self.o_dtype)
        cute.autovec_copy(tOsO, tOrO)
        # copy acc O from rmem to gmem
        if const_expr(not self.pack_gqa):
            for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
                # Equivalent to: this thread's row < rows remaining in the sequence.
                if (
                    t0OcO[0, rest_m, 0][0]
                    < seqlen_q
                    - (self.q_stage * m_block + stage) * self.m_block_size
                    - tOcO[0][0]
                ):
                    cute.copy(
                        gmem_tiled_copy_O,
                        tOrO[None, rest_m, None],
                        tOgO[None, rest_m, None, self.q_stage * m_block + stage],
                        pred=tOpO[None, rest_m, None]
                        if const_expr(self.check_hdim_v_oob)
                        else None,
                    )
        else:
            pack_gqa.store_O(
                mO_cur,
                tOrO,
                gmem_tiled_copy_O,
                tidx,
                self.q_stage * m_block + stage,
                seqlen_q,
            )
|
|
2448
|
+
|
|
2449
|
+
@cute.jit
def epilogue_s2g(
    self,
    mO: cute.Tensor,
    sO: cute.Tensor,
    gmem_tiled_copy_O: cute.TiledCopy,
    tma_atom_O: Optional[cute.CopyAtom],
    mbar_ptr: cute.Pointer,
    block_info: BlockInfo,
    num_splits: int,
    SeqlenInfoCls: Callable,
    TileSchedulerCls: Callable,
):
    """Persistent loop of the epilogue warp(s): copy finished O tiles from sO to global memory.

    For each work tile with work, each q stage proceeds as follows:

    * Wait on ``corr_epi_full`` for that stage.
    * Store ``sO`` stage ``stage`` to M-tile ``q_stage * m_block + stage``. The store is
      either a TMA bulk store (``use_tma_O``) or a per-thread register copy, which goes
      through ``PackGQA.store_O`` when ``pack_gqa`` is set.
    * Arrive on ``corr_epi_empty`` to give the stage back to the correction warps.

    In the TMA path, the release waits until the bulk store has finished reading smem
    (``cp_async_bulk_wait_group(..., read=True)``).

    A tile has work when split-KV is disabled or ``n_block_min < n_block_max``.
    ``epi_consumer_phase`` flips only for tiles that had work.
    """
    epi_consumer_phase = Int32(0)
    tile_scheduler = TileSchedulerCls()
    work_tile = tile_scheduler.initial_work_tile_info()
    while work_tile.is_valid_tile:
        m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
        seqlen = SeqlenInfoCls(batch_idx)
        n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

        if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
            if const_expr(self.is_split_kv):
                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
            else:
                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
            gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0))
            if const_expr(self.use_tma_O):
                store_O, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_O, 0, cute.make_layout(1), sO, gO
                )
                for stage in cutlass.range_constexpr(self.q_stage):
                    # wait from corr, issue tma store on smem
                    # 1. wait for O0 / O1 final
                    cute.arch.mbarrier_wait(
                        mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase
                    )
                    # 2. copy O0 / O1 to gmem
                    store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage)
                    # One bulk-async group per stage so each can be waited on separately.
                    cute.arch.cp_async_bulk_commit_group()
                for stage in cutlass.range_constexpr(self.q_stage):
                    # Ensure O0 / O1 buffer is ready to be released
                    # With two stages, stage 0 may leave the (later) stage-1 group in
                    # flight; stage 1 then waits for everything.
                    if const_expr(self.q_stage == 2):
                        cute.arch.cp_async_bulk_wait_group(1 - stage, read=True)
                    else:
                        cute.arch.cp_async_bulk_wait_group(0, read=True)
                    cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage)
            else:
                # Thread index within the epilogue warp group.
                tidx = cute.arch.thread_idx()[0] % (
                    cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
                )
                gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
                tOsO = gmem_thr_copy_O.partition_S(sO)
                cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
                tOgO = gmem_thr_copy_O.partition_D(gO)
                tOcO = gmem_thr_copy_O.partition_S(cO)
                # Coordinates as seen by thread 0; used for the row predicate below.
                t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
                # NOTE(review): limit uses mO.shape[1] whereas correction_epilogue uses
                # mO_cur.shape[1]; presumably both are the head dim — confirm.
                tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
                pack_gqa = PackGQA(
                    self.m_block_size,
                    self.head_dim_v_padded,
                    self.check_hdim_v_oob,
                    self.qhead_per_kvhead,
                )
                for stage in cutlass.range_constexpr(self.q_stage):
                    # wait from corr, issue tma store on smem
                    # 1. wait for O0 / O1 final
                    cute.arch.mbarrier_wait(
                        mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase
                    )
                    # 2. copy O0 / O1 to gmem
                    # load acc O from smem to rmem for wider vectorization
                    tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype)
                    cute.autovec_copy(tOsO[None, None, None, stage], tOrO)
                    # copy acc O from rmem to gmem
                    if const_expr(not self.pack_gqa):
                        for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
                            # Equivalent to: this thread's row < rows remaining in the sequence.
                            if (
                                t0OcO[0, rest_m, 0][0]
                                < seqlen.seqlen_q
                                - (self.q_stage * m_block + stage) * self.m_block_size
                                - tOcO[0][0]
                            ):
                                cute.copy(
                                    gmem_tiled_copy_O,
                                    tOrO[None, rest_m, None],
                                    tOgO[None, rest_m, None, self.q_stage * m_block + stage],
                                    pred=tOpO[None, rest_m, None]
                                    if const_expr(self.check_hdim_v_oob)
                                    else None,
                                )
                    else:
                        pack_gqa.store_O(
                            mO_cur,
                            tOrO,
                            gmem_tiled_copy_O,
                            tidx,
                            self.q_stage * m_block + stage,
                            seqlen.seqlen_q,
                        )
                    cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage)

            epi_consumer_phase ^= 1

        # Advance to next tile
        tile_scheduler.advance_to_next_work()
        work_tile = tile_scheduler.get_current_work()
|
|
2556
|
+
|
|
2557
|
+
def load_Q(
    self,
    load_Q_fn: Callable,
    mbar_full_ptr: cute.Pointer,
    mbar_empty_ptr: cute.Pointer,
    block: Int32,
    stage: int,
    phase: Int32,
):
    """Issue the load of Q tile ``block`` into shared-memory stage ``stage``.

    The call waits on the stage's "empty" barrier with ``phase``. One elected thread then
    arms the stage's "full" barrier with the expected Q transaction byte count.
    ``load_Q_fn`` is passed that same full barrier as ``tma_bar_ptr``.
    """
    empty_barrier = mbar_empty_ptr + stage
    full_barrier = mbar_full_ptr + stage
    q_bytes = self.tma_copy_bytes["Q"]
    # Block until the consumer has released this Q buffer.
    cute.arch.mbarrier_wait(empty_barrier, phase)
    with cute.arch.elect_one():
        cute.arch.mbarrier_arrive_and_expect_tx(full_barrier, q_bytes)
    load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=full_barrier)
|
|
2570
|
+
|
|
2571
|
+
@cute.jit
def load_KV(
    self,
    tma_atom: Optional[cute.CopyAtom],
    tXgX: Optional[cute.Tensor],
    tXsX: Optional[cute.Tensor],
    paged_kv_manager: Optional[PagedKVManager],
    sX: cute.Tensor,
    mbar_full_ptr: cute.Pointer,
    mbar_empty_ptr: cute.Pointer,
    block: Int32,
    producer_state: cutlass.pipeline.PipelineState,
    K_or_V: Literal["K", "V"],
    page_idx: Optional[Int32] = None,
):
    """Load one K or V tile into the shared-memory stage selected by ``producer_state``.

    The stage index and barrier phase come from ``producer_state``. The call first waits
    until that stage is empty, then takes one of two paths:

    * ``use_tma_KV``: one elected thread arms the stage's full barrier with
      ``tma_copy_bytes[K_or_V]``. A TMA copy then moves ``tXgX`` into ``tXsX``,
      completing on that barrier.
    * otherwise: ``paged_kv_manager.load_KV`` fills ``sX`` for this stage with cp.async.
      A commit follows, and completion of this thread's cp.async copies then arrives on
      the full barrier (``cp_async_mbarrier_arrive_noinc``).

    :param block: KV block index into ``tXgX`` (TMA, non-paged), or forwarded to
        ``paged_kv_manager.load_KV``.
    :param K_or_V: selects which tensor is loaded and its transaction byte count.
    :param page_idx: when given (TMA path), index ``tXgX[None, 0, page_idx]`` instead
        of ``tXgX[None, block]``.
    """
    assert K_or_V in ("K", "V")
    stage, phase = producer_state.index, producer_state.phase
    cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase)
    if const_expr(K_or_V == "K" and self.uneven_kv_smem):
        # Before this round, the smem location was occupied by V, which is smaller than
        # K. So we need to wait for the stage after that (stage 1) to be empty as well.
        if stage == 0:
            cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase)

    if const_expr(self.use_tma_KV):
        assert (
            tXgX is not None and
            tXsX is not None and
            tma_atom is not None
        )
        with cute.arch.elect_one():
            cute.arch.mbarrier_arrive_and_expect_tx(
                mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V],
            )
        tXsX_cur = tXsX[None, stage]
        if const_expr(self.uneven_kv_smem):
            # Since this is the producer_state, the phase starts at 1, so we have to invert it
            tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1)
        # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0
        tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx]
        cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage)
    else:
        # NOTE(review): offset_kv_smem is not applied on this path; presumably
        # uneven_kv_smem is only enabled together with use_tma_KV — confirm.
        assert paged_kv_manager is not None
        paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V)
        cute.arch.cp_async_commit_group()
        cute.arch.cp_async_mbarrier_arrive_noinc(mbar_full_ptr + stage)
|
|
2617
|
+
|
|
2618
|
+
@cute.jit
def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32):
    """Return ``sX`` shifted to the real start of its buffer when K/V smem is uneven.

    If ``self.uneven_kv_smem`` is false, ``sX`` is returned unchanged. Otherwise only
    stage 1 is moved, by ``+uneven_kv_smem_offset`` when ``phase == 0`` or by
    ``-uneven_kv_smem_offset`` when ``phase == 1``. Stages other than 1 get a zero
    offset. The layout is kept and only the iterator moves.
    """
    if const_expr(self.uneven_kv_smem):
        # smem layout is [smem_large, smem_small, smem_large], and the current stride is
        # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if
        # phase == 0, or left by offset if phase == 1.
        # (1 - 2 * phase) maps phase 0 -> +1 and phase 1 -> -1.
        offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)
        return cute.make_tensor(sX.iterator + offset, sX.layout)
    else:
        return sX
|
|
2628
|
+
|
|
2629
|
+
def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr):
    """Build the K/V load pipeline connecting the load warps to the MMA warp.

    The pipeline type and the producer arrival count depend on how K/V tiles are
    fetched:

    * ``self.use_tma_KV`` true: ``PipelineTmaUmma``. The producer group counts
      one thread per load warp, and ``tx_count`` is ``self.tma_copy_bytes["K"]``.
    * otherwise: ``PipelineAsyncUmma``. The producer group counts every thread
      of every load warp (``len(self.load_warp_ids) * cute.arch.WARP_SIZE``).

    The consumer group always has a count of one, for the MMA warp.

    Args:
        load_kv_mbar_ptr: barrier storage backing the ``self.kv_stage`` stages.

    Returns:
        The created pipeline object.
    """
    pipeline = cutlass.pipeline
    thread_agent = pipeline.Agent.Thread
    n_load_warps = len(self.load_warp_ids)

    # Consumer side: only the MMA warp, counted as a single arrival.
    mma_consumers = pipeline.CooperativeGroup(thread_agent, len([self.mma_warp_id]))

    if not self.use_tma_KV:
        # cp.async path: every lane of every load warp arrives on the barrier.
        async_producers = pipeline.CooperativeGroup(
            thread_agent, n_load_warps * cute.arch.WARP_SIZE
        )
        return pipeline.PipelineAsyncUmma.create(
            barrier_storage=load_kv_mbar_ptr,
            num_stages=self.kv_stage,
            producer_group=async_producers,
            consumer_group=mma_consumers,
        )

    # TMA path: one arriving thread per load warp.
    # Presumably this pairs with the elect_one() arrive in the TMA load code.
    tma_producers = pipeline.CooperativeGroup(thread_agent, n_load_warps)
    return pipeline.PipelineTmaUmma.create(
        barrier_storage=load_kv_mbar_ptr,
        num_stages=self.kv_stage,
        producer_group=tma_producers,
        consumer_group=mma_consumers,
        tx_count=self.tma_copy_bytes["K"],
    )
|
|
2654
|
+
|
|
2655
|
+
# @cute.jit
|
|
2656
|
+
# def warp_scheduler_barrier_init(self):
|
|
2657
|
+
# warp_group_idx = utils.canonical_warp_group_idx(sync=False)
|
|
2658
|
+
# if warp_group_idx == 0:
|
|
2659
|
+
# cute.arch.barrier_arrive(
|
|
2660
|
+
# barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128,
|
|
2661
|
+
# )
|
|
2662
|
+
|
|
2663
|
+
# def warp_scheduler_barrier_sync(self):
|
|
2664
|
+
# cute.arch.barrier(
|
|
2665
|
+
# barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False),
|
|
2666
|
+
# number_of_threads=2 * 128
|
|
2667
|
+
# )
|
|
2668
|
+
|
|
2669
|
+
# def warp_scheduler_barrier_arrive(self):
|
|
2670
|
+
# cur_wg = utils.canonical_warp_group_idx(sync=False)
|
|
2671
|
+
# next_wg = 1 - cur_wg
|
|
2672
|
+
# cute.arch.barrier_arrive(
|
|
2673
|
+
# barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128,
|
|
2674
|
+
# )
|
|
2675
|
+
|
|
2676
|
+
@cute.jit
def apply_score_mod(
    self,
    tSrS_t2r,
    thr_tmem_load,
    thr_mma_qk,
    batch_idx,
    head_idx,
    m_block,
    n_block,
    softmax,
    seqlen: SeqlenInfoQK,
    aux_tensors=None,
    fastdiv_mods=(None, None),
):
    """Apply score modification for SM100 (constant q_idx).

    This method does the following:

    1. Builds the (row, col) coordinates of this thread's part of the
       (m_block, n_block) score tile.
    2. Takes a single query index from the first coordinate.
    3. Under Pack-GQA, splits that packed index into the query index and the
       query head.
    4. When ``aux_tensors`` is given, reduces the query index with the first
       fastdiv mod.
    5. Forwards everything to ``apply_score_mod_inner`` together with
       ``self.score_mod``.

    Args:
        tSrS_t2r: this thread's scores. These are presumably the registers
            produced by the tmem load described by ``thr_tmem_load``
            (TODO confirm against caller).
        thr_tmem_load: thread slice whose ``partition_D`` is applied to the
            coordinates.
        thr_mma_qk: thread slice whose ``partition_C`` is applied to the
            coordinates.
        batch_idx: batch index forwarded to the score mod.
        head_idx: head index. Under Pack-GQA it is multiplied by
            ``self.qhead_per_kvhead`` and combined with the per-row head offset
            before being forwarded.
        m_block: tile index used to offset row coordinates by
            ``m_block * self.m_block_size``.
        n_block: tile index used to offset column coordinates by
            ``n_block * self.n_block_size``.
        softmax: object whose ``softmax_scale`` is forwarded.
        seqlen: forwarded as ``seqlen_info``.
        aux_tensors: optional extra tensors forwarded to the score mod. When
            not None, the query index is replaced by the remainder of
            ``divmod(q_idx_logical, fastdiv_mods[0])``.
        fastdiv_mods: pair whose first element is the seqlen_q divisor used
            above. The pair is forwarded unchanged.
    """
    # Prepare index tensor with extra partition
    # Identity tensor: each element holds its own (row, col) coordinate in the tile.
    cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size))
    # Offset the coordinates by this tile's position along both dimensions.
    cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS)
    # Partition the coordinates through the same MMA-C / tmem-load partitions,
    # presumably so tScS_t2r lines up element-for-element with tSrS_t2r.
    tScS = thr_mma_qk.partition_C(cS)
    tScS_t2r = thr_tmem_load.partition_D(tScS)

    # Shared q_idx for all scores
    # Row coordinate of this thread's first element. Using it for every element
    # relies on the "constant q_idx" premise of this path.
    # NOTE(review): confirm that each thread's fragment covers a single row.
    q_idx_logical = tScS_t2r[0][0]

    # For Pack-GQA, compute the logical head index for this tile
    if cutlass.const_expr(self.pack_gqa):
        # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)
        # The packed row splits as q_physical = q_idx * qhead_per_kvhead + head_offset.
        q_physical = q_idx_logical
        q_idx_logical = q_physical // self.qhead_per_kvhead
        head_offset = q_physical - q_idx_logical * self.qhead_per_kvhead
        head_idx = head_idx * self.qhead_per_kvhead + head_offset

    if cutlass.const_expr(aux_tensors is not None):
        # Keep only the remainder of q_idx modulo the seqlen_q fastdiv divisor.
        seqlen_q_divmod, _ = fastdiv_mods
        _, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod)

    apply_score_mod_inner(
        tSrS_t2r,
        tScS_t2r,
        self.score_mod,
        batch_idx,
        head_idx,
        softmax.softmax_scale,
        self.vec_size,
        self.qk_acc_dtype,
        aux_tensors,
        fastdiv_mods,
        seqlen_info=seqlen,
        constant_q_idx=q_idx_logical,
        # Group size is only meaningful under Pack-GQA; otherwise pass 1.
        qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
    )
|