mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
|
3
|
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
|
|
4
|
+
# from Cutlass C++ to Cute-DSL.
|
|
5
|
+
import math
|
|
6
|
+
from typing import Callable, Optional, Type, Literal
|
|
7
|
+
|
|
8
|
+
import cuda.bindings.driver as cuda
|
|
9
|
+
|
|
10
|
+
import cutlass
|
|
11
|
+
import cutlass.cute as cute
|
|
12
|
+
import cutlass.utils.hopper_helpers as sm90_utils_basic
|
|
13
|
+
import cutlass.utils.blackwell_helpers as sm100_utils_basic
|
|
14
|
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
|
15
|
+
from cutlass import Float32, const_expr
|
|
16
|
+
from cutlass.utils import LayoutEnum
|
|
17
|
+
|
|
18
|
+
from mslk.attention.flash_attn import utils
|
|
19
|
+
from mslk.attention.flash_attn import copy_utils
|
|
20
|
+
from mslk.attention.flash_attn import ampere_helpers as sm80_utils
|
|
21
|
+
from mslk.attention.flash_attn import hopper_helpers as sm90_utils
|
|
22
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
|
|
23
|
+
import cutlass.cute.nvgpu.tcgen05 as tcgen05
|
|
24
|
+
from mslk.attention.flash_attn.tile_scheduler import (
|
|
25
|
+
ParamsBase,
|
|
26
|
+
SingleTileScheduler,
|
|
27
|
+
SingleTileVarlenScheduler,
|
|
28
|
+
TileSchedulerArguments,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FlashAttentionBackwardPostprocess:
    """Post-processing kernel for the flash-attention backward pass.

    Reads the Float32 dQ accumulator (``mdQaccum``), multiplies it by ``scale``,
    converts it to ``dtype`` and writes the result to ``mdQ``. One thread block
    handles one ``tile_m x tile_hdim`` tile of dQ for one (head, batch) pair.

    Three code paths are selected at trace time by ``arch``:
    80 (Ampere), 90 (Hopper) and 100 (Blackwell).
    """

    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        arch: Literal[80, 90, 100],
        tile_m: int = 128,
        num_threads: int = 256,
        AtomLayoutMdQ: int = 1,
        dQ_swapAB: bool = False,
    ):
        """
        :param dtype: element type dQ is converted to before it is written out
            (``can_implement`` accepts ``cutlass.Float16`` / ``cutlass.BFloat16``)
        :type dtype: Type[cutlass.Numeric]
        :param head_dim: head dimension
        :type head_dim: int
        :param arch: target architecture: 80 (Ampere), 90 (Hopper) or 100 (Blackwell)
        :type arch: int
        :param tile_m: m block size
        :type tile_m: int
        :param num_threads: threads per block; the arch 100 path asserts this is 128,
            so the default of 256 has to be overridden there
        :type num_threads: int
        :param AtomLayoutMdQ: number of MMA atoms along M in the dQ tiled MMA
            (only read on the arch 80 / 90 paths)
        :type AtomLayoutMdQ: int
        :param dQ_swapAB: if True, the M and N roles of the dQ accumulator are swapped
            (atom layout and tiler are reversed, and registers are stored through the
            transposed shared-memory view)
        :type dQ_swapAB: bool
        :raises AssertionError: if ``arch`` is not one of 80, 90, 100
        """
        self.dtype = dtype
        self.tile_m = tile_m
        assert arch in [80, 90, 100], (
            "Only Ampere (80), Hopper (90), and Blackwell (100) are supported"
        )
        self.arch = arch
        # padding head_dim to a multiple of 32 as k_block_size
        hdim_multiple_of = 32
        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
        # True when head_dim had to be padded. Not read anywhere else in this class;
        # the gmem store in `kernel` is always predicated on the real head_dim.
        self.check_hdim_oob = head_dim != self.tile_hdim
        self.num_threads = num_threads
        self.AtomLayoutMdQ = AtomLayoutMdQ
        self.dQ_swapAB = dQ_swapAB

    @staticmethod
    def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
        """Check if the kernel can be implemented with the given parameters.

        :param dtype: data type
        :type dtype: cutlass.Numeric
        :param head_dim: head dimension
        :type head_dim: int
        :param tile_m: m block size (accepted for signature symmetry; not checked here)
        :type tile_m: int
        :param num_threads: threads per block; must be a multiple of 32
        :type num_threads: int

        :return: True if the kernel can be implemented, False otherwise
        :rtype: bool
        """
        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        if head_dim % 8 != 0:
            return False
        if num_threads % 32 != 0:
            return False
        return True

    def _get_tiled_mma(self):
        """Build the tiled MMA describing the dQ accumulator layout for ``self.arch``.

        No matrix multiply is issued by this class; the tiled MMA is only used to
        derive the accumulator partitioning (``partition_shape_C``) and the
        register -> shared-memory copy in ``kernel``.

        :return: the tiled MMA object for the selected architecture
        :raises AssertionError: on arch 80 / 90 if ``num_threads`` differs from the
            tiled MMA's thread count
        """
        if const_expr(self.arch == 80):
            num_mma_warps = self.num_threads // 32
            # (M, N, K) atom layout; M and N are exchanged when dQ_swapAB is set.
            atom_layout_dQ = (
                (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
                if const_expr(not self.dQ_swapAB)
                else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)
            )
            tiled_mma = cute.make_tiled_mma(
                warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
                atom_layout_dQ,
                permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
            )
        elif const_expr(self.arch == 90):
            num_mma_warp_groups = self.num_threads // 128
            atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ)
            # Per-warp-group tile: the (tile_m, tile_hdim) tile divided by the atom layout.
            tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
            tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
                self.dtype,
                self.dtype,
                warpgroup.OperandMajorMode.K,  # These don't matter, we only care about the accum
                warpgroup.OperandMajorMode.K,
                Float32,
                atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1])
                + (1,),
                tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
            )
        else:
            cta_group = tcgen05.CtaGroup.ONE
            tiled_mma = sm100_utils_basic.make_trivial_tiled_mma(
                self.dtype,
                tcgen05.OperandMajorMode.MN,  # dS_major_mode
                tcgen05.OperandMajorMode.MN,  # Kt_major_mode
                Float32,
                cta_group,
                (self.tile_m, self.tile_hdim),
            )
        if const_expr(self.arch in [80, 90]):
            assert self.num_threads == tiled_mma.size
        return tiled_mma

    def _setup_attributes(self):
        """Create the tiled copies and shared-memory layouts used by ``kernel``.

        Must run after ``self.tiled_mma`` has been assigned (it is read below).
        Sets ``g2s_tiled_copy_dQaccum``, ``s2r_tiled_copy_dQaccum``,
        ``sdQaccum_layout``, ``gmem_tiled_copy_dQ`` and ``sdQ_layout``; on the
        arch 100 path it also sets ``dQ_reduce_ncol``.

        :raises AssertionError: if the accumulator tile cannot be split evenly into
            128-bit copies across ``num_threads``, or (arch 100) if
            ``num_threads != 128``
        """
        # ///////////////////////////////////////////////////////////////////////////////
        # GMEM Tiled copy:
        # ///////////////////////////////////////////////////////////////////////////////
        # Thread layouts for copies
        universal_copy_bits = 128
        # Number of Float32 elements moved by one 128-bit copy.
        async_copy_elems_accum = universal_copy_bits // Float32.width
        atom_async_copy_accum = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            Float32,
            num_bits_per_copy=universal_copy_bits,
        )
        # We don't do bound checking for the gmem -> smem load so we just assert here.
        assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0
        self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
            atom_async_copy_accum,
            cute.make_layout(self.num_threads),
            cute.make_layout(async_copy_elems_accum),
        )
        # Elements per thread per smem -> register copy: 1 on arch 80, 4 otherwise.
        num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4
        if const_expr(self.arch == 80):
            self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
                Float32, self.num_threads, num_s2r_copy_elems
            )
            self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
        elif const_expr(self.arch == 90):
            num_threads_per_warp_group = 128
            num_mma_warp_groups = self.num_threads // 128
            self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
                cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
                cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)),  # thr_layout
                cute.make_layout(128 // Float32.width),  # val_layout
            )
            # Accumulator tile split into one contiguous slab per warp group.
            self.sdQaccum_layout = cute.make_layout(
                (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups)
            )
        else:
            self.dQ_reduce_ncol = 32
            dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
            assert self.num_threads == 128  # TODO: currently hard-coded
            self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
                Float32, self.num_threads, num_s2r_copy_elems
            )
            # Accumulator tile split into one slab per group of dQ_reduce_ncol columns.
            self.sdQaccum_layout = cute.make_layout(
                (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage)
            )

        self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(
            self.dtype, self.tile_hdim, self.num_threads
        )
        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory layout: dQ
        # ///////////////////////////////////////////////////////////////////////////////
        # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
        # then setting kBlockKSmem to 32 will cause "Static shape_div failure".
        # We want to treat it as 64 x 48, so kBlockKSmem should be 16.
        mma_shape_n = self.tiled_mma.get_tile_size(1)
        if const_expr(self.arch == 80):
            sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
            self.sdQ_layout = cute.tile_to_shape(
                sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
            )
        elif const_expr(self.arch == 90):
            self.sdQ_layout = sm90_utils.make_smem_layout(
                self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim)
            )
        else:
            # TODO: this is hard-coded for hdim 128
            self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(
                self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1
            )

    @cute.jit
    def __call__(
        self,
        mdQaccum: cute.Tensor,
        mdQ: cute.Tensor,
        scale: cutlass.Float32,
        mCuSeqlensQ: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        stream: cuda.CUstream,
    ):
        """Validate the inputs, derive the launch configuration and launch ``kernel``.

        Tensor layouts as indexed by this method and by ``kernel``:

        * fixed-length (``mCuSeqlensQ is None``): ``mdQ`` is indexed as
          (batch, seqlen, head, head_dim), ``mdQaccum`` as (batch, head, flat)
        * varlen (``mCuSeqlensQ`` given): ``mdQ`` is indexed as
          (total_q, head, head_dim), ``mdQaccum`` as (head, flat)

        :param mdQaccum: Float32 dQ accumulator to read
        :param mdQ: Float16 / BFloat16 output tensor
        :param scale: factor multiplied into the accumulator before conversion
        :param mCuSeqlensQ: cumulative query sequence lengths with ``num_batch + 1``
            entries, or None; selects the varlen tile scheduler when given
        :param mSeqUsedQ: optional tensor forwarded to the scheduler and to
            ``SeqlenInfoQK.create``
        :param stream: CUDA stream the kernel is launched on
        :raises TypeError: if ``mdQ`` is not Float16 / BFloat16 or ``mdQaccum`` is not
            Float32
        """
        # Get the data type and check if it is fp16 or bf16
        if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if const_expr(mdQaccum is not None):
            if const_expr(mdQaccum.element_type not in [cutlass.Float32]):
                raise TypeError("dQaccum tensor must be Float32")

        # Assume all strides are divisible by 128 bits except the last stride
        new_stride = lambda t: (
            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
            t.stride[-1],
        )
        # Rebuild both tensors on the same iterators with the annotated strides.
        mdQaccum, mdQ = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            for t in (mdQaccum, mdQ)
        ]

        # _setup_attributes() reads self.tiled_mma, so the order of these two matters.
        self.tiled_mma = self._get_tiled_mma()
        self._setup_attributes()

        # The kernel reuses one shared-memory buffer for the Float32 accumulator tile
        # and the converted dQ tile, so it only needs the larger of the two.
        smem_size = max(
            cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout),
            cute.size_in_bytes(self.dtype, self.sdQ_layout),
        )

        if const_expr(mCuSeqlensQ is not None):
            TileScheduler = SingleTileVarlenScheduler
            num_head = mdQ.shape[1]
            num_batch = mCuSeqlensQ.shape[0] - 1
            num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)
        else:
            TileScheduler = SingleTileScheduler
            num_head = mdQ.shape[2]
            num_batch = mdQ.shape[0]
            num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)

        # NOTE(review): `headdim` and `total_q` below index mdQ with the varlen
        # layout; for the fixed-length layout they pick up num_head and num_batch.
        # Presumably the non-varlen scheduler ignores both fields -- confirm.
        tile_sched_args = TileSchedulerArguments(
            num_block=num_block,
            num_head=num_head,
            num_batch=num_batch,
            num_splits=1,
            seqlen_k=0,
            headdim=mdQ.shape[2],
            headdim_v=0,
            total_q=mdQ.shape[0],
            tile_shape_mn=(self.tile_m, 1),
            mCuSeqlensQ=mCuSeqlensQ,
            mSeqUsedQ=mSeqUsedQ,
        )

        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        # grid_dim: (m_block, num_head, batch_size)
        self.kernel(
            mdQaccum,
            mdQ,
            mCuSeqlensQ,
            mSeqUsedQ,
            scale,
            self.tiled_mma,
            self.dQ_swapAB,
            self.sdQaccum_layout,
            self.sdQ_layout,
            self.g2s_tiled_copy_dQaccum,
            self.s2r_tiled_copy_dQaccum,
            self.gmem_tiled_copy_dQ,
            tile_sched_params,
            TileScheduler,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            smem=smem_size,
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mdQaccum: cute.Tensor,
        mdQ: cute.Tensor,
        mCuSeqlensQ: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        scale: cutlass.Float32,
        tiled_mma: cute.TiledMma,
        dQ_swapAB: cutlass.Constexpr,
        sdQaccum_layout: cute.Layout,
        sdQ_layout: cute.ComposedLayout,
        g2s_tiled_copy_dQaccum: cute.TiledCopy,
        s2r_tiled_copy_dQaccum: cute.TiledCopy,
        gmem_tiled_copy_dQ: cute.TiledCopy,
        tile_sched_params: ParamsBase,
        TileScheduler: cutlass.Constexpr[Callable],
    ):
        """Device kernel: convert one ``tile_m x tile_hdim`` tile of dQaccum into dQ.

        Steps for a valid work tile:

        1. async-copy the Float32 accumulator tile from global to shared memory
        2. load it into registers, multiply by ``scale`` and convert to ``self.dtype``
        3. store the converted values to shared memory in the dQ layout
        4. read them back into registers with the global-memory copy partitioning
        5. write to ``mdQ``, predicated on the real head dimension and on the rows
           that lie inside the sequence

        Arguments are the tensors, layouts and tiled copies prepared by ``__call__``.
        """
        # ///////////////////////////////////////////////////////////////////////////////
        # Get shared memory buffer
        # ///////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
        sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
        # sdQ is a view over the same shared-memory buffer as sdQaccum (recast pointer),
        # which is why barriers separate the accumulator reads from the dQ writes below.
        if const_expr(self.arch in [80, 90]):
            sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
        else:
            # extra stage dimension
            sdQ = cute.make_tensor(
                cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype),
                sdQ_layout.outer,
            )[None, None, 0]
        sdQt = utils.transpose_view(sdQ)

        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()

        tile_scheduler = TileScheduler.create(tile_sched_params)
        work_tile = tile_scheduler.initial_work_tile_info()

        m_block, head_idx, batch_idx, _ = work_tile.tile_idx

        if work_tile.is_valid_tile:
            # ///////////////////////////////////////////////////////////////////////////////
            # Get the appropriate tiles for this thread block.
            # ///////////////////////////////////////////////////////////////////////////////

            seqlen = SeqlenInfoQK.create(
                batch_idx,
                mdQ.shape[1],
                0,
                mCuSeqlensQ=mCuSeqlensQ,
                mCuSeqlensK=None,
                mSeqUsedQ=mSeqUsedQ,
                mSeqUsedK=None,
            )
            if const_expr(not seqlen.has_cu_seqlens_q):
                mdQ_cur = mdQ[batch_idx, None, head_idx, None]
                mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
                head_dim = mdQ.shape[3]
            else:
                # In the accumulator each batch starts at its query offset plus
                # batch_idx * tile_m rows; on arch >= 90 that start is additionally
                # rounded down to a multiple of tile_m.
                padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
                if cutlass.const_expr(self.arch >= 90):
                    padded_offset_q = padded_offset_q // self.tile_m * self.tile_m
                mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
                mdQaccum_cur = cute.domain_offset(
                    (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
                )
                head_dim = mdQ.shape[2]

                # HACK: Compiler doesn't seem to recognize that padding
                # by padded_offset_q * self.tile_hdim keeps alignment
                # since statically divisible by 4
                # NOTE(review): this pointer rebuild is placed in the varlen branch
                # because it refers to padded_offset_q -- confirm against upstream.

                mdQaccum_cur_ptr = cute.make_ptr(
                    dtype=mdQaccum_cur.element_type,
                    value=mdQaccum_cur.iterator.toint(),
                    mem_space=mdQaccum_cur.iterator.memspace,
                    assumed_align=mdQaccum.iterator.alignment,
                )
                mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)

            gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))
            gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))

            seqlen_q = seqlen.seqlen_q

            # Step 1: load dQaccum from gmem to smem
            g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx)
            tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum)
            tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat)
            cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s)
            cute.arch.cp_async_commit_group()
            cute.arch.cp_async_wait_group(0)
            cute.arch.barrier()

            # Step 2: load dQ from smem to rmem
            s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx)
            tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum)
            tile_shape = (self.tile_m, self.tile_hdim)
            acc = None
            tiled_copy_t2r = None
            if const_expr(self.arch in [80, 90]):
                # Register fragment shaped like this thread's MMA accumulator partition.
                acc_shape = tiled_mma.partition_shape_C(
                    tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
                )
                acc = cute.make_fragment(acc_shape, cutlass.Float32)
                assert cute.size(acc) == cute.size(tdQsdQaccum)
            else:
                # The tmem-load copy is built only to obtain the per-thread register
                # shape; no tmem load is issued in this kernel.
                thr_mma = tiled_mma.get_slice(0)  # 1-CTA
                dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim))
                tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape)
                tdQcdQ = thr_mma.partition_C(
                    cute.make_identity_tensor((self.tile_m, self.tile_hdim))
                )
                tmem_load_atom = cute.make_copy_atom(
                    tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32
                )
                tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)
                thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
                tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape
                acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32)
            # View the accumulator registers with the shape of the smem source partition.
            tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape))
            cute.autovec_copy(tdQsdQaccum, tdQrdQaccum)
            # Convert tdQrdQaccum from fp32 to fp16/bf16
            rdQ = cute.make_fragment_like(acc, self.dtype)
            rdQ.store((acc.load() * scale).to(self.dtype))

            # Step 3: Copy dQ from register to smem
            cute.arch.barrier()  # make sure all threads have finished loading dQaccum
            if const_expr(self.arch in [80, 90]):
                copy_atom_r2s_dQ = utils.get_smem_store_atom(
                    self.arch, self.dtype, transpose=self.dQ_swapAB
                )
                tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma)
            else:
                # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op(
                #     LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r,
                # )
                # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r)
                thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1))  # 128 threads
                val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width))
                copy_atom_r2s_dQ = cute.make_copy_atom(
                    cute.nvgpu.CopyUniversalOp(),
                    self.dtype,
                    num_bits_per_copy=128,
                )
                tiled_copy_r2s_dQ = cute.make_tiled_copy_tv(
                    copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ
                )
            thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
            cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
            if const_expr(self.arch in [80, 90]):
                taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
            else:
                taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
                taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape)
            # With dQ_swapAB the registers are written through the transposed smem view.
            taccdQsdQ = thr_copy_r2s_dQ.partition_D(sdQ if const_expr(not self.dQ_swapAB) else sdQt)
            cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ)

            # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
            cute.arch.barrier()  # make sure all smem stores are done
            gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx)
            tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ)
            tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ)
            tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype)
            # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled
            cute.autovec_copy(tdQsdQ, tdQrdQ)

            # Step 5: Copy dQ from register to gmem
            tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ)
            # Column predicate: only columns below the real (unpadded) head_dim are written.
            tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim)
            for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True):
                # Row predicate: skip rows of this tile that lie past the sequence end.
                if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m:
                    cute.copy(
                        gmem_tiled_copy_dQ,
                        tdQrdQ[None, rest_m, None],
                        tdQgdQ[None, rest_m, None],
                        pred=tdQpdQ[None, rest_m, None],
                    )
|