mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
|
3
|
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
|
|
4
|
+
# from Cutlass C++ to Cute-DSL.
|
|
5
|
+
import math
|
|
6
|
+
import operator
|
|
7
|
+
from typing import Callable, Type, Optional, Literal
|
|
8
|
+
|
|
9
|
+
import cuda.bindings.driver as cuda
|
|
10
|
+
|
|
11
|
+
import cutlass
|
|
12
|
+
import cutlass.cute as cute
|
|
13
|
+
from cutlass import Float32
|
|
14
|
+
|
|
15
|
+
from mslk.attention.flash_attn import utils
|
|
16
|
+
from mslk.attention.flash_attn import copy_utils
|
|
17
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
|
|
18
|
+
from mslk.attention.flash_attn.tile_scheduler import (
|
|
19
|
+
ParamsBase,
|
|
20
|
+
SingleTileScheduler,
|
|
21
|
+
SingleTileVarlenScheduler,
|
|
22
|
+
TileSchedulerArguments,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class FlashAttentionBackwardPreprocess:
    """Preprocessing step that runs before the FlashAttention backward kernel.

    For every (m_block, head, batch) tile handed out by the tile scheduler it:

    * computes ``dPsum[row] = sum_k O[row, k] * dO[row, k]`` in Float32 and writes it
      out, storing 0.0 for rows of the tile that lie past ``seqlen_q``;
    * if ``mdQaccum`` is given, fills the matching ``m_block_size * head_dim_padded``
      slice of the dQ accumulator with zeros;
    * if ``mLSE`` is given, writes ``LSE * log2(e)`` into ``mLSElog2`` (0.0 where the
      LSE is ``-inf``; ``+inf`` for rows between ``seqlen_q`` and ``seqlen_q`` rounded
      up to ``m_block_size``).

    Both fixed-length (batched) inputs and variable-length inputs (``mCuSeqlensQ``)
    are supported.
    """

    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        arch: Literal[80, 90, 100],
        m_block_size: int = 128,
        num_threads: int = 128,
    ):
        """
        All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
        should be a multiple of 8.

        :param dtype: element type of O / dO used to build the global-memory tiled copy
            (``__call__`` only accepts Float16 or BFloat16 tensors)
        :type dtype: Type[cutlass.Numeric]
        :param head_dim: head dimension
        :type head_dim: int
        :param arch: target architecture (80, 90 or 100); for ``arch >= 90`` the padded
            per-sequence offset used in the varlen path is rounded down to a multiple of
            ``m_block_size``
        :type arch: int
        :param m_block_size: m block size (query rows per CTA tile)
        :type m_block_size: int
        :param num_threads: number of threads per CTA
        :type num_threads: int
        """
        self.dtype = dtype
        self.m_block_size = m_block_size
        self.arch = arch
        # padding head_dim to a multiple of 32 as k_block_size
        self.head_dim_padded = self._pad_head_dim(head_dim)
        # When padding was needed, loads along the head dimension must be predicated.
        self.check_hdim_oob = head_dim != self.head_dim_padded
        self.num_threads = num_threads

    @staticmethod
    def _pad_head_dim(head_dim: int) -> int:
        """Round ``head_dim`` up to the next multiple of 32 (the k-block granularity)."""
        hdim_multiple_of = 32
        return int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)

    @staticmethod
    def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool:
        """Check if the kernel can be implemented with the given parameters.

        The checks mirror the requirements enforced later by ``__call__`` and
        ``_setup_attributes``, so a True result means construction and compilation
        will not trip those checks.

        :param dtype: data type
        :type dtype: cutlass.Numeric
        :param head_dim: head dimension
        :type head_dim: int
        :param m_block_size: m block size
        :type m_block_size: int
        :param num_threads: number of threads
        :type num_threads: int

        :return: True if the kernel can be implemented, False otherwise
        :rtype: bool
        """
        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        if head_dim % 8 != 0:
            return False
        if num_threads <= 0 or num_threads % 32 != 0:
            return False
        # One thread per tile row reads LSE and writes LSElog2 (indexed by tidx).
        if num_threads < m_block_size:
            return False
        # Mirror the assertion in _setup_attributes: 128-bit copies of Float32 move 4
        # elements per thread, and they must tile the flattened
        # (m_block_size * head_dim_padded) dQaccum block evenly across all threads.
        head_dim_padded = FlashAttentionBackwardPreprocess._pad_head_dim(head_dim)
        num_copy_elems_dQaccum = 128 // Float32.width
        if (m_block_size * head_dim_padded // num_copy_elems_dQaccum) % num_threads != 0:
            return False
        return True

    def _setup_attributes(self):
        # ///////////////////////////////////////////////////////////////////////////////
        # GMEM Tiled copy:
        # ///////////////////////////////////////////////////////////////////////////////
        # Thread layouts for copies
        # We want kBlockKGmem to be a power of 2 so that when we do the summing,
        # it's just between threads in the same warp
        # NOTE: head_dim_padded is always a multiple of 32 (see _pad_head_dim), so the
        # final 16 fallback is never selected for objects built through __init__.
        gmem_k_block_size = (
            128
            if self.head_dim_padded % 128 == 0
            else (
                64
                if self.head_dim_padded % 64 == 0
                else (32 if self.head_dim_padded % 32 == 0 else 16)
            )
        )
        self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(
            self.dtype, gmem_k_block_size, self.num_threads
        )
        universal_copy_bits = 128
        num_copy_elems_dQaccum = universal_copy_bits // Float32.width
        # Every thread must get a whole number of 128-bit dQaccum copies
        # (can_implement checks the same condition).
        assert (
            self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum
        ) % self.num_threads == 0
        self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
            Float32, self.num_threads, num_copy_elems_dQaccum
        )

    @cute.jit
    def __call__(
        self,
        mO: cute.Tensor,
        mdO: cute.Tensor,
        mdPsum: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        mLSElog2: Optional[cute.Tensor],
        mdQaccum: Optional[cute.Tensor],
        mCuSeqlensQ: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        stream: cuda.CUstream,
    ):
        # Host-side entry point: validates dtypes, annotates stride alignment, builds
        # the tile scheduler and launches `kernel` on `stream`.
        # Layouts as indexed by the kernel:
        #   fixed length: mO / mdO (batch, seqlen_q, head, headdim_v), mdPsum / mLSE /
        #                 mLSElog2 (batch, head, seqlen), mdQaccum (batch, head, flat)
        #   varlen:       mO / mdO (total_q, head, headdim_v), mdPsum / mLSE /
        #                 mLSElog2 / mdQaccum (head, flat)
        # Raises TypeError on unsupported dtypes.

        # Get the data type and check if it is fp16 or bf16
        if cutlass.const_expr(not (mO.element_type == mdO.element_type)):
            raise TypeError("All tensors must have the same data type")
        if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if cutlass.const_expr(mdPsum.element_type not in [Float32]):
            raise TypeError("dPsum tensor must be Float32")
        if cutlass.const_expr(mdQaccum is not None):
            if cutlass.const_expr(mdQaccum.element_type not in [Float32]):
                raise TypeError("dQaccum tensor must be Float32")
        if cutlass.const_expr(mLSE is not None):
            assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
            if cutlass.const_expr(mLSE.element_type not in [Float32]):
                raise TypeError("LSE tensor must be Float32")
            if cutlass.const_expr(mLSElog2.element_type not in [Float32]):
                raise TypeError("LSElog2 tensor must be Float32")

        # Assume all strides are divisible by 128 bits except the last stride
        new_stride = lambda t: (
            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
            t.stride[-1],
        )
        mO, mdO, mdQaccum = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (mO, mdO, mdQaccum)
        ]

        self._setup_attributes()

        if cutlass.const_expr(mCuSeqlensQ is not None):
            TileScheduler = SingleTileVarlenScheduler
            num_head = mO.shape[1]
            num_batch = mCuSeqlensQ.shape[0] - 1
        else:
            TileScheduler = SingleTileScheduler
            num_head = mO.shape[2]
            num_batch = mO.shape[0]

        # One CTA per m_block of query rows (tile width 1 along n: no K dimension here).
        tile_sched_args = TileSchedulerArguments(
            num_block=cute.ceil_div(mO.shape[1], self.m_block_size),
            num_head=num_head,
            num_batch=num_batch,
            num_splits=1,
            seqlen_k=0,
            headdim=0,
            headdim_v=mO.shape[2],
            total_q=mO.shape[0],
            tile_shape_mn=(self.m_block_size, 1),
            mCuSeqlensQ=mCuSeqlensQ,
            mSeqUsedQ=mSeqUsedQ,
        )

        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        self.kernel(
            mO,
            mdO,
            mdPsum,
            mLSE,
            mLSElog2,
            mdQaccum,
            mCuSeqlensQ,
            mSeqUsedQ,
            self.gmem_tiled_copy_O,
            self.gmem_tiled_copy_dQaccum,
            tile_sched_params,
            TileScheduler,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mO: cute.Tensor,
        mdO: cute.Tensor,
        mdPsum: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        mLSElog2: Optional[cute.Tensor],
        mdQaccum: Optional[cute.Tensor],
        mCuSeqlensQ: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        gmem_tiled_copy_O: cute.TiledCopy,
        gmem_tiled_copy_dQaccum: cute.TiledCopy,
        tile_sched_params: ParamsBase,
        TileScheduler: cutlass.Constexpr[Callable],
    ):
        # Device kernel: one CTA handles one m_block of query rows for one (head, batch).
        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()

        tile_scheduler = TileScheduler.create(tile_sched_params)
        work_tile = tile_scheduler.initial_work_tile_info()
        m_block, head_idx, batch_idx, _ = work_tile.tile_idx

        if work_tile.is_valid_tile:
            # ///////////////////////////////////////////////////////////////////////////////
            # Get the appropriate tiles for this thread block.
            # ///////////////////////////////////////////////////////////////////////////////
            seqlen = SeqlenInfoQK.create(
                batch_idx,
                mO.shape[1],
                0,
                mCuSeqlensQ=mCuSeqlensQ,
                mCuSeqlensK=None,
                mSeqUsedQ=mSeqUsedQ,
                mSeqUsedK=None,
            )

            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                mO_cur = mO[batch_idx, None, head_idx, None]
                mdO_cur = mdO[batch_idx, None, head_idx, None]
                mdPsum_cur = mdPsum[batch_idx, head_idx, None]
                headdim_v = mO.shape[3]
            else:
                mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
                mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])

                # Offset into the padded varlen buffers (dPsum, dQaccum, LSElog2): shifted
                # by batch_idx * m_block_size, then on arch >= 90 rounded down to a
                # multiple of m_block_size.
                # NOTE(review): presumably the allocator reserves m_block_size rows of
                # padding per sequence — confirm against the caller.
                padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
                if cutlass.const_expr(self.arch >= 90):
                    padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
                mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
                headdim_v = mO.shape[2]

            blkOdO_shape = (self.m_block_size, self.head_dim_padded)
            # (m_block_size, head_dim)
            gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0))
            gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0))

            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
            # (CPY_Atom, CPY_M, CPY_K)
            tOgO = gmem_thr_copy_O.partition_S(gO)
            tOgdO = gmem_thr_copy_O.partition_S(gdO)

            # ///////////////////////////////////////////////////////////////////////////////
            # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
            # of tile_shape
            # ///////////////////////////////////////////////////////////////////////////////
            # Construct identity (coordinate) tensor for the O / dO tile
            cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
            tOcO = gmem_thr_copy_O.partition_S(cO)
            t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
            # Head-dim predicates (only applied when check_hdim_oob); O and dO share the
            # same partitioning, so both predicates are identical.
            tOpO = utils.predicate_k(tOcO, limit=headdim_v)
            tOpdO = utils.predicate_k(tOcO, limit=headdim_v)

            seqlen_q = seqlen.seqlen_q
            seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size)

            if cutlass.const_expr(mLSE is not None):
                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                    mLSE_cur = mLSE[batch_idx, head_idx, None]
                else:
                    mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None])

                gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
                # Rows past seqlen_q keep +inf, which is later written to LSElog2 as +inf.
                # NOTE(review): when num_threads > m_block_size, threads with
                # tidx >= m_block_size index past this tile here and below — confirm intended.
                lse = Float32.inf
                if tidx < seqlen_q - m_block * self.m_block_size:
                    lse = gLSE[tidx]

            tOrO = cute.make_fragment_like(tOgO)
            tOrdO = cute.make_fragment_like(tOgdO)
            assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0])
            assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1])
            assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2])
            for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
                # Instead of using tOcO, we use t0OcO and subtract the offset from the limit
                # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time.
                if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]:
                    cute.copy(
                        gmem_thr_copy_O,
                        tOgO[None, m, None],
                        tOrO[None, m, None],
                        pred=tOpO[None, m, None]
                        if cutlass.const_expr(self.check_hdim_oob)
                        else None,
                    )
                    cute.copy(
                        gmem_thr_copy_O,
                        tOgdO[None, m, None],
                        tOrdO[None, m, None],
                        pred=tOpdO[None, m, None]
                        if cutlass.const_expr(self.check_hdim_oob)
                        else None,
                    )
            # Sum across the "k" dimension
            dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(
                cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
            )
            # Finish the row sum across the threads that share a row, which must all sit
            # within one warp.
            threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
            assert cute.arch.WARP_SIZE % threads_per_row == 0
            dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row)
            dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32)
            dP_sum.store(dpsum)

            # Write dPsum from rmem -> gmem
            gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,))
            # Only the thread corresponding to column 0 writes out the dPsum to gmem
            if tOcO[0, 0, 0][1] == 0:
                for m in cutlass.range(cute.size(dP_sum), unroll_full=True):
                    row = tOcO[0, m, 0][0]
                    gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0

            # Clear dQaccum
            if cutlass.const_expr(mdQaccum is not None):
                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                    mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
                else:
                    mdQaccum_cur = cute.domain_offset(
                        (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
                    )

                    # HACK: Compiler doesn't seem to recognize that padding
                    # by padded_offset_q * self.head_dim_padded keeps alignment
                    # since statically divisible by 4

                    mdQaccum_cur_ptr = cute.make_ptr(
                        dtype=mdQaccum_cur.element_type,
                        value=mdQaccum_cur.iterator.toint(),
                        mem_space=mdQaccum_cur.iterator.memspace,
                        assumed_align=mdQaccum.iterator.alignment,
                    )
                    mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)

                blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,)
                gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
                gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
                tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
                zero = cute.make_fragment_like(tdQgdQaccum)
                zero.fill(0.0)
                cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)

            if cutlass.const_expr(mLSE is not None):
                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                    mLSElog2_cur = mLSElog2[batch_idx, head_idx, None]
                else:
                    mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None])

                gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
                LOG2_E = math.log2(math.e)
                # Write up to the rounded sequence length so the padded tail is also set;
                # a fully masked row (LSE == -inf) is stored as 0.0 instead of -inf.
                if tidx < seqlen_q_rounded - m_block * self.m_block_size:
                    gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
|