mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1452 @@
# @nolint # fbcode
"""
Block-sparse runtime utilities for CUTE DSL kernels.

This module contains runtime execution functions for block-sparse attention kernels.
These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
"""

from typing import Callable, Optional
from functools import partial
import math
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr

# Import data structures from block_sparsity
from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
from mslk.attention.flash_attn import utils
from mslk.attention.flash_attn import copy_utils
from mslk.attention.flash_attn.named_barrier import NamedBarrierBwd


@cute.jit
def load_block_list(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    first_block_preloaded: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
):
    """Iterate over the sparse blocks and load K, V (and Q) into the pipeline.

    For the intra_wg_overlap case, the K and V loads are overlapped, which means
    the last V load from the partial-block case must be pipelined with the loads
    for the full blocks. Set first_block_preloaded when the caller has already
    issued the first K load for the list.

    Note:
        The block_n indices are iterated in reverse.

    Returns:
        Updated kv_producer_state after processing the block list.
    """
    if block_count > 0:
        if const_expr(not intra_wg_overlap):
            # Peel first iteration: the first block may need to load Q alongside K,
            # Parameters are already Constexpr, so no need to wrap in const_expr()
            n_block_first = block_indices[block_count - 1]
            extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
            pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)

            if const_expr(load_q_with_first and use_tma_q):
                load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))

            load_K(src_idx=n_block_first, producer_state=kv_producer_state)
            pipeline_v.producer_acquire(kv_producer_state)
            load_V(src_idx=n_block_first, producer_state=kv_producer_state)
            kv_producer_state.advance()

            for offset in cutlass.range(1, block_count):
                n_block = block_indices[block_count - 1 - offset]
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state)
                load_V(src_idx=n_block, producer_state=kv_producer_state)
                kv_producer_state.advance()
        else:
            n_block_first = block_indices[block_count - 1]
            if const_expr(not first_block_preloaded):
                extra_tx = (
                    tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
                )
                pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)

                if const_expr(load_q_with_first and use_tma_q):
                    load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))

                load_K(src_idx=n_block_first, producer_state=kv_producer_state)

            for idx in cutlass.range(block_count - 1, unroll=1):
                n_block_prev = block_indices[block_count - 1 - idx]
                n_block = block_indices[block_count - 2 - idx]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)

    return kv_producer_state


@cute.jit
def finish_overlap_v_load(
    block_indices: cute.Tensor,
    block_count,
    load_V,
    pipeline_v,
    kv_producer_state,
):
    """Load the final V block after overlapped K/V loads."""
    if block_count > 0:
        n_block_last = block_indices[0]
        pipeline_v.producer_acquire(kv_producer_state)
        load_V(src_idx=n_block_last, producer_state=kv_producer_state)
        kv_producer_state.advance()

    return kv_producer_state


@cute.jit
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
):
    """Map packed m_block indices to block-sparse tensor indices."""
    if const_expr(qhead_per_kvhead != 1):
        return m_block // qhead_per_kvhead
    return m_block


@cute.jit
def produce_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
):
    """Iterate over the mask and full block lists for a single tile.

    The masked (partial) list may leave the last V load pending when intra-warp-group
    overlap is enabled. The first full block must consume that pending V while
    issuing its own K load on the next pipeline stage.

    In the intra-wg-overlap path, the last masked block leaves its V copy in flight
    while we advance the producer state to start the next full K. Either the full list
    overlaps that pending V load, or, if no full blocks exist, we explicitly drain it.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    if mask_empty:
        # No masked blocks: the full list owns the initial Q+K load.
        kv_producer_state = load_block_list(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )

        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
            kv_producer_state = finish_overlap_v_load(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_V,
                pipeline_v,
                kv_producer_state,
            )
    else:
        # Masked blocks present: load Q together with the first masked K so consumers can
        # start immediately. When overlap is disabled this fully drains the list.
        kv_producer_state = load_block_list(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )

        if full_empty:
            if const_expr(intra_wg_overlap):
                kv_producer_state = finish_overlap_v_load(
                    curr_mask_block_idx,
                    curr_mask_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
        else:
            if const_expr(intra_wg_overlap):
                # Bridge the masked list to the full list by overlapping the pending masked V
                # with the first full K load.
                n_block_mask_last = curr_mask_block_idx[0]
                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)

                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=True,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )

                kv_producer_state = finish_overlap_v_load(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
            else:
                # Non-overlap path with both lists: run the full list normally (skipping the Q
                # reload because the masked list already issued it).
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=False,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )

    return kv_producer_state


@cute.jit
def consume_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    seqlen,
    kv_consumer_state,
    mma_pv_fn,
    mma_one_n_block,
    process_first_half_block,
    process_last_half_block,
    mask_fn,
    score_mod_fn,
    O_should_accumulate,
    mask_mod,
    fastdiv_mods,
    intra_wg_overlap: cutlass.Constexpr,
    warp_scheduler_barrier_sync: Callable,
    warp_scheduler_barrier_arrive: Callable,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
):
    """Consume the mask and full block lists for a single tile on the consumer side.

    Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses
    the same sparse tensor indexing.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]

    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0

    if const_expr(not intra_wg_overlap):
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            warp_scheduler_barrier_sync()
            kv_consumer_state = mma_one_n_block(
                kv_consumer_state,
                n_block=mask_n_block,
                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                is_first_n_block=True,
            )
            O_should_accumulate = True
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
            if curr_full_block_cnt == 0:
                warp_scheduler_barrier_arrive()

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                warp_scheduler_barrier_sync()
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_seqlen=True),
                    is_first_n_block=True,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            warp_scheduler_barrier_arrive()
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            kv_consumer_state = process_first_half_block(
                n_block=mask_n_block,
                seqlen=seqlen,
                kv_consumer_state=kv_consumer_state,
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                score_mod_fn=score_mod_fn,
                is_first_block=True,
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                kv_consumer_state = process_first_half_block(
                    n_block=full_n_block,
                    seqlen=seqlen,
                    kv_consumer_state=kv_consumer_state,
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    score_mod_fn=score_mod_fn,
                    is_first_block=True,
                )
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                )
                O_should_accumulate = True
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_mask_block_cnt + curr_full_block_cnt > 0:
            kv_consumer_state = process_last_half_block(
                kv_consumer_state=kv_consumer_state,
                zero_init=not O_should_accumulate,
            )
            O_should_accumulate = True

    return kv_consumer_state, O_should_accumulate, processed_any


@cute.jit
def load_block_list_sm100(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    m_block,
    q_stage: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
):
    """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count)."""
    if block_count > 0:
        # First iteration: load Q alongside K if requested
        n_block_first = block_indices[block_count - 1]

        if const_expr(load_q_with_first):
            # SM100 loads Q0 and optionally Q1
            load_Q(block=q_stage * m_block + 0, stage=0)
            if const_expr(q_stage == 2):
                load_Q(block=q_stage * m_block + 1, stage=1)

        # SM100 doesn't use producer_acquire for pipeline_kv in load path
        # The pipeline barriers are handled inside load_KV
        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()

        # Remaining blocks
        for offset in cutlass.range(1, block_count):
            n_block = block_indices[block_count - 1 - offset]
            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()

    return kv_producer_state


# SM100-specific tile processor using SM100 helpers
@cute.jit
def produce_block_sparse_loads_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
    q_stage: cutlass.Constexpr,
    q_producer_phase: Int32,
    qhead_per_kvhead: cutlass.Constexpr,
):
    """SM100 entry point for sparse block iteration.

    SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use
    simplified block processing that just calls producer_acquire without extras.

    Args:
        m_block: which tile of m we are processing
        qhead_per_kvhead: Constexpr pack factor
    """
    # NB: Compute unpacked index for sparse tensor access
    if const_expr(qhead_per_kvhead != 1):
        m_block_sparse = m_block // qhead_per_kvhead
    else:
        m_block_sparse = m_block

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    q_phase_flipped = False

    if mask_empty:
        # No masked blocks: process full list with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            m_block=m_block,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = not full_empty
    else:
        # Process masked blocks with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            m_block=m_block,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = True

        if not full_empty:
            # Process full blocks without Q loading
            kv_producer_state = load_block_list_sm100(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_q_with_first=False,
                m_block=m_block,
                q_stage=q_stage,
                kv_producer_state=kv_producer_state,
                load_Q=load_Q,
                load_K=load_K,
                load_V=load_V,
                pipeline_kv=pipeline_kv,
            )

    if q_phase_flipped:
        q_producer_phase ^= 1

    return kv_producer_state, q_producer_phase


@cute.jit
def get_total_block_count(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    qhead_per_kvhead: cutlass.Constexpr,
):
    # NB: Convert packed m_block to unpacked for sparse tensor indexing
    if const_expr(qhead_per_kvhead != 1):
        m_block_sparse = m_block // qhead_per_kvhead
    else:
        m_block_sparse = m_block

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    if const_expr(full_block_cnt is not None):
        return (
            mask_block_cnt[batch_idx, head_idx, m_block_sparse]
            + full_block_cnt[batch_idx, head_idx, m_block_sparse]
        )
    else:
        return mask_block_cnt[batch_idx, head_idx, m_block_sparse]


@cute.jit
def handle_block_sparse_empty_tile_correction_sm100(
    tidx: Int32,
    q_stage: cutlass.Constexpr,
    m_block_size: cutlass.Constexpr,
    qhead_per_kvhead,
    pack_gqa: cutlass.Constexpr,
    is_split_kv: cutlass.Constexpr,
    learnable_sink,
    mLSE,
    seqlen,
    m_block: Int32,
    head_idx: Int32,
    batch_idx: Int32,
    split_idx: Int32,
    sScale: cute.Tensor,
    stats: list,
    correction_epilogue: Callable,
    thr_mma_pv: cute.core.ThrMma,
    tOtOs: tuple[cute.Tensor],
    sO: cute.Tensor,
    mbar_ptr,
    mbar_softmax_corr_full_offset: Int32,
    mbar_softmax_corr_empty_offset: Int32,
    mbar_P_full_O_rescaled_offset: Int32,
    mbar_P_full_2_offset: Int32,
    mbar_corr_epi_full_offset: Int32,
    mbar_corr_epi_empty_offset: Int32,
    softmax_corr_consumer_phase: Int32,
    o_corr_consumer_phase: Int32,
    corr_epi_producer_phase: Int32,
    softmax_scale_log2: Float32,
    mO_cur: Optional[cute.Tensor] = None,
    gO: Optional[cute.Tensor] = None,
    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
    """Handle the block-sparse case where a tile is fully masked:
    * zero staged results
    * seed stats
    * satisfy the usual barrier protocol so downstream warps continue to make progress.
    """
    LOG2_E = Float32(math.log2(math.e))

    for stage in cutlass.range_constexpr(q_stage):
        row_sum_value = Float32(1.0)
        row_max_value = (
            -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
        )
        if const_expr(learnable_sink is not None):
            sink_val = -Float32.inf
            if const_expr(not pack_gqa):
                sink_val = Float32(learnable_sink[head_idx])
            elif tidx < m_block_size:
                q_head_idx = (
                    (q_stage * m_block + stage) * m_block_size + tidx
                ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
                sink_val = Float32(learnable_sink[q_head_idx])
            if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
                if row_max_value == -Float32.inf:
                    row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
                    row_sum_value = Float32(1.0)
                else:
                    row_sum_value = row_sum_value + utils.exp2f(
                        sink_val * LOG2_E - row_max_value * softmax_scale_log2
                    )
        if tidx < m_block_size:
            scale_row_idx = tidx + stage * m_block_size
            sScale[scale_row_idx] = row_sum_value
            if const_expr(mLSE is not None or learnable_sink is not None):
                sScale[scale_row_idx + m_block_size * 2] = row_max_value
        acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
        stats[stage] = (row_sum_value, row_max_value, acc_flag)

        cute.arch.mbarrier_wait(
            mbar_ptr + mbar_softmax_corr_full_offset + stage,
            softmax_corr_consumer_phase,
        )
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage)

        if const_expr(gmem_tiled_copy_O is None):
            cute.arch.mbarrier_wait(
                mbar_ptr + mbar_corr_epi_empty_offset + stage,
                corr_epi_producer_phase,
            )
        correction_epilogue(
            thr_mma_pv,
            tOtOs[stage],
            tidx,
            stage,
            m_block,
            seqlen.seqlen_q,
            Float32(0.0),  # zero scale ensures empty tile writes zeros into staged outputs
            sO[None, None, stage],
            mO_cur,
            gO,
            gmem_tiled_copy_O,
        )
        if const_expr(gmem_tiled_copy_O is None):
            cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage)

    softmax_corr_consumer_phase ^= 1
    o_corr_consumer_phase ^= 1
    corr_epi_producer_phase ^= 1

    return (
        softmax_corr_consumer_phase,
        o_corr_consumer_phase,
        corr_epi_producer_phase,
    )


@cute.jit
def softmax_block_sparse_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    softmax_step: Callable,
    mask_fn: Callable,
    mask_fn_none: Callable,
    mma_si_consumer_phase: Int32,
    si_corr_producer_phase: Int32,
    s0_s1_sequence_phase: Int32,
    mbar_ptr,
    mbar_softmax_corr_full_offset: Int32,
    mbar_softmax_corr_empty_offset: Int32,
    mbar_P_full_O_rescaled_offset: Int32,
    mbar_P_full_2_offset: Int32,
    q_stage: cutlass.Constexpr,
    stage_idx: Int32,
    check_m_boundary: bool,
    qhead_per_kvhead: cutlass.Constexpr,
):
    # Convert packed m_block to unpacked for sparse tensor indexing
    if const_expr(qhead_per_kvhead != 1):
        m_block_sparse = m_block // qhead_per_kvhead
    else:
        m_block_sparse = m_block

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt

    if total_block_cnt == 0:
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx)
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            (
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
            ) = softmax_step(
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                mask_n_block,
                is_first=True,
                mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    mask_n_block,
                    mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
                )

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=True,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
                    ),
                )
            else:
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=False,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )

    return (
        mma_si_consumer_phase,
        si_corr_producer_phase,
        s0_s1_sequence_phase,
        total_block_cnt == 0,
    )


# =============================================================================
# Backward-specific block-sparse helpers (SM100)
# =============================================================================
#
# In backward, iteration is transposed compared to forward:
# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
#
# The backward block-sparse tensors use "Q direction" indexing:
# - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
#


@cute.jit
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Count total tile iterations for given n_block (KV tile) in backward."""
    q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
    total = q_block_cnt[batch_idx, head_idx, n_block]
    if const_expr(full_block_cnt is not None):
        total = total + full_block_cnt[batch_idx, head_idx, n_block]
    return total * subtile_factor


@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    # Pipeline states (will be returned after advancing)
    producer_state_Q_LSE,
    producer_state_dO_dPsum,
    # Pipelines
    pipeline_Q,
    pipeline_LSE,
    pipeline_dO,
    pipeline_dPsum,
    # Load functions
    load_K,
    load_V,
    load_Q,
    load_dO,
    copy_stats,
    # Global tensors for LSE/dPsum
    gLSE,
    sLSE,
    gdPsum,
    sdPsum,
    # TMA copy bytes for extra_tx_count
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    # Flags for which loads to perform
    should_load_Q: cutlass.Constexpr,
    should_load_dO: cutlass.Constexpr,
    # Subtiling factor and bounds
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """SM100 backward block sparse loading with subtiling.

    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
    First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
    """
    (
        curr_q_cnt,
        curr_q_idx,
        curr_full_cnt,
        curr_full_idx,
        loop_count,
    ) = get_block_sparse_iteration_info_bwd(
        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
    )

    for iter_idx in cutlass.range(loop_count, unroll=1):
        m_block, _ = get_m_block_from_iter_bwd(
            iter_idx,
            curr_q_cnt,
            curr_q_idx,
            curr_full_cnt,
            curr_full_idx,
            subtile_factor,
            m_block_max,
        )
        m_block_safe = m_block
        if m_block_max > 0:
            m_block_safe = cutlass.min(m_block, m_block_max - 1)

        if iter_idx == 0:
            # First block: load K/V alongside Q/dO
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(
                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
                )
                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()
        else:
            # Subsequent blocks: just load Q/dO (K/V already loaded)
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE)
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(producer_state_dO_dPsum)
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()

    return producer_state_Q_LSE, producer_state_dO_dPsum


@cute.jit
def get_block_sparse_iteration_info_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Extract block-sparse iteration info for backward pass.

    Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    sparse_block_count = curr_q_cnt
    if const_expr(full_cnt is not None):
        sparse_block_count = sparse_block_count + curr_full_cnt
    total_count = sparse_block_count * subtile_factor

    return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count


@cute.jit
def get_m_block_from_iter_bwd(
    iter_idx,
    curr_q_cnt,
    curr_q_idx: cute.Tensor,
    curr_full_cnt,
    curr_full_idx: Optional[cute.Tensor],
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Derive m_block index and is_full_block flag from iteration index.

    Returns (m_block, is_full_block):
    - m_block: The actual Q-tile block index
    - is_full_block: True if this is a full block (no mask_mod needed)
    """
|
|
1088
|
+
sparse_iter_idx = iter_idx // subtile_factor
|
|
1089
|
+
subtile_offset = iter_idx % subtile_factor
|
|
1090
|
+
|
|
1091
|
+
sparse_m_block = Int32(0)
|
|
1092
|
+
is_full_block = False
|
|
1093
|
+
if const_expr(curr_full_idx is not None):
|
|
1094
|
+
if sparse_iter_idx < curr_q_cnt:
|
|
1095
|
+
sparse_m_block = curr_q_idx[sparse_iter_idx]
|
|
1096
|
+
else:
|
|
1097
|
+
sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt]
|
|
1098
|
+
is_full_block = True
|
|
1099
|
+
else:
|
|
1100
|
+
sparse_m_block = curr_q_idx[sparse_iter_idx]
|
|
1101
|
+
|
|
1102
|
+
return sparse_m_block * subtile_factor + subtile_offset, is_full_block
|
|
1103
|
+
|
|
1104
|
+
|
|
1105
|
+
@cute.jit
|
|
1106
|
+
def _load_q_do_block_sm90(
|
|
1107
|
+
m_block,
|
|
1108
|
+
producer_state_Q,
|
|
1109
|
+
producer_state_dO,
|
|
1110
|
+
pipeline_Q,
|
|
1111
|
+
pipeline_dO,
|
|
1112
|
+
load_K,
|
|
1113
|
+
load_V,
|
|
1114
|
+
load_Q,
|
|
1115
|
+
load_dO,
|
|
1116
|
+
load_LSE,
|
|
1117
|
+
load_dPsum,
|
|
1118
|
+
tma_copy_bytes_K,
|
|
1119
|
+
tma_copy_bytes_V,
|
|
1120
|
+
Q_stage_eq_dO_stage: cutlass.Constexpr,
|
|
1121
|
+
load_kv: bool,
|
|
1122
|
+
):
|
|
1123
|
+
"""Load one Q/dO block, optionally loading K/V on first iteration."""
|
|
1124
|
+
if load_kv:
|
|
1125
|
+
pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
|
|
1126
|
+
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
|
|
1127
|
+
else:
|
|
1128
|
+
pipeline_Q.producer_acquire(producer_state_Q)
|
|
1129
|
+
load_Q(m_block, producer_state=producer_state_Q)
|
|
1130
|
+
with cute.arch.elect_one():
|
|
1131
|
+
load_LSE(m_block, producer_state=producer_state_Q)
|
|
1132
|
+
|
|
1133
|
+
producer_state_dO_cur = (
|
|
1134
|
+
producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
|
|
1135
|
+
)
|
|
1136
|
+
if load_kv:
|
|
1137
|
+
pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)
|
|
1138
|
+
load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
|
|
1139
|
+
else:
|
|
1140
|
+
pipeline_dO.producer_acquire(producer_state_dO_cur)
|
|
1141
|
+
load_dO(m_block, producer_state=producer_state_dO_cur)
|
|
1142
|
+
with cute.arch.elect_one():
|
|
1143
|
+
load_dPsum(m_block, producer_state=producer_state_dO_cur)
|
|
1144
|
+
|
|
1145
|
+
producer_state_Q.advance()
|
|
1146
|
+
producer_state_dO.advance()
|
|
1147
|
+
return producer_state_Q, producer_state_dO
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
@cute.jit
|
|
1151
|
+
def produce_block_sparse_q_loads_bwd_sm90(
|
|
1152
|
+
blocksparse_tensors: BlockSparseTensors,
|
|
1153
|
+
batch_idx,
|
|
1154
|
+
head_idx,
|
|
1155
|
+
n_block,
|
|
1156
|
+
producer_state_Q,
|
|
1157
|
+
producer_state_dO,
|
|
1158
|
+
pipeline_Q,
|
|
1159
|
+
pipeline_dO,
|
|
1160
|
+
load_K,
|
|
1161
|
+
load_V,
|
|
1162
|
+
load_Q,
|
|
1163
|
+
load_dO,
|
|
1164
|
+
load_LSE,
|
|
1165
|
+
load_dPsum,
|
|
1166
|
+
tma_copy_bytes_K,
|
|
1167
|
+
tma_copy_bytes_V,
|
|
1168
|
+
Q_stage_eq_dO_stage: cutlass.Constexpr,
|
|
1169
|
+
subtile_factor: cutlass.Constexpr,
|
|
1170
|
+
m_block_max: int,
|
|
1171
|
+
):
|
|
1172
|
+
"""SM90 backward block sparse loading with separate partial/full loops.
|
|
1173
|
+
|
|
1174
|
+
K/V are loaded with the first valid block. Iterates partial blocks first,
|
|
1175
|
+
then full blocks, matching consumer order.
|
|
1176
|
+
|
|
1177
|
+
Returns updated (producer_state_Q, producer_state_dO).
|
|
1178
|
+
"""
|
|
1179
|
+
q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
|
|
1180
|
+
curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
|
|
1181
|
+
curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
|
|
1182
|
+
|
|
1183
|
+
if const_expr(full_cnt is not None):
|
|
1184
|
+
curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
|
|
1185
|
+
curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
|
|
1186
|
+
else:
|
|
1187
|
+
curr_full_cnt = Int32(0)
|
|
1188
|
+
curr_full_idx = None
|
|
1189
|
+
|
|
1190
|
+
kv_loaded = False
|
|
1191
|
+
|
|
1192
|
+
for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
|
|
1193
|
+
sparse_idx = iter_idx // subtile_factor
|
|
1194
|
+
subtile_offset = iter_idx % subtile_factor
|
|
1195
|
+
m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
|
|
1196
|
+
|
|
1197
|
+
if m_block < m_block_max:
|
|
1198
|
+
producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
|
|
1199
|
+
m_block,
|
|
1200
|
+
producer_state_Q,
|
|
1201
|
+
producer_state_dO,
|
|
1202
|
+
pipeline_Q,
|
|
1203
|
+
pipeline_dO,
|
|
1204
|
+
load_K,
|
|
1205
|
+
load_V,
|
|
1206
|
+
load_Q,
|
|
1207
|
+
load_dO,
|
|
1208
|
+
load_LSE,
|
|
1209
|
+
load_dPsum,
|
|
1210
|
+
tma_copy_bytes_K,
|
|
1211
|
+
tma_copy_bytes_V,
|
|
1212
|
+
Q_stage_eq_dO_stage,
|
|
1213
|
+
load_kv=not kv_loaded,
|
|
1214
|
+
)
|
|
1215
|
+
kv_loaded = True
|
|
1216
|
+
|
|
1217
|
+
if const_expr(full_cnt is not None):
|
|
1218
|
+
for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
|
|
1219
|
+
sparse_idx = iter_idx // subtile_factor
|
|
1220
|
+
subtile_offset = iter_idx % subtile_factor
|
|
1221
|
+
m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
|
|
1222
|
+
|
|
1223
|
+
if m_block < m_block_max:
|
|
1224
|
+
producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
|
|
1225
|
+
m_block,
|
|
1226
|
+
producer_state_Q,
|
|
1227
|
+
producer_state_dO,
|
|
1228
|
+
pipeline_Q,
|
|
1229
|
+
pipeline_dO,
|
|
1230
|
+
load_K,
|
|
1231
|
+
load_V,
|
|
1232
|
+
load_Q,
|
|
1233
|
+
load_dO,
|
|
1234
|
+
load_LSE,
|
|
1235
|
+
load_dPsum,
|
|
1236
|
+
tma_copy_bytes_K,
|
|
1237
|
+
tma_copy_bytes_V,
|
|
1238
|
+
Q_stage_eq_dO_stage,
|
|
1239
|
+
load_kv=not kv_loaded,
|
|
1240
|
+
)
|
|
1241
|
+
kv_loaded = True
|
|
1242
|
+
|
|
1243
|
+
return producer_state_Q, producer_state_dO
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
@cute.jit
def consume_block_sparse_mma_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    consumer_state_Q,
    consumer_state_dO,
    mma_one_m_block_fn,
    mask,
    mask_mod,
    is_causal: cutlass.Constexpr,
    is_local: cutlass.Constexpr,
    thr_mma_SdP,
    softmax_scale,
    seqlen,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
    aux_tensors=None,
    fastdiv_mods=(None, None),
):
    """SM90 backward block sparse MMA consumption with separate partial/full loops.

    Partial blocks are processed first (with mask_mod applied), then full blocks
    (without mask_mod). This ensures mask_mod is only applied where needed.

    Returns updated (consumer_state_Q, consumer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    dKV_accumulate = False

    mask_fn_partial = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        mask_mod=mask_mod,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    mask_fn_full = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                m_block,
                consumer_state_Q,
                consumer_state_dO,
                mask_fn=mask_fn_partial,
                dKV_accumulate=dKV_accumulate,
                thr_mma_SdP=thr_mma_SdP,
                batch_idx=batch_idx,
                head_idx=head_idx,
                n_block=n_block,
                softmax_scale=softmax_scale,
                seqlen=seqlen,
                aux_tensors=aux_tensors,
                fastdiv_mods=fastdiv_mods,
            )
            dKV_accumulate = True

    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                    m_block,
                    consumer_state_Q,
                    consumer_state_dO,
                    mask_fn=mask_fn_full,
                    dKV_accumulate=dKV_accumulate,
                    thr_mma_SdP=thr_mma_SdP,
                    batch_idx=batch_idx,
                    head_idx=head_idx,
                    n_block=n_block,
                    softmax_scale=softmax_scale,
                    seqlen=seqlen,
                    aux_tensors=aux_tensors,
                    fastdiv_mods=fastdiv_mods,
                )
                dKV_accumulate = True

    return consumer_state_Q, consumer_state_dO
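
# Note (not part of the diffed module): a toy NumPy reference for the split above.
# Both mask_fn_partial and mask_fn_full still apply the seqlen/causal/local masks;
# the only difference is that full blocks skip the user mask_mod, because the
# block-sparsity pass already guarantees every (q, k) position in a full block
# passes it. `apply_block_mask` and the shapes below are illustrative only.
import numpy as np

def apply_block_mask(scores, mask_mod, block_is_full, q0, k0):
    """Apply mask_mod to one (m_block, n_block) tile of attention scores."""
    if block_is_full:
        return scores  # every (q, k) in this tile is kept by construction
    m, n = scores.shape
    q = q0 + np.arange(m)[:, None]  # absolute query indices of this tile
    k = k0 + np.arange(n)[None, :]  # absolute key indices of this tile
    return np.where(mask_mod(q, k), scores, -np.inf)

# A causal-style mask_mod: a diagonal (partial) tile is masked elementwise,
# while a tile strictly below the diagonal (full) is left untouched.
causal_mod = lambda q, k: q >= k
tile = np.zeros((4, 4))
partial_tile = apply_block_mask(tile.copy(), causal_mod, block_is_full=False, q0=0, k0=0)
assert np.isneginf(partial_tile[0, 1]) and partial_tile[1, 0] == 0.0
full_tile = apply_block_mask(tile.copy(), causal_mod, block_is_full=True, q0=8, k0=0)
assert (full_tile == 0.0).all()
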

@cute.jit
def _store_one_dQaccum_sm90(
    m_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """Store dQaccum for a single m_block."""
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        cute.arch.barrier(
            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
        with cute.arch.elect_one():
            copy_utils.cpasync_reduce_bulk_add_f32(
                sdQaccum[None, warp_group_idx].iterator,
                gdQaccum[None, warp_group_idx, m_block].iterator,
                tma_copy_bytes_dQ,
            )
            cute.arch.cp_async_bulk_commit_group()
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
        cute.arch.barrier_arrive(
            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )


@cute.jit
def dQaccum_store_block_sparse_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """SM90 backward block sparse dQaccum store with separate partial/full loops.

    Iterates partial blocks first, then full blocks, matching producer/consumer order.
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            _store_one_dQaccum_sm90(
                m_block,
                sdQaccum,
                gdQaccum,
                num_mma_warp_groups,
                num_threads_per_warp_group,
                tma_copy_bytes_dQ,
            )

    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                _store_one_dQaccum_sm90(
                    m_block,
                    sdQaccum,
                    gdQaccum,
                    num_mma_warp_groups,
                    num_threads_per_warp_group,
                    tma_copy_bytes_dQ,
                )
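
# Note (not part of the diffed module): a minimal NumPy sketch of the dQ
# accumulation pattern implemented above. Each n_block iteration reduce-adds its
# partial dQ tile into a global fp32 accumulator indexed by m_block (the kernel
# does this with copy_utils.cpasync_reduce_bulk_add_f32), so m_blocks visited by
# several n_blocks sum their contributions. Shapes and names are illustrative.
import numpy as np

def accumulate_dqaccum(gdQaccum, partial_tiles):
    """partial_tiles: iterable of (m_block, tile) contributions, one per visited block."""
    for m_block, tile in partial_tiles:
        gdQaccum[m_block] += tile  # reduce-add into the global accumulator
    return gdQaccum

num_m_blocks, tile_elems = 4, 8
gdQaccum = np.zeros((num_m_blocks, tile_elems), dtype=np.float32)
ones = np.ones(tile_elems, dtype=np.float32)
# Two n_blocks both touch m_block 1; their contributions accumulate.
accumulate_dqaccum(gdQaccum, [(1, ones), (1, 2 * ones), (3, ones)])
assert gdQaccum[1, 0] == 3.0 and gdQaccum[3, 0] == 1.0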