mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1262 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
|
3
|
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp
|
|
4
|
+
# from Cutlass C++ to Cute-DSL.
|
|
5
|
+
import math
|
|
6
|
+
from types import SimpleNamespace
|
|
7
|
+
from typing import Type, Callable, Optional
|
|
8
|
+
from functools import partial
|
|
9
|
+
|
|
10
|
+
import cuda.bindings.driver as cuda
|
|
11
|
+
|
|
12
|
+
import cutlass
|
|
13
|
+
import cutlass.cute as cute
|
|
14
|
+
from cutlass.cute.nvgpu import cpasync, warp
|
|
15
|
+
from cutlass import Float32, Int32
|
|
16
|
+
import cutlass.utils as utils_basic
|
|
17
|
+
|
|
18
|
+
from mslk.attention.flash_attn import ampere_helpers as sm80_utils
|
|
19
|
+
from mslk.attention.flash_attn import utils
|
|
20
|
+
from mslk.attention.flash_attn.mask import AttentionMask
|
|
21
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
|
|
22
|
+
from mslk.attention.flash_attn.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FlashAttentionBackwardSm80:
|
|
26
|
+
def __init__(
    self,
    dtype: Type[cutlass.Numeric],
    head_dim: int,
    head_dim_v: Optional[int] = None,
    qhead_per_kvhead: int = 1,
    m_block_size: int = 64,
    n_block_size: int = 128,
    num_stages_Q: int = 2,
    num_stages_dO: int = 2,
    num_threads: int = 256,
    pack_gqa: bool = False,
    is_causal: bool = False,
    SdP_swapAB: bool = False,
    dKV_swapAB: bool = False,
    dQ_swapAB: bool = False,
    AtomLayoutMSdP: int = 1,
    AtomLayoutNdKV: int = 8,
    AtomLayoutMdQ: int = 1,
    V_in_regs: bool = False,
):
    """Initialize the configuration for the SM80 flash attention backward kernel.

    All contiguous dimensions must be at least 16-byte aligned, so the head
    dimensions should be multiples of 8 (for 2-byte element types).

    :param dtype: element type of Q/K/V/dO (Float16 or BFloat16)
    :param head_dim: head dimension of Q and K
    :param head_dim_v: head dimension of V and dO; defaults to ``head_dim``
    :param qhead_per_kvhead: number of query heads per KV head; when != 1 the
        dK/dV tensors are expected to be Float32 accumulators (see ``_check_type``)
    :param m_block_size: tile size of the Q/dO/LSE smem tiles (first mode)
    :param n_block_size: tile size of the K/V smem tiles (first mode)
    :param num_stages_Q: number of smem stages for Q (also used for LSE/dPsum)
    :param num_stages_dO: number of smem stages for dO
    :param num_threads: threads per CTA; expected to be a multiple of 32
    :param pack_gqa: when set, ``qhead_per_kvhead`` is forwarded to the tile
        scheduler as ``qhead_per_kvhead_packgqa``
    :param is_causal: causal-masking flag
    :param SdP_swapAB: swap the two warp-layout modes of the S/dP tiled MMA
    :param dKV_swapAB: swap the two warp-layout modes of the dK/dV tiled MMA
    :param dQ_swapAB: swap the two warp-layout modes of the dQ tiled MMA
    :param AtomLayoutMSdP: warps along the first warp-layout mode of the S/dP MMA
    :param AtomLayoutNdKV: warps along the first warp-layout mode of the dK/dV MMA
    :param AtomLayoutMdQ: warps along the first warp-layout mode of the dQ MMA
    :param V_in_regs: V-in-registers mode; also turns on ``share_QV_smem``
    """
    self.dtype = dtype
    # Pad head dims up to a multiple of 32; the smem layouts are built from the
    # padded sizes. Integer ceil-division avoids a float round trip.
    hdim_multiple_of = 32
    self.head_dim_padded = (head_dim + hdim_multiple_of - 1) // hdim_multiple_of * hdim_multiple_of
    head_dim_v = head_dim_v if head_dim_v is not None else head_dim
    self.same_hdim_kv = head_dim == head_dim_v
    self.head_dim_v_padded = (head_dim_v + hdim_multiple_of - 1) // hdim_multiple_of * hdim_multiple_of
    # Can save registers (and hence be faster) if we don't have to check hdim predication
    self.check_hdim_oob = head_dim != self.head_dim_padded
    self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
    self.qhead_per_kvhead = qhead_per_kvhead
    self.m_block_size = m_block_size
    self.n_block_size = n_block_size
    self.num_threads = num_threads
    self.pack_gqa = pack_gqa
    self.is_causal = is_causal
    self.num_stages_Q = num_stages_Q
    self.num_stages_dO = num_stages_dO
    self.SdP_swapAB = SdP_swapAB
    self.dKV_swapAB = dKV_swapAB
    self.dQ_swapAB = dQ_swapAB
    self.AtomLayoutMSdP = AtomLayoutMSdP
    self.AtomLayoutNdKV = AtomLayoutNdKV
    self.AtomLayoutMdQ = AtomLayoutMdQ
    num_mma_warps = self.num_threads // cute.arch.WARP_SIZE
    # Presumably "RS" = dK/dV MMA with its A operand sourced from registers;
    # only enabled for this specific warp arrangement. TODO confirm in kernel body.
    self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB
    self.V_in_regs = V_in_regs
    # Selects SharedStorageSharedQV in _get_shared_storage_cls.
    self.share_QV_smem = V_in_regs
|
|
90
|
+
|
|
91
|
+
@staticmethod
def can_implement(
    dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages_Q, num_stages_dO,
    num_threads, is_causal,
    V_in_regs=False
) -> bool:
    """Check if the kernel can be implemented with the given parameters.

    The shared-memory estimate mirrors the buffers reserved by
    ``_get_shared_storage_cls``:

    - head dims are padded to a multiple of 32, as in ``__init__``;
    - the P/dS tiles are counted;
    - the Float32 LSE/dPsum buffers are counted;
    - with ``V_in_regs`` the shared-QV struct still reserves sV in addition to a
      max(sQ, sV)-sized Q buffer.

    Small alignment padding between struct members is not counted.

    :param dtype: data type
    :type dtype: cutlass.Numeric
    :param head_dim: head dimension of Q/K
    :type head_dim: int
    :param head_dim_v: head dimension of V/dO
    :type head_dim_v: int
    :param m_block_size: m block size
    :type m_block_size: int
    :param n_block_size: n block size
    :type n_block_size: int
    :param num_stages_Q: number of Q (and LSE/dPsum) smem stages
    :type num_stages_Q: int
    :param num_stages_dO: number of dO smem stages
    :type num_stages_dO: int
    :param num_threads: number of threads
    :type num_threads: int
    :param is_causal: is causal
    :type is_causal: bool
    :param V_in_regs: whether the shared-QV smem struct is used
    :type V_in_regs: bool

    :return: True if the kernel can be implemented, False otherwise
    :rtype: bool
    """
    if dtype not in [cutlass.Float16, cutlass.BFloat16]:
        return False
    if head_dim % 8 != 0:
        return False
    if head_dim_v % 8 != 0:
        return False
    if n_block_size % 16 != 0:
        return False
    if num_threads % 32 != 0:
        return False
    # Must match the head-dim padding applied in __init__.
    hdim_multiple_of = 32
    head_dim_padded = (head_dim + hdim_multiple_of - 1) // hdim_multiple_of * hdim_multiple_of
    head_dim_v_padded = (head_dim_v + hdim_multiple_of - 1) // hdim_multiple_of * hdim_multiple_of
    elem_bytes = 2  # Float16 / BFloat16 (checked above)
    accum_bytes = 4  # Float32 LSE / dPsum
    smem_usage_Q = m_block_size * head_dim_padded * num_stages_Q * elem_bytes
    smem_usage_dO = m_block_size * head_dim_v_padded * num_stages_dO * elem_bytes
    smem_usage_K = n_block_size * head_dim_padded * elem_bytes
    smem_usage_V = n_block_size * head_dim_v_padded * elem_bytes
    # The shared-QV struct sizes its Q buffer as max(sQ, sV) but still reserves sV.
    smem_usage_Q_buf = max(smem_usage_Q, smem_usage_V) if V_in_regs else smem_usage_Q
    # sP and sdS: two (m_block_size, n_block_size) tiles of the input dtype.
    smem_usage_PdS = 2 * m_block_size * n_block_size * elem_bytes
    # sLSE and sdPsum: (m_block_size, num_stages_Q) with the per-stage stride rounded up to 64.
    lse_stage_stride = (m_block_size + 63) // 64 * 64
    lse_cosize = (num_stages_Q - 1) * lse_stage_stride + m_block_size
    smem_usage_LSE_dPsum = 2 * lse_cosize * accum_bytes
    smem_usage = (
        smem_usage_K
        + smem_usage_V
        + smem_usage_Q_buf
        + smem_usage_dO
        + smem_usage_LSE_dPsum
        + smem_usage_PdS
    )
    smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
    if smem_usage > smem_capacity:
        return False
    return True
|
|
137
|
+
|
|
138
|
+
def _check_type(
    self,
    mQ_type: Type[cutlass.Numeric],
    mK_type: Type[cutlass.Numeric],
    mV_type: Type[cutlass.Numeric],
    mdO_type: Type[cutlass.Numeric],
    mLSE_type: Type[cutlass.Numeric],
    mdPsum_type: Type[cutlass.Numeric],
    mdQaccum_type: Type[cutlass.Numeric],
    mdK_type: Type[cutlass.Numeric],
    mdV_type: Type[cutlass.Numeric],
    mCuSeqlensQ_type: Type[cutlass.Numeric] | None,
    mCuSeqlensK_type: Type[cutlass.Numeric] | None,
    mSeqUsedQ_type: Type[cutlass.Numeric] | None,
    mSeqUsedK_type: Type[cutlass.Numeric] | None,
):
    """Validate the element types of every kernel argument at trace time.

    Rules, checked in this order (first violation raises ``TypeError``):

    1. Q, K, V and dO share one element type.
    2. dK/dV match Q's type when ``qhead_per_kvhead == 1``; otherwise they are
       Float32 accumulators.
    3. Q is Float16 or BFloat16.
    4. LSE, dPsum and dQaccum are Float32.
    5. The optional varlen tensors are Int32 when provided (``None`` means absent).

    Finally asserts that Q's type equals the configured ``self.dtype``.
    """
    if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
        raise TypeError("All tensors must have the same data type")

    dkv_are_accumulators = self.qhead_per_kvhead != 1
    if cutlass.const_expr(dkv_are_accumulators):
        if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)):
            raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
    elif cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)):
        raise TypeError("mdK and mdV tensors must have the same data type as mQ")

    if cutlass.const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
        raise TypeError("Only Float16 or BFloat16 is supported")

    must_be_float32 = (
        (mLSE_type, "LSE tensor must be Float32"),
        (mdPsum_type, "dPsum tensor must be Float32"),
        (mdQaccum_type, "dQaccum tensor must be Float32"),
    )
    for actual_type, error_message in must_be_float32:
        if cutlass.const_expr(actual_type not in [cutlass.Float32]):
            raise TypeError(error_message)

    optional_int32 = (
        (mCuSeqlensQ_type, "cuSeqlensQ tensor must be Int32"),
        (mCuSeqlensK_type, "cuSeqlensK tensor must be Int32"),
        (mSeqUsedQ_type, "SeqUsedQ tensor must be Int32"),
        (mSeqUsedK_type, "SeqUsedK tensor must be Int32"),
    )
    for actual_type, error_message in optional_int32:
        if cutlass.const_expr(actual_type not in [None, cutlass.Int32]):
            raise TypeError(error_message)

    assert mQ_type == self.dtype
|
|
179
|
+
|
|
180
|
+
def _setup_attributes(self):
    """Build the shared-memory layouts and global-memory tiled copies.

    Sets on ``self``:

    - ``sQ_layout``, ``sK_layout``, ``sV_layout``, ``sdO_layout``, ``sPdS_layout``;
    - ``sLSE_layout``, ``sLSEMma_layout``;
    - the ``is_even_*`` flags;
    - the ``gmem_tiled_copy_*`` objects.

    Reads ``self.varlen_q``, which ``__init__`` does not set; it must be assigned
    before this runs (``__call__`` does so).
    """
    # ///////////////////////////////////////////////////////////////////////////////
    # Shared memory layout: Q/K/V
    # ///////////////////////////////////////////////////////////////////////////////
    sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded)
    self.sQ_layout = cute.tile_to_shape(
        sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2),
    )
    sK_layout_atom = sQ_layout_atom
    self.sK_layout = cute.tile_to_shape(
        sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1),
    )
    sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded)
    self.sV_layout = cute.tile_to_shape(
        sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1),
    )
    sdO_layout_atom = sV_layout_atom
    self.sdO_layout = cute.tile_to_shape(
        sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2),
    )
    # TODO: do we set swizzle to be 3 here explicitly?
    sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size)
    self.sPdS_layout = cute.tile_to_shape(
        sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1),
    )
    # We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,
    # it's still a valid smem address.
    lse_stage_stride = cute.round_up(self.m_block_size, 64)
    self.sLSE_layout = cute.make_layout(
        (self.m_block_size, self.num_stages_Q),
        stride=(1, lse_stage_stride),
    )
    # Views of the same LSE buffer with an extra stride-0 n_block_size mode, so each
    # m index's value is broadcast along n. The transposed variant puts n first and
    # is selected when SdP_swapAB is set.
    sLSEMma_layout = cute.make_layout(
        (self.m_block_size, self.n_block_size, self.num_stages_Q),
        stride=(1, 0, lse_stage_stride),
    )
    sLSEMma_layout_transposed = cute.make_layout(
        (self.n_block_size, self.m_block_size, self.num_stages_Q),
        stride=(0, 1, lse_stage_stride),
    )
    self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed

    # ///////////////////////////////////////////////////////////////////////////////
    # GMEM Tiled copy:
    # ///////////////////////////////////////////////////////////////////////////////
    # Thread layouts for copies
    universal_copy_bits = 128
    async_copy_elems = universal_copy_bits // self.dtype.width
    # atom_async_copy: cp.async (global -> shared) atom for the Q/K and V/dO loads
    atom_async_copy = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
        self.dtype,
        num_bits_per_copy=universal_copy_bits,
    )
    # atom_universal_copy: 128-bit universal copy atom, used for the dK/dV tiled copies
    atom_universal_copy = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits,
    )
    # tQK_layout: thread layout for QK load
    tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
    assert self.num_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1"
    tQK_layout = cute.make_ordered_layout(
        (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0),
    )
    # Do we need to check if we overshot kBlockM when we load Q?
    self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0
    # Do we need to check if we overshot kBlockN when we load K?
    self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 0
    tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems
    assert self.num_threads % tVdO_shape_dim_1 == 0, "num_threads must be divisible by tVdO_shape_dim_1"
    tVdO_layout = cute.make_ordered_layout(
        (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0),
    )
    # Do we need to check if we overshot kBlockN when we load V?
    self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0
    # Same question for kBlockM when we load dO.
    self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0

    # Value layouts for copies
    vQKVdO_layout = cute.make_layout((1, async_copy_elems))

    # gmem_tiled_copy_QK: tiled copy for QK load
    self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout)
    self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout)
    self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout)
    self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout)

    # LSE copy: 128-bit cp.async when not varlen_q, scalar universal copy otherwise.
    # I think we wouldn't require this with smarter padding
    if cutlass.const_expr(not self.varlen_q):
        async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width
        atom_async_copy_accum = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            cutlass.Float32,
            num_bits_per_copy=universal_copy_bits,
        )
    else:
        async_copy_elems_accum = 1
        atom_async_copy_accum = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            cutlass.Float32,
            num_bits_per_copy=cutlass.Float32.width,
        )
    self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
        atom_async_copy_accum,
        cute.make_layout(self.num_threads),
        cute.make_layout(async_copy_elems_accum),
    )
    # One Float32 element per thread per copy.
    self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
        cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width
        ),
        cute.make_layout(self.num_threads),
        cute.make_layout(1)
    )
    # With GQA, dK/dV are Float32 accumulators (see _check_type), so they reuse the
    # scalar Float32 copy instead of the dtype-vectorized one.
    if cutlass.const_expr(self.qhead_per_kvhead > 1):
        self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum
        self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum
|
|
296
|
+
|
|
297
|
+
def _get_tiled_mma(self):
    """Create the three warp-level tiled MMAs used by the backward pass.

    Each uses the 16x8x16 F16/BF16 MMA atom with Float32 accumulation. The
    ``num_threads // 32`` warps are arranged as ``(n_outer, warps // n_outer, 1)``,
    with the first two modes swapped when the matching ``*_swapAB`` flag is set.
    The permutation tiles the first two modes at 16 elements per warp.

    :return: ``(tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq)``
    """
    warps_total = self.num_threads // 32

    def build(n_outer, swapped):
        n_inner = warps_total // n_outer
        if cutlass.const_expr(swapped):
            warp_arrangement = (n_inner, n_outer, 1)
        else:
            warp_arrangement = (n_outer, n_inner, 1)
        return cute.make_tiled_mma(
            warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
            warp_arrangement,
            permutation_mnk=(warp_arrangement[0] * 16, warp_arrangement[1] * 16, 16),
        )

    sdp_mma = build(self.AtomLayoutMSdP, self.SdP_swapAB)
    dkv_mma = build(self.AtomLayoutNdKV, self.dKV_swapAB)
    dq_mma = build(self.AtomLayoutMdQ, self.dQ_swapAB)
    return sdp_mma, dkv_mma, dq_mma
|
|
318
|
+
|
|
319
|
+
def _get_shared_storage_cls(self):
    """Build the ``cute.struct`` type describing the kernel's shared memory.

    Requires ``_setup_attributes`` to have run, because buffer sizes are the
    cosizes of the smem layouts it creates:

    - Q/K/V/dO (``self.dtype``) are 1024-byte aligned;
    - LSE/dPsum (Float32) and P/dS (``self.dtype``) are 128-byte aligned.

    Returns ``SharedStorageSharedQV`` when ``self.share_QV_smem`` is set,
    otherwise ``SharedStorageSeparateQV``. ``__call__`` launches the kernel with
    this struct's ``size_in_bytes()`` as the dynamic smem size. Field order
    determines the smem layout.
    """
    sQ_struct, sK_struct, sV_struct, sdO_struct = [
        cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]
        for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout)
    ]
    # Buffer large enough for either Q or V, used as sQ in the shared-QV variant.
    cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
    sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
    # dPsum reuses the LSE layout (same shape and staging).
    sLSE_struct, sdPsum_struct = [
        cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128]
        for layout in (self.sLSE_layout, self.sLSE_layout)
    ]
    sP_struct, sdS_struct = [
        cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128]
        for layout in (self.sPdS_layout, self.sPdS_layout)
    ]

    @cute.struct
    class SharedStorageSeparateQV:
        sK: sK_struct
        sV: sV_struct
        sQ: sQ_struct
        sdO: sdO_struct
        sLSE: sLSE_struct
        sdPsum: sdPsum_struct
        sP: sP_struct
        sdS: sdS_struct
        # TODO: the case where there's no sP

    @cute.struct
    class SharedStorageSharedQV:
        # NOTE(review): sV is still reserved here in addition to the
        # max(sQ, sV)-sized sQ buffer, so this struct is never smaller than
        # SharedStorageSeparateQV. If the kernel aliases V onto sQ in this mode,
        # this field may be removable; confirm against the kernel body first.
        sK: sK_struct
        sV: sV_struct
        sQ: sQV_struct
        sdO: sdO_struct
        sLSE: sLSE_struct
        sdPsum: sdPsum_struct
        sP: sP_struct
        sdS: sdS_struct

    return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV
|
|
359
|
+
|
|
360
|
+
@cute.jit
def __call__(
    self,
    mQ: cute.Tensor,
    mK: cute.Tensor,
    mV: cute.Tensor,
    mdO: cute.Tensor,
    mLSE: cute.Tensor,
    mdPsum: cute.Tensor,
    mdQaccum: cute.Tensor,
    mdK: cute.Tensor,
    mdV: cute.Tensor,
    softmax_scale: cutlass.Float32,
    stream: cuda.CUstream,
    mCuSeqlensQ: Optional[cute.Tensor] = None,
    mCuSeqlensK: Optional[cute.Tensor] = None,
    mSeqUsedQ: Optional[cute.Tensor] = None,
    mSeqUsedK: Optional[cute.Tensor] = None,
    softcap: Float32 | float | None = None,
    window_size_left: Int32 | int | None = None,
    window_size_right: Int32 | int | None = None,
    mdQ_semaphore: Optional[cute.Tensor] = None,
):
    """Validate inputs, build layouts/MMAs/scheduler, and launch ``self.kernel``.

    Assumed layouts (inferred from the shape indexing below; confirm with callers):

    - non-varlen tensors look like (batch, seqlen, nheads, hdim);
    - varlen tensors, when ``mCuSeqlensQ``/``mCuSeqlensK`` are given, look like
      (total_seqlen, nheads, hdim).

    :param softmax_scale: softmax scale; ``softmax_scale * log2(e)`` is also
        precomputed and passed to the kernel
    :param stream: CUDA stream the kernel is launched on
    :param mdQ_semaphore: must be None (semaphore not supported yet)

    NOTE(review): ``softcap``, ``window_size_left`` and ``window_size_right`` are
    accepted but not referenced in this body nor forwarded to ``self.kernel``.
    """
    assert mdQ_semaphore is None, "semaphore not supported yet"
    # Get the data type and check if it is fp16 or bf16
    self._check_type(*(t.element_type if t is not None else None
                       for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
    # Assume all strides are divisible by 128 bits except the last stride
    new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
    mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)]
    # Must be set before _setup_attributes(), which reads it to pick the LSE copy atom.
    self.varlen_q = (mCuSeqlensQ is not None)
    self._setup_attributes()
    SharedStorage = self._get_shared_storage_cls()
    tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma()

    # Number of *query* heads: the head mode is index 1 for varlen Q, index 2 otherwise.
    num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2]

    if cutlass.const_expr(mCuSeqlensK is not None):
        TileScheduler = SingleTileVarlenScheduler
        num_batch = mCuSeqlensK.shape[0] - 1
    else:
        TileScheduler = SingleTileScheduler
        num_batch = mK.shape[0]

    # Uses seqlen k, etc. since main bwd kernel's blocks are over n
    # (hence K-side quantities go into the scheduler's Q-named slots below).
    tile_sched_args = TileSchedulerArguments(
        num_block=cute.ceil_div(mK.shape[1], self.n_block_size),
        num_head=num_head,
        num_batch=num_batch,
        num_splits=1,
        seqlen_k=0,
        # NOTE(review): if non-varlen mK/mV are (batch, seqlen, nheads, hdim),
        # shape[2] here is nheads rather than hdim. Check whether the scheduler
        # actually uses headdim/headdim_v.
        headdim=mK.shape[2],
        headdim_v=mV.shape[2],
        total_q=mK.shape[0],
        tile_shape_mn=(self.n_block_size, self.m_block_size),
        qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
        mCuSeqlensQ=mCuSeqlensK,
        mSeqUsedQ=mSeqUsedK,
    )

    tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
    grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

    softmax_scale_log2 = softmax_scale * math.log2(math.e)
    self.kernel(
        mQ,
        mK,
        mV,
        mdO,
        mLSE,
        mdPsum,
        mdQaccum,
        mdK,
        mdV,
        mCuSeqlensQ,
        mCuSeqlensK,
        mSeqUsedQ,
        mSeqUsedK,
        softmax_scale,
        softmax_scale_log2,
        self.sQ_layout,
        self.sK_layout,
        self.sV_layout,
        self.sdO_layout,
        self.sPdS_layout,
        self.sLSE_layout,
        self.sLSEMma_layout,
        self.gmem_tiled_copy_QK,
        self.gmem_tiled_copy_VdO,
        self.gmem_tiled_copy_dK,
        self.gmem_tiled_copy_dV,
        self.gmem_tiled_copy_LSE,
        self.gmem_tiled_copy_dQaccum,
        tiled_mma_sdp,
        tiled_mma_dkv,
        tiled_mma_dq,
        SharedStorage,
        tile_sched_params,
        TileScheduler,
    ).launch(
        grid=grid_dim,
        block=[self.num_threads, 1, 1],
        # Dynamic smem size is the full storage struct.
        smem=SharedStorage.size_in_bytes(),
        stream=stream,
    )
|
|
465
|
+
|
|
466
|
+
@cute.kernel
def kernel(
    self,
    mQ: cute.Tensor,
    mK: cute.Tensor,
    mV: cute.Tensor,
    mdO: cute.Tensor,
    mLSE: cute.Tensor,
    mdPsum: cute.Tensor,
    mdQaccum: cute.Tensor,
    mdK: cute.Tensor,
    mdV: cute.Tensor,
    mCuSeqlensQ: Optional[cute.Tensor],
    mCuSeqlensK: Optional[cute.Tensor],
    mSeqUsedQ: Optional[cute.Tensor],
    mSeqUsedK: Optional[cute.Tensor],
    softmax_scale: cutlass.Float32,
    softmax_scale_log2: cutlass.Float32,
    sQ_layout: cute.ComposedLayout,
    sK_layout: cute.ComposedLayout,
    sV_layout: cute.ComposedLayout,
    sdO_layout: cute.ComposedLayout,
    sPdS_layout: cute.ComposedLayout,
    sLSE_layout: cute.Layout,
    sLSEMma_layout: cute.Layout,
    gmem_tiled_copy_QK: cute.TiledCopy,
    gmem_tiled_copy_VdO: cute.TiledCopy,
    gmem_tiled_copy_dK: cute.TiledCopy,
    gmem_tiled_copy_dV: cute.TiledCopy,
    gmem_tiled_copy_LSE: cute.TiledCopy,
    gmem_tiled_copy_dQaccum: cute.TiledCopy,
    tiled_mma_sdp: cute.TiledMma,
    tiled_mma_dkv: cute.TiledMma,
    tiled_mma_dq: cute.TiledMma,
    SharedStorage: cutlass.Constexpr,
    tile_sched_params: ParamsBase,
    TileScheduler: cutlass.Constexpr[Callable],
):
    """Attention backward pass for one (n_block, head, batch) work tile.

    The tile scheduler assigns this CTA one n_block (a tile of K/V rows) of
    one head of one batch entry. K and V for that n_block are loaded once
    into shared memory. The kernel then loops over the m_blocks of
    Q / dO / LSE / dPsum from ``m_block_min`` to ``m_block_max``; when
    ``self.is_causal`` the leading m_blocks that cannot see this n_block are
    skipped. Each iteration is delegated to ``self.compute_one_m_block``,
    which accumulates dK / dV into the fp32 register fragments
    ``acc_dK`` / ``acc_dV`` and adds dQ contributions into ``mdQaccum``.
    Finally dK (scaled by ``softmax_scale`` when there is no GQA) and dV
    are written out through ``self.epilogue``.

    When cumulative sequence lengths are supplied (``mCuSeqlensQ`` /
    ``mCuSeqlensK``), the Q-side / K-side tensors are addressed by token
    offset instead of by batch index.
    """
    # Thread index within the CTA; the block's work coordinates come from the tile scheduler.
    tidx, _, _ = cute.arch.thread_idx()

    tile_scheduler = TileScheduler.create(tile_sched_params)
    work_tile = tile_scheduler.initial_work_tile_info()

    n_block, head_idx, batch_idx, _ = work_tile.tile_idx

    if work_tile.is_valid_tile:
        seqlen = SeqlenInfoQK.create(
            batch_idx, mQ.shape[1], mK.shape[1],
            mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK,
            mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK,
        )

        # Half-open range [m_block_min, m_block_max) of Q-row tiles processed against this n_block.
        m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
        m_block_min = 0
        if cutlass.const_expr(self.is_causal):
            # First m_block whose rows can attend to a key in this n_block. The
            # (seqlen_q - seqlen_k) term aligns the causal diagonal to the bottom-right.
            m_block_min = max(
                (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size,
                m_block_min,
            )
        # TODO: return early if m_block_max == 0

        # ///////////////////////////////////////////////////////////////////////////////
        # Get the appropriate tiles for this thread block.
        # ///////////////////////////////////////////////////////////////////////////////
        blkQ_shape = (self.m_block_size, self.head_dim_padded)
        blkK_shape = (self.n_block_size, self.head_dim_padded)
        blkV_shape = (self.n_block_size, self.head_dim_v_padded)
        blkdO_shape = (self.m_block_size, self.head_dim_v_padded)

        if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
            # Fixed-length layout: slice out this (batch, head).
            mQ_cur = mQ[batch_idx, None, head_idx, None]
            mLSE_cur = mLSE[batch_idx, head_idx, None]
            mdO_cur = mdO[batch_idx, None, head_idx, None]
            mdPsum_cur = mdPsum[batch_idx, head_idx, None]
            mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
        else:
            # Variable-length layout: offset into the packed token dimension.
            # LSE / dPsum / dQaccum use an offset with an extra batch_idx * m_block_size,
            # presumably because those buffers pad every sequence by one m_block -- confirm
            # against the allocation in the host code.
            padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
            mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
            mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
            mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
            mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
            mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None])
        # Without PackGQA several Q heads share one KV head.
        head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx

        if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
            mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)]
        else:
            mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)]

        # (m_block_size, head_dim, m_block)
        gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0))
        # (n_block_size, head_dim)
        gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0))
        # (n_block_size, head_dim_v)
        gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0))
        # (m_block_size, head_dim_v, m_block)
        gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0))
        gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,))
        gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,))
        gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,))

        # ///////////////////////////////////////////////////////////////////////////////
        # Get shared memory buffer
        # ///////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)
        sQ = storage.sQ.get_tensor(sQ_layout)
        sK = storage.sK.get_tensor(sK_layout)
        if cutlass.const_expr(not self.share_QV_smem):
            sV = storage.sV.get_tensor(sV_layout)
        else:
            # V aliases the Q buffer when Q and V share smem.
            sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)
        sdO = storage.sdO.get_tensor(sdO_layout)
        sP = storage.sP.get_tensor(sPdS_layout)
        sdS = storage.sdS.get_tensor(sPdS_layout)
        sLSE = storage.sLSE.get_tensor(sLSE_layout)
        sdPsum = storage.sdPsum.get_tensor(sLSE_layout)
        # Second views of the same LSE / dPsum storage, later partitioned like the S/dP accumulator.
        sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout)
        sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout)

        # Transpose view of tensors for tiled mma
        sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)]

        gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx)
        gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx)
        gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx)
        gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
        # (CPY_Atom, CPY_M, CPY_K, m_block)
        tQgQ = gmem_thr_copy_QK.partition_S(gQ)
        tQsQ = gmem_thr_copy_QK.partition_D(sQ)
        # (CPY_Atom, CPY_N, CPY_K)
        tKgK = gmem_thr_copy_QK.partition_S(gK)
        tKsK = gmem_thr_copy_QK.partition_D(sK)
        # (CPY_Atom, CPY_N, CPY_K)
        tVgV = gmem_thr_copy_VdO.partition_S(gV)
        tVsV = gmem_thr_copy_VdO.partition_D(sV)
        # (CPY_Atom, CPY_M, CPY_K, m_block)
        tdOgdO = gmem_thr_copy_VdO.partition_S(gdO)
        tdOsdO = gmem_thr_copy_VdO.partition_D(sdO)
        tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE)
        tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE)
        tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum)
        tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum)
        tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)

        # ///////////////////////////////////////////////////////////////////////////////
        # Tile MMA compute thread partitions and allocate accumulators
        # ///////////////////////////////////////////////////////////////////////////////
        thr_mma_sdp = tiled_mma_sdp.get_slice(tidx)
        thr_mma_dkv = tiled_mma_dkv.get_slice(tidx)
        thr_mma_dq = tiled_mma_dq.get_slice(tidx)
        acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded))
        acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded))
        # dK / dV accumulate across every m_block iteration in fp32.
        acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32)
        acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32)
        acc_dK.fill(0.0)
        acc_dV.fill(0.0)

        tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
        tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB)
        tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
        tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB)
        tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB)
        tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
        tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB)
        tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
        tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB)
        tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)

        # Keep one value per accumulator row (which mode is "row" flips with SdP_swapAB).
        LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)
        tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
        tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]

        # ///////////////////////////////////////////////////////////////////////////////
        # Smem copy atom tiling
        # ///////////////////////////////////////////////////////////////////////////////
        smem_copy_atom = cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,
        )
        smem_copy_atom_transposed = cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype,
        )
        smem_thr_copy_QdO = utils.make_tiled_copy_A(
            smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
        ).get_slice(tidx)
        smem_thr_copy_KV = utils.make_tiled_copy_B(
            smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
        ).get_slice(tidx)
        # TODO: should this be smem_copy_atom_transposed?
        smem_thr_copy_PdSt = utils.make_tiled_copy_A(
            smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
        ).get_slice(tidx)
        smem_thr_copy_QdOt = utils.make_tiled_copy_B(
            smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
        ).get_slice(tidx)
        smem_thr_copy_dS = utils.make_tiled_copy_A(
            smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB
        ).get_slice(tidx)
        smem_thr_copy_Kt = utils.make_tiled_copy_B(
            smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB
        ).get_slice(tidx)
        # TODO: what's the number of bits? What if SdP_swapAB
        # Register -> smem copy for P and dS, laid out like the S/dP accumulator.
        r2s_thr_copy_PdS = cute.make_tiled_copy_C(
            cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
            ),
            tiled_mma_sdp,
        ).get_slice(tidx)

        tSsQ = smem_thr_copy_QdO.partition_S(sQ)
        tdPsdO = smem_thr_copy_QdO.partition_S(sdO)
        tSsK = smem_thr_copy_KV.partition_S(sK)
        tdPsV = smem_thr_copy_KV.partition_S(sV)
        tdVsPt = smem_thr_copy_PdSt.partition_S(sPt)
        tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt)
        tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt)
        tdKsQt = smem_thr_copy_QdOt.partition_S(sQt)
        tdQsdS = smem_thr_copy_dS.partition_S(sdS)
        tdQsKt = smem_thr_copy_Kt.partition_S(sKt)
        tPsP = r2s_thr_copy_PdS.partition_D(sP)
        tdSsdS = r2s_thr_copy_PdS.partition_D(sdS)

        # ///////////////////////////////////////////////////////////////////////////////
        # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
        # of tile_shape
        # ///////////////////////////////////////////////////////////////////////////////
        # Identity (coordinate) tensors for the Q and dO tiles, used for bounds predicates.
        cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tQcQ = gmem_thr_copy_QK.partition_S(cQ)
        t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ)
        if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
            tdOcdO = tQcQ
            t0dOcdO = t0QcQ
        else:
            cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
            tdOcdO = gmem_thr_copy_VdO.partition_S(cdO)
            t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO)
        cLSE = cute.make_identity_tensor((self.m_block_size,))
        tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE)

        # Allocate predicate tensors for m and n, here we only allocate the tile of k, and
        # use "if" on the mn dimension.
        # This is to reduce register pressure and gets 2-3% performance gain.

        d_head = mQ.shape[cute.rank(mQ) - 1]
        d_head_v = mdO.shape[cute.rank(mdO) - 1]

        tQpQ = utils.predicate_k(tQcQ, limit=d_head)
        if cutlass.const_expr(self.same_hdim_kv):
            tdOpdO = tQpQ
        else:
            tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v)

        # group parameters for compute_one_m_block
        mma_params = SimpleNamespace(
            thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq,
            tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV,
            tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ,
            tdQrdS=tdQrdS, tdQrK=tdQrK,
            acc_dK=acc_dK, acc_dV=acc_dV,
        )
        smem_copy_params = SimpleNamespace(
            smem_thr_copy_QdO=smem_thr_copy_QdO,
            smem_thr_copy_KV=smem_thr_copy_KV,
            smem_thr_copy_PdSt=smem_thr_copy_PdSt,
            smem_thr_copy_QdOt=smem_thr_copy_QdOt,
            smem_thr_copy_dS=smem_thr_copy_dS,
            smem_thr_copy_Kt=smem_thr_copy_Kt,
            r2s_thr_copy_PdS=r2s_thr_copy_PdS,
            tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV,
            tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma,
            tPsP=tPsP, tdSsdS=tdSsdS,
            tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt,
            tdQsdS=tdQsdS, tdQsKt=tdQsKt,
        )
        gmem_copy_params = SimpleNamespace(
            gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum
        )
        load_Q_LSE = partial(
            self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE,
            tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ,
            tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q
        )
        load_dO_dPsum = partial(
            self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE,
            tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO,
            tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q
        )
        compute_one_m_block = partial(
            self.compute_one_m_block, mma_params=mma_params,
            smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params,
            load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum,
            m_block_max=m_block_max,
            softmax_scale_log2=softmax_scale_log2,
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Prologue
        # ///////////////////////////////////////////////////////////////////////////////
        # Start async loads of this CTA's K/V tile; the load helpers handle the mn residue.
        self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k,
                    headdim=d_head_v)
        if cutlass.const_expr(self.V_in_regs):
            # Separate cp.async group for V so it can be waited on before K.
            cute.arch.cp_async_commit_group()
        self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k,
                    headdim=d_head)
        cute.arch.cp_async_commit_group()

        if cutlass.const_expr(self.V_in_regs):
            cute.arch.cp_async_wait_group(1)
            cute.arch.barrier()
            tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV)
            cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view)
            # Sync to avoid loading Q to smem_q, which overlaps with smem_v
            cute.arch.barrier()

        # Prefetch the first pipeline stages of Q/LSE and dO/dPsum. Q keeps its last
        # stage free for load_Q_next inside the mainloop (unless single-buffered).
        # Commits happen even when the load is skipped so every stage owns one cp.async group.
        m_block = m_block_min
        assert self.num_stages_Q >= self.num_stages_dO
        for stage in cutlass.range_constexpr(self.num_stages_Q):
            if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1):
                if stage == 0 or m_block + stage < m_block_max:
                    load_Q_LSE(m_block + stage, smem_pipe_write_q=stage)
                cute.arch.cp_async_commit_group()
            if cutlass.const_expr(stage < self.num_stages_dO):
                if stage == 0 or m_block + stage < m_block_max:
                    load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage)
                cute.arch.cp_async_commit_group()

        # ///////////////////////////////////////////////////////////////////////////////
        # Mainloop
        # ///////////////////////////////////////////////////////////////////////////////
        # Iterate over every m_block that interacts with this CTA's single n_block.
        mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k)
        mask_fn = partial(
            mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
            mask_seqlen=True, mask_causal=self.is_causal
        )
        smem_pipe_read_q = cutlass.Int32(0)
        smem_pipe_read_do = cutlass.Int32(0)
        smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1)
        smem_pipe_write_do = cutlass.Int32(0)
        for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1):
            compute_one_m_block(
                m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do,
                mask_fn=mask_fn,
            )
            smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q)
            smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO)
            smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q)
            smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO)

        # ///////////////////////////////////////////////////////////////////////////////
        # Epilogue
        # ///////////////////////////////////////////////////////////////////////////////
        # If GQA, we scale dK in the postprocessing kernel instead
        if cutlass.const_expr(self.qhead_per_kvhead == 1):
            acc_dK.store(acc_dK.load() * softmax_scale)
        # reuse sK and sV data iterator
        sdK = cute.make_tensor(sK.iterator, sK_layout)
        sdV = cute.make_tensor(sV.iterator, sV_layout)
        self.epilogue(
            acc_dK, acc_dV, mdK, mdV, sdK, sdV,
            gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv,
            tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v
        )
|
|
828
|
+
|
|
829
|
+
@cute.jit
def compute_one_m_block(
    self,
    m_block: cutlass.Int32,
    smem_pipe_read_q: cutlass.Int32,
    smem_pipe_read_do: cutlass.Int32,
    smem_pipe_write_q: cutlass.Int32,
    smem_pipe_write_do: cutlass.Int32,
    mma_params: SimpleNamespace,
    smem_copy_params: SimpleNamespace,
    gmem_copy_params: SimpleNamespace,
    load_Q_LSE: Callable,
    load_dO_dPsum: Callable,
    m_block_max: cutlass.Int32,
    softmax_scale_log2: cutlass.Float32,
    mask_fn: Optional[Callable] = None,
):
    """Process one m_block of Q / dO against the CTA's resident K / V tile.

    Steps:
      1. S = Q K^T; optionally masked by ``mask_fn``; then
         P = exp2(S * softmax_scale_log2 - LSE) per row.
      2. dP = dO V^T; then dS = P * (dP - dPsum) per row.
      3. dV += P^T dO and dK += dS^T Q, into ``mma_params.acc_dV`` / ``acc_dK``.
      4. dQ = dS K, added with fp32 atomics into the dQaccum tile
         ``gmem_copy_params.tdQgdQaccum[..., m_block]``.

    Prefetches of later Q/LSE and dO/dPsum stages are issued as GEMM hooks
    (``load_Q_next`` / ``load_dO_next``) so they overlap with compute.
    """
    def load_Q_next():
        # Prefetch Q/LSE for a later iteration into the free Q stage.
        m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1)
        if m_block_next < m_block_max:
            load_Q_LSE(m_block_next, smem_pipe_write_q)
        # Commit even when nothing was loaded so the per-iteration cp.async group
        # count stays fixed for the wait_group calls below.
        cute.arch.cp_async_commit_group()

    def load_dO_next():
        # Prefetch dO/dPsum num_stages_dO iterations ahead.
        if m_block + self.num_stages_dO < m_block_max:
            load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do)
        cute.arch.cp_async_commit_group()

    # Stage of the multi-buffered Q / dO smem tiles read in this iteration
    # (always 0 when single-buffered).
    q_stage = smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0
    do_stage = smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0

    # MMA S
    acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C(
        (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size)
    )
    acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
    acc_S.fill(0.0)
    # With multiple Q stages, one younger cp.async group may stay in flight.
    cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0)
    cute.arch.barrier()
    sm80_utils.gemm(
        mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK,
        smem_copy_params.tSsQ[None, None, None, q_stage],
        smem_copy_params.tSsK,
        smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
        swap_AB=self.SdP_swapAB,
    )
    tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])
    cute.autovec_copy(smem_copy_params.tSsLSEMma[None, q_stage], tLSErLSE)
    if cutlass.const_expr(mask_fn is not None):
        mask_fn(acc_S, m_block=m_block)
    acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
    assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)
    for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
        # P = exp2(S * scale_log2 - LSE): recompute the softmax probabilities.
        acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r]))

    # MMA dP
    acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
    acc_dP.fill(0.0)
    cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0)
    cute.arch.barrier()
    # NOTE(review): on the V_in_regs path the kernel prologue copies V into tdPrV, but this
    # GEMM is not told that B lives in registers and still reads tdPsV from smem, which the
    # kernel notes overlaps smem_q when shared -- confirm whether B_in_regs should be passed.
    sm80_utils.gemm(
        mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV,
        smem_copy_params.tdPsdO[None, None, None, do_stage],
        smem_copy_params.tdPsV,
        smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
        # With multiple Q stages, prefetch the next Q/LSE tile while this GEMM runs.
        hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None,
        swap_AB=self.SdP_swapAB,
    )
    tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0])
    cute.autovec_copy(smem_copy_params.tSsdPsumMma[None, do_stage], tLSErdPsum)
    acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP)
    assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
    for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
        # dS = P * (dP - dPsum), computed in place in acc_dP.
        acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))
    rP = cute.make_fragment_like(acc_S, self.dtype)
    rP.store(acc_S.load().to(self.dtype))
    if cutlass.const_expr(not self.Mma_dKV_is_RS):
        tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP)  # ((Atom,AtomNum), MMA_N, MMA_N)
        cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP)
    rdS = cute.make_fragment_like(acc_dP, self.dtype)
    rdS.store(acc_dP.load().to(self.dtype))
    if cutlass.const_expr(not self.Mma_dKV_is_RS):
        cute.arch.barrier()  # Make sure P is written
    # For hdim 64, It's faster to write to smem_dS first before the dV gemm
    if cutlass.const_expr(not self.Mma_dKV_is_RS):
        tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
        cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
    if cutlass.const_expr(self.Mma_dKV_is_RS):
        # P is consumed directly from registers as the A operand.
        tdVrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
    else:
        tdVrP = mma_params.tdVrP

    # MMA dV
    sm80_utils.gemm(
        mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO,
        smem_copy_params.tdVsPt,
        smem_copy_params.tdVsdOt[None, None, None, do_stage],
        smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
        A_in_regs=self.Mma_dKV_is_RS,
        swap_AB=self.dKV_swapAB,
    )
    if cutlass.const_expr(self.Mma_dKV_is_RS):
        # The dQ GEMM below always reads dS from smem (tdQsdS), so it must be staged there
        # in the register-sourced path too; it was only written above when not RS.
        tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
        cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
    cute.arch.barrier()  # Make sure dS is written

    # MMA dQ
    def dQ_mma(hook_fn):
        acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C(
            (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, self.m_block_size)
        )
        acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32)
        acc_dQ.fill(0.0)
        sm80_utils.gemm(
            mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK,
            smem_copy_params.tdQsdS, smem_copy_params.tdQsKt,
            smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt,
            swap_AB=self.dQ_swapAB,
            hook_fn=hook_fn
        )
        # ((1, 1), num_elements)
        acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ)
        tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block]
        assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic)
        # Other CTAs (other n_blocks) add into the same dQ tile, hence the atomics.
        for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True):
            utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i))

    # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration
    if cutlass.const_expr(self.num_stages_Q > 1):
        dQ_mma(load_dO_next)

    # MMA dK
    if cutlass.const_expr(self.Mma_dKV_is_RS):
        tdKrdS = cute.make_tensor(rdS.iterator, utils.convert_layout_acc_frgA(rdS.layout))
    else:
        tdKrdS = mma_params.tdKrdS
    sm80_utils.gemm(
        mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ,
        smem_copy_params.tdKsdSt,
        smem_copy_params.tdKsQt[None, None, None, q_stage],
        smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
        A_in_regs=self.Mma_dKV_is_RS,
        swap_AB=self.dKV_swapAB,
        hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None,
    )
    if cutlass.const_expr(self.num_stages_Q == 1):
        # Single Q buffer: every thread must be done reading sQ (dK GEMM) before
        # load_Q_next overwrites it during the dQ GEMM.
        cute.arch.barrier()
        dQ_mma(load_Q_next)
|
|
985
|
+
|
|
986
|
+
@cute.jit
def epilogue(
    self,
    acc_dK: cute.Tensor,
    acc_dV: cute.Tensor,
    mdK: cute.Tensor,
    mdV: cute.Tensor,
    sdK: cute.Tensor,
    sdV: cute.Tensor,
    gmem_tiled_copy_dK: cute.TiledCopy,
    gmem_tiled_copy_dV: cute.TiledCopy,
    tiled_mma: cute.TiledMma,
    tidx: cutlass.Int32,
    n_block: cutlass.Int32,
    num_head: cutlass.Int32,
    batch_size: cutlass.Int32,
    seqlen: SeqlenInfoQK,
    d_head: cutlass.Int32,
    d_head_v: cutlass.Int32
):
    """Write this CTA's dK / dV accumulators for key block ``n_block`` to global memory.

    Two paths, selected at compile time:

    * ``qhead_per_kvhead == 1``: the accumulators are converted to ``self.dtype``,
      staged through shared memory (``sdK`` / ``sdV``), read back in the gmem copy's
      layout and stored to ``mdK`` / ``mdV``. Rows past ``seqlen.seqlen_k`` are
      skipped. Head-dim columns are predicated when ``check_hdim_oob`` (dK) or
      ``check_hdim_v_oob`` (dV) is set.
    * ``qhead_per_kvhead > 1``: the accumulators are atomically added
      (``utils.atomic_add_fp32``) into flattened ``mdK`` / ``mdV`` tiles. This is how
      multiple query heads mapped to the same KV head are combined.

    Args:
        acc_dK, acc_dV: per-thread MMA accumulators for dK and dV.
        mdK, mdV: global outputs. Direct-store path: indexed as
            ``[batch, seqlen, head, dim]``, or ``[total_seqlen, head, dim]`` with
            cu_seqlens. Atomic path: ``[batch, head, flat]``, or ``[head, flat]``
            with cu_seqlens.
        sdK, sdV: shared-memory staging tiles (direct-store path only).
        gmem_tiled_copy_dK, gmem_tiled_copy_dV: tiled copies for the gmem writes.
        tiled_mma: MMA used to build the register->smem copy layout.
        tidx: thread index within the CTA.
        n_block: index of the key block owned by this CTA.
        num_head: head index. It is divided by ``qhead_per_kvhead`` to obtain the
            KV head unless ``pack_gqa`` is set.
        batch_size: despite its name, used as the batch *index*.
        seqlen: sequence-length info (``seqlen_k``, ``offset_k``, ``has_cu_seqlens_k``).
        d_head, d_head_v: actual (unpadded) head dims of K and V.
    """
    gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx)
    gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx)

    # The ``batch_size`` argument carries the batch index.
    batch_idx = batch_size
    # Shared by both output paths below.
    head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head

    if cutlass.const_expr(self.qhead_per_kvhead == 1):
        # Convert the accumulators to the output dtype. Only this path needs the
        # converted copies; the atomic-add path consumes acc_dK / acc_dV directly.
        rdV = cute.make_fragment_like(acc_dV, self.dtype)
        rdV.store(acc_dV.load().to(self.dtype))
        rdK = cute.make_fragment_like(acc_dK, self.dtype)
        rdK.store(acc_dK.load().to(self.dtype))

        # Make sure all threads have finished with the main loop's shared memory
        # before it is overwritten with dK / dV. Presumably sdK / sdV alias smem
        # used earlier (e.g. for K / V / Q) - TODO confirm against the smem layout.
        cute.arch.barrier()
        # Register -> smem copy atom for dK / dV; each copy moves two elements.
        smem_copy_atom_dKV = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
        )
        smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx)
        taccdVrdV = smem_thr_copy_dKV.retile(rdV)
        taccdKrdK = smem_thr_copy_dKV.retile(rdK)
        taccdVsdV = smem_thr_copy_dKV.partition_D(sdV)
        taccdKsdK = smem_thr_copy_dKV.partition_D(sdK)
        # Copy the converted dV / dK fragments from registers to smem.
        cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
        cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)

        if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
            mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)]
        else:
            mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)]

        blkdK_shape = (self.n_block_size, self.head_dim_padded)
        blkdV_shape = (self.n_block_size, self.head_dim_v_padded)
        gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0))
        gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0))
        tdKsdK = gmem_thr_copy_dK.partition_S(sdK)
        tdKgdK = gmem_thr_copy_dK.partition_D(gdK)
        tdVsdV = gmem_thr_copy_dV.partition_S(sdV)
        tdVgdV = gmem_thr_copy_dV.partition_D(gdV)
        tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype)
        tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype)
        # Wait until every thread's smem stores above are complete.
        cute.arch.barrier()
        # Read dK / dV back from smem into registers in the gmem copy's layout,
        # for wider vectorization on the global store.
        # TODO: check OOB when reading from smem if kBlockN isn't evenly tiled.
        cute.autovec_copy(tdKsdK, tdKrdK)
        cute.autovec_copy(tdVsdV, tdVrdV)

        cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
        tdKcdK = gmem_thr_copy_dK.partition_S(cdK)
        t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK)
        if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
            tdVcdV = tdKcdK
            t0dVcdV = t0dKcdK
        else:
            cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
            tdVcdV = gmem_thr_copy_dV.partition_S(cdV)
            t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV)
        tdKpdK = utils.predicate_k(tdKcdK, limit=d_head)
        if cutlass.const_expr(self.same_hdim_kv):
            tdVpdV = tdKpdK
        else:
            tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v)
        # Copy dK and dV from registers to gmem. The row check compares thread 0's
        # (compile-time) coordinates against a limit shifted by this thread's offset.
        for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])):
            if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]:
                cute.copy(
                    gmem_tiled_copy_dK,
                    tdKrdK[None, rest_m, None],
                    tdKgdK[None, rest_m, None],
                    pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None,
                )
        for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])):
            if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]:
                cute.copy(
                    gmem_tiled_copy_dV,
                    tdVrdV[None, rest_m, None],
                    tdVgdV[None, rest_m, None],
                    pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None,
                )

    else:  # qhead_per_kvhead > 1, do atomic add
        # For Sm90, we need to sync to avoid racy writes to smem_q
        # For Sm80, we don't need to sync since we're not touching smem
        if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
            mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)]
        else:
            # The flat accumulation buffers are offset by an extra n_block_size rows per batch.
            padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size
            mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None])
            mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None])

        gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,))
        gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,))
        tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV)
        tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK)
        acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV)
        acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK)
        assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum)
        assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum)
        for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True):
            utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i))
        for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True):
            utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i))
|
|
1114
|
+
|
|
1115
|
+
@cute.jit
def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr):
    """Return the stage index that follows ``pipeline_index`` in a ring of ``num_stages``.

    Increments the index and wraps back to 0 after ``num_stages - 1``.
    """
    return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0
|
|
1118
|
+
|
|
1119
|
+
@cute.jit
def load_K(
    self,
    gmem_thr_copy: cute.TiledCopy,
    tKgK: cute.Tensor,
    tKsK: cute.Tensor,
    block: cutlass.Int32,
    seqlen: cutlass.Int32,
    headdim: cutlass.Int32,
):
    """Issue the predicated gmem -> smem copy of one (n_block_size, head_dim_padded) K tile.

    Rows at or beyond ``seqlen`` (counted from ``block * n_block_size``) are masked out.
    When ``self.check_hdim_oob`` is set, columns at or beyond ``headdim`` are also masked.

    Args:
        gmem_thr_copy: copy used both to partition the identity (coordinate) tensor and
            to perform the copy. ``get_slice(0)`` is also called on it.
        tKgK, tKsK: source (gmem) and destination (smem) partitions; mode 1 is
            iterated over as row groups.
        block: index of the key block being loaded.
        seqlen: number of valid key rows.
        headdim: number of valid head-dim columns.
    """
    cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
    tKcK = gmem_thr_copy.partition_S(cK)
    t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK)
    # Per-column head-dim predicate (column coordinate < headdim).
    tKpK = utils.predicate_k(tKcK, limit=headdim)
    for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
        # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
        if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size:
            # Instead of using tKcK, we use t0KcK and subtract this thread's offset from the limit
            # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.
            predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0]
            predicate = cute.make_fragment_like(tKpK[None, 0, None])
            # Combine the row predicate with the (optional) head-dim predicate.
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n
            cute.copy(
                gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate,
            )
    # NOTE(review): the sK tiles are expected to be clear in masked positions because sKt is reused
    # for mma_dq, but this method does not clear them itself. Confirm that the predicated copy
    # zero-fills masked elements or that they are cleared elsewhere.
|
|
1147
|
+
|
|
1148
|
+
@cute.jit
def load_V(
    self,
    gmem_thr_copy: cute.TiledCopy,
    tVgV: cute.Tensor,
    tVsV: cute.Tensor,
    block: cutlass.Int32,
    seqlen: cutlass.Int32,
    headdim: cutlass.Int32,
):
    """Issue the predicated gmem -> smem copy of one (n_block_size, head_dim_v_padded) V tile.

    Rows at or beyond ``seqlen`` (counted from ``block * n_block_size``) are masked out.
    When ``self.check_hdim_v_oob`` is set, columns at or beyond ``headdim`` are also
    masked. V's head dim is ``head_dim_v``, so it is gated on the V flag, matching how
    the epilogue predicates dV.

    Args:
        gmem_thr_copy: copy used both to partition the identity (coordinate) tensor and
            to perform the copy. ``get_slice(0)`` is also called on it.
        tVgV, tVsV: source (gmem) and destination (smem) partitions; mode 1 is
            iterated over as row groups.
        block: index of the key block being loaded.
        seqlen: number of valid key rows.
        headdim: number of valid V head-dim columns.
    """
    cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
    tVcV = gmem_thr_copy.partition_S(cV)
    t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV)
    # Per-column head-dim predicate (column coordinate < headdim).
    tVpV = utils.predicate_k(tVcV, limit=headdim)
    for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
        # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
        if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:
            # Instead of using tVcV, we use t0VcV and subtract this thread's offset from the limit
            # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time.
            predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0]
            predicate = cute.make_fragment_like(tVpV[None, 0, None])
            # Combine the row predicate with the (optional) V head-dim predicate.
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_v_oob) else True) and predicate_n
            cute.copy(
                gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate,
            )
|
|
1175
|
+
|
|
1176
|
+
@cute.jit
def load_Q_LSE(
    self,
    gmem_tiled_copy_Q: cute.TiledCopy,
    gmem_tiled_copy_LSE: cute.TiledCopy,
    tQgQ: cute.Tensor,
    tQsQ: cute.Tensor,
    tQcQ: cute.Tensor,
    t0QcQ: cute.Tensor,
    tQpQ: cute.Tensor,
    tLSEgLSE: cute.Tensor,
    tLSEsLSE: cute.Tensor,
    tLSEcLSE: cute.Tensor,
    block: cutlass.Int32,
    smem_pipe_write_q: cutlass.Int32,
    seqlen: cutlass.Int32,
):
    """Issue the gmem -> smem copies of one Q tile and the matching LSE slice.

    Q rows at or beyond ``seqlen`` (counted from ``block * m_block_size``) are masked
    out. Head-dim columns are additionally masked by ``tQpQ`` when
    ``self.check_hdim_oob`` is set. LSE is copied for every row below
    ``m_block_size`` without a sequence-length check.

    Args:
        gmem_tiled_copy_Q, gmem_tiled_copy_LSE: tiled copies for Q and LSE.
        tQgQ, tQsQ: Q gmem source (last mode = block) and smem destination
            (last mode = stage).
        tQcQ, t0QcQ: coordinate partitions for this thread and for thread 0.
        tQpQ: head-dim predicate for Q.
        tLSEgLSE, tLSEsLSE, tLSEcLSE: LSE gmem source, smem destination, coordinates.
        block: index of the query block being loaded.
        smem_pipe_write_q: Q smem stage to write; stage 0 is used when
            ``num_stages_Q == 1``.
        seqlen: number of valid query rows.
    """
    # Chosen at compile time: a single Q stage only has smem slot 0. Shared by the Q
    # and LSE copies so both always target the same stage.
    stage_q = smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0
    for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
        # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
        if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size:
            # Instead of using tQcQ, we use t0QcQ and subtract this thread's offset from the limit
            # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.
            predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]
            predicate = cute.make_fragment_like(tQpQ[None, 0, None])
            # Combine the row predicate with the (optional) head-dim predicate.
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
            cute.copy(
                gmem_tiled_copy_Q,
                tQgQ[None, m, None, block],
                tQsQ[None, m, None, stage_q],
                pred=predicate,
            )
    # NOTE(review): sQ is reused transposed (sQt) for mma_dK, so masked Q rows are expected to be
    # clear; this method does not clear them itself. Confirm the predicated copy zero-fills.
    # LSE is read for all kBlockM rows without a seqlen check. This relies on the LSE buffer
    # being padded so that every sLSE entry is initialized.
    for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])):
        if tLSEcLSE[0, m][0] < self.m_block_size:
            cute.copy(
                gmem_tiled_copy_LSE,
                tLSEgLSE[None, m, block],
                tLSEsLSE[None, m, stage_q],
            )
|
|
1219
|
+
|
|
1220
|
+
@cute.jit
def load_dO_dPsum(
    self,
    gmem_tiled_copy_dO: cute.TiledCopy,
    gmem_tiled_copy_dPsum: cute.TiledCopy,
    tdOgdO: cute.Tensor,
    tdOsdO: cute.Tensor,
    tdOcdO: cute.Tensor,
    t0dOcdO: cute.Tensor,
    tdOpdO: cute.Tensor,
    tdPsumgdPsum: cute.Tensor,
    tdPsumsdPsum: cute.Tensor,
    tdPsumcdPsum: cute.Tensor,
    block: cutlass.Int32,
    smem_pipe_write_q: cutlass.Int32,
    seqlen: cutlass.Int32,
):
    """Issue the gmem -> smem copies of one dO tile and the matching dPsum slice.

    dO rows at or beyond ``seqlen`` (counted from ``block * m_block_size``) are masked
    out. Head-dim columns are additionally masked by ``tdOpdO`` when
    ``self.check_hdim_v_oob`` is set. dO spans V's head dim, matching how the epilogue
    predicates dV. dPsum is copied for every row below ``m_block_size`` without a
    sequence-length check.

    Args:
        gmem_tiled_copy_dO, gmem_tiled_copy_dPsum: tiled copies for dO and dPsum.
        tdOgdO, tdOsdO: dO gmem source (last mode = block) and smem destination
            (last mode = stage).
        tdOcdO, t0dOcdO: coordinate partitions for this thread and for thread 0.
        tdOpdO: head-dim predicate for dO.
        tdPsumgdPsum, tdPsumsdPsum, tdPsumcdPsum: dPsum gmem source, smem destination,
            coordinates.
        block: index of the query block being loaded.
        smem_pipe_write_q: smem stage to write; stage 0 is used when
            ``num_stages_dO == 1``.
        seqlen: number of valid query rows.
    """
    # Chosen at compile time: a single dO stage only has smem slot 0. Shared by the dO
    # and dPsum copies so both always target the same stage.
    stage_dO = smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0
    for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])):
        # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
        if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size:
            # Instead of using tdOcdO, we use t0dOcdO and subtract this thread's offset from the limit
            # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time.
            predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0]
            predicate = cute.make_fragment_like(tdOpdO[None, 0, None])
            # Combine the row predicate with the (optional) V head-dim predicate.
            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                    predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_v_oob) else True) and predicate_m
            cute.copy(
                gmem_tiled_copy_dO,
                tdOgdO[None, m, None, block],
                tdOsdO[None, m, None, stage_dO],
                pred=predicate,
            )
    # NOTE(review): if sdO is reused transposed in a later MMA, masked dO rows are expected to be
    # clear; this method does not clear them itself. Confirm the predicated copy zero-fills.
    # dPsum is read for all kBlockM rows without a seqlen check. This presumably relies on the
    # dPsum buffer being padded (as for LSE) so every sdPsum entry is initialized. TODO confirm.
    for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])):
        if tdPsumcdPsum[0, m][0] < self.m_block_size:
            cute.copy(
                gmem_tiled_copy_dPsum,
                tdPsumgdPsum[None, m, block],
                tdPsumsdPsum[None, m, stage_dO],
            )
|