mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
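The listing above is the wheel's internal file manifest. A wheel is just a zip archive, so a comparable listing can be reproduced locally with the standard library alone; this is a minimal sketch, assuming the .whl has already been downloaded and using an illustrative filename derived from the dist-info directory and the platform tag in the title:

    import zipfile

    # Wheels are zip archives; namelist() yields every file recorded in the archive.
    wheel = "mslk_cuda_nightly-2026.1.19-cp310-cp310-manylinux_2_28_x86_64.whl"
    with zipfile.ZipFile(wheel) as whl:
        for name in whl.namelist():
            print(f"{name}  ({whl.getinfo(name).file_size} bytes)")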
mslk/attention/flash_attn/flash_bwd_sm100.py
@@ -0,0 +1,2951 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Ted Zadouri, Markus Hoehnerbach, Jay Shah, Tri Dao.
3
+ import math
4
+ from typing import Callable, Optional
5
+ from functools import partial
6
+
7
+ import cuda.bindings.driver as cuda
8
+
9
+ import cutlass
10
+ import cutlass.cute as cute
11
+ from cutlass.cute import FastDivmodDivisor
12
+ from cutlass import Float32, Int32, const_expr
13
+ from cutlass.utils import LayoutEnum
14
+ from cutlass.cute.nvgpu import cpasync, tcgen05
15
+ import cutlass.utils.blackwell_helpers as sm100_utils_basic
16
+ from cutlass.pipeline import PipelineAsync, PipelineConsumer
17
+
18
+ from mslk.attention.flash_attn import utils
19
+ from mslk.attention.flash_attn import copy_utils
20
+ from mslk.attention.flash_attn import pipeline
21
+ from mslk.attention.flash_attn.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa
22
+ from mslk.attention.flash_attn.mask import AttentionMask
23
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
24
+ from mslk.attention.flash_attn.block_info import BlockInfo
25
+ from mslk.attention.flash_attn.tile_scheduler import (
26
+ TileSchedulerArguments,
27
+ SingleTileScheduler,
28
+ SingleTileLPTBwdScheduler, # noqa
29
+ SingleTileVarlenScheduler,
30
+ ParamsBase,
31
+ )
32
+
33
+ from mslk.attention.flash_attn import barrier
34
+ from mslk.attention.flash_attn.named_barrier import NamedBarrierBwdSm100
35
+ from mslk.attention.flash_attn.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
36
+ from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
37
+ from mslk.attention.flash_attn.block_sparse_utils import (
38
+ get_total_q_block_count_bwd,
39
+ get_block_sparse_iteration_info_bwd,
40
+ get_m_block_from_iter_bwd,
41
+ produce_block_sparse_q_loads_bwd_sm100,
42
+ )
43
+
44
+
45
+ class FlashAttentionBackwardSm100:
46
+ arch = 100
47
+
48
+ def __init__(
49
+ self,
50
+ head_dim: int,
51
+ head_dim_v: Optional[int] = None,
52
+ is_causal: bool = False,
53
+ is_local: bool = False,
54
+ qhead_per_kvhead: cutlass.Constexpr[int] = 1,
55
+ tile_m: int = 128,
56
+ tile_n: int = 128,
57
+ is_persistent: bool = False,
58
+ deterministic: bool = False,
59
+ cluster_size: int = 1,
60
+ score_mod: cutlass.Constexpr | None = None,
61
+ score_mod_bwd: cutlass.Constexpr | None = None,
62
+ mask_mod: cutlass.Constexpr | None = None,
63
+ has_aux_tensors: cutlass.Constexpr = False,
64
+ subtile_factor: cutlass.Constexpr[int] = 1,
65
+ ):
66
+ # pad head_dim up to a multiple of 16 to get the head-dim tile size (tile_hdim)
67
+ hdim_multiple_of = 16
68
+ self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
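+ # e.g. head_dim = 72 pads to ceil(72 / 16) * 16 = 80; head_dim = 128 is already a multiple of 16 and stays 128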
69
+ head_dim_v = head_dim_v if head_dim_v is not None else head_dim
70
+ self.same_hdim_kv = head_dim == head_dim_v
71
+ assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now"
72
+ self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
73
+ assert self.tile_hdim == self.tile_hdimv, (
74
+ "tile_hdim and tile_hdimv must be the same for now"
75
+ )
76
+ self.check_hdim_oob = head_dim != self.tile_hdim
77
+ self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
78
+
79
+ self.tile_m = tile_m
80
+ self.tile_n = tile_n
81
+
82
+ # CTA tiler
83
+ self.cta_tiler = (tile_n, tile_m, self.tile_hdim)
84
+ # S = K @ Q.T
85
+ self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim)
86
+ # dP = V @ dO.T
87
+ self.mma_tiler_vdo = (tile_n, tile_m, self.tile_hdimv)
88
+ # dV = P.T @ dO
89
+ self.mma_tiler_pdo = (tile_n, self.tile_hdimv, tile_m)
90
+ # dK = dS.T @ Q (N, M) (M, D)
91
+ self.mma_tiler_dsq = (tile_n, self.tile_hdimv, tile_m)
92
+ # dQ = dS @ K
93
+ self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n)
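+ # with the defaults tile_m = tile_n = 128 and a padded head_dim of 128, every tiler above is (128, 128, 128)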
94
+
95
+ self.acc_dtype = Float32
96
+
97
+ assert cluster_size in (1, 2), "Only cluster_size=1 or 2 is supported"
98
+ self.cluster_shape_mn = (cluster_size, 1)
99
+ self.is_persistent = is_persistent
100
+ self.is_causal = is_causal
101
+ self.is_local = is_local
102
+ self.qhead_per_kvhead = qhead_per_kvhead
103
+ self.pack_gqa = False
104
+ self.deterministic = deterministic
105
+
106
+ # Score mod and mask mod support
107
+ self.score_mod = score_mod
108
+ self.score_mod_bwd = score_mod_bwd
109
+ self.mask_mod = mask_mod
110
+ self.has_aux_tensors = has_aux_tensors
111
+ self.subtile_factor = subtile_factor
112
+ # For score_mod, use vec_size=1 (like forward) to handle per-element indices
113
+ if cutlass.const_expr(has_aux_tensors):
114
+ self.vec_size: cutlass.Constexpr = 1
115
+ else:
116
+ self.vec_size: cutlass.Constexpr = 4
117
+ self.qk_acc_dtype = Float32
118
+
119
+ # Speed optimizations; these do not affect correctness
120
+ self.shuffle_LSE = False
121
+ self.shuffle_dPsum = False
122
+ self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal
123
+
124
+ self.reduce_warp_ids = (0, 1, 2, 3)
125
+ self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11)
126
+ self.mma_warp_id = 12
127
+ self.load_warp_id = 13
128
+ self.epi_warp_id = 14
129
+ self.empty_warp_id = 15
130
+
131
+ # 16 warps -> 512 threads
132
+ self.threads_per_cta = cute.arch.WARP_SIZE * len(
133
+ (
134
+ *self.reduce_warp_ids,
135
+ *self.compute_warp_ids,
136
+ self.mma_warp_id,
137
+ self.load_warp_id,
138
+ self.epi_warp_id,
139
+ self.empty_warp_id,
140
+ )
141
+ )
142
+
143
+ # NamedBarrier
144
+ self.compute_sync_barrier = cutlass.pipeline.NamedBarrier(
145
+ barrier_id=int(NamedBarrierBwdSm100.Compute),
146
+ num_threads=len(self.compute_warp_ids) * cute.arch.WARP_SIZE,
147
+ )
148
+ # self.epilogue_sync_barrier = pipeline.NamedBarrier(
149
+ # barrier_id=2,
150
+ # num_threads=self.num_compute_warps * self.threads_per_warp,
151
+ # )
152
+ self.reduce_sync_barrier = cutlass.pipeline.NamedBarrier(
153
+ barrier_id=int(NamedBarrierBwdSm100.dQaccReduce),
154
+ num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE,
155
+ )
156
+
157
+ # TMEM setup
158
+ SM100_TMEM_CAPACITY_COLUMNS = 512
159
+ self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
160
+
161
+ # self.tmem_dK_offset = 0
162
+ # self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim
163
+ # self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv
164
+ # self.tmem_dP_offset = self.tmem_dQ_offset # overlap with dQ
165
+ # self.tmem_S_offset = self.tmem_dQ_offset + max(self.tile_m, self.tile_hdim)
166
+ # self.tmem_P_offset = self.tmem_S_offset # overlap with S
167
+ # self.tmem_total = self.tmem_S_offset + self.tile_n
168
+ # assert self.tmem_total <= self.tmem_alloc_cols
169
+
170
+ self.tmem_S_offset = 0
171
+ self.tmem_P_offset = 0 # overlap with S
172
+ self.tmem_dV_offset = self.tmem_S_offset + self.tile_n
173
+ self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv
174
+ self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP
175
+ self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m
176
+ self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP
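+ # e.g. tile_m = tile_n = tile_hdim = 128: S/P start at TMEM column 0, dV at 128, dP/dQ/dS at 256, dK at 384, filling the 512-column allocation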
177
+
178
+ if (not is_causal and not is_local) or deterministic:
179
+ self.num_regs_reduce = 152
180
+ self.num_regs_compute = 136
181
+ else:
182
+ self.num_regs_reduce = 136
183
+ self.num_regs_compute = 144
184
+ self.num_regs_other = 96 - 8
185
+ self.num_regs_empty = 24
186
+ assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512
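+ # budget check: 152 + 2 * 136 + 88 = 512 and 136 + 2 * 144 + 88 = 512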
187
+
188
+ self.buffer_align_bytes = 1024
189
+
190
+ def _setup_attributes(self):
191
+ self.Q_stage = 2
192
+ self.dO_stage = 1
193
+ # LSE_stage = Q_stage and dPsum_stage = dO_stage
194
+ # self.sdKVaccum_stage = 2
195
+ # number of tma reduce adds per dQacc mma
196
+ self.dQ_reduce_ncol = 32
197
+ self.sdQaccum_stage = 64 // self.dQ_reduce_ncol
198
+ assert self.tile_hdim % self.dQ_reduce_ncol == 0
199
+ self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
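+ # e.g. tile_hdim = 128: sdQaccum_stage = 64 // 32 = 2 and dQaccum_reduce_stage = 128 // 32 = 4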
200
+ self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1
201
+ # number of tma reduce adds for dKacc and dVacc epilogue
202
+ self.dK_reduce_ncol = 32
203
+
204
+ def _get_tiled_mma(self):
205
+ cta_group = tcgen05.CtaGroup.ONE
206
+ # S = K @ Q.T
207
+ tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma(
208
+ self.q_dtype,
209
+ tcgen05.OperandMajorMode.K,
210
+ tcgen05.OperandMajorMode.K,
211
+ self.acc_dtype,
212
+ cta_group,
213
+ self.mma_tiler_kq[:2],
214
+ )
215
+ # dP = V @ dO.T
216
+ tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma(
217
+ self.do_dtype,
218
+ tcgen05.OperandMajorMode.K,
219
+ tcgen05.OperandMajorMode.K,
220
+ self.acc_dtype,
221
+ cta_group,
222
+ self.mma_tiler_vdo[:2],
223
+ )
224
+ # dV += P @ dO --> (K, MN) major
225
+ tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma(
226
+ self.do_dtype,
227
+ tcgen05.OperandMajorMode.K, # P_major_mode
228
+ tcgen05.OperandMajorMode.MN, # dO_major_mode
229
+ self.acc_dtype,
230
+ cta_group,
231
+ self.mma_tiler_pdo[:2],
232
+ a_source=tcgen05.OperandSource.TMEM,
233
+ )
234
+ # dK += dS.T @ Q
235
+ if const_expr(self.use_smem_dS_for_mma_dK):
236
+ mma_dK_a_src = tcgen05.OperandSource.SMEM
237
+ else:
238
+ mma_dK_a_src = tcgen05.OperandSource.TMEM
239
+ tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma(
240
+ self.do_dtype,
241
+ tcgen05.OperandMajorMode.K, # dS_major_mode
242
+ tcgen05.OperandMajorMode.MN, # Q_major_mode
243
+ self.acc_dtype,
244
+ cta_group,
245
+ self.mma_tiler_dsq[:2],
246
+ a_source=mma_dK_a_src,
247
+ )
248
+ # dQ = dS @ K
249
+ tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma(
250
+ self.k_dtype,
251
+ tcgen05.OperandMajorMode.MN, # dS_major_mode
252
+ tcgen05.OperandMajorMode.MN, # Kt_major_mode
253
+ self.acc_dtype,
254
+ cta_group,
255
+ self.mma_tiler_dsk[:2],
256
+ )
257
+ return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
258
+
259
+ def _setup_smem_layout(self):
260
+ # S = K @ Q.T
261
+ sK_layout = sm100_utils_basic.make_smem_layout_a(
262
+ self.tiled_mma_S,
263
+ self.mma_tiler_kq,
264
+ self.k_dtype,
265
+ 1,
266
+ )
267
+ self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0))
268
+ self.sQ_layout = sm100_utils_basic.make_smem_layout_b(
269
+ self.tiled_mma_S,
270
+ self.mma_tiler_kq,
271
+ self.q_dtype,
272
+ self.Q_stage,
273
+ )
274
+ # dP = V @ dO.T
275
+ sV_layout = sm100_utils_basic.make_smem_layout_a(
276
+ self.tiled_mma_dP,
277
+ self.mma_tiler_vdo,
278
+ self.v_dtype,
279
+ 1,
280
+ )
281
+ self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0))
282
+ self.sdOt_layout = sm100_utils_basic.make_smem_layout_b(
283
+ self.tiled_mma_dP,
284
+ self.mma_tiler_vdo,
285
+ self.do_dtype,
286
+ self.dO_stage,
287
+ )
288
+ # dV += P @ dO
289
+ tP_layout = sm100_utils_basic.make_smem_layout_a(
290
+ self.tiled_mma_dV,
291
+ self.mma_tiler_pdo,
292
+ self.do_dtype,
293
+ 1,
294
+ )
295
+ self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0))
296
+ self.sdO_layout = sm100_utils_basic.make_smem_layout_b(
297
+ self.tiled_mma_dV,
298
+ self.mma_tiler_pdo,
299
+ self.do_dtype,
300
+ self.dO_stage,
301
+ )
302
+ # dK += dS.T @ Q
303
+ sdSt_layout = sm100_utils_basic.make_smem_layout_a(
304
+ self.tiled_mma_dK,
305
+ self.mma_tiler_dsq,
306
+ self.ds_dtype,
307
+ 1,
308
+ )
309
+ self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0))
310
+ tdS_layout = sm100_utils_basic.make_smem_layout_a(
311
+ self.tiled_mma_dK,
312
+ self.mma_tiler_dsq,
313
+ self.ds_dtype,
314
+ 1,
315
+ )
316
+ self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0))
317
+ self.sQt_layout = sm100_utils_basic.make_smem_layout_b(
318
+ self.tiled_mma_dK,
319
+ self.mma_tiler_dsq,
320
+ self.q_dtype,
321
+ self.Q_stage,
322
+ )
323
+ # dQ = dS @ K
324
+ sdS_layout = sm100_utils_basic.make_smem_layout_a(
325
+ self.tiled_mma_dQ,
326
+ self.mma_tiler_dsk,
327
+ self.ds_dtype,
328
+ 1,
329
+ )
330
+ self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0))
331
+ sKt_layout = sm100_utils_basic.make_smem_layout_b(
332
+ self.tiled_mma_dQ,
333
+ self.mma_tiler_dsk,
334
+ self.k_dtype,
335
+ 1,
336
+ )
337
+ self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0))
338
+ self.sdQaccum_layout = cute.make_layout(
339
+ (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage)
340
+ )
341
+ self.sLSE_layout = cute.make_layout(
342
+ shape=(self.tile_m, self.Q_stage),
343
+ stride=(1, cute.round_up(self.tile_m, 64)),
344
+ )
345
+ self.sdPsum_layout = cute.make_layout(
346
+ shape=(self.tile_m, self.dO_stage),
347
+ stride=(1, cute.round_up(self.tile_m, 64)),
348
+ )
349
+ self.sdKV_epi_tile = (
350
+ self.tile_n,
351
+ min(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32
352
+ ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2]
353
+ # headdim_64 gets 1 stage
354
+ self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1])
355
+ self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages
356
+ # TODO: dK and dV could have different shapes
357
+ if const_expr(not self.dKV_postprocess):
358
+ self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi(
359
+ self.dk_dtype,
360
+ LayoutEnum.ROW_MAJOR,
361
+ self.sdKV_epi_tile,
362
+ 2, # num compute wgs
363
+ )
364
+ else:
365
+ self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2))
366
+
367
+ @cute.jit
368
+ def __call__(
369
+ self,
370
+ mQ: cute.Tensor,
371
+ mK: cute.Tensor,
372
+ mV: cute.Tensor,
373
+ mdO: cute.Tensor,
374
+ mLSE: cute.Tensor,
375
+ mdPsum: cute.Tensor,
376
+ mdQaccum: cute.Tensor,
377
+ mdK: cute.Tensor,
378
+ mdV: cute.Tensor,
379
+ softmax_scale: Float32,
380
+ stream: cuda.CUstream,
381
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
382
+ mCuSeqlensK: Optional[cute.Tensor] = None,
383
+ mSeqUsedQ: Optional[cute.Tensor] = None,
384
+ mSeqUsedK: Optional[cute.Tensor] = None,
385
+ softcap: Float32 | float | None = None,
386
+ window_size_left: Int32 | int | None = None,
387
+ window_size_right: Int32 | int | None = None,
388
+ mdQ_semaphore: Optional[cute.Tensor] = None,
389
+ mdK_semaphore: Optional[cute.Tensor] = None,
390
+ mdV_semaphore: Optional[cute.Tensor] = None,
391
+ aux_tensors: Optional[list] = None,
392
+ # Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
393
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
394
+ ):
395
+ self.q_dtype = mQ.element_type
396
+ self.k_dtype = mK.element_type
397
+ self.v_dtype = mV.element_type
398
+ self.do_dtype = mdO.element_type
399
+ self.lse_dtype = mLSE.element_type
400
+ self.dpsum_dtype = mdPsum.element_type
401
+ self.dqaccum_dtype = mdQaccum.element_type
402
+ self.dk_dtype = mdK.element_type
403
+ self.dv_dtype = mdV.element_type
404
+ self.ds_dtype = self.q_dtype
405
+
406
+ self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
407
+ self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
408
+ self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None)
409
+ self.dKV_postprocess = self.qhead_per_kvhead > 1
410
+
411
+ if const_expr(self.dKV_postprocess):
412
+ assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA"
413
+ assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA"
414
+
415
+ # Assume all strides are divisible by 128 bits except the last stride
416
+ new_stride = lambda t: (
417
+ *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
418
+ t.stride[-1],
419
+ )
420
+ (
421
+ mdQaccum,
422
+ mdK,
423
+ mdV,
424
+ ) = [
425
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
426
+ if t is not None
427
+ else None
428
+ for t in (
429
+ mdQaccum,
430
+ mdK,
431
+ mdV,
432
+ )
433
+ ]
434
+
435
+ # (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n)
436
+ QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
437
+ mQ, mdO = [utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)]
438
+
439
+ KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
440
+ mK, mV = [utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)]
441
+
442
+ # (b, n, s) --> (s, n, b) or (n, t) --> (t, n)
443
+ LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
444
+ mLSE, mdPsum, mdQaccum = [
445
+ utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
446
+ ]
447
+
448
+ if const_expr(not self.dKV_postprocess):
449
+ layout_dKV_transpose = KV_layout_transpose
450
+ else:
451
+ layout_dKV_transpose = [2, 1, 0] if const_expr(mCuSeqlensK is None) else [1, 0]
452
+ mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)]
453
+ # (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, n)
454
+ dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2]
455
+ mdO = utils.select(mdO, mode=dO_transpose)
456
+
457
+ # (b, n, block, stage) -> (block, stage, n, b)
458
+ semaphore_transpose = [2, 3, 1, 0]
459
+ if const_expr(self.deterministic):
460
+ assert mdQ_semaphore is not None
461
+ mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose)
462
+
463
+ if const_expr(self.deterministic and self.qhead_per_kvhead > 1):
464
+ assert mdK_semaphore is not None
465
+ assert mdV_semaphore is not None
466
+ mdK_semaphore, mdV_semaphore = [
467
+ utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore)
468
+ ]
469
+ else:
470
+ mdK_semaphore = None
471
+ mdV_semaphore = None
472
+
473
+ self._setup_attributes()
474
+ (
475
+ self.tiled_mma_S,
476
+ self.tiled_mma_dP,
477
+ self.tiled_mma_dK,
478
+ self.tiled_mma_dV,
479
+ self.tiled_mma_dQ,
480
+ ) = self._get_tiled_mma()
481
+ self._setup_smem_layout()
482
+
483
+ cta_group = tcgen05.CtaGroup.ONE
484
+
485
+ self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
486
+ self.cluster_layout_vmnk = cute.tiled_divide(
487
+ cute.make_layout(self.cluster_shape_mnk),
488
+ (self.tiled_mma_S.thr_id.shape,),
489
+ )
490
+ self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
491
+ self.is_q_do_mcast = self.num_mcast_ctas_b > 1
492
+
493
+ if const_expr(not self.dKV_postprocess):
494
+ self.mdK_layout_enum = LayoutEnum.from_tensor(mdK)
495
+ self.mdV_layout_enum = LayoutEnum.from_tensor(mdV)
496
+ dK_major_mode = self.mdK_layout_enum.mma_major_mode()
497
+ dV_major_mode = self.mdV_layout_enum.mma_major_mode()
498
+ if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K):
499
+ raise RuntimeError("The layout of mdK is wrong")
500
+ if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K):
501
+ raise RuntimeError("The layout of mdV is wrong")
502
+
503
+ if const_expr(self.use_tma_store and not self.dKV_postprocess):
504
+ tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp()
505
+ tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom(
506
+ tma_copy_op_dKV,
507
+ mdK,
508
+ cute.select(self.sdKV_layout, mode=[0, 1]),
509
+ self.sdKV_epi_tile,
510
+ 1, # no mcast
511
+ )
512
+ tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom(
513
+ tma_copy_op_dKV,
514
+ mdV,
515
+ cute.select(self.sdKV_layout, mode=[0, 1]),
516
+ self.sdKV_epi_tile,
517
+ 1, # no mcast
518
+ )
519
+ else:
520
+ mdV_tma_tensor = mdV
521
+ mdK_tma_tensor = mdK
522
+ tma_atom_dV = None
523
+ tma_atom_dK = None
524
+
525
+ if const_expr(not self.dKV_postprocess):
526
+ thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads
527
+ val_layout_r2s_dKV = cute.make_ordered_layout(
528
+ (1, 128 // self.dk_dtype.width), order=(1, 0)
529
+ ) # 4 or 8 vals for 16 byte store
530
+ copy_atom_r2s_dKV = cute.make_copy_atom(
531
+ cute.nvgpu.CopyUniversalOp(),
532
+ self.dk_dtype,
533
+ num_bits_per_copy=128,
534
+ )
535
+ tiled_copy_r2s_dKV = cute.make_tiled_copy_tv(
536
+ copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV
537
+ )
538
+ else:
539
+ tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d(
540
+ Float32, 128, num_copy_elems=128 // Float32.width
541
+ )
542
+
543
+ tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
544
+ tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group)
545
+
546
+ # S.T = K @ Q.T
547
+ tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A(
548
+ tma_load_op,
549
+ mK,
550
+ cute.select(self.sK_layout, mode=[0, 1, 2]),
551
+ self.mma_tiler_kq,
552
+ self.tiled_mma_S,
553
+ self.cluster_layout_vmnk.shape,
554
+ )
555
+ Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B(
556
+ self.cluster_shape_mnk, self.tiled_mma_S.thr_id
557
+ )
558
+ tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B(
559
+ # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast,
560
+ Q_tma_op,
561
+ mQ,
562
+ cute.select(self.sQ_layout, mode=[0, 1, 2]),
563
+ self.mma_tiler_kq,
564
+ self.tiled_mma_S,
565
+ self.cluster_layout_vmnk.shape,
566
+ )
567
+ # dP.T = V @ dO.T
568
+ tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A(
569
+ tma_load_op,
570
+ mV,
571
+ cute.select(self.sV_layout, mode=[0, 1, 2]),
572
+ self.mma_tiler_vdo,
573
+ self.tiled_mma_dP,
574
+ self.cluster_layout_vmnk.shape,
575
+ )
576
+ dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B(
577
+ self.cluster_shape_mnk, self.tiled_mma_dV.thr_id
578
+ )
579
+ tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B(
580
+ # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast,
581
+ dO_tma_op,
582
+ mdO,
583
+ cute.select(self.sdO_layout, mode=[0, 1, 2]),
584
+ self.mma_tiler_pdo,
585
+ self.tiled_mma_dV,
586
+ self.cluster_layout_vmnk.shape,
587
+ )
588
+
589
+ self.tma_copy_bytes = {
590
+ name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
591
+ for name, mX, layout in [
592
+ ("Q", mQ, self.sQ_layout),
593
+ ("K", mK, self.sK_layout),
594
+ ("V", mV, self.sV_layout),
595
+ ("dO", mdO, self.sdO_layout),
596
+ ]
597
+ }
598
+ self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
599
+ self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
600
+ self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8
601
+ self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8
602
+
603
+ # TileScheduler = SingleTileScheduler
604
+ if const_expr(self.is_varlen_k):
605
+ TileScheduler = SingleTileVarlenScheduler
606
+ elif const_expr(self.deterministic):
607
+ TileScheduler = SingleTileLPTBwdScheduler
608
+ else:
609
+ TileScheduler = SingleTileScheduler
610
+ # reads n_blocks right-to-left
611
+ self.spt = (self.is_causal or self.is_local) and self.deterministic
612
+ tile_sched_args = TileSchedulerArguments(
613
+ cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks
614
+ cute.size(mQ.shape[2]), # num_heads = num_query_heads
615
+ cute.size(mK.shape[3])
616
+ if const_expr(mCuSeqlensK is None)
617
+ else cute.size(mCuSeqlensK.shape[0] - 1), # num_batches
618
+ 1, # num_splits
619
+ cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k
620
+ mQ.shape[1], # headdim
621
+ mV.shape[1], # headdim_v
622
+ total_q=cute.size(mK.shape[0]) # pass total_k for total_q
623
+ if const_expr(mCuSeqlensK is not None)
624
+ else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),
625
+ tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m)
626
+ cluster_shape_mn=self.cluster_shape_mnk[:2],
627
+ mCuSeqlensQ=mCuSeqlensK,
628
+ mSeqUsedQ=mSeqUsedK,
629
+ qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd
630
+ element_size=self.k_dtype.width // 8,
631
+ is_persistent=self.is_persistent, # persistent mode not tested
632
+ lpt=self.spt,
633
+ head_swizzle=self.deterministic,
634
+ )
635
+
636
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
637
+ self.tile_scheduler_cls = TileScheduler
638
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
639
+ # cute.printf("grid_dim = {}", grid_dim)
640
+
641
+ # Compute allocation sizes for shared buffers that are reused
642
+ # sQ is reused for sdK, sdO is reused for sdV
643
+ sQ_alloc_bytes = max(
644
+ cute.size_in_bytes(self.q_dtype, self.sQ_layout),
645
+ cute.size_in_bytes(self.dk_dtype, self.sdKV_layout),
646
+ )
647
+ sdO_alloc_bytes = max(
648
+ cute.size_in_bytes(self.dv_dtype, self.sdKV_layout),
649
+ cute.size_in_bytes(self.do_dtype, self.sdO_layout),
650
+ )
651
+ # Sanity check that layouts fit in allocation
652
+ sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdKV_layout)
653
+ sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdKV_layout)
654
+ assert sdV_bytes <= sdO_alloc_bytes, "sdV doesn't fit in sdO storage allocation"
655
+ assert sdK_bytes <= sQ_alloc_bytes, "sdK doesn't fit in sQ storage allocation"
656
+
657
+ @cute.struct
658
+ class SharedStorage:
659
+ Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]
660
+ dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]
661
+ LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]
662
+ dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]
663
+ S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1]
664
+ dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1]
665
+ dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1]
666
+ dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2]
667
+ dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
668
+ dQ_cluster_full_mbar_ptr: cute.struct.MemRange[
669
+ cutlass.Int64, self.dQaccum_reduce_stage // 2
670
+ ]
671
+ dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[
672
+ cutlass.Int64, self.dQaccum_reduce_stage // 2
673
+ ]
674
+ tmem_holding_buf: Int32
675
+ tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1]
676
+
677
+ # Smem tensors
678
+
679
+ # sQ is reused for sdK which in the non-MHA case needs float32
680
+ sQ: cute.struct.Align[
681
+ cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes],
682
+ self.buffer_align_bytes,
683
+ ]
684
+ sK: cute.struct.Align[
685
+ cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)],
686
+ self.buffer_align_bytes,
687
+ ]
688
+ sV: cute.struct.Align[
689
+ cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)],
690
+ self.buffer_align_bytes,
691
+ ]
692
+ # sdO is reused for sdV which in the non-MHA case needs float32
693
+ sdO: cute.struct.Align[
694
+ cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes],
695
+ self.buffer_align_bytes,
696
+ ]
697
+ sdS: cute.struct.Align[
698
+ cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)],
699
+ 128,
700
+ ]
701
+ sLSE: cute.struct.Align[
702
+ cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)],
703
+ 128,
704
+ ]
705
+ sdPsum: cute.struct.Align[
706
+ cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)],
707
+ 128,
708
+ ]
709
+ sdQaccum: cute.struct.Align[
710
+ cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)],
711
+ self.buffer_align_bytes,
712
+ ]
713
+
714
+ self.shared_storage = SharedStorage
715
+
716
+ LOG2_E = math.log2(math.e)
717
+ if const_expr(self.score_mod is None):
718
+ # Without score_mod: bake scale into log2
719
+ softmax_scale_log2 = softmax_scale * LOG2_E
720
+ else:
721
+ # With score_mod: score_mod applied to S * softmax_scale, then use LOG2_E only
722
+ softmax_scale_log2 = LOG2_E
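+ # underlying identity: exp(x) == exp2(x * log2(e)), hence the LOG2_E factor in both branches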
723
+
724
+ if const_expr(window_size_left is not None):
725
+ window_size_left = Int32(window_size_left)
726
+ if const_expr(window_size_right is not None):
727
+ window_size_right = Int32(window_size_right)
728
+
729
+ fastdiv_mods = None
730
+ if const_expr(aux_tensors is not None):
731
+ seqlen_q = cute.size(mQ.shape[0]) // (
732
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
733
+ )
734
+ seqlen_k = cute.size(mK.shape[0])
735
+ seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
736
+ seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
737
+ fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
738
+ self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
739
+
740
+ if const_expr(self.use_block_sparsity or aux_tensors is not None):
741
+ assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (
742
+ "Variable sequence length is not supported yet for blocksparse or aux tensors in bwd"
743
+ )
744
+
745
+ self.kernel(
746
+ tma_tensor_Q,
747
+ tma_tensor_K,
748
+ tma_tensor_V,
749
+ mLSE,
750
+ mdPsum,
751
+ tma_tensor_dO,
752
+ mdV,
753
+ mdK,
754
+ mdQaccum,
755
+ mdV_tma_tensor,
756
+ mdK_tma_tensor,
757
+ mdQ_semaphore,
758
+ mdK_semaphore,
759
+ mdV_semaphore,
760
+ mCuSeqlensQ,
761
+ mCuSeqlensK,
762
+ mSeqUsedQ,
763
+ mSeqUsedK,
764
+ tma_atom_Q,
765
+ tma_atom_K,
766
+ tma_atom_V,
767
+ tma_atom_dO,
768
+ tma_atom_dV,
769
+ tma_atom_dK,
770
+ self.sQ_layout,
771
+ self.sQt_layout,
772
+ self.sK_layout,
773
+ self.sV_layout,
774
+ self.sLSE_layout,
775
+ self.sdPsum_layout,
776
+ self.sdO_layout,
777
+ self.sdOt_layout,
778
+ self.sdSt_layout,
779
+ self.sdS_layout,
780
+ self.sKt_layout,
781
+ self.sdQaccum_layout,
782
+ self.sdKV_layout,
783
+ self.tP_layout,
784
+ self.tdS_layout,
785
+ self.tiled_mma_S,
786
+ self.tiled_mma_dP,
787
+ self.tiled_mma_dV,
788
+ self.tiled_mma_dK,
789
+ self.tiled_mma_dQ,
790
+ tiled_copy_r2s_dKV,
791
+ softmax_scale,
792
+ softmax_scale_log2,
793
+ window_size_left,
794
+ window_size_right,
795
+ tile_sched_params,
796
+ aux_tensors,
797
+ fastdiv_mods,
798
+ blocksparse_tensors,
799
+ ).launch(
800
+ grid=grid_dim,
801
+ block=[self.threads_per_cta, 1, 1],
802
+ cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None,
803
+ smem=self.shared_storage.size_in_bytes(),
804
+ stream=stream,
805
+ min_blocks_per_mp=1,
806
+ )
807
+
808
+ @cute.kernel
809
+ def kernel(
810
+ self,
811
+ mQ: cute.Tensor,
812
+ mK: cute.Tensor,
813
+ mV: cute.Tensor,
814
+ mLSE: cute.Tensor,
815
+ mdPsum: cute.Tensor,
816
+ mdO: cute.Tensor,
817
+ mdV: cute.Tensor,
818
+ mdK: cute.Tensor,
819
+ mdQaccum: cute.Tensor,
820
+ mdV_tma_tensor: Optional[cute.Tensor],
821
+ mdK_tma_tensor: Optional[cute.Tensor],
822
+ mdQ_semaphore: Optional[cute.Tensor],
823
+ mdK_semaphore: Optional[cute.Tensor],
824
+ mdV_semaphore: Optional[cute.Tensor],
825
+ mCuSeqlensQ: Optional[cute.Tensor],
826
+ mCuSeqlensK: Optional[cute.Tensor],
827
+ mSeqUsedQ: Optional[cute.Tensor],
828
+ mSeqUsedK: Optional[cute.Tensor],
829
+ tma_atom_Q: cute.CopyAtom,
830
+ tma_atom_K: cute.CopyAtom,
831
+ tma_atom_V: cute.CopyAtom,
832
+ tma_atom_dO: cute.CopyAtom,
833
+ tma_atom_dV: Optional[cute.CopyAtom],
834
+ tma_atom_dK: Optional[cute.CopyAtom],
835
+ sQ_layout: cute.ComposedLayout,
836
+ sQt_layout: cute.ComposedLayout,
837
+ sK_layout: cute.ComposedLayout,
838
+ sV_layout: cute.ComposedLayout,
839
+ sLSE_layout: cute.Layout,
840
+ sdPsum_layout: cute.Layout,
841
+ sdO_layout: cute.ComposedLayout,
842
+ sdOt_layout: cute.ComposedLayout,
843
+ sdSt_layout: cute.ComposedLayout,
844
+ sdS_layout: cute.ComposedLayout,
845
+ sKt_layout: cute.ComposedLayout,
846
+ sdQaccum_layout: cute.Layout,
847
+ sdKV_layout: cute.ComposedLayout | cute.Layout,
848
+ tP_layout: cute.ComposedLayout,
849
+ tdS_layout: cute.ComposedLayout,
850
+ tiled_mma_S: cute.TiledMma,
851
+ tiled_mma_dP: cute.TiledMma,
852
+ tiled_mma_dV: cute.TiledMma,
853
+ tiled_mma_dK: cute.TiledMma,
854
+ tiled_mma_dQ: cute.TiledMma,
855
+ tiled_copy_r2s_dKV: cute.TiledCopy,
856
+ softmax_scale: cutlass.Float32,
857
+ softmax_scale_log2: cutlass.Float32,
858
+ window_size_left: Optional[Int32],
859
+ window_size_right: Optional[Int32],
860
+ tile_sched_params: ParamsBase,
861
+ aux_tensors: Optional[list] = None,
862
+ fastdiv_mods=(None, None),
863
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
864
+ ):
865
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
866
+
867
+ # Prefetch tma descriptor
868
+ if warp_idx == self.load_warp_id:
869
+ with cute.arch.elect_one():
870
+ cpasync.prefetch_descriptor(tma_atom_Q)
871
+ cpasync.prefetch_descriptor(tma_atom_K)
872
+ cpasync.prefetch_descriptor(tma_atom_V)
873
+ cpasync.prefetch_descriptor(tma_atom_dO)
874
+ if const_expr(tma_atom_dV is not None):
875
+ cpasync.prefetch_descriptor(tma_atom_dV)
876
+ if const_expr(tma_atom_dK is not None):
877
+ cpasync.prefetch_descriptor(tma_atom_dK)
878
+
879
+ cluster_layout_vmnk = cute.tiled_divide(
880
+ cute.make_layout(self.cluster_shape_mnk),
881
+ (tiled_mma_S.thr_id.shape,),
882
+ )
883
+
884
+ # Alloc
885
+ smem = cutlass.utils.SmemAllocator()
886
+ storage = smem.allocate(self.shared_storage)
887
+
888
+ tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr()
889
+ dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr()
890
+ dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr()
891
+
892
+ if warp_idx == 1:
893
+ cute.arch.mbarrier_init(
894
+ tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids)
895
+ )
896
+ if const_expr(self.cluster_reduce_dQ):
897
+ if warp_idx == 4:
898
+ for i in range(self.dQaccum_reduce_stage // 2):
899
+ cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1)
900
+ cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1)
901
+
902
+ # UMMA producers and AsyncThread consumers
903
+ pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(
904
+ cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
905
+ )
906
+ # Only 1 thread per warp will signal
907
+ pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(
908
+ cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids)
909
+ )
910
+ pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create(
911
+ num_stages=1,
912
+ producer_group=pipeline_producer_group_MMA_AsyncThread,
913
+ consumer_group=pipeline_consumer_group_MMA_AsyncThread,
914
+ barrier_storage=storage.S_mbar_ptr.data_ptr(),
915
+ )
916
+ pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create(
917
+ num_stages=1,
918
+ producer_group=pipeline_producer_group_MMA_AsyncThread,
919
+ consumer_group=pipeline_consumer_group_MMA_AsyncThread,
920
+ barrier_storage=storage.dP_mbar_ptr.data_ptr(),
921
+ )
922
+ pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create(
923
+ num_stages=2,
924
+ producer_group=pipeline_producer_group_MMA_AsyncThread,
925
+ consumer_group=pipeline_consumer_group_MMA_AsyncThread,
926
+ barrier_storage=storage.dKV_mbar_ptr.data_ptr(),
927
+ )
928
+ pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup(
929
+ cutlass.pipeline.Agent.Thread,
930
+ len(self.reduce_warp_ids),
931
+ ) # Compute
932
+ pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create(
933
+ num_stages=1,
934
+ producer_group=pipeline_producer_group_MMA_AsyncThread,
935
+ consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ,
936
+ barrier_storage=storage.dQ_mbar_ptr.data_ptr(),
937
+ )
938
+
939
+ # AsyncThread producers and UMMA consumers
940
+ # Only 1 thread per warp will signal
941
+ pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup(
942
+ cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids)
943
+ ) # Compute
944
+ pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup(
945
+ cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
946
+ ) # MMA
947
+ pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create(
948
+ num_stages=1,
949
+ producer_group=pipeline_PdS_producer_group,
950
+ consumer_group=pipeline_PdS_consumer_group,
951
+ barrier_storage=storage.dS_mbar_ptr.data_ptr(),
952
+ )
953
+
954
+ # TMA producer and UMMA consumers
955
+ pipeline_producer_group = cutlass.pipeline.CooperativeGroup(
956
+ cutlass.pipeline.Agent.Thread, len([self.load_warp_id])
957
+ )
958
+ # The arrive count is the mcast size (number of CTAs participating in the multicast)
959
+ pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(
960
+ cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b
961
+ )
962
+ pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup(
963
+ # cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.num_mcast_ctas_b
964
+ cutlass.pipeline.Agent.Thread,
965
+ len(self.compute_warp_ids) * 1,
966
+ )
967
+ pipeline_LSE = cutlass.pipeline.PipelineTmaAsync.create(
968
+ barrier_storage=storage.LSE_mbar_ptr.data_ptr(),
969
+ num_stages=self.Q_stage,
970
+ producer_group=pipeline_producer_group,
971
+ consumer_group=pipeline_consumer_group_compute,
972
+ tx_count=self.tma_copy_bytes["LSE"],
973
+ # cta_layout_vmnk=cluster_layout_vmnk,
974
+ # init_wait=False,
975
+ )
976
+ pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create(
977
+ barrier_storage=storage.dPsum_mbar_ptr.data_ptr(),
978
+ num_stages=self.dO_stage,
979
+ producer_group=pipeline_producer_group,
980
+ consumer_group=pipeline_consumer_group_compute,
981
+ tx_count=self.tma_copy_bytes["dPsum"],
982
+ # cta_layout_vmnk=cluster_layout_vmnk,
983
+ # init_wait=False,
984
+ )
985
+ pipeline_Q = pipeline.PipelineTmaUmma.create(
986
+ barrier_storage=storage.Q_mbar_ptr.data_ptr(),
987
+ num_stages=self.Q_stage,
988
+ producer_group=pipeline_producer_group,
989
+ consumer_group=pipeline_consumer_group,
990
+ tx_count=self.tma_copy_bytes["Q"],
991
+ cta_layout_vmnk=cluster_layout_vmnk,
992
+ init_wait=False,
993
+ )
994
+ pipeline_dO = pipeline.PipelineTmaUmma.create(
995
+ barrier_storage=storage.dO_mbar_ptr.data_ptr(),
996
+ num_stages=self.dO_stage,
997
+ producer_group=pipeline_producer_group,
998
+ consumer_group=pipeline_consumer_group,
999
+ tx_count=self.tma_copy_bytes["dO"],
1000
+ cta_layout_vmnk=cluster_layout_vmnk,
1001
+ init_wait=True,
1002
+ )
1003
+
1004
+ sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype)
1005
+ sQt = cute.make_tensor(
1006
+ cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer
1007
+ )
1008
+ sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
1009
+ sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer)
1010
+ sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
1011
+ sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner)
1012
+ sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer)
1013
+ sdO = storage.sdO.get_tensor(
1014
+ sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype
1015
+ )
1016
+ sdOt = cute.make_tensor(
1017
+ cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype), sdOt_layout.outer
1018
+ )
1019
+ sLSE = storage.sLSE.get_tensor(sLSE_layout)
1020
+ sdPsum = storage.sdPsum.get_tensor(sdPsum_layout)
1021
+ if const_expr(not self.dKV_postprocess):
1022
+ sdV = storage.sdO.get_tensor(
1023
+ sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype
1024
+ )
1025
+ sdK = storage.sQ.get_tensor(
1026
+ sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype
1027
+ )
1028
+ else:
1029
+ sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype)
1030
+ sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype)
1031
+
1032
+ # Buffer sizing is guaranteed by max(...) in SharedStorage declarations
1033
+ # for both sQ (reused as sdK) and sdO (reused as sdV)
1034
+
1035
+ sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)
1036
+
1037
+ # TMEM
1038
+ # This is a fake pointer; strictly we would need to retrieve tmem_ptr, but since we always
1039
+ # request 512 columns of tmem, we know it starts at column 0.
1040
+ tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16)
1041
+ # S
1042
+ thr_mma_S = tiled_mma_S.get_slice(0)
1043
+ Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N)
1044
+ tStS = thr_mma_S.make_fragment_C(Sacc_shape)
1045
+ # (MMA, MMA_M, MMA_N)
1046
+ tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout)
1047
+ # dP
1048
+ thr_mma_dP = tiled_mma_dP.get_slice(0)
1049
+ dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2])
1050
+ tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape)
1051
+ tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout)
1052
+ # dV
1053
+ thr_mma_dV = tiled_mma_dV.get_slice(0)
1054
+ dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2])
1055
+ tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape)
1056
+ tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout)
1057
+ tP = cute.make_tensor(
1058
+ cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer
1059
+ )
1060
+ # dK
1061
+ thr_mma_dK = tiled_mma_dK.get_slice(0)
1062
+ dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2])
1063
+ tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape)
1064
+ tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout)
1065
+ tdS = cute.make_tensor(
1066
+ cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer
1067
+ )
1068
+ # dQ
1069
+ thr_mma_dQ = tiled_mma_dQ.get_slice(0)
1070
+ dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2])
1071
+ tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape)
1072
+ tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout)
1073
+
1074
+ block_info = BlockInfo(
1075
+ self.tile_m,
1076
+ # self.tile_n,
1077
+ self.tile_n * self.cluster_shape_mnk[0], # careful, this case is not very well-tested
1078
+ self.is_causal,
1079
+ self.is_local,
1080
+ False, # is_split_kv
1081
+ window_size_left,
1082
+ window_size_right,
1083
+ qhead_per_kvhead_packgqa=1,
1084
+ )
1085
+ SeqlenInfoCls = partial(
1086
+ SeqlenInfoQK.create,
1087
+ seqlen_q_static=mQ.shape[0],
1088
+ seqlen_k_static=mK.shape[0],
1089
+ mCuSeqlensQ=mCuSeqlensQ,
1090
+ mCuSeqlensK=mCuSeqlensK,
1091
+ mSeqUsedQ=mSeqUsedQ,
1092
+ mSeqUsedK=mSeqUsedK,
1093
+ tile_m=self.tile_m,
1094
+ tile_n=self.tile_n,
1095
+ )
1096
+ TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)
1097
+
1098
+ AttentionMaskCls = partial(
1099
+ AttentionMask,
1100
+ self.tile_m,
1101
+ self.tile_n,
1102
+ swap_AB=True,
1103
+ window_size_left=window_size_left,
1104
+ window_size_right=window_size_right,
1105
+ )
1106
+
1107
+ # EMPTY
1108
+ # (15)
1109
+ if warp_idx == self.empty_warp_id:
1110
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
1111
+
1112
+ # EPI
1113
+ # (14)
1114
+ if warp_idx == self.epi_warp_id:
1115
+ # currently no-op, could use for tma store/reduce
1116
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
1117
+
1118
+ # LOAD
1119
+ # (13)
1120
+ if warp_idx == self.load_warp_id:
1121
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1122
+ self.load(
1123
+ thr_mma_S,
1124
+ thr_mma_dP,
1125
+ thr_mma_dV,
1126
+ mQ,
1127
+ mK,
1128
+ mV,
1129
+ mLSE,
1130
+ mdPsum,
1131
+ mdO,
1132
+ sQ,
1133
+ sK,
1134
+ sV,
1135
+ sLSE,
1136
+ sdPsum,
1137
+ sdO,
1138
+ tma_atom_Q,
1139
+ tma_atom_K,
1140
+ tma_atom_V,
1141
+ tma_atom_dO,
1142
+ pipeline_Q,
1143
+ pipeline_dO,
1144
+ pipeline_LSE,
1145
+ pipeline_dPsum,
1146
+ cluster_layout_vmnk,
1147
+ block_info,
1148
+ SeqlenInfoCls,
1149
+ TileSchedulerCls,
1150
+ blocksparse_tensors,
1151
+ should_load_Q=True,
1152
+ should_load_dO=True,
1153
+ )
1154
+
1155
+ # MMA
1156
+ # (12)
1157
+ if warp_idx == self.mma_warp_id:
1158
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1159
+
1160
+ # Alloc tmem buffer
1161
+ tmem_alloc_cols = Int32(self.tmem_alloc_cols)
1162
+ cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf)
1163
+ cute.arch.sync_warp()
1164
+
1165
+ self.mma(
1166
+ tiled_mma_S,
1167
+ tiled_mma_dP,
1168
+ tiled_mma_dV,
1169
+ tiled_mma_dK,
1170
+ tiled_mma_dQ,
1171
+ sQ,
1172
+ sQt,
1173
+ sK,
1174
+ sV,
1175
+ sdO,
1176
+ sdOt,
1177
+ sdSt,
1178
+ sdS,
1179
+ sKt,
1180
+ tP,
1181
+ tdS,
1182
+ tStS,
1183
+ tdPtdP,
1184
+ tdVtdV,
1185
+ tdKtdK,
1186
+ tdQtdQ,
1187
+ pipeline_Q.make_consumer(),
1188
+ pipeline_dO,
1189
+ pipeline_S_P,
1190
+ pipeline_dS,
1191
+ pipeline_dKV,
1192
+ pipeline_dP,
1193
+ pipeline_dQ,
1194
+ block_info,
1195
+ SeqlenInfoCls,
1196
+ TileSchedulerCls,
1197
+ blocksparse_tensors,
1198
+ )
1199
+ cute.arch.relinquish_tmem_alloc_permit()
1200
+ tmem_ptr = cute.arch.retrieve_tmem_ptr(
1201
+ Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf
1202
+ )
1203
+
1204
+ cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
1205
+ tmem_alloc_cols = Int32(self.tmem_alloc_cols)
1206
+ cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols, is_two_cta=False)
1207
+
1208
+ # Compute
1209
+ # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps
1210
+ if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:
1211
+ cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps
1212
+ self.compute_loop(
1213
+ thr_mma_S,
1214
+ thr_mma_dP,
1215
+ thr_mma_dV,
1216
+ thr_mma_dK,
1217
+ tStS,
1218
+ sLSE,
1219
+ sdPsum,
1220
+ tdVtdV,
1221
+ tdKtdK,
1222
+ mdV,
1223
+ mdK,
1224
+ sdS,
1225
+ tdPtdP,
1226
+ pipeline_LSE,
1227
+ pipeline_dPsum,
1228
+ pipeline_S_P,
1229
+ pipeline_dS,
1230
+ pipeline_dKV,
1231
+ pipeline_dP,
1232
+ softmax_scale,
1233
+ softmax_scale_log2,
1234
+ block_info,
1235
+ SeqlenInfoCls,
1236
+ AttentionMaskCls,
1237
+ TileSchedulerCls,
1238
+ sdV,
1239
+ sdK,
1240
+ mdV_tma_tensor,
1241
+ mdK_tma_tensor,
1242
+ tma_atom_dV,
1243
+ tma_atom_dK,
1244
+ tiled_copy_r2s_dKV,
1245
+ mdK_semaphore,
1246
+ mdV_semaphore,
1247
+ aux_tensors,
1248
+ fastdiv_mods,
1249
+ blocksparse_tensors,
1250
+ )
1251
+ cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr)
1252
+
1253
+ # Reduce
1254
+ # (0, 1, 2, 3) - dQ
1255
+ if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]:
1256
+ cute.arch.warpgroup_reg_alloc(self.num_regs_reduce)
1257
+ self.dQacc_reduce(
1258
+ mdQaccum,
1259
+ sdQaccum,
1260
+ thr_mma_dQ,
1261
+ tdQtdQ,
1262
+ pipeline_dQ,
1263
+ block_info,
1264
+ SeqlenInfoCls,
1265
+ TileSchedulerCls,
1266
+ mdQ_semaphore,
1267
+ blocksparse_tensors,
1268
+ )
1269
+
1270
+ return
1271
+
1272
+ @cute.jit
1273
+ def load(
1274
+ self,
1275
+ thr_mma_S: cute.core.ThrMma,
1276
+ thr_mma_dP: cute.core.ThrMma,
1277
+ thr_mma_dV: cute.core.ThrMma,
1278
+ mQ: cute.Tensor,
1279
+ mK: cute.Tensor,
1280
+ mV: cute.Tensor,
1281
+ mLSE: cute.Tensor,
1282
+ mdPsum: cute.Tensor,
1283
+ mdO: cute.Tensor,
1284
+ sQ: cute.Tensor,
1285
+ sK: cute.Tensor,
1286
+ sV: cute.Tensor,
1287
+ sLSE: cute.Tensor,
1288
+ sdPsum: cute.Tensor,
1289
+ sdO: cute.Tensor,
1290
+ tma_atom_Q: cute.CopyAtom,
1291
+ tma_atom_K: cute.CopyAtom,
1292
+ tma_atom_V: cute.CopyAtom,
1293
+ tma_atom_dO: cute.CopyAtom,
1294
+ pipeline_Q: PipelineAsync,
1295
+ pipeline_dO: PipelineAsync,
1296
+ pipeline_LSE: PipelineAsync,
1297
+ pipeline_dPsum: PipelineAsync,
1298
+ cluster_layout_vmnk: cute.Layout,
1299
+ block_info: BlockInfo,
1300
+ SeqlenInfoCls: Callable,
1301
+ TileSchedulerCls: Callable,
1302
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
1303
+ should_load_Q: bool = True,
1304
+ should_load_dO: bool = True,
1305
+ ):
1306
+ producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state(
1307
+ cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
1308
+ )
1309
+ producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state(
1310
+ cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
1311
+ )
1312
+
1313
+ # Compute multicast mask for Q & dO buffer full
1314
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
1315
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
1316
+ q_do_mcast_mask = None
1317
+ if const_expr(self.is_q_do_mcast):
1318
+ q_do_mcast_mask = cpasync.create_tma_multicast_mask(
1319
+ cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
1320
+ )
1321
+
1322
+ tile_scheduler = TileSchedulerCls()
1323
+ work_tile = tile_scheduler.initial_work_tile_info()
1324
+ while work_tile.is_valid_tile:
1325
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
1326
+ seqlen = SeqlenInfoCls(batch_idx)
1327
+ m_block_min, m_block_max = block_info.get_m_block_min_max(
1328
+ seqlen, n_block // self.cluster_shape_mnk[0]
1329
+ )
1330
+ head_idx_kv = head_idx // self.qhead_per_kvhead
1331
+ mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
1332
+ mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
1333
+ mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
1334
+ if const_expr(not seqlen.has_cu_seqlens_q):
1335
+ mdO_cur = mdO[None, None, head_idx, batch_idx]
1336
+ else:
1337
+ mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx])
1338
+ mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx]
1339
+ mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
1340
+ None, head_idx
1341
+ ]
1342
+
1343
+ gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0))
1344
+ tSgK = thr_mma_S.partition_A(gK)
1345
+ gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0))
1346
+ tdPgV = thr_mma_dP.partition_A(gV)
1347
+ gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0))
1348
+ tSgQ = thr_mma_S.partition_B(gQ)
1349
+ gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
1350
+ gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
1351
+ gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))
1352
+ tdPgdO = thr_mma_dV.partition_B(gdO)
1353
+
1354
+ load_K, _, _ = copy_utils.tma_get_copy_fn(
1355
+ tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True
1356
+ )
1357
+ load_V, _, _ = copy_utils.tma_get_copy_fn(
1358
+ tma_atom_V,
1359
+ 0,
1360
+ cute.make_layout(1),
1361
+ tdPgV,
1362
+ sV,
1363
+ single_stage=True,
1364
+ )
1365
+ b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
1366
+ load_Q, _, _ = copy_utils.tma_get_copy_fn(
1367
+ tma_atom_Q,
1368
+ cta_coord=block_in_cluster_coord_vmnk[1],
1369
+ cta_layout=b_cta_layout,
1370
+ src_tensor=tSgQ,
1371
+ dst_tensor=sQ,
1372
+ mcast_mask=q_do_mcast_mask,
1373
+ )
1374
+ load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)
1375
+ load_dO, _, _ = copy_utils.tma_get_copy_fn(
1376
+ tma_atom_dO,
1377
+ cta_coord=block_in_cluster_coord_vmnk[1],
1378
+ cta_layout=b_cta_layout,
1379
+ src_tensor=tdPgdO,
1380
+ dst_tensor=sdO,
1381
+ mcast_mask=q_do_mcast_mask,
1382
+ )
1383
+ load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)
1384
+ copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)
1385
+ copy_stats = partial(cute.copy, copy_atom_stats)
1386
+ # copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32)
1387
+ # sLSE = cute.logical_divide(sLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
1388
+ # gLSE = cute.logical_divide(gLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
1389
+ # sdPsum = cute.logical_divide(sdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
1390
+ # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
1391
+ # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask)
1392
+
1393
+ # some tiles might be empty due to block sparsity
1394
+ if const_expr(self.use_block_sparsity):
1395
+ total_m_block_cnt = get_total_q_block_count_bwd(
1396
+ blocksparse_tensors,
1397
+ batch_idx,
1398
+ head_idx,
1399
+ n_block,
1400
+ subtile_factor=self.subtile_factor,
1401
+ m_block_max=m_block_max,
1402
+ )
1403
+ process_tile = total_m_block_cnt > Int32(0)
1404
+ else:
1405
+ process_tile = (
1406
+ const_expr(not self.is_local and not self.is_varlen_q)
1407
+ or m_block_min < m_block_max
1408
+ )
1409
+
1410
+ if process_tile:
1411
+ if const_expr(self.use_block_sparsity):
1412
+ producer_state_Q_LSE, producer_state_dO_dPsum = (
1413
+ produce_block_sparse_q_loads_bwd_sm100(
1414
+ blocksparse_tensors,
1415
+ batch_idx,
1416
+ head_idx,
1417
+ n_block,
1418
+ producer_state_Q_LSE,
1419
+ producer_state_dO_dPsum,
1420
+ pipeline_Q,
1421
+ pipeline_LSE,
1422
+ pipeline_dO,
1423
+ pipeline_dPsum,
1424
+ load_K,
1425
+ load_V,
1426
+ load_Q,
1427
+ load_dO,
1428
+ copy_stats,
1429
+ gLSE,
1430
+ sLSE,
1431
+ gdPsum,
1432
+ sdPsum,
1433
+ self.tma_copy_bytes["K"],
1434
+ self.tma_copy_bytes["V"],
1435
+ should_load_Q=should_load_Q,
1436
+ should_load_dO=should_load_dO,
1437
+ subtile_factor=self.subtile_factor,
1438
+ m_block_max=m_block_max,
1439
+ )
1440
+ )
1441
+ else:
1442
+ first_m_block = m_block_min
1443
+
1444
+ # First iteration: load K together with Q & LSE (the K TMA bytes are added to the Q barrier via extra_tx_count), then V together with dO & dPsum
1445
+ if const_expr(should_load_Q):
1446
+ pipeline_Q.producer_acquire(
1447
+ producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"]
1448
+ )
1449
+ load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
1450
+ load_Q(first_m_block, producer_state=producer_state_Q_LSE)
1451
+ pipeline_Q.producer_commit(producer_state_Q_LSE)
1452
+ pipeline_LSE.producer_acquire(producer_state_Q_LSE)
1453
+ with cute.arch.elect_one():
1454
+ copy_stats(
1455
+ gLSE[None, first_m_block],
1456
+ sLSE[None, producer_state_Q_LSE.index],
1457
+ mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
1458
+ )
1459
+ producer_state_Q_LSE.advance()
1460
+ if const_expr(should_load_dO):
1461
+ pipeline_dO.producer_acquire(
1462
+ producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"]
1463
+ )
1464
+ load_V(
1465
+ tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)
1466
+ )
1467
+ load_dO(first_m_block, producer_state=producer_state_dO_dPsum)
1468
+ pipeline_dO.producer_commit(producer_state_dO_dPsum)
1469
+ pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
1470
+ with cute.arch.elect_one():
1471
+ copy_stats(
1472
+ gdPsum[None, first_m_block],
1473
+ sdPsum[None, producer_state_dO_dPsum.index],
1474
+ mbar_ptr=pipeline_dPsum.producer_get_barrier(
1475
+ producer_state_dO_dPsum
1476
+ ),
1477
+ )
1478
+ producer_state_dO_dPsum.advance()
1479
+
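Editorial note: the producer_acquire calls above pass extra_tx_count so that the K (or V) TMA bytes are counted on the same transaction barrier as Q (or dO), letting the first iteration issue both loads behind a single wait. A toy Python model of that byte-counting idea, with hypothetical sizes and no claim to match the cutlass.pipeline API:

class TxBarrierModel:
    """Toy model of a transaction barrier that tracks expected vs. arrived bytes."""

    def __init__(self):
        self.expected_bytes = 0
        self.arrived_bytes = 0

    def expect(self, nbytes):
        # producer_acquire(..., extra_tx_count=...) adds to the expected byte count
        self.expected_bytes += nbytes

    def complete(self, nbytes):
        # each TMA load reports its bytes on completion
        self.arrived_bytes += nbytes

    def ready(self):
        # the consumer's wait succeeds only once every expected byte has arrived
        return 0 < self.expected_bytes <= self.arrived_bytes


bar = TxBarrierModel()
bar.expect(64 * 1024 + 32 * 1024)  # hypothetical Q-tile bytes plus extra K-tile bytes
bar.complete(32 * 1024)            # K lands first
assert not bar.ready()
bar.complete(64 * 1024)            # Q lands
assert bar.ready()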
1480
+ # Dense path: iterate from m_block_min+1 to m_block_max
1481
+ for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
1482
+ if const_expr(should_load_Q):
1483
+ pipeline_Q.producer_acquire(producer_state_Q_LSE)
1484
+ load_Q(m_block, producer_state=producer_state_Q_LSE)
1485
+ pipeline_Q.producer_commit(producer_state_Q_LSE)
1486
+ pipeline_LSE.producer_acquire(producer_state_Q_LSE)
1487
+ with cute.arch.elect_one():
1488
+ copy_stats(
1489
+ gLSE[None, m_block],
1490
+ sLSE[None, producer_state_Q_LSE.index],
1491
+ mbar_ptr=pipeline_LSE.producer_get_barrier(
1492
+ producer_state_Q_LSE
1493
+ ),
1494
+ )
1495
+ producer_state_Q_LSE.advance()
1496
+ if const_expr(should_load_dO):
1497
+ pipeline_dO.producer_acquire(producer_state_dO_dPsum)
1498
+ load_dO(m_block, producer_state=producer_state_dO_dPsum)
1499
+ pipeline_dO.producer_commit(producer_state_dO_dPsum)
1500
+ pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
1501
+ with cute.arch.elect_one():
1502
+ copy_stats(
1503
+ gdPsum[None, m_block],
1504
+ sdPsum[None, producer_state_dO_dPsum.index],
1505
+ mbar_ptr=pipeline_dPsum.producer_get_barrier(
1506
+ producer_state_dO_dPsum
1507
+ ),
1508
+ )
1509
+ producer_state_dO_dPsum.advance()
1510
+
1511
+ if const_expr(should_load_Q):
1512
+ pipeline_Q.producer_tail(
1513
+ producer_state_Q_LSE.clone()
1514
+ ) # will hang if we don't clone
1515
+ pipeline_LSE.producer_tail(producer_state_Q_LSE)
1516
+ if const_expr(should_load_dO):
1517
+ pipeline_dO.producer_tail(producer_state_dO_dPsum.clone())
1518
+ pipeline_dPsum.producer_tail(producer_state_dO_dPsum)
1519
+
1520
+ tile_scheduler.prefetch_next_work()
1521
+ tile_scheduler.advance_to_next_work()
1522
+ work_tile = tile_scheduler.get_current_work()
1523
+
1524
+ @cute.jit
1525
+ def mma(
1526
+ self,
1527
+ tiled_mma_S: cute.TiledMma,
1528
+ tiled_mma_dP: cute.TiledMma,
1529
+ tiled_mma_dV: cute.TiledMma,
1530
+ tiled_mma_dK: cute.TiledMma,
1531
+ tiled_mma_dQ: cute.TiledMma,
1532
+ sQ: cute.Tensor,
1533
+ sQt: cute.Tensor,
1534
+ sK: cute.Tensor,
1535
+ sV: cute.Tensor,
1536
+ sdO: cute.Tensor,
1537
+ sdOt: cute.Tensor,
1538
+ sdSt: cute.Tensor,
1539
+ sdS: cute.Tensor,
1540
+ sKt: cute.Tensor,
1541
+ tP: cute.Tensor,
1542
+ tdS: cute.Tensor,
1543
+ tStS: cute.Tensor,
1544
+ tdPtdP: cute.Tensor,
1545
+ tdVtdV: cute.Tensor,
1546
+ tdKtdK: cute.Tensor,
1547
+ tdQtdQ: cute.Tensor,
1548
+ pipeline_Q_consumer: PipelineConsumer,
1549
+ pipeline_dO: PipelineAsync,
1550
+ pipeline_S_P: PipelineAsync,
1551
+ pipeline_dS: PipelineAsync,
1552
+ pipeline_dKV: PipelineAsync,
1553
+ pipeline_dP: PipelineAsync,
1554
+ pipeline_dQ: PipelineAsync,
1555
+ block_info: BlockInfo,
1556
+ SeqlenInfoCls: Callable,
1557
+ TileSchedulerCls: Callable,
1558
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
1559
+ ):
1560
+ # [2025-10-21] For reasons I don't understand, putting these partitionings in the main
1561
+ # kernel (before warp specialization) is a lot slower than putting them here.
1562
+ # Partition smem / tmem tensors
1563
+ # S = K @ Q.T
1564
+ tSrK = tiled_mma_S.make_fragment_A(sK)
1565
+ tSrQ = tiled_mma_S.make_fragment_B(sQ)
1566
+ # dP = V @ dO.T
1567
+ tdPrV = tiled_mma_dP.make_fragment_A(sV)
1568
+ tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt)
1569
+ # dK = dS.T @ Q
1570
+ if const_expr(self.use_smem_dS_for_mma_dK):
1571
+ tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)
1572
+ else:
1573
+ tdKrdS = tiled_mma_dK.make_fragment_A(tdS)
1574
+ tdKrQ = tiled_mma_dK.make_fragment_B(sQt)
1575
+ # dQ = dS @ K
1576
+ tdQrdS = tiled_mma_dQ.make_fragment_A(sdS)
1577
+ tdQrK = tiled_mma_dQ.make_fragment_B(sKt)
1578
+ # dV = P @ dO.T
1579
+ tdVrdO = tiled_mma_dV.make_fragment_B(sdO)
1580
+ tdVrP = tiled_mma_dV.make_fragment_A(tP)
1581
+
1582
+ # mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True)
1583
+ mma_qk_fn = partial(
1584
+ gemm_ptx_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True
1585
+ )
1586
+ # mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True)
1587
+ mma_dov_fn = partial(
1588
+ gemm_ptx_w_idx,
1589
+ tiled_mma_dP,
1590
+ tdPtdP,
1591
+ tdPrV,
1592
+ tdPrdOt,
1593
+ sA=sV,
1594
+ sB=sdOt,
1595
+ zero_init=True,
1596
+ )
1597
+ # mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO)
1598
+ mma_pdo_fn = partial(
1599
+ gemm_ptx_w_idx,
1600
+ tiled_mma_dV,
1601
+ tdVtdV,
1602
+ tdVrP,
1603
+ tdVrdO,
1604
+ sA=None,
1605
+ sB=sdO,
1606
+ tA_addr=self.tmem_P_offset,
1607
+ )
1608
+ mma_dsk_fn = partial(gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, zero_init=True)
1609
+ # mma_dsk_fn = partial(
1610
+ # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True
1611
+ # )
1612
+ if const_expr(self.use_smem_dS_for_mma_dK):
1613
+ mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)
1614
+ else:
1615
+ # Need to explicitly pass in tA_addr for correctness
1616
+ mma_dsq_fn = partial(
1617
+ gemm_ptx_w_idx,
1618
+ tiled_mma_dK,
1619
+ tdKtdK,
1620
+ tdKrdS,
1621
+ tdKrQ,
1622
+ sA=None,
1623
+ sB=sQt,
1624
+ tA_addr=self.tmem_dS_offset,
1625
+ )
1626
+
1627
+ consumer_state_dO = cutlass.pipeline.make_pipeline_state(
1628
+ cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
1629
+ )
1630
+ producer_phase_acc = Int32(1) # For S & P, dP, dQ
1631
+ consumer_state_dS = cutlass.pipeline.make_pipeline_state(
1632
+ cutlass.pipeline.PipelineUserType.Consumer, 1
1633
+ )
1634
+ # producer_state_dKV = cutlass.pipeline.make_pipeline_state(
1635
+ # cutlass.pipeline.PipelineUserType.Producer, 2
1636
+ # )
1637
+ producer_phase_dKV = Int32(1)
1638
+ cta_group = pipeline_S_P.cta_group
1639
+
1640
+ tile_scheduler = TileSchedulerCls()
1641
+ work_tile = tile_scheduler.initial_work_tile_info()
1642
+ while work_tile.is_valid_tile:
1643
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
1644
+ seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k
1645
+ m_block_min, m_block_max = block_info.get_m_block_min_max(
1646
+ seqlen, n_block // self.cluster_shape_mnk[0]
1647
+ )
1648
+
1649
+ if const_expr(self.use_block_sparsity):
1650
+ block_iter_count = get_total_q_block_count_bwd(
1651
+ blocksparse_tensors,
1652
+ batch_idx,
1653
+ head_idx,
1654
+ n_block,
1655
+ subtile_factor=self.subtile_factor,
1656
+ m_block_max=m_block_max,
1657
+ )
1658
+ process_tile = block_iter_count > Int32(0)
1659
+ else:
1660
+ block_iter_count = m_block_max - m_block_min
1661
+ process_tile = (
1662
+ const_expr(not self.is_local and not self.is_varlen_q)
1663
+ or m_block_min < m_block_max
1664
+ )
1665
+
1666
+ if process_tile:
1667
+ accumulate_dK = False
1668
+ # -----------------------------------------------------------
1669
+ ###### Prologue
1670
+ # -----------------------------------------------------------
1671
+ # 1. S = Q0 @ K.T
1672
+ # 2. dP = V @ dO.T
1673
+ # 3. dV = P @ dO
1674
+ # 1) S = Q0 @ K.T
1675
+ handle_Q = pipeline_Q_consumer.wait_and_advance()
1676
+ pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
1677
+ mma_qk_fn(B_idx=handle_Q.index)
1678
+ # Don't release Q yet
1679
+ pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
1680
+
1681
+ # 2) dP = V @ dO.T
1682
+ pipeline_dO.consumer_wait(consumer_state_dO)
1683
+ pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
1684
+ # dQ uses the same tmem as dP
1685
+ pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)
1686
+ mma_dov_fn(B_idx=consumer_state_dO.index)
1687
+ # Don't release dO yet
1688
+ pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)
1689
+
1690
+ producer_phase_acc ^= 1
1691
+ # 3) dV = P.T @ dO
1692
+ # wait for P to be ready, which uses the same tmem as S
1693
+ pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
1694
+ mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True)
1695
+ pipeline_dO.consumer_release(consumer_state_dO)
1696
+ consumer_state_dO.advance()
1697
+ # -----------------------------------------------------------
1698
+ ###### MAIN LOOP
1699
+ # -----------------------------------------------------------
1700
+ # 1. S = K @ Q.T
1701
+ # 2. dQ = dS @ K
1702
+ # 3. dK = dS.T @ Q
1703
+ # 4. dP = V @ dO.T
1704
+ # 5. dV = P.T @ dO
1705
+
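For reference, steps 1-5 realize the standard attention backward pass. A minimal NumPy sketch of the per-tile math (editorial; shapes are simplified and the kernel applies softmax_scale at different points, e.g. dK is scaled in its epilogue):

import numpy as np

def attn_bwd_reference(q, k, v, do, o, lse, scale):
    # q, do, o: (seqlen_q, d); k, v: (seqlen_k, d); lse: (seqlen_q,)
    s = scale * (q @ k.T)                        # 1) scores (the kernel forms K @ Q.T instead)
    p = np.exp(s - lse[:, None])                 # softmax recomputed from the saved LSE
    dp = do @ v.T                                # 4) dP (kernel: V @ dO.T, transposed layout)
    d = np.sum(do * o, axis=-1, keepdims=True)   # dPsum / D from the preprocess kernel
    ds = p * (dp - d)                            # dS = P * (dP - D)
    dq = scale * (ds @ k)                        # 2) dQ += dS @ K
    dk = scale * (ds.T @ q)                      # 3) dK += dS.T @ Q
    dv = p.T @ do                                # 5) dV += P.T @ dO
    return dq, dk, dv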
1706
+ # For block sparsity, we use block_iter_count; for dense, use m_block range
1707
+ # MMA doesn't need actual m_block indices, just the iteration count
1708
+ main_loop_iters = (
1709
+ block_iter_count - 1
1710
+ if const_expr(self.use_block_sparsity)
1711
+ else m_block_max - m_block_min - 1
1712
+ )
1713
+ for _ in cutlass.range(main_loop_iters, unroll=1):
1714
+ # 1) S = K @ Q_i
1715
+ handle_Q_next = pipeline_Q_consumer.wait_and_advance()
1716
+ # Don't need to wait for S, as P must have been ready earlier, i.e., S is ready
1717
+ mma_qk_fn(B_idx=handle_Q_next.index)
1718
+ pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
1719
+
1720
+ # 2-3)
1721
+ # If dS is read from tmem for the dK mma, do dK = dS.T @ Q first, then dQ = dS @ K;
1722
+ # otherwise (dS read from smem), reverse the order
1723
+ pipeline_dS.consumer_wait(consumer_state_dS)
1724
+
1725
+ if const_expr(self.use_smem_dS_for_mma_dK):
1726
+ mma_dsk_fn()
1727
+ pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
1728
+ mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
1729
+ accumulate_dK = True
1730
+ handle_Q.release()
1731
+ else:
1732
+ mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
1733
+ accumulate_dK = True
1734
+ handle_Q.release()
1735
+ mma_dsk_fn()
1736
+ pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
1737
+
1738
+ # dP uses the same tmem as dQ
1739
+ # However, if dS is ready, then dP must have been ready,
1740
+ # so we don't need this wait before mma_dsk_fn()
1741
+ # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
1742
+
1743
+ pipeline_dS.consumer_release(consumer_state_dS)
1744
+ consumer_state_dS.advance()
1745
+
1746
+ # 4) dP = V @ dO.T
1747
+ pipeline_dO.consumer_wait(consumer_state_dO)
1748
+ # dQ uses the same tmem as dP
1749
+ pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)
1750
+ mma_dov_fn(B_idx=consumer_state_dO.index)
1751
+ pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)
1752
+
1753
+ producer_phase_acc ^= 1
1754
+ # 5) dV += P @ dO
1755
+ # wait for P to be ready, which uses the same tmem as S
1756
+ pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
1757
+ mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False)
1758
+ pipeline_dO.consumer_release(consumer_state_dO)
1759
+ consumer_state_dO.advance()
1760
+
1761
+ handle_Q = handle_Q_next
1762
+
1763
+ pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
1764
+
1765
+ # signal to the epilogue that dV is ready
1766
+ # pipeline_dKV.producer_acquire(producer_state_dKV)
1767
+ pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)
1768
+ # pipeline_dKV.producer_commit(producer_state_dKV)
1769
+ pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)
1770
+ # producer_state_dKV.advance()
1771
+ # pipeline_dKV.producer_acquire(producer_state_dKV)
1772
+ pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)
1773
+
1774
+ # -----------------------------------------------------------
1775
+ ###### Remaining 2
1776
+ # -----------------------------------------------------------
1777
+ # 1) dK += dS.T @ Q
1778
+ pipeline_dS.consumer_wait(consumer_state_dS)
1779
+ mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
1780
+ # signal to the epilogue that dK is ready
1781
+ # pipeline_dKV.producer_commit(producer_state_dKV)
1782
+ pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)
1783
+ # producer_state_dKV.advance()
1784
+ producer_phase_dKV ^= 1
1785
+
1786
+ # 2) dQ = dS @ K
1787
+ # dS is done, so dP must have been ready; we don't need to wait
1788
+ mma_dsk_fn()
1789
+ pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
1790
+ # Wait until dQ is done before releasing Q, since K and Q0 use the same mbarrier
1791
+ handle_Q.release()
1792
+ pipeline_dS.consumer_release(consumer_state_dS)
1793
+ consumer_state_dS.advance()
1794
+
1795
+ producer_phase_acc ^= 1
1796
+
1797
+ tile_scheduler.advance_to_next_work()
1798
+ work_tile = tile_scheduler.get_current_work()
1799
+
1800
+ # Currently it hangs if we have this S_P.producer_tail; we still need to understand why
1801
+ # pipeline_S_P.producer_tail(producer_state_S_P)
1802
+ # pipeline_dP.producer_tail(producer_state_dP)
1803
+ # pipeline_dKV.producer_tail(producer_state_dKV)
1804
+ # pipeline_dQ.producer_tail(producer_state_dQ)
1805
+
1806
+ @cute.jit
1807
+ def split_wg(
1808
+ self,
1809
+ t: cute.Tensor,
1810
+ wg_idx: cutlass.Int32,
1811
+ num_wg: cutlass.Constexpr[int],
1812
+ ):
1813
+ reduced_shape = cute.product_each(t.shape)
1814
+ rank = len(reduced_shape)
1815
+ if const_expr(reduced_shape[1] > 1):
1816
+ assert rank >= 2, "Need rank >= 2 for t in split_wg"
1817
+ t = cute.logical_divide(t, (reduced_shape[0], reduced_shape[1] // num_wg))
1818
+ coord = (None, (None, wg_idx)) + (None,) * (rank - 2)
1819
+ else:
1820
+ assert rank >= 3, "Need rank >= 3 for t in split_wg"
1821
+ if const_expr(rank == 3):
1822
+ t = cute.logical_divide(
1823
+ t, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg)
1824
+ )
1825
+ coord = (
1826
+ None,
1827
+ None,
1828
+ (None, wg_idx),
1829
+ ) + (None,) * (rank - 3)
1830
+ else:
1831
+ t = cute.logical_divide(
1832
+ t,
1833
+ (
1834
+ reduced_shape[0],
1835
+ reduced_shape[1],
1836
+ reduced_shape[2],
1837
+ reduced_shape[3] // num_wg,
1838
+ ),
1839
+ )
1840
+ coord = (
1841
+ None,
1842
+ None,
1843
+ None,
1844
+ (None, wg_idx),
1845
+ ) + (None,) * (rank - 4)
1846
+ return t[coord]
1847
+
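split_wg above carves out the per-warpgroup slice of a partitioned tensor along whichever mode is divisible by num_wg. A rough editorial NumPy analogue of the idea (the real routine operates on CuTe layouts via logical_divide, not on dense arrays):

import numpy as np

def split_wg_reference(t: np.ndarray, wg_idx: int, num_wg: int, axis: int = 1) -> np.ndarray:
    # Conceptually: partition one mode into num_wg equal pieces and keep piece wg_idx.
    assert t.shape[axis] % num_wg == 0
    return np.split(t, num_wg, axis=axis)[wg_idx]

x = np.arange(2 * 8).reshape(2, 8)
assert split_wg_reference(x, wg_idx=1, num_wg=2).shape == (2, 4)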
1848
+ @cute.jit
1849
+ def apply_score_mod(
1850
+ self,
1851
+ tSrS_t2r,
1852
+ thr_copy_t2r,
1853
+ thr_mma_S,
1854
+ batch_idx,
1855
+ head_idx,
1856
+ m_block,
1857
+ n_block,
1858
+ softmax_scale,
1859
+ seqlen_info,
1860
+ aux_tensors=None,
1861
+ fastdiv_mods=(None, None),
1862
+ ):
1863
+ """Apply forward score modification for SM100 backward pass."""
1864
+ # In bwd, S is computed as K @ Q.T so dimensions are (tile_n, tile_m)
1865
+ cS = cute.make_identity_tensor((self.tile_n, self.tile_m))
1866
+ cS = cute.domain_offset((n_block * self.tile_n, m_block * self.tile_m), cS)
1867
+ tScS = thr_mma_S.partition_C(cS)
1868
+ tScS_idx = thr_copy_t2r.partition_D(tScS)
1869
+
1870
+ apply_score_mod_inner(
1871
+ tSrS_t2r,
1872
+ tScS_idx,
1873
+ self.score_mod,
1874
+ batch_idx,
1875
+ head_idx,
1876
+ softmax_scale,
1877
+ self.vec_size,
1878
+ self.qk_acc_dtype,
1879
+ aux_tensors,
1880
+ fastdiv_mods,
1881
+ seqlen_info,
1882
+ constant_q_idx=None,
1883
+ qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1884
+ transpose_indices=True,
1885
+ )
1886
+
1887
+ @cute.jit
1888
+ def apply_score_mod_bwd(
1889
+ self,
1890
+ grad_tensor,
1891
+ score_tensor,
1892
+ index_tensor,
1893
+ batch_idx,
1894
+ head_idx,
1895
+ softmax_scale,
1896
+ seqlen_info,
1897
+ aux_tensors=None,
1898
+ fastdiv_mods=(None, None),
1899
+ ):
1900
+ """Apply backward score modification (joint graph) for SM100."""
1901
+ apply_score_mod_bwd_inner(
1902
+ grad_tensor,
1903
+ score_tensor,
1904
+ index_tensor,
1905
+ self.score_mod_bwd,
1906
+ batch_idx,
1907
+ head_idx,
1908
+ softmax_scale,
1909
+ self.vec_size,
1910
+ self.qk_acc_dtype,
1911
+ aux_tensors,
1912
+ fastdiv_mods,
1913
+ seqlen_info,
1914
+ constant_q_idx=None,
1915
+ qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1916
+ transpose_indices=True,
1917
+ )
1918
+
1919
+ @cute.jit
1920
+ def compute_loop(
1921
+ self,
1922
+ thr_mma_S: cute.core.ThrMma,
1923
+ thr_mma_dP: cute.core.ThrMma,
1924
+ thr_mma_dV: cute.core.ThrMma,
1925
+ thr_mma_dK: cute.core.ThrMma,
1926
+ tStS: cute.Tensor,
1927
+ sLSE: cute.Tensor,
1928
+ sdPsum: cute.Tensor,
1929
+ tdVtdV: cute.Tensor,
1930
+ tdKtdK: cute.Tensor,
1931
+ mdV: cute.Tensor,
1932
+ mdK: cute.Tensor,
1933
+ sdS: cute.Tensor,
1934
+ tdPtdP: cute.Tensor,
1935
+ pipeline_LSE: PipelineAsync,
1936
+ pipeline_dPsum: PipelineAsync,
1937
+ pipeline_S_P: PipelineAsync,
1938
+ pipeline_dS: PipelineAsync,
1939
+ pipeline_dKV: PipelineAsync,
1940
+ pipeline_dP: PipelineAsync,
1941
+ softmax_scale: cutlass.Float32,
1942
+ softmax_scale_log2: cutlass.Float32,
1943
+ block_info: BlockInfo,
1944
+ SeqlenInfoCls: Callable,
1945
+ AttentionMaskCls: Callable,
1946
+ TileSchedulerCls: Callable,
1947
+ sdV: Optional[cute.Tensor],
1948
+ sdK: Optional[cute.Tensor],
1949
+ mdV_tma_tensor: Optional[cute.Tensor],
1950
+ mdK_tma_tensor: Optional[cute.Tensor],
1951
+ tma_atom_dV: Optional[cute.CopyAtom],
1952
+ tma_atom_dK: Optional[cute.CopyAtom],
1953
+ tiled_copy_r2s_dKV: Optional[cute.TiledCopy],
1954
+ mdK_semaphore: Optional[cute.Tensor],
1955
+ mdV_semaphore: Optional[cute.Tensor],
1956
+ aux_tensors: Optional[list] = None,
1957
+ fastdiv_mods=(None, None),
1958
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
1959
+ ):
1960
+ sLSE_2D = cute.make_tensor(
1961
+ sLSE.iterator,
1962
+ cute.make_layout(
1963
+ (self.tile_m, self.tile_n, self.Q_stage),
1964
+ stride=(1, 0, cute.round_up(self.tile_m, 64)),
1965
+ ),
1966
+ )
1967
+ sdPsum_2D = cute.make_tensor(
1968
+ sdPsum.iterator,
1969
+ cute.make_layout(
1970
+ (self.tile_m, self.tile_n, self.dO_stage),
1971
+ stride=(1, 0, cute.round_up(self.tile_m, 64)),
1972
+ ),
1973
+ )
1974
+ # if const_expr(self.SdP_swapAB):
1975
+ if const_expr(True):
1976
+ sLSE_2D = utils.transpose_view(sLSE_2D)
1977
+ sdPsum_2D = utils.transpose_view(sdPsum_2D)
1978
+
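The stride of 0 in the second mode of sLSE_2D / sdPsum_2D is a broadcast: each query row's statistic is reused across the whole key dimension without duplicating storage. A toy NumPy illustration of the same zero-stride trick:

import numpy as np

tile_m, tile_n = 4, 8                      # toy sizes
lse = np.arange(tile_m, dtype=np.float32)  # one value per query row
lse_2d = np.lib.stride_tricks.as_strided(
    lse, shape=(tile_m, tile_n), strides=(lse.strides[0], 0)
)
assert np.array_equal(lse_2d, np.broadcast_to(lse[:, None], (tile_m, tile_n)))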
1979
+ # These compute warps cover threads [128...384), i.e. 8 warps
1980
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11
1981
+ tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))
1982
+ # tidx = cute.arch.thread_idx()[0] - (cute.arch.WARP_SIZE * self.compute_warp_ids[0])
1983
+ dp_idx = tidx % 128
1984
+ num_wg = len(self.compute_warp_ids) // 4 # 2
1985
+ # wg_idx:
1986
+ # 0: [256...384]
1987
+ # 1: [128...256]
1988
+
1989
+ tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # 64 for tile_n = 128
1990
+ # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1)
1991
+ # tP overlap with tS
1992
+ tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
1993
+ tStP = cute.make_tensor(tStS.iterator, tStP.layout) # Otherwise the tmem address is wrong
1994
+ tScS = thr_mma_S.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2]))
1995
+ tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
1996
+ # tdS overlap with tdP
1997
+ tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
1998
+ tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2]))
1999
+ tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
2000
+
2001
+ tmem_load_atom = cute.make_copy_atom(
2002
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
2003
+ )
2004
+ tmem_store_atom = cute.make_copy_atom(
2005
+ tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32
2006
+ )
2007
+
2008
+ # tmem -> rmem
2009
+ thr_copy_t2r = copy_utils.make_tmem_copy(tmem_load_atom, num_wg).get_slice(tidx)
2010
+ tStS_t2r = thr_copy_t2r.partition_S(tStS) # (((32, 32), 1), 2, 1, 1)
2011
+ tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP)
2012
+ tScS_t2r = thr_copy_t2r.partition_D(tScS) # ((32, 1), 2, 1, 1)
2013
+ t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS) # ((32, 1), 2, 1, 1)
2014
+ # ((32, 1), 2, 1, 1, STAGE)
2015
+ tSsLSE = thr_copy_t2r.partition_D(thr_mma_S.partition_C(sLSE_2D))
2016
+ tSsdPsum = thr_copy_t2r.partition_D(thr_mma_dP.partition_C(sdPsum_2D))
2017
+ # rmem -> tmem
2018
+ thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx)
2019
+ tScP_r2t = thr_copy_r2t.partition_S(tScP)
2020
+ tStP_r2t = thr_copy_r2t.partition_D(tStP)
2021
+ tdPcdS_r2t = thr_copy_r2t.partition_S(tdPcdS)
2022
+ tdPtdS_r2t = thr_copy_r2t.partition_D(tdPtdS)
2023
+ # rmem -> smem
2024
+ # This part is a bit iffy; we might be making a lot of assumptions here
2025
+ copy_atom_r2s = sm100_utils_basic.get_smem_store_op(
2026
+ LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r
2027
+ )
2028
+ thr_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx)
2029
+ # We assume the swizzle (i.e. layout.inner) stays the same
2030
+ sdS_layout = sm100_utils_basic.make_smem_layout_epi(
2031
+ self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1
2032
+ ).outer # ((8,16), (64,2), (1, 1))
2033
+ sdS_layout = cute.slice_(sdS_layout, (None, None, 0)) # ((8,16), (64,2))
2034
+ # Need to group into 1 mode to be compatible with thr_copy_r2s
2035
+ sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,))
2036
+ sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout)
2037
+ tRS_sdS = thr_copy_r2s.partition_D(sdS_epi)
2038
+
2039
+ consumer_state_S_P_dP = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1
2040
+ cutlass.pipeline.PipelineUserType.Consumer, 1
2041
+ )
2042
+ # consumer_phase_S_P_dP = Int32(0)
2043
+ producer_state_dS = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1
2044
+ cutlass.pipeline.PipelineUserType.Producer, 1
2045
+ )
2046
+ consumer_state_dKV = cutlass.pipeline.make_pipeline_state(
2047
+ cutlass.pipeline.PipelineUserType.Consumer, 2
2048
+ )
2049
+ consumer_state_LSE = cutlass.pipeline.make_pipeline_state(
2050
+ cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
2051
+ )
2052
+ # consumer_state_dPsum = cutlass.pipeline.make_pipeline_state(
2053
+ consumer_state_dPsum = pipeline.make_pipeline_state(
2054
+ cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
2055
+ )
2056
+
2057
+ tile_scheduler = TileSchedulerCls()
2058
+ work_tile = tile_scheduler.initial_work_tile_info()
2059
+ while work_tile.is_valid_tile:
2060
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
2061
+ seqlen = SeqlenInfoCls(batch_idx)
2062
+ m_block_min, m_block_max = block_info.get_m_block_min_max(
2063
+ seqlen, n_block // self.cluster_shape_mnk[0]
2064
+ )
2065
+ mask = AttentionMaskCls(seqlen)
2066
+ # TODO: condition mask_seqlen
2067
+ mask_fn = partial(
2068
+ mask.apply_mask_sm100_transposed,
2069
+ tScS_t2r=tScS_t2r,
2070
+ t0ScS_t2r=t0ScS_t2r,
2071
+ n_block=n_block,
2072
+ mask_seqlen=True,
2073
+ mask_causal=self.is_causal,
2074
+ mask_local=self.is_local,
2075
+ mask_mod=self.mask_mod,
2076
+ batch_idx=batch_idx,
2077
+ head_idx=head_idx,
2078
+ aux_tensors=aux_tensors,
2079
+ fastdiv_mods=fastdiv_mods,
2080
+ )
2081
+
2082
+ # prefetch_LSE = not self.is_causal
2083
+ prefetch_LSE = False
2084
+
2085
+ # some tiles might be empty due to block sparsity
2086
+ if const_expr(self.use_block_sparsity):
2087
+ (
2088
+ curr_q_cnt,
2089
+ curr_q_idx,
2090
+ curr_full_cnt,
2091
+ curr_full_idx,
2092
+ loop_count,
2093
+ ) = get_block_sparse_iteration_info_bwd(
2094
+ blocksparse_tensors,
2095
+ batch_idx,
2096
+ head_idx,
2097
+ n_block,
2098
+ subtile_factor=self.subtile_factor,
2099
+ m_block_max=m_block_max,
2100
+ )
2101
+ process_tile = loop_count > Int32(0)
2102
+ else:
2103
+ process_tile = (
2104
+ const_expr(not self.is_local and not self.is_varlen_q)
2105
+ or m_block_min < m_block_max
2106
+ )
2107
+ loop_count = m_block_max - m_block_min
2108
+
2109
+ # Mainloop
2110
+ # Block sparsity: iterate over sparse m_block count and derive actual m_block
2111
+ # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly.
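A purely illustrative sketch of the iteration scheme just described; the real mapping lives in get_m_block_from_iter_bwd, and the ordering of full vs. partial Q blocks assumed here is only for the example:

def iter_to_m_block(iter_idx, full_idx_list, partial_idx_list):
    # Hypothetical convention: walk the fully-unmasked blocks first, then the partial ones.
    full_cnt = len(full_idx_list)
    if iter_idx < full_cnt:
        return full_idx_list[iter_idx], True            # fully-unmasked Q block
    return partial_idx_list[iter_idx - full_cnt], False  # partially-masked Q block

full_blocks, partial_blocks = [0, 1, 3], [4]
assert [iter_to_m_block(i, full_blocks, partial_blocks) for i in range(4)] == [
    (0, True), (1, True), (3, True), (4, False)
]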
2112
+ for iter_idx in cutlass.range(loop_count, unroll=1):
2113
+ if const_expr(self.use_block_sparsity):
2114
+ m_block, is_full_block = get_m_block_from_iter_bwd(
2115
+ iter_idx,
2116
+ curr_q_cnt,
2117
+ curr_q_idx,
2118
+ curr_full_cnt,
2119
+ curr_full_idx,
2120
+ subtile_factor=self.subtile_factor,
2121
+ m_block_max=m_block_max,
2122
+ )
2123
+ m_block_oob = m_block >= m_block_max
2124
+ else:
2125
+ m_block = m_block_min + iter_idx
2126
+ m_block_oob = False
2127
+ is_full_block = False
2128
+ # Prefetch 1 stage of LSE
2129
+ pipeline_LSE.consumer_wait(consumer_state_LSE)
2130
+ tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32)
2131
+ if const_expr(prefetch_LSE and not self.shuffle_LSE):
2132
+ cute.autovec_copy(tSsLSE[None, 0, 0, 0, consumer_state_LSE.index], tSrLSE_s2r)
2133
+
2134
+ pipeline_S_P.consumer_wait(consumer_state_S_P_dP)
2135
+ # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP)
2136
+ #### TMEM->RMEM (Load S from TMEM)
2137
+ tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32)
2138
+ cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r)
2139
+ if const_expr(self.score_mod_bwd is not None):
2140
+ tSrS_pre = cute.make_fragment_like(tSrS_t2r)
2141
+ cute.autovec_copy(tSrS_t2r, tSrS_pre)
2142
+
2143
+ if const_expr(self.score_mod is not None):
2144
+ # Apply score_mod FIRST -> matches forward
2145
+ self.apply_score_mod(
2146
+ tSrS_t2r,
2147
+ thr_copy_t2r,
2148
+ thr_mma_S,
2149
+ batch_idx,
2150
+ head_idx,
2151
+ m_block,
2152
+ n_block,
2153
+ softmax_scale,
2154
+ seqlen,
2155
+ aux_tensors,
2156
+ fastdiv_mods,
2157
+ )
2158
+
2159
+ #### APPLY MASK (after score_mod, matching forward pass order)
2160
+ check_m_boundary = (m_block + 1) * self.tile_m > seqlen.seqlen_q
2161
+ mask_fn(
2162
+ tSrS_t2r,
2163
+ m_block=m_block,
2164
+ is_full_block=is_full_block,
2165
+ check_m_boundary=check_m_boundary,
2166
+ )
2167
+
2168
+ num_stages = cute.size(tScS_t2r, mode=[1])
2169
+
2170
+ # ---------------------------------------------
2171
+ #### P = exp(S - LSE)
2172
+ # ---------------------------------------------
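The loop below recomputes P with a fused multiply-add feeding exp2 rather than exp, assuming the log2(e) factors are folded into softmax_scale_log2 and into how the kernel prepares the stored LSE. A tiny editorial check of the underlying identity:

import math

s, scale, lse = 1.7, 0.125, 0.9
log2e = math.log2(math.e)
p_ref = math.exp(s * scale - lse)                     # natural-exp reference
p_exp2 = 2.0 ** (s * (scale * log2e) - lse * log2e)   # FMA result fed to exp2
assert abs(p_ref - p_exp2) < 1e-12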
2173
+ lane_idx = cute.arch.lane_idx()
2174
+ tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32) # 64
2175
+ tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype)
2176
+ for stage in cutlass.range_constexpr(num_stages):
2177
+ tSrS_cur = tSrS_t2r[None, stage, 0, 0]
2178
+ tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index]
2179
+ if const_expr(not self.shuffle_LSE):
2180
+ if const_expr(stage > 0 or not prefetch_LSE):
2181
+ cute.autovec_copy(tSsLSE_cur, tSrLSE_s2r)
2182
+ tSrLSE = tSrLSE_s2r
2183
+ else:
2184
+ tSrLSE = tSsLSE_cur[lane_idx]
2185
+ for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2):
2186
+ if const_expr(not self.shuffle_LSE):
2187
+ lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1])
2188
+ else:
2189
+ lse_pair = (
2190
+ utils.shuffle_sync(tSrLSE, offset=2 * v),
2191
+ utils.shuffle_sync(tSrLSE, offset=2 * v + 1),
2192
+ )
2193
+ tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = utils.fma_packed_f32x2(
2194
+ ((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])),
2195
+ (softmax_scale_log2, softmax_scale_log2),
2196
+ (-lse_pair[0], -lse_pair[1]),
2197
+ )
2198
+ tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True)
2199
+ tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True)
2200
+ utils.cvt_f16(tSrS_cur, tSrP_r2t[None, stage, 0, 0])
2201
+ if const_expr(stage == 0):
2202
+ cute.arch.fence_view_async_tmem_load()
2203
+ # Without this barrier, we could have 1 warp writing to P in tmem while
2204
+ # another warp is still reading S from tmem.
2205
+ self.compute_sync_barrier.arrive_and_wait()
2206
+ cute.copy(
2207
+ thr_copy_r2t,
2208
+ tSrP_r2t_f32[None, stage, None, None],
2209
+ tStP_r2t[None, stage, None, None],
2210
+ )
2211
+
2212
+ cute.arch.fence_view_async_tmem_store()
2213
+ self.compute_sync_barrier.arrive_and_wait()
2214
+
2215
+ with cute.arch.elect_one():
2216
+ pipeline_S_P.consumer_release(consumer_state_S_P_dP)
2217
+ # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
2218
+ pipeline_LSE.consumer_release(consumer_state_LSE)
2219
+ # consumer_state_S_P_dP.advance()
2220
+ consumer_state_LSE.advance()
2221
+
2222
+ # ---------------------------------------------
2223
+ # dS.T = P.T * (dP.T - D)
2224
+ # ---------------------------------------------
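The D subtracted below is the dPsum row statistic that the preprocess step computes from dO and O alone. A short editorial check of why that equals rowsum(P * dP):

import numpy as np

rng = np.random.default_rng(0)
p = rng.random((4, 6)); v = rng.random((6, 3)); do = rng.random((4, 3))
o, dp = p @ v, do @ v.T
# rowsum(P * dP) == rowsum(dO * O), since O = P @ V and dP = dO @ V.T
assert np.allclose((p * dp).sum(-1), (do * o).sum(-1))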
2225
+ pipeline_dPsum.consumer_wait(consumer_state_dPsum)
2226
+
2227
+ pipeline_dP.consumer_wait(consumer_state_S_P_dP)
2228
+ # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP)
2229
+ consumer_state_S_P_dP.advance()
2230
+ # consumer_phase_S_P_dP ^= 1
2231
+
2232
+ ##### dS.T = P.T * (dP.T - Psum)
2233
+ for stage in cutlass.range_constexpr(num_stages):
2234
+ tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32)
2235
+ cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r)
2236
+ cute.arch.fence_view_async_tmem_load()
2237
+ self.compute_sync_barrier.arrive_and_wait()
2238
+ tdPrdP_cur = tdPrdP_t2r[None, 0, 0]
2239
+ tSrS_cur = tSrS_t2r[None, stage, 0, 0]
2240
+ tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index]
2241
+ if const_expr(not self.shuffle_dPsum):
2242
+ tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32)
2243
+ cute.autovec_copy(tSsdPsum_cur, tSrdPsum)
2244
+ else:
2245
+ tSrdPsum = tSsdPsum_cur[lane_idx]
2246
+ for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r, mode=[0]) // 2):
2247
+ if const_expr(not self.shuffle_dPsum):
2248
+ dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1])
2249
+ else:
2250
+ dPsum_pair = (
2251
+ utils.shuffle_sync(tSrdPsum, offset=2 * v),
2252
+ utils.shuffle_sync(tSrdPsum, offset=2 * v + 1),
2253
+ )
2254
+ tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.sub_packed_f32x2(
2255
+ (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair
2256
+ )
2257
+ tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.mul_packed_f32x2(
2258
+ (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]),
2259
+ (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]),
2260
+ )
2261
+
2262
+ if const_expr(self.score_mod_bwd is not None):
2263
+ tSrS_pre_cur = tSrS_pre[None, stage, 0, 0]
2264
+ cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m))
2265
+ cS_bwd = cute.domain_offset(
2266
+ (n_block * self.tile_n, m_block * self.tile_m), cS_bwd
2267
+ )
2268
+ tScS_bwd = thr_mma_S.partition_C(cS_bwd)
2269
+ tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd)
2270
+ tScS_idx_cur = tScS_idx_bwd[None, stage, 0, 0]
2271
+ self.apply_score_mod_bwd(
2272
+ tdPrdP_cur,
2273
+ tSrS_pre_cur,
2274
+ tScS_idx_cur,
2275
+ batch_idx,
2276
+ head_idx,
2277
+ softmax_scale,
2278
+ seqlen,
2279
+ aux_tensors,
2280
+ fastdiv_mods,
2281
+ )
2282
+ # Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd
2283
+ for i in cutlass.range(cute.size(tdPrdP_cur), unroll_full=True):
2284
+ kv_idx = tScS_idx_cur[i][0]
2285
+ tdPrdP_cur[i] = 0.0 if kv_idx >= seqlen.seqlen_k else tdPrdP_cur[i]
2286
+
2287
+ tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype)
2288
+ utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt)
2289
+ if const_expr(stage == 0):
2290
+ pipeline_dS.producer_acquire(producer_state_dS)
2291
+ cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])
2292
+ if const_expr(not self.use_smem_dS_for_mma_dK):
2293
+ tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)
2294
+ cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])
2295
+
2296
+ if const_expr(not self.use_smem_dS_for_mma_dK):
2297
+ cute.arch.fence_view_async_tmem_store()
2298
+ cute.arch.fence_proxy(
2299
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2300
+ )
2301
+ self.compute_sync_barrier.arrive_and_wait()
2302
+
2303
+ # with cute.arch.elect_one():
2304
+ # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive
2305
+ # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask)
2306
+ pipeline_dPsum.consumer_release(consumer_state_dPsum)
2307
+ consumer_state_dPsum.advance()
2308
+ with cute.arch.elect_one():
2309
+ pipeline_dS.producer_commit(producer_state_dS)
2310
+ producer_state_dS.advance()
2311
+
2312
+ # Epilogue
2313
+ # Run epilogue if we processed any m_blocks for this n_block
2314
+ if process_tile:
2315
+ if const_expr(not self.use_tma_store):
2316
+ consumer_state_dKV = self.epilogue_dKV(
2317
+ dp_idx,
2318
+ warp_idx,
2319
+ batch_idx,
2320
+ head_idx,
2321
+ n_block,
2322
+ seqlen,
2323
+ thr_mma_dV,
2324
+ thr_mma_dK,
2325
+ tdVtdV,
2326
+ tdKtdK,
2327
+ mdV,
2328
+ mdK,
2329
+ pipeline_dKV,
2330
+ consumer_state_dKV,
2331
+ softmax_scale,
2332
+ )
2333
+ else:
2334
+ thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx)
2335
+ #### STORE dV
2336
+ consumer_state_dKV = self.epilogue_dK_or_dV_tma(
2337
+ dp_idx,
2338
+ batch_idx,
2339
+ head_idx,
2340
+ n_block,
2341
+ seqlen,
2342
+ thr_mma_dV,
2343
+ tdVtdV,
2344
+ mdV_tma_tensor,
2345
+ sdV,
2346
+ tma_atom_dV,
2347
+ thr_copy_r2s_dKV,
2348
+ pipeline_dKV,
2349
+ consumer_state_dKV,
2350
+ None, # Don't scale
2351
+ int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id
2352
+ mdV_semaphore,
2353
+ )
2354
+ #### STORE dK
2355
+ consumer_state_dKV = self.epilogue_dK_or_dV_tma(
2356
+ dp_idx,
2357
+ batch_idx,
2358
+ head_idx,
2359
+ n_block,
2360
+ seqlen,
2361
+ thr_mma_dK,
2362
+ tdKtdK,
2363
+ mdK_tma_tensor,
2364
+ sdK,
2365
+ tma_atom_dK,
2366
+ thr_copy_r2s_dKV,
2367
+ pipeline_dKV,
2368
+ consumer_state_dKV,
2369
+ softmax_scale if const_expr(not self.dKV_postprocess) else None,
2370
+ int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id
2371
+ mdK_semaphore,
2372
+ )
2373
+ # Zero dK/dV for empty tiles (local attention or block sparsity)
2374
+ # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile
2375
+ if const_expr(not self.dKV_postprocess):
2376
+ should_zero_dKV = False
2377
+ if const_expr(self.is_local or self.is_varlen_q):
2378
+ should_zero_dKV = m_block_min >= m_block_max
2379
+ if const_expr(self.use_block_sparsity):
2380
+ # For block sparsity, zero when no m_blocks contribute to this n_block
2381
+ if not process_tile:
2382
+ should_zero_dKV = True
2383
+
2384
+ if should_zero_dKV:
2385
+ # Like other epilogues, this currently assumes hdim == hdimv
2386
+ gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d(
2387
+ self.dk_dtype,
2388
+ self.tile_hdim,
2389
+ 128, # num_threads
2390
+ )
2391
+ gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx)
2392
+ mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
2393
+ mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
2394
+ gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
2395
+ gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
2396
+ tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK)
2397
+ tdVgdV = gmem_thr_copy_zero_dKV.partition_D(gdV)
2398
+ assert tdKgdK.shape[2] == 1
2399
+ assert tdVgdV.shape[2] == 1
2400
+ cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim))
2401
+ tdKVcdKV = gmem_thr_copy_zero_dKV.partition_D(cdKV)
2402
+ zero = cute.make_fragment_like(tdKgdK[None, 0, 0])
2403
+ zero.fill(0.0)
2404
+ if tidx < 128:
2405
+ for i in cutlass.range_constexpr(tdKgdK.shape[1]):
2406
+ row_idx = tdKVcdKV[0, i, 0][0]
2407
+ if row_idx < seqlen.seqlen_k - self.tile_n * n_block:
2408
+ cute.copy(gmem_tiled_copy_zero_dKV, zero, tdKgdK[None, i, 0])
2409
+ else:
2410
+ for i in cutlass.range_constexpr(tdVgdV.shape[1]):
2411
+ row_idx = tdKVcdKV[0, i, 0][0]
2412
+ if row_idx < seqlen.seqlen_k - self.tile_n * n_block:
2413
+ cute.copy(gmem_tiled_copy_zero_dKV, zero, tdVgdV[None, i, 0])
2414
+
2415
+ tile_scheduler.advance_to_next_work()
2416
+ work_tile = tile_scheduler.get_current_work()
2417
+
2418
+ @cute.jit
2419
+ def dQacc_reduce(
2420
+ self,
2421
+ mdQaccum: cute.Tensor,
2422
+ sdQaccum: cute.Tensor,
2423
+ thr_mma_dQ: cute.core.ThrMma,
2424
+ tdQtdQ: cute.Tensor,
2425
+ pipeline_dQ: PipelineAsync,
2426
+ block_info: BlockInfo,
2427
+ SeqlenInfoCls: Callable,
2428
+ TileSchedulerCls: Callable,
2429
+ mdQ_semaphore: Optional[cute.Tensor],
2430
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
2431
+ ):
2432
+ num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids)
2433
+ tidx = cute.arch.thread_idx()[0] % num_reduce_threads
2434
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids))
2435
+ is_tma_warp = warp_idx == 0
2436
+ # TMEM -> RMEM
2437
+ tmem_load_atom = cute.make_copy_atom(
2438
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32
2439
+ )
2440
+ thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx)
2441
+ tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ)
2442
+ tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2]))
2443
+ tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape
2444
+ assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == self.dQaccum_reduce_stage, (
2445
+ "dQaccum reduce stage mismatch"
2446
+ )
2447
+
2448
+ thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d(
2449
+ self.dqaccum_dtype, num_reduce_threads, num_copy_elems=128 // self.dqaccum_dtype.width
2450
+ ).get_slice(tidx)
2451
+ tdQsdQ = thr_copy_dQaccum_r2s.partition_D(sdQaccum)
2452
+
2453
+ read_flag = const_expr(not self.deterministic)
2454
+
2455
+ tile_scheduler = TileSchedulerCls()
2456
+ work_tile = tile_scheduler.initial_work_tile_info()
2457
+ dQ_consumer_state = pipeline.make_pipeline_state(
2458
+ cutlass.pipeline.PipelineUserType.Consumer, 1
2459
+ )
2460
+ dQ_tma_store_producer_state = pipeline.make_pipeline_state(
2461
+ pipeline.PipelineUserType.Producer, self.sdQaccum_stage
2462
+ )
2463
+ while work_tile.is_valid_tile:
2464
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
2465
+ seqlen = SeqlenInfoCls(batch_idx)
2466
+ m_block_min, m_block_max = block_info.get_m_block_min_max(
2467
+ seqlen, n_block // self.cluster_shape_mnk[0]
2468
+ )
2469
+ if const_expr(not seqlen.has_cu_seqlens_q):
2470
+ mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
2471
+ else:
2472
+ mdQaccum_cur = cute.domain_offset(
2473
+ (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]
2474
+ )
2475
+ gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
2476
+ # (M * K / STAGE, STAGE, _)
2477
+ gdQaccum = cute.flat_divide(
2478
+ gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,)
2479
+ )
2480
+
2481
+ if const_expr(self.deterministic):
2482
+ mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]
2483
+
2484
+ delay_semaphore_release = self.is_causal
2485
+ n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n)
2486
+
2487
+ # some tiles might be empty due to block sparsity
2488
+ if const_expr(self.use_block_sparsity):
2489
+ (
2490
+ curr_q_cnt,
2491
+ curr_q_idx,
2492
+ curr_full_cnt,
2493
+ curr_full_idx,
2494
+ loop_count,
2495
+ ) = get_block_sparse_iteration_info_bwd(
2496
+ blocksparse_tensors,
2497
+ batch_idx,
2498
+ head_idx,
2499
+ n_block,
2500
+ subtile_factor=self.subtile_factor,
2501
+ m_block_max=m_block_max,
2502
+ )
2503
+ process_tile = loop_count > Int32(0)
2504
+ else:
2505
+ process_tile = (
2506
+ const_expr(not self.is_local and not self.is_varlen_q)
2507
+ or m_block_min < m_block_max
2508
+ )
2509
+ loop_count = m_block_max - m_block_min
2510
+
2511
+ # dQacc_reduce mainloop
2512
+ # Block sparsity: iterate over sparse m_block count and derive actual m_block
2513
+ # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly.
2514
+ for iter_idx in cutlass.range(loop_count, unroll=1):
2515
+ if const_expr(self.use_block_sparsity):
2516
+ m_block, _ = get_m_block_from_iter_bwd(
2517
+ iter_idx,
2518
+ curr_q_cnt,
2519
+ curr_q_idx,
2520
+ curr_full_cnt,
2521
+ curr_full_idx,
2522
+ subtile_factor=self.subtile_factor,
2523
+ m_block_max=m_block_max,
2524
+ )
2525
+ if m_block_max > 0:
2526
+ m_block = cutlass.min(m_block, m_block_max - 1)
2527
+ else:
2528
+ m_block = m_block_min + iter_idx
2529
+ pipeline_dQ.consumer_wait(dQ_consumer_state)
2530
+ # TMEM -> RMEM
2531
+ tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32)
2532
+ cute.copy(thr_copy_t2r, tdQtdQ_t2r, tdQrdQ_t2r)
2533
+ cute.arch.fence_view_async_tmem_load()
2534
+ cute.arch.sync_warp()
2535
+ with cute.arch.elect_one():
2536
+ pipeline_dQ.consumer_release(dQ_consumer_state)
2537
+ dQ_consumer_state.advance()
2538
+
2539
+ gdQaccum_cur = gdQaccum[None, None, m_block]
2540
+
2541
+ for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4
2542
+ smem_idx = dQ_tma_store_producer_state.index
2543
+ tdQsdQ_r2s = tdQsdQ[None, None, smem_idx]
2544
+ tdQrdQ_r2s = cute.make_tensor(
2545
+ tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape
2546
+ )
2547
+ cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s)
2548
+ # Fence and barrier to make sure shared memory store is visible to TMA store
2549
+ cute.arch.fence_proxy(
2550
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2551
+ )
2552
+ # semaphore acquire
2553
+ if const_expr(self.deterministic and stage == 0):
2554
+ if const_expr(self.spt):
2555
+ if const_expr(
2556
+ self.is_causal or block_info.window_size_right is not None
2557
+ ):
2558
+ n_idx_right = (
2559
+ (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q
2560
+ )
2561
+ if const_expr(block_info.window_size_right is not None):
2562
+ n_idx_right += block_info.window_size_right
2563
+ n_block_max_for_m_block = min(
2564
+ n_block_global_max,
2565
+ cute.ceil_div(n_idx_right, self.tile_n),
2566
+ )
2567
+ else:
2568
+ n_block_max_for_m_block = n_block_global_max
2569
+ lock_value = n_block_max_for_m_block - 1 - n_block
2570
+ else:
2571
+ lock_value = n_block
2572
+ barrier.wait_eq(
2573
+ mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value
2574
+ )
2575
+ self.reduce_sync_barrier.arrive_and_wait()
2576
+ # Copy from shared memory to global memory
2577
+ if is_tma_warp:
2578
+ with cute.arch.elect_one():
2579
+ copy_utils.cpasync_reduce_bulk_add_f32(
2580
+ sdQaccum[None, smem_idx].iterator,
2581
+ gdQaccum_cur[None, stage].iterator,
2582
+ self.tma_copy_bytes["dQ"] // 1,
2583
+ )
2584
+ cute.arch.cp_async_bulk_commit_group()
2585
+ cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag)
2586
+ self.reduce_sync_barrier.arrive_and_wait()
2587
+ dQ_tma_store_producer_state.advance()
2588
+ # Directly add to gmem, much slower
2589
+ # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block])
2590
+ # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ)
2591
+ # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True):
2592
+ # copy_utils.atomic_add_fp32x4(
2593
+ # tdQrdQ_r2s[4 * i],
2594
+ # tdQrdQ_r2s[4 * i + 1],
2595
+ # tdQrdQ_r2s[4 * i + 2],
2596
+ # tdQrdQ_r2s[4 * i + 3],
2597
+ # utils.elem_pointer(tdQgdQ, 4 * i),
2598
+ # )
2599
+ # semaphore release for prior m_block
2600
+ if const_expr(self.deterministic and stage == 0 and delay_semaphore_release):
2601
+ if m_block > m_block_min:
2602
+ barrier.arrive_inc(
2603
+ mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1
2604
+ )
2605
+
2606
+ # semaphore release
2607
+ # NOTE: arrive_inc calls red_release which issues membar
2608
+ if const_expr(self.deterministic and not delay_semaphore_release):
2609
+ if is_tma_warp:
2610
+ cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
2611
+ self.reduce_sync_barrier.arrive_and_wait()
2612
+ barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1)
2613
+
2614
+ if const_expr(not self.is_local) or m_block_min < m_block_max:
2615
+ if is_tma_warp:
2616
+ cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
2617
+ self.reduce_sync_barrier.arrive_and_wait()
2618
+ # final semaphore release
2619
+ if const_expr(self.deterministic and delay_semaphore_release):
2620
+ barrier.arrive_inc(
2621
+ mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1
2622
+ )
2623
+
2624
+ if const_expr(
2625
+ self.deterministic and not self.spt and block_info.window_size_left is not None
2626
+ ):
2627
+ m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m)
2628
+ for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1):
2629
+ barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1)
2630
+
2631
+ tile_scheduler.advance_to_next_work()
2632
+ work_tile = tile_scheduler.get_current_work()
2633
+
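The deterministic path in dQacc_reduce serializes dQ accumulation across n_blocks with a per-m_block semaphore (barrier.wait_eq to take a turn, barrier.arrive_inc to pass it on), so floating-point adds always happen in the same order. A small editorial Python model of that turn-taking idea, not of the GPU primitives themselves:

import threading

class TurnCounter:
    """Toy stand-in for the wait_eq / arrive_inc pair used above."""

    def __init__(self):
        self.value = 0
        self.cond = threading.Condition()

    def wait_eq(self, expected):
        # block until the counter equals our turn number
        with self.cond:
            self.cond.wait_for(lambda: self.value == expected)

    def arrive_inc(self):
        # hand the turn to the next contributor
        with self.cond:
            self.value += 1
            self.cond.notify_all()


acc = [0.0]
sem = TurnCounter()

def accumulate(turn, contribution):
    sem.wait_eq(turn)        # acquire: wait for our fixed position in the order
    acc[0] += contribution   # adds always happen in the same order -> reproducible sum
    sem.arrive_inc()         # release the next contributor

threads = [threading.Thread(target=accumulate, args=(t, 0.1 * t)) for t in range(4)]
for t in reversed(threads):  # start them out of order on purpose
    t.start()
for t in threads:
    t.join()
assert abs(acc[0] - 0.6) < 1e-9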
2634
+ @cute.jit
2635
+ def epilogue_dKV(
2636
+ self,
2637
+ tidx: Int32,
2638
+ warp_idx: Int32,
2639
+ batch_idx: Int32,
2640
+ head_idx: Int32,
2641
+ n_block: Int32,
2642
+ seqlen,
2643
+ thr_mma_dV: cute.core.ThrMma,
2644
+ thr_mma_dK: cute.core.ThrMma,
2645
+ tdVtdV: cute.Tensor,
2646
+ tdKtdK: cute.Tensor,
2647
+ mdV: cute.Tensor,
2648
+ mdK: cute.Tensor,
2649
+ pipeline_dKV: PipelineAsync,
2650
+ consumer_state_dKV: cutlass.pipeline.PipelineState,
2651
+ softmax_scale: Float32,
2652
+ ):
2653
+ wg_idx = (
2654
+ cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))
2655
+ ) // 128
2656
+ num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128
2657
+
2658
+ assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA"
2659
+ mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
2660
+ mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
2661
+
2662
+ tmem_load_atom = cute.make_copy_atom(
2663
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32
2664
+ )
2665
+
2666
+ # dV
2667
+ pipeline_dKV.consumer_wait(consumer_state_dKV)
2668
+
2669
+ tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV)
2670
+ thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx)
2671
+
2672
+ tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV)
2673
+ tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg)
2674
+
2675
+ cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1]))
2676
+ tdVcdV = thr_mma_dV.partition_C(cdV)
2677
+ tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout)
2678
+
2679
+ tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor)
2680
+ tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg)
2681
+ tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32)
2682
+
2683
+ cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r)
2684
+ cute.arch.fence_view_async_tmem_load()
2685
+
2686
+ universal_copy_bits = 128
2687
+ atom_universal_copy = cute.make_copy_atom(
2688
+ cute.nvgpu.CopyUniversalOp(),
2689
+ self.dv_dtype,
2690
+ num_bits_per_copy=universal_copy_bits,
2691
+ )
2692
+ tiled_gmem_store_dV = cute.make_tiled_copy(
2693
+ atom_universal_copy,
2694
+ layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled,
2695
+ tiler_mn=tiled_tmem_ld_dV.tiler_mn,
2696
+ )
2697
+
2698
+ tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype)
2699
+ for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])):
2700
+ dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load()
2701
+ tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype))
2702
+
2703
+ gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
2704
+ gdV_tile = gdV[None, None, n_block]
2705
+
2706
+ tdVgdV = thr_mma_dV.partition_C(gdV_tile)
2707
+ tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV)
2708
+ tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg)
2709
+
2710
+ if tidx < seqlen.seqlen_k - self.tile_n * n_block:
2711
+ cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g)
2712
+
2713
+ cute.arch.sync_warp()
2714
+ with cute.arch.elect_one():
2715
+ pipeline_dKV.consumer_release(consumer_state_dKV)
2716
+ consumer_state_dKV.advance()
2717
+
2718
+ # dK
2719
+ pipeline_dKV.consumer_wait(consumer_state_dKV)
2720
+
2721
+ tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK)
2722
+ thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx)
2723
+
2724
+ tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK)
2725
+ tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg)
2726
+
2727
+ cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1]))
2728
+ tdKcdK = thr_mma_dK.partition_C(cdK)
2729
+ tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout)
2730
+
2731
+ tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor)
2732
+ tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg)
2733
+ tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32)
2734
+
2735
+ cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r)
2736
+ cute.arch.fence_view_async_tmem_load()
2737
+
2738
+ universal_copy_bits = 128
2739
+ atom_universal_copy = cute.make_copy_atom(
2740
+ cute.nvgpu.CopyUniversalOp(),
2741
+ self.dk_dtype,
2742
+ num_bits_per_copy=universal_copy_bits,
2743
+ )
2744
+
2745
+ tiled_gmem_store_dK = cute.make_tiled_copy(
2746
+ atom_universal_copy,
2747
+ layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled,
2748
+ tiler_mn=tiled_tmem_ld_dK.tiler_mn,
2749
+ )
2750
+
2751
+ tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype)
2752
+
2753
+ for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])):
2754
+ dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale
2755
+ tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype))
2756
+
2757
+ gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdimv), (None, 0))
2758
+ gdK_tile = gdK[None, None, n_block]
2759
+
2760
+ tdKgdK = thr_mma_dK.partition_C(gdK_tile)
2761
+ tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK)
2762
+ tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg)
2763
+
2764
+ if tidx < seqlen.seqlen_k - self.tile_n * n_block:
2765
+ cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g)
2766
+
2767
+ cute.arch.sync_warp()
2768
+ with cute.arch.elect_one():
2769
+ pipeline_dKV.consumer_release(consumer_state_dKV)
2770
+ consumer_state_dKV.advance()
2771
+ return consumer_state_dKV
2772
+
2773
+ @cute.jit
2774
+ def epilogue_dK_or_dV_tma(
2775
+ self,
2776
+ tidx: Int32,
2777
+ batch_idx: Int32,
2778
+ head_idx: Int32,
2779
+ n_block: Int32,
2780
+ seqlen,
2781
+ thr_mma: cute.core.ThrMma,
2782
+ tdKVtdKV: cute.Tensor,
2783
+ mdKV: cute.Tensor,
2784
+ sdKV: cute.Tensor,
2785
+ tma_atom_dKV: cute.CopyAtom,
2786
+ thr_copy_r2s_dKV: cute.TiledCopy,
2787
+ pipeline_dKV: PipelineAsync,
2788
+ consumer_state_dKV: cutlass.pipeline.PipelineState,
2789
+ scale: Optional[Float32],
2790
+ barrier_id: Int32,
2791
+ mdKV_semaphore: Optional[cute.Tensor],
2792
+ ) -> cutlass.pipeline.PipelineState:
2793
+ # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim)
2794
+ # head_dim = head_dim_v, dk_dtype = dv_dtype
2795
+ num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids)
2796
+ wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128
2797
+ num_wg = num_compute_threads // 128
2798
+ leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0
2799
+
2800
+ if const_expr(not self.dKV_postprocess):
2801
+ sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16
2802
+ else:
2803
+ sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32
2804
+
2805
+ # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8)
2806
+ tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV)
2807
+
2808
+ head_idx_kv = head_idx // self.qhead_per_kvhead
2809
+ if const_expr(not self.dKV_postprocess):
2810
+ assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path"
2811
+ mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim)
2812
+ gdKV_p = cute.local_tile(
2813
+ mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0)
2814
+ ) # (tile_n, hdim)
2815
+ gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2)
2816
+ gdKV_epi = cute.local_tile(
2817
+ gdKV, self.sdKV_epi_tile, (0, None)
2818
+ ) # (tile_n, 64, epi_stage = (hdim / 2) / 64)
2819
+ else:
2820
+ if const_expr(not seqlen.has_cu_seqlens_k):
2821
+ mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim)
2822
+ else:
2823
+ mdKV_cur = cute.domain_offset(
2824
+ (seqlen.padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv]
2825
+ )
2826
+ gdKV_p = cute.local_tile(
2827
+ mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,)
2828
+ ) # (tile_n * hdim)
2829
+ gdKV = cute.logical_divide(gdKV_p, (self.tile_n * self.tile_hdim // num_wg,))[
2830
+ ((None, wg_idx),)
2831
+ ] # (tile_n * hdim / 2)
2832
+ gdKV_epi = cute.flat_divide(
2833
+ gdKV, (self.sdKV_flat_epi_tile,)
2834
+ ) # (tile_n * hdim / 2 / epi_stage, epi_stage)
2835
+
2836
+ deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1
2837
+ if const_expr(deterministic_KV):
2838
+ mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]
2839
+
2840
+ if const_expr(not self.dKV_postprocess):
2841
+ tdKVsdKV, tdKVgdKV = cpasync.tma_partition(
2842
+ tma_atom_dKV,
2843
+ 0, # no multicast
2844
+ cute.make_layout(1),
2845
+ cute.group_modes(sdKV, 0, 2),
2846
+ cute.group_modes(gdKV_epi, 0, 2),
2847
+ ) # (TMA) and (TMA, EPI_STAGE)
2848
+ assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV"
2849
+ assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV"
2850
+ num_epi_stages = cute.size(tdKVgdKV.shape[1])
2851
+ assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong"
2852
+ else:
2853
+ num_epi_stages = self.num_epi_stages
2854
+
2855
+ tmem_load_atom = cute.make_copy_atom(
2856
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
2857
+ )
2858
+
2859
+ read_flag = const_expr(not deterministic_KV)
2860
+
2861
+ pipeline_dKV.consumer_wait(consumer_state_dKV)
2862
+
2863
+ # semaphore acquire
2864
+ if const_expr(deterministic_KV):
2865
+ barrier.wait_eq(
2866
+ mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead
2867
+ )
2868
+ cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
2869
+
2870
+        for epi_stage in cutlass.range_constexpr(num_epi_stages):
+            # TMEM -> RMEM -- setup
+            thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx)
+            tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV)
+            tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]
+            if const_expr(num_epi_stages > 1):
+                tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage]
+
+            cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim))
+            tdKVcdKV = thr_mma.partition_C(cdKV)
+            tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV)
+            tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]
+            if const_expr(num_epi_stages > 1):
+                tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage]
+
+            tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32)
+
+            assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, (
+                "RMEM<->TMEM fragment size mismatch"
+            )
+
+            # TMEM -> RMEM -- copy and fence
+            cute.copy(thr_copy_t2r, tdKVtdKV_t2r, tdKVrdKV_t2r)
+            cute.arch.fence_view_async_tmem_load()
+
+            # RMEM -- scale and convert
+            if const_expr(scale is not None):
+                for i in cutlass.range(cute.size(tdKVrdKV_t2r.shape) // 2, unroll_full=True):
+                    tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2(
+                        (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale)
+                    )
+            tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype)  # (32 columns)
+            tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype))
+
+            # RMEM -> SMEM -- copy, fence and barrier
+            tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape)
+            cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s)
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
+
+            # SMEM -> GMEM
+            if leader_warp:
+                if const_expr(not self.dKV_postprocess):
+                    cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage])
+                else:
+                    with cute.arch.elect_one():
+                        copy_utils.cpasync_reduce_bulk_add_f32(
+                            sdKV.iterator,
+                            gdKV_epi[None, epi_stage].iterator,
+                            self.tma_copy_bytes["dKacc"],
+                        )
+                if const_expr(epi_stage < num_epi_stages - 1):
+                    cute.arch.cp_async_bulk_commit_group()
+                    cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
+                cute.arch.barrier_arrive(
+                    barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE
+                )
+
+            # Barrier since all warps need to wait for SMEM to be freed
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            cute.arch.barrier(
+                barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE
+            )
+
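+        # Before releasing the semaphore below, the leader waits for the bulk GMEM
+        # writes to retire (read_flag is False on this path), so the next query head
+        # in line only starts its update after this tile's dK/dV writes are visible.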
+        # semaphore release
+        # NOTE: arrive_inc calls red_release which issues membar
+        if const_expr(deterministic_KV):
+            if leader_warp:
+                cute.arch.cp_async_bulk_commit_group()
+                cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
+            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
+            barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1)
+
+        cute.arch.sync_warp()
+        with cute.arch.elect_one():
+            pipeline_dKV.consumer_release(consumer_state_dKV)
+        consumer_state_dKV.advance()
+        return consumer_state_dKV