mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/block_sparse_utils.py
@@ -0,0 +1,1452 @@
1
+ # @nolint # fbcode
2
+ """
3
+ Block-sparse runtime utilities for CUTE DSL kernels.
4
+
5
+ This module contains runtime execution functions for block-sparse attention kernels.
6
+ These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
7
+ """
8
+
9
+ from typing import Callable, Optional
10
+ from functools import partial
11
+ import math
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+ from cutlass import Float32, Int32, const_expr
15
+
16
+ # Import data structures from block_sparsity
17
+ from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
18
+ from mslk.attention.flash_attn import utils
19
+ from mslk.attention.flash_attn import copy_utils
20
+ from mslk.attention.flash_attn.named_barrier import NamedBarrierBwd
21
+
22
+
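Every helper below indexes the block-sparse metadata the same way, so a minimal host-side sketch of the assumed layout may help. The plain nested lists and the enumerate_kv_blocks helper here are illustrative assumptions for this page, not the packaged BlockSparseTensors definition.

# Illustrative sketch only (assumed layout, not the packaged definition):
#   mask_block_cnt[b, h, m]      -> number of partially-masked KV blocks for Q tile m
#   mask_block_idx[b, h, m, :]   -> their n_block indices (padded; first cnt entries valid)
#   full_block_cnt / full_block_idx -> same layout for fully-unmasked KV blocks (may be None)
def enumerate_kv_blocks(blocksparse_tensors, b, h, m):
    """Yield (n_block, is_full) in the reverse order the producer loops use."""
    mask_cnt, mask_idx, full_cnt, full_idx = blocksparse_tensors
    for i in range(mask_cnt[b][h][m]):                    # masked (partial) blocks first
        yield mask_idx[b][h][m][mask_cnt[b][h][m] - 1 - i], False
    if full_cnt is not None:
        for i in range(full_cnt[b][h][m]):                # then fully-unmasked blocks
            yield full_idx[b][h][m][full_cnt[b][h][m] - 1 - i], True

# One batch, one head, one Q tile with one masked KV block {3} and two full blocks {1, 2}.
bst = ([[[1]]], [[[[3, 0, 0]]]], [[[2]]], [[[[1, 2, 0]]]])
assert list(enumerate_kv_blocks(bst, 0, 0, 0)) == [(3, False), (2, True), (1, True)]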
23
+ @cute.jit
24
+ def load_block_list(
25
+ block_indices: cute.Tensor,
26
+ block_count,
27
+ load_q_with_first: cutlass.Constexpr,
28
+ first_block_preloaded: cutlass.Constexpr,
29
+ kv_producer_state,
30
+ load_Q,
31
+ load_K,
32
+ load_V,
33
+ pipeline_k,
34
+ pipeline_v,
35
+ use_tma_q: cutlass.Constexpr,
36
+ tma_q_bytes: cutlass.Constexpr,
37
+ intra_wg_overlap: cutlass.Constexpr,
38
+ ):
39
+ """Iterate over the sparse blocks and load K, V (and Q) into the pipeline.
40
+ for the intra_wg_overlap case, we overlap the loads of K and V. And this
41
+ means we need to pipeline the last V load from the partial block case,
42
+ with the loads for the full blocks. Set first_block_preloaded when the
43
+ caller has already issued the first K load for the list.
44
+
45
+ Note:
46
+ we iterate along the block_n indices in reverse.
47
+
48
+ Returns:
49
+ Updated kv_producer_state after processing the block list.
50
+
51
+ """
52
+ if block_count > 0:
53
+ if const_expr(not intra_wg_overlap):
54
+ # Peel the first iteration: the first block may need to load Q alongside K.
55
+ # Parameters are already Constexpr, so no need to wrap in const_expr().
56
+ n_block_first = block_indices[block_count - 1]
57
+ extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
58
+ pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
59
+
60
+ if const_expr(load_q_with_first and use_tma_q):
61
+ load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
62
+
63
+ load_K(src_idx=n_block_first, producer_state=kv_producer_state)
64
+ pipeline_v.producer_acquire(kv_producer_state)
65
+ load_V(src_idx=n_block_first, producer_state=kv_producer_state)
66
+ kv_producer_state.advance()
67
+
68
+ for offset in cutlass.range(1, block_count):
69
+ n_block = block_indices[block_count - 1 - offset]
70
+ pipeline_k.producer_acquire(kv_producer_state)
71
+ load_K(src_idx=n_block, producer_state=kv_producer_state)
72
+ pipeline_v.producer_acquire(kv_producer_state)
73
+ load_V(src_idx=n_block, producer_state=kv_producer_state)
74
+ kv_producer_state.advance()
75
+ else:
76
+ n_block_first = block_indices[block_count - 1]
77
+ if const_expr(not first_block_preloaded):
78
+ extra_tx = (
79
+ tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
80
+ )
81
+ pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
82
+
83
+ if const_expr(load_q_with_first and use_tma_q):
84
+ load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
85
+
86
+ load_K(src_idx=n_block_first, producer_state=kv_producer_state)
87
+
88
+ for idx in cutlass.range(block_count - 1, unroll=1):
89
+ n_block_prev = block_indices[block_count - 1 - idx]
90
+ n_block = block_indices[block_count - 2 - idx]
91
+ kv_producer_state_prev = kv_producer_state.clone()
92
+ kv_producer_state.advance()
93
+ pipeline_k.producer_acquire(kv_producer_state)
94
+ load_K(src_idx=n_block, producer_state=kv_producer_state)
95
+ pipeline_v.producer_acquire(kv_producer_state_prev)
96
+ load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
97
+
98
+ return kv_producer_state
99
+
100
+
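To make the two schedules concrete, here is a plain-Python sketch of the issue order (an editorial illustration, not code from the wheel): without overlap each block acquires and loads K then V on the same pipeline stage; with intra-warp-group overlap the K load for the next block is issued before the V load of the previous one, leaving the final V to be drained separately by finish_overlap_v_load below.

def issue_order(block_indices, intra_wg_overlap):
    """Order in which one block list issues its K/V loads (illustrative sketch)."""
    blocks = list(reversed(block_indices))          # producers walk the list in reverse
    if not blocks:
        return []
    if not intra_wg_overlap:
        return [(kv, n) for n in blocks for kv in ("K", "V")]
    order = [("K", blocks[0])]                      # first K (Q may ride along via TMA)
    for prev, nxt in zip(blocks, blocks[1:]):
        order += [("K", nxt), ("V", prev)]          # overlap the next K with the previous V
    order.append(("V", blocks[-1]))                 # drained later by finish_overlap_v_load
    return order

assert issue_order([5, 9, 12], intra_wg_overlap=False) == [
    ("K", 12), ("V", 12), ("K", 9), ("V", 9), ("K", 5), ("V", 5)]
assert issue_order([5, 9, 12], intra_wg_overlap=True) == [
    ("K", 12), ("K", 9), ("V", 12), ("K", 5), ("V", 9), ("V", 5)]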
101
+ @cute.jit
102
+ def finish_overlap_v_load(
103
+ block_indices: cute.Tensor,
104
+ block_count,
105
+ load_V,
106
+ pipeline_v,
107
+ kv_producer_state,
108
+ ):
109
+ """Load the final V block after overlapped K/V loads."""
110
+ if block_count > 0:
111
+ n_block_last = block_indices[0]
112
+ pipeline_v.producer_acquire(kv_producer_state)
113
+ load_V(src_idx=n_block_last, producer_state=kv_producer_state)
114
+ kv_producer_state.advance()
115
+
116
+ return kv_producer_state
117
+
118
+
119
+ @cute.jit
120
+ def sparse_tensor_m_block(
121
+ m_block,
122
+ qhead_per_kvhead: cutlass.Constexpr[int],
123
+ ):
124
+ """Map packed m_block indices to block-sparse tensor indices."""
125
+ if const_expr(qhead_per_kvhead != 1):
126
+ return m_block // qhead_per_kvhead
127
+ return m_block
128
+
129
+
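For illustration (assuming qhead_per_kvhead = 4), consecutive packed m_blocks share one row of the sparse metadata; a host-side mirror of the mapping above:

def sparse_tensor_m_block_ref(m_block, qhead_per_kvhead):
    """Packed Q tiles share one sparse-metadata row (host-side sketch of the mapping)."""
    return m_block // qhead_per_kvhead if qhead_per_kvhead != 1 else m_block

# Packed m_blocks 0..7 with a pack factor of 4 all index sparse rows 0 and 1.
assert [sparse_tensor_m_block_ref(m, 4) for m in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]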
130
+ @cute.jit
131
+ def produce_block_sparse_loads(
132
+ blocksparse_tensors: BlockSparseTensors,
133
+ batch_idx,
134
+ head_idx,
135
+ m_block,
136
+ kv_producer_state,
137
+ load_Q,
138
+ load_K,
139
+ load_V,
140
+ pipeline_k,
141
+ pipeline_v,
142
+ use_tma_q: cutlass.Constexpr,
143
+ tma_q_bytes: cutlass.Constexpr,
144
+ intra_wg_overlap: cutlass.Constexpr,
145
+ qhead_per_kvhead: cutlass.Constexpr[int] = 1,
146
+ ):
147
+ """Iterate over the mask and full block lists for a single tile.
148
+
149
+ The masked (partial) list may leave the last V load pending when intra-warp-group
150
+ overlap is enabled. The first full block must consume that pending V while
151
+ issuing its own K load on the next pipeline stage.
152
+
153
+ In the intra-wg-overlap path, the last masked block leaves its V copy in flight
154
+ while we advance the producer state to start the next full K. Either the full list
155
+ overlaps that pending V load, or, if no full blocks exist, we explicitly drain it.
156
+
157
+ Args:
158
+ qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
159
+ must be converted to unpacked for sparse tensor indexing.
160
+ """
161
+
162
+ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
163
+
164
+ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)
165
+
166
+ curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
167
+ curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
168
+
169
+ if const_expr(full_block_cnt is not None):
170
+ curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
171
+ curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
172
+ else:
173
+ curr_full_block_cnt = Int32(0)
174
+ curr_full_block_idx = None
175
+
176
+ mask_empty = curr_mask_block_cnt == 0
177
+ full_empty = curr_full_block_cnt == 0
178
+
179
+ if mask_empty:
180
+ # No masked blocks: the full list owns the initial Q+K load.
181
+ kv_producer_state = load_block_list(
182
+ curr_full_block_idx,
183
+ curr_full_block_cnt,
184
+ load_q_with_first=True,
185
+ first_block_preloaded=False,
186
+ kv_producer_state=kv_producer_state,
187
+ load_Q=load_Q,
188
+ load_K=load_K,
189
+ load_V=load_V,
190
+ pipeline_k=pipeline_k,
191
+ pipeline_v=pipeline_v,
192
+ use_tma_q=use_tma_q,
193
+ tma_q_bytes=tma_q_bytes,
194
+ intra_wg_overlap=intra_wg_overlap,
195
+ )
196
+
197
+ if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
198
+ kv_producer_state = finish_overlap_v_load(
199
+ curr_full_block_idx,
200
+ curr_full_block_cnt,
201
+ load_V,
202
+ pipeline_v,
203
+ kv_producer_state,
204
+ )
205
+ else:
206
+ # Masked blocks present: load Q together with the first masked K so consumers can
207
+ # start immediately. When overlap is disabled this fully drains the list.
208
+ kv_producer_state = load_block_list(
209
+ curr_mask_block_idx,
210
+ curr_mask_block_cnt,
211
+ load_q_with_first=True,
212
+ first_block_preloaded=False,
213
+ kv_producer_state=kv_producer_state,
214
+ load_Q=load_Q,
215
+ load_K=load_K,
216
+ load_V=load_V,
217
+ pipeline_k=pipeline_k,
218
+ pipeline_v=pipeline_v,
219
+ use_tma_q=use_tma_q,
220
+ tma_q_bytes=tma_q_bytes,
221
+ intra_wg_overlap=intra_wg_overlap,
222
+ )
223
+
224
+ if full_empty:
225
+ if const_expr(intra_wg_overlap):
226
+ kv_producer_state = finish_overlap_v_load(
227
+ curr_mask_block_idx,
228
+ curr_mask_block_cnt,
229
+ load_V,
230
+ pipeline_v,
231
+ kv_producer_state,
232
+ )
233
+ else:
234
+ if const_expr(intra_wg_overlap):
235
+ # Bridge the masked list to the full list by overlapping the pending masked V
236
+ # with the first full K load.
237
+ n_block_mask_last = curr_mask_block_idx[0]
238
+ n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
239
+ kv_producer_state_prev = kv_producer_state.clone()
240
+ kv_producer_state.advance()
241
+ pipeline_k.producer_acquire(kv_producer_state)
242
+ load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
243
+ pipeline_v.producer_acquire(kv_producer_state_prev)
244
+ load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)
245
+
246
+ kv_producer_state = load_block_list(
247
+ curr_full_block_idx,
248
+ curr_full_block_cnt,
249
+ load_q_with_first=False,
250
+ first_block_preloaded=True,
251
+ kv_producer_state=kv_producer_state,
252
+ load_Q=load_Q,
253
+ load_K=load_K,
254
+ load_V=load_V,
255
+ pipeline_k=pipeline_k,
256
+ pipeline_v=pipeline_v,
257
+ use_tma_q=use_tma_q,
258
+ tma_q_bytes=tma_q_bytes,
259
+ intra_wg_overlap=intra_wg_overlap,
260
+ )
261
+
262
+ kv_producer_state = finish_overlap_v_load(
263
+ curr_full_block_idx,
264
+ curr_full_block_cnt,
265
+ load_V,
266
+ pipeline_v,
267
+ kv_producer_state,
268
+ )
269
+ else:
270
+ # Non-overlap path with both lists: run the full list normally (skipping the Q
271
+ # reload because the masked list already issued it).
272
+ kv_producer_state = load_block_list(
273
+ curr_full_block_idx,
274
+ curr_full_block_cnt,
275
+ load_q_with_first=False,
276
+ first_block_preloaded=False,
277
+ kv_producer_state=kv_producer_state,
278
+ load_Q=load_Q,
279
+ load_K=load_K,
280
+ load_V=load_V,
281
+ pipeline_k=pipeline_k,
282
+ pipeline_v=pipeline_v,
283
+ use_tma_q=use_tma_q,
284
+ tma_q_bytes=tma_q_bytes,
285
+ intra_wg_overlap=intra_wg_overlap,
286
+ )
287
+
288
+ return kv_producer_state
289
+
290
+
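A compact sketch of the resulting issue order for the overlap path (illustrative only; assumes at least one block in total): the masked list's pending V is bridged with the first full K, and the very last V is drained by finish_overlap_v_load.

def combined_issue_order(mask_blocks, full_blocks):
    """K/V issue order for one tile with intra-warp-group overlap enabled (sketch)."""
    ks = list(reversed(mask_blocks)) + list(reversed(full_blocks))  # every K, masked list first
    if not ks:
        return []
    order = [("K", ks[0])]                     # first K (Q rides along via TMA)
    for k_next, v_prev in zip(ks[1:], ks):
        order += [("K", k_next), ("V", v_prev)]
    order.append(("V", ks[-1]))                # final drain via finish_overlap_v_load
    return order

# Masked block {7}, full blocks {2, 4}: the pending V7 overlaps the first full K load.
assert combined_issue_order([7], [2, 4]) == [
    ("K", 7), ("K", 4), ("V", 7), ("K", 2), ("V", 4), ("V", 2)]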
291
+ @cute.jit
292
+ def consume_block_sparse_loads(
293
+ blocksparse_tensors: BlockSparseTensors,
294
+ batch_idx,
295
+ head_idx,
296
+ m_block,
297
+ seqlen,
298
+ kv_consumer_state,
299
+ mma_pv_fn,
300
+ mma_one_n_block,
301
+ process_first_half_block,
302
+ process_last_half_block,
303
+ mask_fn,
304
+ score_mod_fn,
305
+ O_should_accumulate,
306
+ mask_mod,
307
+ fastdiv_mods,
308
+ intra_wg_overlap: cutlass.Constexpr,
309
+ warp_scheduler_barrier_sync: Callable,
310
+ warp_scheduler_barrier_arrive: Callable,
311
+ qhead_per_kvhead: cutlass.Constexpr[int] = 1,
312
+ ):
313
+ """Consume the mask and full block lists for a single tile on the consumer side.
314
+
315
+ Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses
316
+ the same sparse tensor indexing.
317
+
318
+ Args:
319
+ qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
320
+ must be converted to unpacked for sparse tensor indexing.
321
+ """
322
+
323
+ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
324
+
325
+ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)
326
+
327
+ curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
328
+ curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
329
+ curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
330
+ curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
331
+
332
+ processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0
333
+
334
+ if const_expr(not intra_wg_overlap):
335
+ if curr_mask_block_cnt > 0:
336
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
337
+ warp_scheduler_barrier_sync()
338
+ kv_consumer_state = mma_one_n_block(
339
+ kv_consumer_state,
340
+ n_block=mask_n_block,
341
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
342
+ mask_fn=partial(
343
+ mask_fn,
344
+ mask_mod=mask_mod,
345
+ mask_seqlen=True,
346
+ fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
347
+ ),
348
+ is_first_n_block=True,
349
+ )
350
+ O_should_accumulate = True
351
+ for i in cutlass.range(1, curr_mask_block_cnt):
352
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
353
+ kv_consumer_state = mma_one_n_block(
354
+ kv_consumer_state,
355
+ n_block=mask_n_block,
356
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
357
+ mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
358
+ is_first_n_block=False,
359
+ )
360
+ O_should_accumulate = True
361
+ if curr_full_block_cnt == 0:
362
+ warp_scheduler_barrier_arrive()
363
+
364
+ if curr_full_block_cnt > 0:
365
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
366
+ if curr_mask_block_cnt == 0:
367
+ warp_scheduler_barrier_sync()
368
+ kv_consumer_state = mma_one_n_block(
369
+ kv_consumer_state,
370
+ n_block=full_n_block,
371
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
372
+ mask_fn=partial(mask_fn, mask_seqlen=True),
373
+ is_first_n_block=True,
374
+ )
375
+ O_should_accumulate = True
376
+ for i in cutlass.range(1, curr_full_block_cnt):
377
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
378
+ kv_consumer_state = mma_one_n_block(
379
+ kv_consumer_state,
380
+ n_block=full_n_block,
381
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
382
+ mask_fn=partial(mask_fn, mask_seqlen=False),
383
+ is_first_n_block=False,
384
+ )
385
+ O_should_accumulate = True
386
+ else:
387
+ kv_consumer_state = mma_one_n_block(
388
+ kv_consumer_state,
389
+ n_block=full_n_block,
390
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
391
+ mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
392
+ is_first_n_block=False,
393
+ )
394
+ O_should_accumulate = True
395
+ for i in cutlass.range(1, curr_full_block_cnt):
396
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
397
+ kv_consumer_state = mma_one_n_block(
398
+ kv_consumer_state,
399
+ n_block=full_n_block,
400
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
401
+ mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
402
+ is_first_n_block=False,
403
+ )
404
+ O_should_accumulate = True
405
+ warp_scheduler_barrier_arrive()
406
+ else:
407
+ if curr_mask_block_cnt > 0:
408
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
409
+ kv_consumer_state = process_first_half_block(
410
+ n_block=mask_n_block,
411
+ seqlen=seqlen,
412
+ kv_consumer_state=kv_consumer_state,
413
+ mask_fn=partial(
414
+ mask_fn,
415
+ mask_mod=mask_mod,
416
+ mask_seqlen=True,
417
+ fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
418
+ ),
419
+ score_mod_fn=score_mod_fn,
420
+ is_first_block=True,
421
+ )
422
+ for i in cutlass.range(1, curr_mask_block_cnt):
423
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
424
+ kv_consumer_state = mma_one_n_block(
425
+ kv_consumer_state,
426
+ n_block=mask_n_block,
427
+ seqlen=seqlen,
428
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
429
+ mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
430
+ )
431
+ O_should_accumulate = True
432
+
433
+ if curr_full_block_cnt > 0:
434
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
435
+ if curr_mask_block_cnt == 0:
436
+ kv_consumer_state = process_first_half_block(
437
+ n_block=full_n_block,
438
+ seqlen=seqlen,
439
+ kv_consumer_state=kv_consumer_state,
440
+ mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
441
+ score_mod_fn=score_mod_fn,
442
+ is_first_block=True,
443
+ )
444
+ else:
445
+ kv_consumer_state = mma_one_n_block(
446
+ kv_consumer_state,
447
+ n_block=full_n_block,
448
+ seqlen=seqlen,
449
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
450
+ mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
451
+ )
452
+ O_should_accumulate = True
453
+ for i in cutlass.range(1, curr_full_block_cnt):
454
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
455
+ kv_consumer_state = mma_one_n_block(
456
+ kv_consumer_state,
457
+ n_block=full_n_block,
458
+ seqlen=seqlen,
459
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
460
+ mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
461
+ )
462
+ O_should_accumulate = True
463
+
464
+ if curr_mask_block_cnt + curr_full_block_cnt > 0:
465
+ kv_consumer_state = process_last_half_block(
466
+ kv_consumer_state=kv_consumer_state,
467
+ zero_init=not O_should_accumulate,
468
+ )
469
+ O_should_accumulate = True
470
+
471
+ return kv_consumer_state, O_should_accumulate, processed_any
472
+
473
+
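The O_should_accumulate bookkeeping above follows a simple rule that a host-side sketch (illustrative, mirroring the non-overlap path) makes explicit: only the first block processed for a tile initializes the output accumulator; every later block accumulates, whether it comes from the masked or the full list.

def zero_init_flags(n_mask_blocks, n_full_blocks, already_accumulating=False):
    """zero_init choice per MMA call: True only for the first block of a tile (sketch)."""
    flags, accumulating = [], already_accumulating
    for _ in range(n_mask_blocks + n_full_blocks):
        flags.append(not accumulating)   # zero_init = not O_should_accumulate
        accumulating = True              # every processed block makes O accumulate
    return flags

assert zero_init_flags(2, 3) == [True, False, False, False, False]
assert zero_init_flags(0, 2, already_accumulating=True) == [False, False]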
474
+ @cute.jit
475
+ def load_block_list_sm100(
476
+ block_indices: cute.Tensor,
477
+ block_count,
478
+ load_q_with_first: cutlass.Constexpr,
479
+ m_block,
480
+ q_stage: cutlass.Constexpr,
481
+ kv_producer_state,
482
+ load_Q,
483
+ load_K,
484
+ load_V,
485
+ pipeline_kv,
486
+ ):
487
+ """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count)."""
488
+ if block_count > 0:
489
+ # First iteration: load Q alongside K if requested
490
+ n_block_first = block_indices[block_count - 1]
491
+
492
+ if const_expr(load_q_with_first):
493
+ # SM100 loads Q0 and optionally Q1
494
+ load_Q(block=q_stage * m_block + 0, stage=0)
495
+ if const_expr(q_stage == 2):
496
+ load_Q(block=q_stage * m_block + 1, stage=1)
497
+
498
+ # SM100 does not use producer_acquire for pipeline_kv in the load path;
499
+ # the pipeline barriers are handled inside load_KV.
500
+ load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
501
+ kv_producer_state.advance()
502
+ load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
503
+ kv_producer_state.advance()
504
+
505
+ # Remaining blocks
506
+ for offset in cutlass.range(1, block_count):
507
+ n_block = block_indices[block_count - 1 - offset]
508
+ load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
509
+ kv_producer_state.advance()
510
+ load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
511
+ kv_producer_state.advance()
512
+
513
+ return kv_producer_state
514
+
515
+
516
+ # SM100-specific tile processor using SM100 helpers
517
+ @cute.jit
518
+ def produce_block_sparse_loads_sm100(
519
+ blocksparse_tensors: BlockSparseTensors,
520
+ batch_idx,
521
+ head_idx,
522
+ m_block,
523
+ kv_producer_state,
524
+ load_Q,
525
+ load_K,
526
+ load_V,
527
+ pipeline_kv,
528
+ q_stage: cutlass.Constexpr,
529
+ q_producer_phase: Int32,
530
+ qhead_per_kvhead: cutlass.Constexpr,
531
+ ):
532
+ """SM100 entry point for sparse block iteration.
533
+
534
+ SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use
535
+ simplified block processing that just calls producer_acquire without extras.
536
+
537
+ Args:
538
+ m_block: which tile of m we are processing
539
+ qhead_per_kvhead: Constexpr pack factor
540
+ """
541
+ # NB: Compute unpacked index for sparse tensor access
542
+ if const_expr(qhead_per_kvhead != 1):
543
+ m_block_sparse = m_block // qhead_per_kvhead
544
+ else:
545
+ m_block_sparse = m_block
546
+
547
+ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
548
+
549
+ curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
550
+ curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
551
+
552
+ if const_expr(full_block_cnt is not None):
553
+ curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
554
+ curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
555
+ else:
556
+ curr_full_block_cnt = Int32(0)
557
+ curr_full_block_idx = None
558
+
559
+ mask_empty = curr_mask_block_cnt == 0
560
+ full_empty = curr_full_block_cnt == 0
561
+
562
+ q_phase_flipped = False
563
+
564
+ if mask_empty:
565
+ # No masked blocks: process full list with Q loading
566
+ kv_producer_state = load_block_list_sm100(
567
+ curr_full_block_idx,
568
+ curr_full_block_cnt,
569
+ load_q_with_first=True,
570
+ m_block=m_block,
571
+ q_stage=q_stage,
572
+ kv_producer_state=kv_producer_state,
573
+ load_Q=load_Q,
574
+ load_K=load_K,
575
+ load_V=load_V,
576
+ pipeline_kv=pipeline_kv,
577
+ )
578
+ q_phase_flipped = not full_empty
579
+ else:
580
+ # Process masked blocks with Q loading
581
+ kv_producer_state = load_block_list_sm100(
582
+ curr_mask_block_idx,
583
+ curr_mask_block_cnt,
584
+ load_q_with_first=True,
585
+ m_block=m_block,
586
+ q_stage=q_stage,
587
+ kv_producer_state=kv_producer_state,
588
+ load_Q=load_Q,
589
+ load_K=load_K,
590
+ load_V=load_V,
591
+ pipeline_kv=pipeline_kv,
592
+ )
593
+ q_phase_flipped = True
594
+
595
+ if not full_empty:
596
+ # Process full blocks without Q loading
597
+ kv_producer_state = load_block_list_sm100(
598
+ curr_full_block_idx,
599
+ curr_full_block_cnt,
600
+ load_q_with_first=False,
601
+ m_block=m_block,
602
+ q_stage=q_stage,
603
+ kv_producer_state=kv_producer_state,
604
+ load_Q=load_Q,
605
+ load_K=load_K,
606
+ load_V=load_V,
607
+ pipeline_kv=pipeline_kv,
608
+ )
609
+
610
+ if q_phase_flipped:
611
+ q_producer_phase ^= 1
612
+
613
+ return kv_producer_state, q_producer_phase
614
+
615
+
616
+ @cute.jit
617
+ def get_total_block_count(
618
+ blocksparse_tensors: BlockSparseTensors,
619
+ batch_idx,
620
+ head_idx,
621
+ m_block,
622
+ qhead_per_kvhead: cutlass.Constexpr,
623
+ ):
624
+ # NB: Convert packed m_block to unpacked for sparse tensor indexing
625
+ if const_expr(qhead_per_kvhead != 1):
626
+ m_block_sparse = m_block // qhead_per_kvhead
627
+ else:
628
+ m_block_sparse = m_block
629
+
630
+ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
631
+ if const_expr(full_block_cnt is not None):
632
+ return (
633
+ mask_block_cnt[batch_idx, head_idx, m_block_sparse]
634
+ + full_block_cnt[batch_idx, head_idx, m_block_sparse]
635
+ )
636
+ else:
637
+ return mask_block_cnt[batch_idx, head_idx, m_block_sparse]
638
+
639
+
640
+ @cute.jit
641
+ def handle_block_sparse_empty_tile_correction_sm100(
642
+ tidx: Int32,
643
+ q_stage: cutlass.Constexpr,
644
+ m_block_size: cutlass.Constexpr,
645
+ qhead_per_kvhead,
646
+ pack_gqa: cutlass.Constexpr,
647
+ is_split_kv: cutlass.Constexpr,
648
+ learnable_sink,
649
+ mLSE,
650
+ seqlen,
651
+ m_block: Int32,
652
+ head_idx: Int32,
653
+ batch_idx: Int32,
654
+ split_idx: Int32,
655
+ sScale: cute.Tensor,
656
+ stats: list,
657
+ correction_epilogue: Callable,
658
+ thr_mma_pv: cute.core.ThrMma,
659
+ tOtOs: tuple[cute.Tensor],
660
+ sO: cute.Tensor,
661
+ mbar_ptr,
662
+ mbar_softmax_corr_full_offset: Int32,
663
+ mbar_softmax_corr_empty_offset: Int32,
664
+ mbar_P_full_O_rescaled_offset: Int32,
665
+ mbar_P_full_2_offset: Int32,
666
+ mbar_corr_epi_full_offset: Int32,
667
+ mbar_corr_epi_empty_offset: Int32,
668
+ softmax_corr_consumer_phase: Int32,
669
+ o_corr_consumer_phase: Int32,
670
+ corr_epi_producer_phase: Int32,
671
+ softmax_scale_log2: Float32,
672
+ mO_cur: Optional[cute.Tensor] = None,
673
+ gO: Optional[cute.Tensor] = None,
674
+ gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
675
+ ):
676
+ """Handle the block-sparse case where a tile is fully masked:
677
+ * zero staged results
678
+ * seed stats
679
+ * satisfy the usual barrier protocol so downstream warps continue to make progress.
680
+ """
681
+ LOG2_E = Float32(math.log2(math.e))
682
+
683
+ for stage in cutlass.range_constexpr(q_stage):
684
+ row_sum_value = Float32(1.0)
685
+ row_max_value = (
686
+ -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
687
+ )
688
+ if const_expr(learnable_sink is not None):
689
+ sink_val = -Float32.inf
690
+ if const_expr(not pack_gqa):
691
+ sink_val = Float32(learnable_sink[head_idx])
692
+ elif tidx < m_block_size:
693
+ q_head_idx = (
694
+ (q_stage * m_block + stage) * m_block_size + tidx
695
+ ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
696
+ sink_val = Float32(learnable_sink[q_head_idx])
697
+ if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
698
+ if row_max_value == -Float32.inf:
699
+ row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
700
+ row_sum_value = Float32(1.0)
701
+ else:
702
+ row_sum_value = row_sum_value + utils.exp2f(
703
+ sink_val * LOG2_E - row_max_value * softmax_scale_log2
704
+ )
705
+ if tidx < m_block_size:
706
+ scale_row_idx = tidx + stage * m_block_size
707
+ sScale[scale_row_idx] = row_sum_value
708
+ if const_expr(mLSE is not None or learnable_sink is not None):
709
+ sScale[scale_row_idx + m_block_size * 2] = row_max_value
710
+ acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
711
+ stats[stage] = (row_sum_value, row_max_value, acc_flag)
712
+
713
+ cute.arch.mbarrier_wait(
714
+ mbar_ptr + mbar_softmax_corr_full_offset + stage,
715
+ softmax_corr_consumer_phase,
716
+ )
717
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage)
718
+
719
+ if const_expr(gmem_tiled_copy_O is None):
720
+ cute.arch.mbarrier_wait(
721
+ mbar_ptr + mbar_corr_epi_empty_offset + stage,
722
+ corr_epi_producer_phase,
723
+ )
724
+ correction_epilogue(
725
+ thr_mma_pv,
726
+ tOtOs[stage],
727
+ tidx,
728
+ stage,
729
+ m_block,
730
+ seqlen.seqlen_q,
731
+ Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs
732
+ sO[None, None, stage],
733
+ mO_cur,
734
+ gO,
735
+ gmem_tiled_copy_O,
736
+ )
737
+ if const_expr(gmem_tiled_copy_O is None):
738
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
739
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage)
740
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage)
741
+
742
+ softmax_corr_consumer_phase ^= 1
743
+ o_corr_consumer_phase ^= 1
744
+ corr_epi_producer_phase ^= 1
745
+
746
+ return (
747
+ softmax_corr_consumer_phase,
748
+ o_corr_consumer_phase,
749
+ corr_epi_producer_phase,
750
+ )
751
+
752
+
753
+ @cute.jit
754
+ def softmax_block_sparse_sm100(
755
+ blocksparse_tensors: BlockSparseTensors,
756
+ batch_idx,
757
+ head_idx,
758
+ m_block,
759
+ softmax_step: Callable,
760
+ mask_fn: Callable,
761
+ mask_fn_none: Callable,
762
+ mma_si_consumer_phase: Int32,
763
+ si_corr_producer_phase: Int32,
764
+ s0_s1_sequence_phase: Int32,
765
+ mbar_ptr,
766
+ mbar_softmax_corr_full_offset: Int32,
767
+ mbar_softmax_corr_empty_offset: Int32,
768
+ mbar_P_full_O_rescaled_offset: Int32,
769
+ mbar_P_full_2_offset: Int32,
770
+ q_stage: cutlass.Constexpr,
771
+ stage_idx: Int32,
772
+ check_m_boundary: bool,
773
+ qhead_per_kvhead: cutlass.Constexpr,
774
+ ):
775
+ # Convert packed m_block to unpacked for sparse tensor indexing
776
+ if const_expr(qhead_per_kvhead != 1):
777
+ m_block_sparse = m_block // qhead_per_kvhead
778
+ else:
779
+ m_block_sparse = m_block
780
+
781
+ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
782
+
783
+ curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
784
+ curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
785
+
786
+ if const_expr(full_block_cnt is not None):
787
+ curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
788
+ curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
789
+ else:
790
+ curr_full_block_cnt = Int32(0)
791
+ curr_full_block_idx = None
792
+
793
+ total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt
794
+
795
+ if total_block_cnt == 0:
796
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx)
797
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx)
798
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx)
799
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx)
800
+ else:
801
+ if curr_mask_block_cnt > 0:
802
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
803
+ (
804
+ mma_si_consumer_phase,
805
+ si_corr_producer_phase,
806
+ s0_s1_sequence_phase,
807
+ ) = softmax_step(
808
+ mma_si_consumer_phase,
809
+ si_corr_producer_phase,
810
+ s0_s1_sequence_phase,
811
+ mask_n_block,
812
+ is_first=True,
813
+ mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
814
+ )
815
+ for i in cutlass.range(1, curr_mask_block_cnt):
816
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
817
+ (
818
+ mma_si_consumer_phase,
819
+ si_corr_producer_phase,
820
+ s0_s1_sequence_phase,
821
+ ) = softmax_step(
822
+ mma_si_consumer_phase,
823
+ si_corr_producer_phase,
824
+ s0_s1_sequence_phase,
825
+ mask_n_block,
826
+ mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
827
+ )
828
+
829
+ if curr_full_block_cnt > 0:
830
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
831
+ if curr_mask_block_cnt == 0:
832
+ (
833
+ mma_si_consumer_phase,
834
+ si_corr_producer_phase,
835
+ s0_s1_sequence_phase,
836
+ ) = softmax_step(
837
+ mma_si_consumer_phase,
838
+ si_corr_producer_phase,
839
+ s0_s1_sequence_phase,
840
+ full_n_block,
841
+ is_first=True,
842
+ mask_fn=partial(
843
+ mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
844
+ ),
845
+ )
846
+ else:
847
+ (
848
+ mma_si_consumer_phase,
849
+ si_corr_producer_phase,
850
+ s0_s1_sequence_phase,
851
+ ) = softmax_step(
852
+ mma_si_consumer_phase,
853
+ si_corr_producer_phase,
854
+ s0_s1_sequence_phase,
855
+ full_n_block,
856
+ is_first=False,
857
+ mask_fn=partial(
858
+ mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
859
+ ),
860
+ )
861
+ for i in cutlass.range(1, curr_full_block_cnt):
862
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
863
+ (
864
+ mma_si_consumer_phase,
865
+ si_corr_producer_phase,
866
+ s0_s1_sequence_phase,
867
+ ) = softmax_step(
868
+ mma_si_consumer_phase,
869
+ si_corr_producer_phase,
870
+ s0_s1_sequence_phase,
871
+ full_n_block,
872
+ mask_fn=partial(
873
+ mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
874
+ ),
875
+ )
876
+
877
+ return (
878
+ mma_si_consumer_phase,
879
+ si_corr_producer_phase,
880
+ s0_s1_sequence_phase,
881
+ total_block_cnt == 0,
882
+ )
883
+
884
+
885
+ # =============================================================================
886
+ # Backward-specific block-sparse helpers (SM100)
887
+ # =============================================================================
888
+ #
889
+ # In backward, iteration is transposed compared to forward:
890
+ # - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
891
+ # - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
892
+ #
893
+ # The backward block-sparse tensors use "Q direction" indexing:
894
+ # - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
895
+ # - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
896
+ #
897
+
898
+
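A host-side sketch of that transposed iteration (illustrative, with assumed plain-list metadata): for each KV tile the partial m_blocks are visited before the full ones, and every sparse entry expands into subtile_factor consecutive m_blocks that are bounds-checked against m_block_max.

def bwd_m_block_visits(q_cnt, q_idx, full_cnt, full_idx, subtile_factor, m_block_max):
    """Yield (m_block, is_full_block) for one KV tile, partial blocks before full ones."""
    for is_full, cnt, idx in ((False, q_cnt, q_idx), (True, full_cnt or 0, full_idx)):
        for it in range(cnt * subtile_factor):
            m_block = idx[it // subtile_factor] * subtile_factor + it % subtile_factor
            if m_block < m_block_max:            # skip padded / out-of-range entries
                yield m_block, is_full

# Partial sparse blocks {0, 2}, one full block {1}, subtile_factor = 2, 6 m_blocks in range.
assert list(bwd_m_block_visits(2, [0, 2], 1, [1], 2, 6)) == [
    (0, False), (1, False), (4, False), (5, False), (2, True), (3, True)]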
899
+ @cute.jit
900
+ def get_total_q_block_count_bwd(
901
+ blocksparse_tensors: BlockSparseTensors,
902
+ batch_idx,
903
+ head_idx,
904
+ n_block,
905
+ subtile_factor: cutlass.Constexpr = 1,
906
+ m_block_max: int = 0,
907
+ ):
908
+ """Count total tile iterations for given n_block (KV tile) in backward."""
909
+ q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
910
+ total = q_block_cnt[batch_idx, head_idx, n_block]
911
+ if const_expr(full_block_cnt is not None):
912
+ total = total + full_block_cnt[batch_idx, head_idx, n_block]
913
+ return total * subtile_factor
914
+
915
+
916
+ @cute.jit
917
+ def produce_block_sparse_q_loads_bwd_sm100(
918
+ blocksparse_tensors: BlockSparseTensors,
919
+ batch_idx,
920
+ head_idx,
921
+ n_block,
922
+ # Pipeline states (will be returned after advancing)
923
+ producer_state_Q_LSE,
924
+ producer_state_dO_dPsum,
925
+ # Pipelines
926
+ pipeline_Q,
927
+ pipeline_LSE,
928
+ pipeline_dO,
929
+ pipeline_dPsum,
930
+ # Load functions
931
+ load_K,
932
+ load_V,
933
+ load_Q,
934
+ load_dO,
935
+ copy_stats,
936
+ # Global tensors for LSE/dPsum
937
+ gLSE,
938
+ sLSE,
939
+ gdPsum,
940
+ sdPsum,
941
+ # TMA copy bytes for extra_tx_count
942
+ tma_copy_bytes_K,
943
+ tma_copy_bytes_V,
944
+ # Flags for which loads to perform
945
+ should_load_Q: cutlass.Constexpr,
946
+ should_load_dO: cutlass.Constexpr,
947
+ # Subtiling factor and bounds
948
+ subtile_factor: cutlass.Constexpr = 1,
949
+ m_block_max: int = 0,
950
+ ):
951
+ """SM100 backward block sparse loading with subtiling.
952
+
953
+ Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
954
+ First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
955
+ """
956
+ (
957
+ curr_q_cnt,
958
+ curr_q_idx,
959
+ curr_full_cnt,
960
+ curr_full_idx,
961
+ loop_count,
962
+ ) = get_block_sparse_iteration_info_bwd(
963
+ blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
964
+ )
965
+
966
+ for iter_idx in cutlass.range(loop_count, unroll=1):
967
+ m_block, _ = get_m_block_from_iter_bwd(
968
+ iter_idx,
969
+ curr_q_cnt,
970
+ curr_q_idx,
971
+ curr_full_cnt,
972
+ curr_full_idx,
973
+ subtile_factor,
974
+ m_block_max,
975
+ )
976
+ m_block_safe = m_block
977
+ if m_block_max > 0:
978
+ m_block_safe = cutlass.min(m_block, m_block_max - 1)
979
+
980
+ if iter_idx == 0:
981
+ # First block: load K/V alongside Q/dO
982
+ if const_expr(should_load_Q):
983
+ pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
984
+ load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
985
+ load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
986
+ pipeline_Q.producer_commit(producer_state_Q_LSE)
987
+ pipeline_LSE.producer_acquire(producer_state_Q_LSE)
988
+ with cute.arch.elect_one():
989
+ copy_stats(
990
+ gLSE[None, m_block_safe],
991
+ sLSE[None, producer_state_Q_LSE.index],
992
+ mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
993
+ )
994
+ producer_state_Q_LSE.advance()
995
+ if const_expr(should_load_dO):
996
+ pipeline_dO.producer_acquire(
997
+ producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
998
+ )
999
+ load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
1000
+ load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
1001
+ pipeline_dO.producer_commit(producer_state_dO_dPsum)
1002
+ pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
1003
+ with cute.arch.elect_one():
1004
+ copy_stats(
1005
+ gdPsum[None, m_block_safe],
1006
+ sdPsum[None, producer_state_dO_dPsum.index],
1007
+ mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
1008
+ )
1009
+ producer_state_dO_dPsum.advance()
1010
+ else:
1011
+ # Subsequent blocks: just load Q/dO (K/V already loaded)
1012
+ if const_expr(should_load_Q):
1013
+ pipeline_Q.producer_acquire(producer_state_Q_LSE)
1014
+ load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
1015
+ pipeline_Q.producer_commit(producer_state_Q_LSE)
1016
+ pipeline_LSE.producer_acquire(producer_state_Q_LSE)
1017
+ with cute.arch.elect_one():
1018
+ copy_stats(
1019
+ gLSE[None, m_block_safe],
1020
+ sLSE[None, producer_state_Q_LSE.index],
1021
+ mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
1022
+ )
1023
+ producer_state_Q_LSE.advance()
1024
+ if const_expr(should_load_dO):
1025
+ pipeline_dO.producer_acquire(producer_state_dO_dPsum)
1026
+ load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
1027
+ pipeline_dO.producer_commit(producer_state_dO_dPsum)
1028
+ pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
1029
+ with cute.arch.elect_one():
1030
+ copy_stats(
1031
+ gdPsum[None, m_block_safe],
1032
+ sdPsum[None, producer_state_dO_dPsum.index],
1033
+ mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
1034
+ )
1035
+ producer_state_dO_dPsum.advance()
1036
+
1037
+ return producer_state_Q_LSE, producer_state_dO_dPsum
1038
+
1039
+
1040
+ @cute.jit
1041
+ def get_block_sparse_iteration_info_bwd(
1042
+ blocksparse_tensors: BlockSparseTensors,
1043
+ batch_idx,
1044
+ head_idx,
1045
+ n_block,
1046
+ subtile_factor: cutlass.Constexpr = 1,
1047
+ m_block_max: int = 0,
1048
+ ):
1049
+ """Extract block-sparse iteration info for backward pass.
1050
+
1051
+ Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
1052
+ """
1053
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1054
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1055
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1056
+
1057
+ if const_expr(full_cnt is not None):
1058
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1059
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1060
+ else:
1061
+ curr_full_cnt = Int32(0)
1062
+ curr_full_idx = None
1063
+
1064
+ sparse_block_count = curr_q_cnt
1065
+ if const_expr(full_cnt is not None):
1066
+ sparse_block_count = sparse_block_count + curr_full_cnt
1067
+ total_count = sparse_block_count * subtile_factor
1068
+
1069
+ return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count
1070
+
1071
+
1072
+ @cute.jit
1073
+ def get_m_block_from_iter_bwd(
1074
+ iter_idx,
1075
+ curr_q_cnt,
1076
+ curr_q_idx: cute.Tensor,
1077
+ curr_full_cnt,
1078
+ curr_full_idx: Optional[cute.Tensor],
1079
+ subtile_factor: cutlass.Constexpr = 1,
1080
+ m_block_max: int = 0,
1081
+ ):
1082
+ """Derive m_block index and is_full_block flag from iteration index.
1083
+
1084
+ Returns (m_block, is_full_block):
1085
+ - m_block: The actual Q-tile block index
1086
+ - is_full_block: True if this is a full block (no mask_mod needed)
1087
+ """
1088
+ sparse_iter_idx = iter_idx // subtile_factor
1089
+ subtile_offset = iter_idx % subtile_factor
1090
+
1091
+ sparse_m_block = Int32(0)
1092
+ is_full_block = False
1093
+ if const_expr(curr_full_idx is not None):
1094
+ if sparse_iter_idx < curr_q_cnt:
1095
+ sparse_m_block = curr_q_idx[sparse_iter_idx]
1096
+ else:
1097
+ sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt]
1098
+ is_full_block = True
1099
+ else:
1100
+ sparse_m_block = curr_q_idx[sparse_iter_idx]
1101
+
1102
+ return sparse_m_block * subtile_factor + subtile_offset, is_full_block
1103
+
1104
+
1105
+ @cute.jit
1106
+ def _load_q_do_block_sm90(
1107
+ m_block,
1108
+ producer_state_Q,
1109
+ producer_state_dO,
1110
+ pipeline_Q,
1111
+ pipeline_dO,
1112
+ load_K,
1113
+ load_V,
1114
+ load_Q,
1115
+ load_dO,
1116
+ load_LSE,
1117
+ load_dPsum,
1118
+ tma_copy_bytes_K,
1119
+ tma_copy_bytes_V,
1120
+ Q_stage_eq_dO_stage: cutlass.Constexpr,
1121
+ load_kv: bool,
1122
+ ):
1123
+ """Load one Q/dO block, optionally loading K/V on first iteration."""
1124
+ if load_kv:
1125
+ pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
1126
+ load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
1127
+ else:
1128
+ pipeline_Q.producer_acquire(producer_state_Q)
1129
+ load_Q(m_block, producer_state=producer_state_Q)
1130
+ with cute.arch.elect_one():
1131
+ load_LSE(m_block, producer_state=producer_state_Q)
1132
+
1133
+ producer_state_dO_cur = (
1134
+ producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
1135
+ )
1136
+ if load_kv:
1137
+ pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)
1138
+ load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
1139
+ else:
1140
+ pipeline_dO.producer_acquire(producer_state_dO_cur)
1141
+ load_dO(m_block, producer_state=producer_state_dO_cur)
1142
+ with cute.arch.elect_one():
1143
+ load_dPsum(m_block, producer_state=producer_state_dO_cur)
1144
+
1145
+ producer_state_Q.advance()
1146
+ producer_state_dO.advance()
1147
+ return producer_state_Q, producer_state_dO
1148
+
1149
+
1150
+ @cute.jit
1151
+ def produce_block_sparse_q_loads_bwd_sm90(
1152
+ blocksparse_tensors: BlockSparseTensors,
1153
+ batch_idx,
1154
+ head_idx,
1155
+ n_block,
1156
+ producer_state_Q,
1157
+ producer_state_dO,
1158
+ pipeline_Q,
1159
+ pipeline_dO,
1160
+ load_K,
1161
+ load_V,
1162
+ load_Q,
1163
+ load_dO,
1164
+ load_LSE,
1165
+ load_dPsum,
1166
+ tma_copy_bytes_K,
1167
+ tma_copy_bytes_V,
1168
+ Q_stage_eq_dO_stage: cutlass.Constexpr,
1169
+ subtile_factor: cutlass.Constexpr,
1170
+ m_block_max: int,
1171
+ ):
1172
+ """SM90 backward block sparse loading with separate partial/full loops.
1173
+
1174
+ K/V are loaded with the first valid block. Iterates partial blocks first,
1175
+ then full blocks, matching consumer order.
1176
+
1177
+ Returns updated (producer_state_Q, producer_state_dO).
1178
+ """
1179
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1180
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1181
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1182
+
1183
+ if const_expr(full_cnt is not None):
1184
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1185
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1186
+ else:
1187
+ curr_full_cnt = Int32(0)
1188
+ curr_full_idx = None
1189
+
1190
+ kv_loaded = False
1191
+
1192
+ for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1193
+ sparse_idx = iter_idx // subtile_factor
1194
+ subtile_offset = iter_idx % subtile_factor
1195
+ m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1196
+
1197
+ if m_block < m_block_max:
1198
+ producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
1199
+ m_block,
1200
+ producer_state_Q,
1201
+ producer_state_dO,
1202
+ pipeline_Q,
1203
+ pipeline_dO,
1204
+ load_K,
1205
+ load_V,
1206
+ load_Q,
1207
+ load_dO,
1208
+ load_LSE,
1209
+ load_dPsum,
1210
+ tma_copy_bytes_K,
1211
+ tma_copy_bytes_V,
1212
+ Q_stage_eq_dO_stage,
1213
+ load_kv=not kv_loaded,
1214
+ )
1215
+ kv_loaded = True
1216
+
1217
+ if const_expr(full_cnt is not None):
1218
+ for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1219
+ sparse_idx = iter_idx // subtile_factor
1220
+ subtile_offset = iter_idx % subtile_factor
1221
+ m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1222
+
1223
+ if m_block < m_block_max:
1224
+ producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
1225
+ m_block,
1226
+ producer_state_Q,
1227
+ producer_state_dO,
1228
+ pipeline_Q,
1229
+ pipeline_dO,
1230
+ load_K,
1231
+ load_V,
1232
+ load_Q,
1233
+ load_dO,
1234
+ load_LSE,
1235
+ load_dPsum,
1236
+ tma_copy_bytes_K,
1237
+ tma_copy_bytes_V,
1238
+ Q_stage_eq_dO_stage,
1239
+ load_kv=not kv_loaded,
1240
+ )
1241
+ kv_loaded = True
1242
+
1243
+ return producer_state_Q, producer_state_dO
1244
+
1245
+
1246
+ @cute.jit
1247
+ def consume_block_sparse_mma_bwd_sm90(
1248
+ blocksparse_tensors: BlockSparseTensors,
1249
+ batch_idx,
1250
+ head_idx,
1251
+ n_block,
1252
+ consumer_state_Q,
1253
+ consumer_state_dO,
1254
+ mma_one_m_block_fn,
1255
+ mask,
1256
+ mask_mod,
1257
+ is_causal: cutlass.Constexpr,
1258
+ is_local: cutlass.Constexpr,
1259
+ thr_mma_SdP,
1260
+ softmax_scale,
1261
+ seqlen,
1262
+ subtile_factor: cutlass.Constexpr,
1263
+ m_block_max: int,
1264
+ aux_tensors=None,
1265
+ fastdiv_mods=(None, None),
1266
+ ):
1267
+ """SM90 backward block sparse MMA consumption with separate partial/full loops.
1268
+
1269
+ Partial blocks are processed first (with mask_mod applied), then full blocks
1270
+ (without mask_mod). This ensures mask_mod is only applied where needed.
1271
+
1272
+ Returns updated (consumer_state_Q, consumer_state_dO).
1273
+ """
1274
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1275
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1276
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1277
+
1278
+ if const_expr(full_cnt is not None):
1279
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1280
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1281
+ else:
1282
+ curr_full_cnt = Int32(0)
1283
+ curr_full_idx = None
1284
+
1285
+ dKV_accumulate = False
1286
+
1287
+ mask_fn_partial = partial(
1288
+ mask.apply_mask,
1289
+ batch_idx=batch_idx,
1290
+ head_idx=head_idx,
1291
+ n_block=n_block,
1292
+ thr_mma=thr_mma_SdP,
1293
+ mask_seqlen=True,
1294
+ mask_causal=is_causal,
1295
+ mask_local=is_local,
1296
+ mask_mod=mask_mod,
1297
+ aux_tensors=aux_tensors,
1298
+ fastdiv_mods=fastdiv_mods,
1299
+ )
1300
+
1301
+ mask_fn_full = partial(
1302
+ mask.apply_mask,
1303
+ batch_idx=batch_idx,
1304
+ head_idx=head_idx,
1305
+ n_block=n_block,
1306
+ thr_mma=thr_mma_SdP,
1307
+ mask_seqlen=True,
1308
+ mask_causal=is_causal,
1309
+ mask_local=is_local,
1310
+ aux_tensors=aux_tensors,
1311
+ fastdiv_mods=fastdiv_mods,
1312
+ )
1313
+
1314
+ for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1315
+ sparse_idx = iter_idx // subtile_factor
1316
+ subtile_offset = iter_idx % subtile_factor
1317
+ m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1318
+
1319
+ if m_block < m_block_max:
1320
+ consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
1321
+ m_block,
1322
+ consumer_state_Q,
1323
+ consumer_state_dO,
1324
+ mask_fn=mask_fn_partial,
1325
+ dKV_accumulate=dKV_accumulate,
1326
+ thr_mma_SdP=thr_mma_SdP,
1327
+ batch_idx=batch_idx,
1328
+ head_idx=head_idx,
1329
+ n_block=n_block,
1330
+ softmax_scale=softmax_scale,
1331
+ seqlen=seqlen,
1332
+ aux_tensors=aux_tensors,
1333
+ fastdiv_mods=fastdiv_mods,
1334
+ )
1335
+ dKV_accumulate = True
1336
+
1337
+ if const_expr(full_cnt is not None):
1338
+ for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1339
+ sparse_idx = iter_idx // subtile_factor
1340
+ subtile_offset = iter_idx % subtile_factor
1341
+ m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1342
+
1343
+ if m_block < m_block_max:
1344
+ consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
1345
+ m_block,
1346
+ consumer_state_Q,
1347
+ consumer_state_dO,
1348
+ mask_fn=mask_fn_full,
1349
+ dKV_accumulate=dKV_accumulate,
1350
+ thr_mma_SdP=thr_mma_SdP,
1351
+ batch_idx=batch_idx,
1352
+ head_idx=head_idx,
1353
+ n_block=n_block,
1354
+ softmax_scale=softmax_scale,
1355
+ seqlen=seqlen,
1356
+ aux_tensors=aux_tensors,
1357
+ fastdiv_mods=fastdiv_mods,
1358
+ )
1359
+ dKV_accumulate = True
1360
+
1361
+ return consumer_state_Q, consumer_state_dO
1362
+
1363
+
1364
+ @cute.jit
1365
+ def _store_one_dQaccum_sm90(
1366
+ m_block,
1367
+ sdQaccum: cute.Tensor,
1368
+ gdQaccum: cute.Tensor,
1369
+ num_mma_warp_groups: cutlass.Constexpr,
1370
+ num_threads_per_warp_group: cutlass.Constexpr,
1371
+ tma_copy_bytes_dQ,
1372
+ ):
1373
+ """Store dQaccum for a single m_block."""
1374
+ for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1375
+ cute.arch.barrier(
1376
+ barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1377
+ number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1378
+ )
1379
+ with cute.arch.elect_one():
1380
+ copy_utils.cpasync_reduce_bulk_add_f32(
1381
+ sdQaccum[None, warp_group_idx].iterator,
1382
+ gdQaccum[None, warp_group_idx, m_block].iterator,
1383
+ tma_copy_bytes_dQ,
1384
+ )
1385
+ cute.arch.cp_async_bulk_commit_group()
1386
+ for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1387
+ cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
1388
+ cute.arch.barrier_arrive(
1389
+ barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1390
+ number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1391
+ )
1392
+
1393
+
1394
+ @cute.jit
1395
+ def dQaccum_store_block_sparse_bwd_sm90(
1396
+ blocksparse_tensors: BlockSparseTensors,
1397
+ batch_idx,
1398
+ head_idx,
1399
+ n_block,
1400
+ sdQaccum: cute.Tensor,
1401
+ gdQaccum: cute.Tensor,
1402
+ subtile_factor: cutlass.Constexpr,
1403
+ m_block_max: int,
1404
+ num_mma_warp_groups: cutlass.Constexpr,
1405
+ num_threads_per_warp_group: cutlass.Constexpr,
1406
+ tma_copy_bytes_dQ,
1407
+ ):
1408
+ """SM90 backward block sparse dQaccum store with separate partial/full loops.
1409
+
1410
+ Iterates partial blocks first, then full blocks, matching producer/consumer order.
1411
+ """
1412
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1413
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1414
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1415
+
1416
+ if const_expr(full_cnt is not None):
1417
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1418
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1419
+ else:
1420
+ curr_full_cnt = Int32(0)
1421
+ curr_full_idx = None
1422
+
1423
+ for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1424
+ sparse_idx = iter_idx // subtile_factor
1425
+ subtile_offset = iter_idx % subtile_factor
1426
+ m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1427
+
1428
+ if m_block < m_block_max:
1429
+ _store_one_dQaccum_sm90(
1430
+ m_block,
1431
+ sdQaccum,
1432
+ gdQaccum,
1433
+ num_mma_warp_groups,
1434
+ num_threads_per_warp_group,
1435
+ tma_copy_bytes_dQ,
1436
+ )
1437
+
1438
+ if const_expr(full_cnt is not None):
1439
+ for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1440
+ sparse_idx = iter_idx // subtile_factor
1441
+ subtile_offset = iter_idx % subtile_factor
1442
+ m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1443
+
1444
+ if m_block < m_block_max:
1445
+ _store_one_dQaccum_sm90(
1446
+ m_block,
1447
+ sdQaccum,
1448
+ gdQaccum,
1449
+ num_mma_warp_groups,
1450
+ num_threads_per_warp_group,
1451
+ tma_copy_bytes_dQ,
1452
+ )