mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,705 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
|
3
|
+
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h
|
|
4
|
+
# from Cutlass C++ to Cute-DSL.
|
|
5
|
+
import math
|
|
6
|
+
import operator
|
|
7
|
+
from typing import Type, Optional
|
|
8
|
+
from functools import partial
|
|
9
|
+
|
|
10
|
+
import cuda.bindings.driver as cuda
|
|
11
|
+
|
|
12
|
+
import cutlass
|
|
13
|
+
import cutlass.cute as cute
|
|
14
|
+
from cutlass.cute.nvgpu import cpasync
|
|
15
|
+
from cutlass import Float32, Int32, const_expr
|
|
16
|
+
|
|
17
|
+
from mslk.attention.flash_attn import utils
|
|
18
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfo
|
|
19
|
+
from cutlass.cute import FastDivmodDivisor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FlashAttentionForwardCombine:
|
|
23
|
+
def __init__(
    self,
    dtype: Type[cutlass.Numeric],
    dtype_partial: Type[cutlass.Numeric],
    head_dim: int,
    m_block_size: int = 8,
    k_block_size: int = 64,
    log_max_splits: int = 4,
    num_threads: int = 256,
    stages: int = 4,
):
    """
    Forward combine kernel for split attention computation.

    Merges the per-split partial outputs ``O_partial`` into the final output
    ``O`` using weights derived from the per-split ``LSE_partial`` values, and
    optionally writes the combined ``LSE``. Variable-length mode is not a
    constructor option: ``__call__`` enables it when ``cu_seqlens`` or
    ``seqused`` is passed.

    :param dtype: output data type (element type of ``O``)
    :param dtype_partial: partial accumulation data type (element type of ``O_partial``)
    :param head_dim: head dimension
    :param m_block_size: number of flattened (head, seqlen) rows handled per thread block
    :param k_block_size: number of head-dim columns handled per thread block
    :param log_max_splits: log2 of the maximum number of splits; the LSE
        shared-memory tile holds ``1 << log_max_splits`` splits
    :param num_threads: number of threads per thread block
    :param stages: number of O_partial shared-memory pipeline stages
    """
    self.dtype = dtype
    self.dtype_partial = dtype_partial
    self.head_dim = head_dim
    self.m_block_size = m_block_size
    self.k_block_size = k_block_size
    self.max_splits = 1 << log_max_splits
    self.num_threads = num_threads
    # When head_dim is not a multiple of k_block_size, the kernel must
    # predicate its O_partial loads along the head dimension.
    self.is_even_k = head_dim % k_block_size == 0
    self.stages = stages
|
|
56
|
+
|
|
57
|
+
@staticmethod
def can_implement(
    dtype,
    dtype_partial,
    head_dim,
    m_block_size,
    k_block_size,
    log_max_splits,
    num_threads,
) -> bool:
    """Check if the kernel can be implemented with the given parameters.

    Besides the dtype and divisibility requirements of the kernel, this
    mirrors every ``assert`` in ``_setup_attributes``. A configuration
    accepted here therefore does not later abort with an ``AssertionError``
    when the kernel is built. The same holds for a ``ZeroDivisionError`` or a
    negative-shift ``ValueError`` raised during this check itself.

    :param dtype: output data type
    :param dtype_partial: partial accumulation data type
    :param head_dim: head dimension
    :param m_block_size: m block size
    :param k_block_size: k block size
    :param log_max_splits: log2 of maximum splits
    :param num_threads: number of threads per thread block
    :return: True if the configuration is supported, False otherwise
    """
    if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
        return False
    if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]:
        return False
    # Non-positive sizes are meaningless. They would also make the modulo
    # checks below (e.g. ``% num_threads``) divide by zero, and a negative
    # shift count raises ValueError.
    if min(head_dim, m_block_size, k_block_size, num_threads) <= 0 or log_max_splits < 0:
        return False
    if head_dim % 8 != 0:
        return False
    if num_threads % 32 != 0:
        return False
    if m_block_size % 8 != 0:
        return False
    max_splits = 1 << log_max_splits
    if max_splits > 256:
        return False
    if (m_block_size * max_splits) % num_threads != 0:
        return False

    # O_partial is copied with 128-bit vectors along k (see _setup_attributes).
    async_copy_elems = 128 // dtype_partial.width
    if k_block_size % async_copy_elems != 0:
        return False
    k_block_gmem = 128 if k_block_size % 128 == 0 else (64 if k_block_size % 64 == 0 else 32)
    if num_threads % (k_block_gmem // async_copy_elems) != 0:
        return False

    # The LSE tiles use the largest of {128, 64, 32, 16, 8} dividing
    # m_block_size. m_block_size % 8 == 0 was checked above, so a match exists.
    m_block_smem = next(b for b in (128, 64, 32, 16, 8) if m_block_size % b == 0)
    if num_threads % m_block_smem != 0:
        return False
    # The threads sharing one LSE column are combined with
    # utils.warp_reduce(width=...), so their count must divide the warp size.
    if 32 % (num_threads // m_block_smem) != 0:
        return False
    return True
|
|
84
|
+
|
|
85
|
+
def _setup_attributes(self):
    """Build the tiled copies and shared-memory layouts used by ``kernel``.

    Reads ``dtype``, ``dtype_partial``, ``m_block_size``, ``k_block_size``,
    ``num_threads``, ``max_splits`` and ``stages`` from ``self`` and sets
    ``gmem_tiled_copy_O_partial``, ``gmem_tiled_copy_O``,
    ``gmem_tiled_copy_LSE``, ``smem_threads_per_col_lse``,
    ``s2r_tiled_copy_LSE``, ``smem_layout_lse`` and ``smem_layout_o``.

    Inconsistent thread/tile configurations are rejected with plain
    ``assert`` statements (AssertionError). These are skipped under
    ``python -O``, so callers should validate with ``can_implement`` first.
    """
    # GMEM copy setup for O partial
    universal_copy_bits = 128
    # Elements per 128-bit vector: 4 for a 32-bit dtype_partial, 8 for a 16-bit one.
    async_copy_elems = universal_copy_bits // self.dtype_partial.width
    assert self.k_block_size % async_copy_elems == 0

    # Width (in elements) of the k span covered by one row of copying threads.
    k_block_gmem = (
        128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
    )
    gmem_threads_per_row = k_block_gmem // async_copy_elems
    assert self.num_threads % gmem_threads_per_row == 0

    # Async copy atom for O partial load
    # NOTE: GLOBAL (cache-global) mode; PTX cp.async.cg only supports
    # 16-byte copies, which matches the 128-bit vectors used here.
    atom_async_copy_partial = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
        self.dtype_partial,
        num_bits_per_copy=universal_copy_bits,
    )
    tOpartial_layout = cute.make_ordered_layout(
        (self.num_threads // gmem_threads_per_row, gmem_threads_per_row),
        order=(1, 0),
    )
    vOpartial_layout = cute.make_layout((1, async_copy_elems))  # async_copy_elems vals (one 128-bit vector) per load
    self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
        atom_async_copy_partial, tOpartial_layout, vOpartial_layout
    )

    # GMEM copy setup for final O (use universal copy for store)
    # The store reuses the O_partial thread/value layouts, so each store moves
    # async_copy_elems elements of dtype (width async_copy_elems * dtype.width bits).
    atom_universal_copy = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        self.dtype,
        num_bits_per_copy=async_copy_elems * self.dtype.width,
    )
    self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
        atom_universal_copy,
        tOpartial_layout,
        vOpartial_layout,  # async_copy_elems vals per store, same as the O_partial load
    )

    # LSE copy setup with async copy (alignment = 1)
    lse_copy_bits = Float32.width  # 1 element per copy, width is in bits
    # Largest of {128, 64, 32, 16, 8} that divides m_block_size (falls back to 8).
    m_block_smem = (
        128
        if self.m_block_size % 128 == 0
        else (
            64
            if self.m_block_size % 64 == 0
            else (
                32
                if self.m_block_size % 32 == 0
                else (16 if self.m_block_size % 16 == 0 else 8)
            )
        )
    )
    gmem_threads_per_row_lse = m_block_smem
    assert self.num_threads % gmem_threads_per_row_lse == 0

    # Async copy atom for LSE load
    # ALWAYS (cache-all) mode: each copy is a single 4-byte Float32, which the
    # 16-byte-only cache-global variant cannot express.
    atom_async_copy_lse = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
        Float32,
        num_bits_per_copy=lse_copy_bits,
    )
    tLSE_layout = cute.make_ordered_layout(
        (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),
        order=(1, 0),
    )
    vLSE_layout = cute.make_layout(1)
    self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
        atom_async_copy_lse, tLSE_layout, vLSE_layout
    )

    # ///////////////////////////////////////////////////////////////////////////////
    # Shared memory
    # ///////////////////////////////////////////////////////////////////////////////

    # Shared memory to register copy for LSE
    # smem_threads_per_col_lse threads cooperate on each LSE column. The kernel
    # combines their partial max/sum with utils.warp_reduce(width=smem_threads_per_col_lse).
    self.smem_threads_per_col_lse = self.num_threads // m_block_smem
    assert 32 % self.smem_threads_per_col_lse == 0  # Must divide warp size

    s2r_layout_atom_lse = cute.make_ordered_layout(
        (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
        order=(0, 1),
    )
    self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
        cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
        s2r_layout_atom_lse,
        cute.make_layout(1),
    )

    # LSE shared memory layout with swizzling to avoid bank conflicts
    # This works for m_block_smem (kBlockMSmem in the C++ kernel) = 8, 16, 32, 64, 128, no bank conflicts
    if const_expr(m_block_smem == 8):
        smem_lse_swizzle = cute.make_swizzle(5, 0, 5)
    elif const_expr(m_block_smem == 16):
        smem_lse_swizzle = cute.make_swizzle(4, 0, 4)
    else:
        smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
    smem_layout_atom_lse = cute.make_composed_layout(
        smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
    )
    # Full sLSE tile: (max_splits, m_block_size) Float32 values.
    self.smem_layout_lse = cute.tile_to_shape(
        smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
    )

    # O partial shared memory layout (simple layout for pipeline stages)
    # Shape (m_block_size, k_block_size, stages): one (m, k) tile per pipeline stage.
    self.smem_layout_o = cute.make_ordered_layout(
        (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
    )
|
|
194
|
+
|
|
195
|
+
@cute.jit
def __call__(
    self,
    mO_partial: cute.Tensor,
    mLSE_partial: cute.Tensor,
    mO: cute.Tensor,
    mLSE: Optional[cute.Tensor] = None,
    cu_seqlens: Optional[cute.Tensor] = None,
    seqused: Optional[cute.Tensor] = None,
    num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
    semaphore_to_reset: Optional[cute.Tensor] = None,
    stream: cuda.CUstream = None,
):
    """Validate inputs, permute them into kernel layout and launch ``kernel``.

    :param mO_partial: partial outputs with element type ``dtype_partial``:
        (num_splits, batch, seqlen, nheads, headdim), or
        (num_splits, total_q, nheads, headdim) when ``cu_seqlens`` is given
    :param mLSE_partial: Float32 partial LSEs: (num_splits, batch, seqlen, nheads),
        or (num_splits, total_q, nheads) when ``cu_seqlens`` is given
    :param mO: output with element type ``dtype``: (batch, seqlen, nheads, headdim),
        or (total_q, nheads, headdim) when ``cu_seqlens`` is given
    :param mLSE: optional Float32 combined-LSE output: (batch, seqlen, nheads) or
        (total_q, nheads); the kernel writes it only from k_block == 0 blocks
    :param cu_seqlens: optional; selects the packed (total_q, ...) layout, and the
        batch size becomes ``cu_seqlens.shape[0] - 1``
    :param seqused: optional; forwarded to ``SeqlenInfo.create`` in the kernel
    :param num_splits_dynamic_ptr: optional per-batch split count read as
        ``num_splits_dynamic_ptr[batch_idx]``; otherwise the split count comes
        from the num_splits mode of ``mLSE_partial``
    :param semaphore_to_reset: optional; the kernel sets element 0 to 0 from
        thread 0 of the last block in the grid
    :param stream: CUDA stream to launch on
    :raises TypeError: if an element type does not match ``dtype``/``dtype_partial``/Float32
    :raises ValueError: if a tensor has an unsupported rank
    """
    # Type checking
    if const_expr(not (mO_partial.element_type == self.dtype_partial)):
        raise TypeError("O partial tensor must match dtype_partial")
    if const_expr(not (mO.element_type == self.dtype)):
        raise TypeError("O tensor must match dtype")
    if const_expr(mLSE_partial.element_type not in [Float32]):
        raise TypeError("LSE partial tensor must be Float32")
    if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
        raise TypeError("LSE tensor must be Float32")

    # Shape validation - input tensors are in user format, need to be converted to kernel format
    # NOTE(review): rank is not cross-checked against cu_seqlens. For example, a 4-D
    # mO_partial without cu_seqlens presumably fails later inside cute.select
    # rather than here. Confirm whether a clearer check is wanted.
    if const_expr(len(mO_partial.shape) not in [4, 5]):
        raise ValueError(
            "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
        )
    if const_expr(len(mLSE_partial.shape) not in [3, 4]):
        raise ValueError(
            "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
        )
    if const_expr(len(mO.shape) not in [3, 4]):
        raise ValueError(
            "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
        )
    if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
        raise ValueError(
            "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
        )

    # Assume all strides are divisible by 128 bits except the last stride.
    # In elements that is 128 // element width. This presumably lets the
    # 128-bit vector copies be proven aligned.
    new_stride = lambda t: (
        *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
        t.stride[-1],
    )
    mO_partial, mO = [
        cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
        for t in (mO_partial, mO)
    ]
    # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
    # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
    O_partial_layout_transpose = (
        [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
    )
    mO_partial = cute.make_tensor(
        mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
    )
    # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
    O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
    mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
    # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
    # or (num_splits, total_q, h) -> (total_q, num_splits, h)
    LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
    mLSE_partial = cute.make_tensor(
        mLSE_partial.iterator,
        cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
    )
    # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
    LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
    mLSE = (
        cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
        if mLSE is not None
        else None
    )

    # Determine if we have variable length sequences
    varlen = const_expr(cu_seqlens is not None or seqused is not None)

    self._setup_attributes()

    # Shared memory: swizzled LSE tile, per-row max valid split index, and
    # the staged O_partial tiles, each 128-byte aligned.
    @cute.struct
    class SharedStorage:
        sLSE: cute.struct.Align[
            cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
        ]
        sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
        sO: cute.struct.Align[
            cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
        ]

    smem_size = SharedStorage.size_in_bytes()

    # Grid dimensions: (ceil_div(seqlen * num_head, m_block_size), ceil_div(head_dim, k_block_size), batch)
    # Grid x tiles the flattened (head, seqlen) index; the kernel splits it back
    # with divmod(idx, seqlen). In packed varlen mode, seqlen here is total_q.
    seqlen = mO_partial.shape[0]
    num_head = mO_partial.shape[3]
    batch_size = (
        mO_partial.shape[4]
        if const_expr(cu_seqlens is None)
        else Int32(cu_seqlens.shape[0] - 1)
    )

    # Create FastDivmodDivisor objects for efficient division
    seqlen_divmod = FastDivmodDivisor(seqlen)
    head_divmod = FastDivmodDivisor(num_head)

    grid_dim = (
        cute.ceil_div(seqlen * num_head, self.m_block_size),
        cute.ceil_div(self.head_dim, self.k_block_size),
        batch_size,
    )

    self.kernel(
        mO_partial,
        mLSE_partial,
        mO,
        mLSE,
        cu_seqlens,
        seqused,
        num_splits_dynamic_ptr,
        semaphore_to_reset,
        SharedStorage,
        self.smem_layout_lse,
        self.smem_layout_o,
        self.gmem_tiled_copy_O_partial,
        self.gmem_tiled_copy_O,
        self.gmem_tiled_copy_LSE,
        self.s2r_tiled_copy_LSE,
        seqlen_divmod,
        head_divmod,
        varlen,
    ).launch(
        grid=grid_dim,
        block=[self.num_threads, 1, 1],
        smem=smem_size,
        stream=stream,
    )
|
|
332
|
+
|
|
333
|
+
@cute.kernel
|
|
334
|
+
def kernel(
|
|
335
|
+
self,
|
|
336
|
+
mO_partial: cute.Tensor,
|
|
337
|
+
mLSE_partial: cute.Tensor,
|
|
338
|
+
mO: cute.Tensor,
|
|
339
|
+
mLSE: Optional[cute.Tensor],
|
|
340
|
+
cu_seqlens: Optional[cute.Tensor],
|
|
341
|
+
seqused: Optional[cute.Tensor],
|
|
342
|
+
num_splits_dynamic_ptr: Optional[cute.Tensor],
|
|
343
|
+
semaphore_to_reset: Optional[cute.Tensor],
|
|
344
|
+
SharedStorage: cutlass.Constexpr,
|
|
345
|
+
smem_layout_lse: cute.Layout | cute.ComposedLayout,
|
|
346
|
+
smem_layout_o: cute.Layout,
|
|
347
|
+
gmem_tiled_copy_O_partial: cute.TiledCopy,
|
|
348
|
+
gmem_tiled_copy_O: cute.TiledCopy,
|
|
349
|
+
gmem_tiled_copy_LSE: cute.TiledCopy,
|
|
350
|
+
s2r_tiled_copy_LSE: cute.TiledCopy,
|
|
351
|
+
seqlen_divmod: FastDivmodDivisor,
|
|
352
|
+
head_divmod: FastDivmodDivisor,
|
|
353
|
+
varlen: cutlass.Constexpr[bool],
|
|
354
|
+
):
|
|
355
|
+
# Thread and block indices
|
|
356
|
+
tidx, _, _ = cute.arch.thread_idx()
|
|
357
|
+
m_block, k_block, batch_idx = cute.arch.block_idx()
|
|
358
|
+
|
|
359
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
360
|
+
# Get shared memory buffer
|
|
361
|
+
# ///////////////////////////////////////////////////////////////////////////////
|
|
362
|
+
smem = cutlass.utils.SmemAllocator()
|
|
363
|
+
storage = smem.allocate(SharedStorage)
|
|
364
|
+
sLSE = storage.sLSE.get_tensor(smem_layout_lse)
|
|
365
|
+
sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,))
|
|
366
|
+
sO = storage.sO.get_tensor(smem_layout_o)
|
|
367
|
+
|
|
368
|
+
# Handle semaphore reset
|
|
369
|
+
if const_expr(semaphore_to_reset is not None):
|
|
370
|
+
if (
|
|
371
|
+
tidx == 0
|
|
372
|
+
and m_block == cute.arch.grid_dim()[0] - 1
|
|
373
|
+
and k_block == cute.arch.grid_dim()[1] - 1
|
|
374
|
+
and batch_idx == cute.arch.grid_dim()[2] - 1
|
|
375
|
+
):
|
|
376
|
+
semaphore_to_reset[0] = 0
|
|
377
|
+
|
|
378
|
+
# Get number of splits
|
|
379
|
+
num_splits = (
|
|
380
|
+
num_splits_dynamic_ptr[batch_idx]
|
|
381
|
+
if const_expr(num_splits_dynamic_ptr is not None)
|
|
382
|
+
else mLSE_partial.shape[1]
|
|
383
|
+
)
|
|
384
|
+
# Handle variable length sequences using SeqlenInfo
|
|
385
|
+
seqlen_info = SeqlenInfo.create(
|
|
386
|
+
batch_idx=batch_idx,
|
|
387
|
+
seqlen_static=mO_partial.shape[0],
|
|
388
|
+
cu_seqlens=cu_seqlens,
|
|
389
|
+
seqused=seqused,
|
|
390
|
+
)
|
|
391
|
+
seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
|
|
392
|
+
|
|
393
|
+
# Extract number of heads (head index will be determined dynamically)
|
|
394
|
+
num_head = mO_partial.shape[3]
|
|
395
|
+
max_idx = seqlen * num_head
|
|
396
|
+
|
|
397
|
+
# Early exit for single split if dynamic
|
|
398
|
+
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
|
|
399
|
+
const_expr(not varlen) or m_block * self.m_block_size < max_idx
|
|
400
|
+
):
|
|
401
|
+
# ===============================
|
|
402
|
+
# Step 1: Load LSE_partial from gmem to shared memory
|
|
403
|
+
# ===============================
|
|
404
|
+
|
|
405
|
+
if const_expr(cu_seqlens is None):
|
|
406
|
+
# mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx]
|
|
407
|
+
mLSE_partial_cur = utils.coord_offset_i64(mLSE_partial, batch_idx, dim=3)
|
|
408
|
+
else:
|
|
409
|
+
# mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial)
|
|
410
|
+
mLSE_partial_cur = utils.domain_offset_i64((offset, 0, 0), mLSE_partial)
|
|
411
|
+
mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
|
|
412
|
+
|
|
413
|
+
gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
|
|
414
|
+
tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
|
|
415
|
+
|
|
416
|
+
# Create identity tensor for coordinate tracking
|
|
417
|
+
cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size))
|
|
418
|
+
tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
|
|
419
|
+
|
|
420
|
+
# Load LSE partial values
|
|
421
|
+
for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
|
|
422
|
+
mi = tLSEcLSE[0, 0, m][1] # Get m coordinate
|
|
423
|
+
idx = m_block * self.m_block_size + mi
|
|
424
|
+
if idx < max_idx:
|
|
425
|
+
# Calculate actual sequence position and head using FastDivmodDivisor
|
|
426
|
+
if const_expr(not varlen):
|
|
427
|
+
head_idx, m_idx = divmod(idx, seqlen_divmod)
|
|
428
|
+
else:
|
|
429
|
+
head_idx = idx // seqlen
|
|
430
|
+
m_idx = idx - head_idx * seqlen
|
|
431
|
+
mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx]
|
|
432
|
+
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
|
433
|
+
si = tLSEcLSE[0, s, 0][0] # Get split coordinate
|
|
434
|
+
if si < num_splits:
|
|
435
|
+
cute.copy(
|
|
436
|
+
gmem_thr_copy_LSE,
|
|
437
|
+
mLSE_partial_cur_copy[None, si],
|
|
438
|
+
tLSEsLSE[None, s, m],
|
|
439
|
+
)
|
|
440
|
+
else:
|
|
441
|
+
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
|
442
|
+
# Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
|
|
443
|
+
cute.arch.cp_async_commit_group()
|
|
444
|
+
|
|
445
|
+
# ===============================
|
|
446
|
+
# Step 2: Load O_partial for pipeline stages
|
|
447
|
+
# ===============================
|
|
448
|
+
|
|
449
|
+
gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
|
|
450
|
+
cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size))
|
|
451
|
+
tOcO = gmem_thr_copy_O_partial.partition_D(cO)
|
|
452
|
+
tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
|
|
453
|
+
if const_expr(cu_seqlens is None):
|
|
454
|
+
# mO_partial_cur = mO_partial[None, None, None, None, batch_idx]
|
|
455
|
+
mO_partial_cur = utils.coord_offset_i64(mO_partial, batch_idx, dim=4)
|
|
456
|
+
else:
|
|
457
|
+
# mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial)
|
|
458
|
+
mO_partial_cur = utils.domain_offset_i64((offset, 0, 0, 0), mO_partial)
|
|
459
|
+
|
|
460
|
+
# Precompute these values to avoid recomputing them in the loop
|
|
461
|
+
num_rows = const_expr(cute.size(tOcO, mode=[1]))
|
|
462
|
+
tOmidx = cute.make_fragment(num_rows, cutlass.Int32)
|
|
463
|
+
tOhidx = cute.make_fragment(num_rows, cutlass.Int32)
|
|
464
|
+
tOrOptr = cute.make_fragment(num_rows, cutlass.Int64)
|
|
465
|
+
for m in cutlass.range(num_rows, unroll_full=True):
|
|
466
|
+
mi = tOcO[0, m, 0][0] # m coordinate
|
|
467
|
+
idx = m_block * self.m_block_size + mi
|
|
468
|
+
if const_expr(not varlen):
|
|
469
|
+
tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
|
|
470
|
+
else:
|
|
471
|
+
tOhidx[m] = idx // seqlen
|
|
472
|
+
tOmidx[m] = idx - tOhidx[m] * seqlen
|
|
473
|
+
tOrOptr[m] = utils.elem_pointer_i64(
|
|
474
|
+
mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
|
|
475
|
+
).toint()
|
|
476
|
+
if idx >= max_idx:
|
|
477
|
+
tOhidx[m] = -1
|
|
478
|
+
|
|
479
|
+
tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean)
|
|
480
|
+
if const_expr(not self.is_even_k):
|
|
481
|
+
for k in cutlass.range(cute.size(tOpO), unroll_full=True):
|
|
482
|
+
tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
|
|
483
|
+
# if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)
|
|
484
|
+
|
|
485
|
+
load_O_partial = partial(
|
|
486
|
+
self.load_O_partial,
|
|
487
|
+
gmem_tiled_copy_O_partial,
|
|
488
|
+
tOrOptr,
|
|
489
|
+
tOsO_partial,
|
|
490
|
+
tOhidx,
|
|
491
|
+
tOpO,
|
|
492
|
+
tOcO,
|
|
493
|
+
mO_partial_cur.layout,
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
# Load first few stages of O_partial
|
|
497
|
+
for stage in cutlass.range(self.stages - 1, unroll_full=True):
|
|
498
|
+
if stage < num_splits:
|
|
499
|
+
load_O_partial(stage, stage)
|
|
500
|
+
cute.arch.cp_async_commit_group()
|
|
501
|
+
|
|
502
|
+
# ===============================
|
|
503
|
+
# Step 3: Load and transpose LSE from smem to registers
|
|
504
|
+
# ===============================
|
|
505
|
+
|
|
506
|
+
# Wait for LSE and initial O partial stages to complete
|
|
507
|
+
cute.arch.cp_async_wait_group(self.stages - 1)
|
|
508
|
+
cute.arch.sync_threads()
|
|
509
|
+
# if cute.arch.thread_idx()[0] == 0:
|
|
510
|
+
# # cute.print_tensor(sLSE)
|
|
511
|
+
# for i in range(64):
|
|
512
|
+
# cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0])
|
|
513
|
+
# cute.arch.sync_threads()
|
|
514
|
+
|
|
515
|
+
s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
|
|
516
|
+
ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
|
|
517
|
+
ts2rrLSE = cute.make_fragment_like(ts2rsLSE)
|
|
518
|
+
cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
|
|
519
|
+
|
|
520
|
+
# ===============================
|
|
521
|
+
# Step 4: Compute final LSE along split dimension
|
|
522
|
+
# ===============================
|
|
523
|
+
|
|
524
|
+
lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32)
|
|
525
|
+
ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
|
|
526
|
+
# We compute the max valid split for each row to short-circuit the computation later
|
|
527
|
+
max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32)
|
|
528
|
+
assert cute.size(ts2rrLSE, mode=[0]) == 1
|
|
529
|
+
# Compute max, scales, and final LSE for each row
|
|
530
|
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
|
531
|
+
# Find max LSE value across splits
|
|
532
|
+
threads_per_col = const_expr(self.smem_threads_per_col_lse)
|
|
533
|
+
lse_max = utils.warp_reduce(
|
|
534
|
+
ts2rrLSE[None, None, m]
|
|
535
|
+
.load()
|
|
536
|
+
.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
|
537
|
+
op=cute.arch.fmax,
|
|
538
|
+
width=threads_per_col,
|
|
539
|
+
)
|
|
540
|
+
# if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max)
|
|
541
|
+
# Find max valid split index
|
|
542
|
+
max_valid_idx = -1
|
|
543
|
+
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
|
544
|
+
if ts2rrLSE[0, s, m] != -Float32.inf:
|
|
545
|
+
max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate
|
|
546
|
+
# if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
|
|
547
|
+
max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col)
|
|
548
|
+
# Compute exp scales and sum
|
|
549
|
+
lse_max_cur = (
|
|
550
|
+
0.0 if lse_max == -Float32.inf else lse_max
|
|
551
|
+
) # In case all local LSEs are -inf
|
|
552
|
+
LOG2_E = math.log2(math.e)
|
|
553
|
+
lse_sum_cur = 0.0
|
|
554
|
+
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
|
555
|
+
scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E))
|
|
556
|
+
lse_sum_cur += scale
|
|
557
|
+
ts2rrLSE[0, s, m] = scale # Store scale for later use
|
|
558
|
+
lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col)
|
|
559
|
+
lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
|
|
560
|
+
# Normalize scales
|
|
561
|
+
inv_sum = (
|
|
562
|
+
0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
|
|
563
|
+
)
|
|
564
|
+
ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
|
|
565
|
+
# Store the scales exp(lse - lse_logsum) back to smem
|
|
566
|
+
cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
|
|
567
|
+
|
|
568
|
+
# Store max valid split to smem
|
|
569
|
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
|
570
|
+
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
|
571
|
+
mi = ts2rcLSE[0, 0, m][1]
|
|
572
|
+
if mi < self.m_block_size:
|
|
573
|
+
sMaxValidSplit[mi] = max_valid_split[m]
|
|
574
|
+
|
|
575
|
+
# ===============================
|
|
576
|
+
# Step 5: Store final LSE to gmem
|
|
577
|
+
# ===============================
|
|
578
|
+
|
|
579
|
+
if const_expr(mLSE is not None):
|
|
580
|
+
if const_expr(cu_seqlens is None):
|
|
581
|
+
# mLSE_cur = mLSE[None, None, batch_idx]
|
|
582
|
+
mLSE_cur = utils.coord_offset_i64(mLSE, batch_idx, dim=2)
|
|
583
|
+
else:
|
|
584
|
+
# mLSE_cur = cute.domain_offset((offset, 0), mLSE)
|
|
585
|
+
mLSE_cur = utils.domain_offset_i64((offset, 0), mLSE)
|
|
586
|
+
if k_block == 0: # Only first k_block writes LSE when mLSE is provided
|
|
587
|
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
|
588
|
+
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
|
589
|
+
mi = ts2rcLSE[0, 0, m][1]
|
|
590
|
+
idx = m_block * self.m_block_size + mi
|
|
591
|
+
if idx < max_idx:
|
|
592
|
+
if const_expr(not varlen):
|
|
593
|
+
head_idx, m_idx = divmod(idx, seqlen_divmod)
|
|
594
|
+
else:
|
|
595
|
+
head_idx = idx // seqlen
|
|
596
|
+
m_idx = idx - head_idx * seqlen
|
|
597
|
+
mLSE_cur[m_idx, head_idx] = lse_sum[m]
|
|
598
|
+
|
|
599
|
+
# ===============================
|
|
600
|
+
# Step 6: Read O_partial and accumulate final O
|
|
601
|
+
# ===============================
|
|
602
|
+
|
|
603
|
+
cute.arch.sync_threads()
|
|
604
|
+
|
|
605
|
+
# Get max valid split for this thread
|
|
606
|
+
thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
|
|
607
|
+
for m in cutlass.range(1, cute.size(tOcO, mode=[1])):
|
|
608
|
+
thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
|
|
609
|
+
|
|
610
|
+
tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0])
|
|
611
|
+
tOrO = cute.make_fragment_like(tOrO_partial, Float32)
|
|
612
|
+
tOrO.fill(0.0)
|
|
613
|
+
|
|
614
|
+
stage_load = self.stages - 1
|
|
615
|
+
stage_compute = 0
|
|
616
|
+
|
|
617
|
+
# Main accumulation loop
|
|
618
|
+
for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
|
|
619
|
+
# Get scales for this split
|
|
620
|
+
scale = cute.make_fragment(num_rows, Float32)
|
|
621
|
+
for m in cutlass.range(num_rows, unroll_full=True):
|
|
622
|
+
scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem
|
|
623
|
+
|
|
624
|
+
# Load next stage if needed
|
|
625
|
+
split_to_load = s + self.stages - 1
|
|
626
|
+
if split_to_load <= thr_max_valid_split:
|
|
627
|
+
load_O_partial(split_to_load, stage_load)
|
|
628
|
+
cute.arch.cp_async_commit_group()
|
|
629
|
+
stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1
|
|
630
|
+
|
|
631
|
+
# Wait for the current stage to be ready
|
|
632
|
+
cute.arch.cp_async_wait_group(self.stages - 1)
|
|
633
|
+
# We don't need __syncthreads() because each thread is just reading its own data from smem
|
|
634
|
+
# Copy from smem to registers
|
|
635
|
+
cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
|
|
636
|
+
stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1
|
|
637
|
+
|
|
638
|
+
# Accumulate scaled partial results
|
|
639
|
+
for m in cutlass.range(num_rows, unroll_full=True):
|
|
640
|
+
if tOhidx[m] >= 0 and scale[m] > 0.0:
|
|
641
|
+
tOrO[None, m, None].store(
|
|
642
|
+
tOrO[None, m, None].load()
|
|
643
|
+
+ scale[m] * tOrO_partial[None, m, None].load().to(Float32)
|
|
644
|
+
)
|
|
645
|
+
|
|
646
|
+
# ===============================
|
|
647
|
+
# Step 7: Write final O to gmem
|
|
648
|
+
# ===============================
|
|
649
|
+
|
|
650
|
+
rO = cute.make_fragment_like(tOrO, self.dtype)
|
|
651
|
+
rO.store(tOrO.load().to(self.dtype))
|
|
652
|
+
if const_expr(cu_seqlens is None):
|
|
653
|
+
# mO_cur = mO[None, None, None, batch_idx]
|
|
654
|
+
mO_cur = utils.coord_offset_i64(mO, batch_idx, dim=3)
|
|
655
|
+
else:
|
|
656
|
+
# mO_cur = cute.domain_offset((offset, 0, 0), mO)
|
|
657
|
+
mO_cur = utils.domain_offset_i64((offset, 0, 0), mO)
|
|
658
|
+
mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur)
|
|
659
|
+
elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
|
|
660
|
+
# mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,))
|
|
661
|
+
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
|
|
662
|
+
# Write final results
|
|
663
|
+
for m in cutlass.range(num_rows, unroll_full=True):
|
|
664
|
+
if tOhidx[m] >= 0:
|
|
665
|
+
mO_cur_copy = cute.tiled_divide(
|
|
666
|
+
mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
|
|
667
|
+
)
|
|
668
|
+
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
|
669
|
+
k_idx = tOcO[0, 0, k][1] // elems_per_store
|
|
670
|
+
if const_expr(self.is_even_k) or tOpO[k]:
|
|
671
|
+
cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx])
|
|
672
|
+
|
|
673
|
+
@cute.jit
def load_O_partial(
    self,
    gmem_tiled_copy_O_partial: cute.TiledCopy,
    tOrOptr: cute.Tensor,
    tOsO_partial: cute.Tensor,
    tOhidx: cute.Tensor,
    tOpO: cute.Tensor,
    tOcO: cute.Tensor,
    mO_cur_partial_layout: cute.Layout,
    split: Int32,
    stage: Int32,
) -> None:
    """Issue the gmem -> staging-buffer copies of one split's partial O into one stage.

    For every row ``m`` handled by this thread whose head index is non-negative
    (``tOhidx[m] >= 0``), a gmem tensor is rebuilt from the row's base address
    ``tOrOptr[m]`` and a slice of ``mO_cur_partial_layout``, offset to ``split``
    with ``utils.coord_offset_i64``, and copied chunk by chunk along the column
    mode into ``tOsO_partial[None, m, k, stage]``. Chunk ``k`` is skipped when
    ``tOpO[k]`` is false, unless ``self.is_even_k`` is set at compile time.

    This function only issues the copies. Callers are responsible for grouping
    and waiting on them (the combine kernel commits a cp.async group right
    after calling it).

    Args:
        gmem_tiled_copy_O_partial: Tiled copy used for every chunk. The size of
            its value mode (``layout_tv_tiled[1]``) is the number of elements
            moved per copy (``elems_per_load``).
        tOrOptr: Per-row gmem base addresses, indexed by ``m``.
        tOsO_partial: Destination staging buffer. Its last mode selects the
            pipeline stage.
        tOhidx: Per-row head indices. A negative value means the row is skipped.
        tOpO: Per-chunk column predicate. It is only consulted when
            ``self.is_even_k`` is false.
        tOcO: Coordinate tensor. Mode 1 enumerates rows and mode 2 enumerates
            column chunks; ``tOcO[0, 0, k][1]`` is the starting column of chunk ``k``.
        mO_cur_partial_layout: Rank-4 layout of the partial-O tensor. Modes 0
            and 3 are fixed to 0 here, presumably because the row/head offset
            is already folded into ``tOrOptr`` (TODO confirm against caller).
        split: Index of the split whose partial result is loaded.
        stage: Slot in ``tOsO_partial``'s last mode to fill.
    """
    # Elements moved by one copy instruction; a compile-time constant.
    elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
    tOsO_partial_cur = tOsO_partial[None, None, None, stage]
    for m in cutlass.range(cute.size(tOcO, mode=[1]), unroll_full=True):
        if tOhidx[m] >= 0:  # rows with a negative head index have nothing to load
            o_gmem_ptr = cute.make_ptr(
                tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
            )
            mO_partial_cur = cute.make_tensor(
                o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
            )
            # Group the column mode into copy-width vectors: (vec, chunk, split).
            mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
            # Select this split once per row rather than once per chunk, since
            # the offset does not depend on k. The i64 helper (per its name)
            # does the offset in 64-bit arithmetic instead of plain indexing.
            mO_partial_split = utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)
            for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                k_idx = tOcO[0, 0, k][1] // elems_per_load
                if const_expr(self.is_even_k) or tOpO[k]:
                    cute.copy(
                        gmem_tiled_copy_O_partial,
                        mO_partial_split[None, k_idx],
                        tOsO_partial_cur[None, m, k],
                    )
|