mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/flash_bwd_sm90.py
@@ -0,0 +1,1703 @@
1
+ # @nolint # fbcode
2
+ import math
3
+ from typing import Callable, Optional, Type
4
+ from functools import partial
5
+
6
+ import cuda.bindings.driver as cuda
7
+
8
+ import cutlass
9
+ import cutlass.cute as cute
10
+ import cutlass.utils.hopper_helpers as sm90_utils_basic
11
+ from cutlass.cute.nvgpu import cpasync, warpgroup
12
+ from cutlass.cute.arch import ProxyKind, SharedSpace
13
+ from cutlass.cute import FastDivmodDivisor
14
+ from cutlass import Float32, Int32, Boolean, const_expr
15
+ from cutlass.utils import LayoutEnum
16
+
17
+ from mslk.attention.flash_attn import hopper_helpers as sm90_utils
18
+ from mslk.attention.flash_attn import utils
19
+ from mslk.attention.flash_attn import copy_utils
20
+ from mslk.attention.flash_attn.hopper_helpers import gemm_zero_init, gemm_w_idx
21
+ from mslk.attention.flash_attn.mask import AttentionMask
22
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
23
+ from mslk.attention.flash_attn.block_info import BlockInfo
24
+ from mslk.attention.flash_attn import pipeline
25
+ from mslk.attention.flash_attn.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase
26
+ from mslk.attention.flash_attn.named_barrier import NamedBarrierFwd, NamedBarrierBwd
27
+ from mslk.attention.flash_attn.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
28
+ from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
29
+ from mslk.attention.flash_attn.block_sparse_utils import (
30
+ get_total_q_block_count_bwd,
31
+ produce_block_sparse_q_loads_bwd_sm90,
32
+ consume_block_sparse_mma_bwd_sm90,
33
+ dQaccum_store_block_sparse_bwd_sm90,
34
+ )
35
+
36
+
37
+ def mma_partition_fragment_AB(
38
+ thr_mma: cute.core.ThrMma, sA: Optional[cute.Tensor], sB: Optional[cute.Tensor], swap_AB: bool
39
+ ):
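+ # With swap_AB the two smem operands are partitioned with their roles exchanged
+ # (sA goes through the B-operand path and sB through the A-operand path), which
+ # effectively has the warpgroup MMA produce the transposed output tile.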
40
+ if const_expr(not swap_AB):
41
+ return (
42
+ thr_mma.make_fragment_A(thr_mma.partition_A(sA)) if sA is not None else None,
43
+ thr_mma.make_fragment_B(thr_mma.partition_B(sB)) if sB is not None else None,
44
+ )
45
+ else:
46
+ return (
47
+ thr_mma.make_fragment_B(thr_mma.partition_B(sA)) if sA is not None else None,
48
+ thr_mma.make_fragment_A(thr_mma.partition_A(sB)) if sB is not None else None,
49
+ )
50
+
51
+
52
+ class FlashAttentionBackwardSm90:
53
+ arch = 90
54
+
55
+ def __init__(
56
+ self,
57
+ dtype: Type[cutlass.Numeric],
58
+ head_dim: int,
59
+ head_dim_v: Optional[int] = None,
60
+ qhead_per_kvhead: int = 1,
61
+ is_causal: bool = False,
62
+ tile_m: int = 64,
63
+ tile_n: int = 128,
64
+ Q_stage: int = 2,
65
+ dO_stage: int = 2,
66
+ PdS_stage: int = 2,
67
+ SdP_swapAB: bool = False,
68
+ dKV_swapAB: bool = False,
69
+ dQ_swapAB: bool = False,
70
+ AtomLayoutMSdP: int = 1,
71
+ AtomLayoutNdKV: int = 2,
72
+ AtomLayoutMdQ: int = 1,
73
+ num_threads: int = 384,
74
+ V_in_regs: bool = False,
75
+ score_mod: cutlass.Constexpr | None = None,
76
+ score_mod_bwd: cutlass.Constexpr | None = None,
77
+ mask_mod: cutlass.Constexpr | None = None,
78
+ has_aux_tensors: cutlass.Constexpr = False,
79
+ subtile_factor: cutlass.Constexpr[int] = 1,
80
+ ):
81
+ self.dtype = dtype
82
+ # padding head_dim to a multiple of 16 as k_block_size
83
+ hdim_multiple_of = 16
84
+ self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
85
+ head_dim_v = head_dim_v if head_dim_v is not None else head_dim
86
+ self.same_hdim_kv = head_dim == head_dim_v
87
+ self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
88
+ # Can save registers (and hence be faster) if we don't have to check hdim predication
89
+ self.check_hdim_oob = head_dim != self.tile_hdim
90
+ self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
91
+ self.qhead_per_kvhead = qhead_per_kvhead
92
+ self.is_causal = is_causal
93
+ self.is_local = False
94
+ self.tile_m = tile_m
95
+ self.tile_n = tile_n
96
+ self.num_threads = num_threads
97
+ self.Q_stage = Q_stage
98
+ self.dO_stage = dO_stage
99
+ self.PdS_stage = PdS_stage
100
+ assert self.dO_stage in [1, self.Q_stage]
101
+ assert self.PdS_stage in [1, self.Q_stage]
102
+ self.SdP_swapAB = SdP_swapAB
103
+ self.dKV_swapAB = dKV_swapAB
104
+ self.dQ_swapAB = dQ_swapAB
105
+ self.AtomLayoutMSdP = AtomLayoutMSdP
106
+ self.AtomLayoutNdKV = AtomLayoutNdKV
107
+ self.AtomLayoutMdQ = AtomLayoutMdQ
108
+ self.num_mma_warp_groups = (self.num_threads // 128) - 1
109
+ self.mma_dkv_is_rs = (
110
+ AtomLayoutMSdP == 1
111
+ and AtomLayoutNdKV == self.num_mma_warp_groups
112
+ and SdP_swapAB
113
+ and not dKV_swapAB
114
+ )
115
+ self.V_in_regs = V_in_regs
116
+ if qhead_per_kvhead > 1:
117
+ assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v"
118
+ assert self.num_mma_warp_groups == 2, "GQA backward assumes 2 warp groups"
119
+ # These are tuned for speed
120
+ # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share
121
+ # them and then shuffle to get the value whenever we need? This can reduce register
122
+ # pressure when SdP_swapAB, where each thread needs to keep statistics for (tile_m / 4)
123
+ # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.
124
+ # TODO: impl these for hdim 64
125
+ self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
126
+ self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64
127
+
128
+ self.score_mod = score_mod
129
+ self.score_mod_bwd = score_mod_bwd
130
+ self.mask_mod = mask_mod
131
+ self.has_aux_tensors = has_aux_tensors
132
+ self.subtile_factor = subtile_factor
133
+ if cutlass.const_expr(has_aux_tensors):
134
+ self.vec_size: cutlass.Constexpr = 1
135
+ else:
136
+ self.vec_size: cutlass.Constexpr = 4
137
+ self.qk_acc_dtype = Float32
138
+
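+ # Illustrative configuration (a sketch, not exercised in this module): a BF16,
+ # head_dim-128 causal instance would be built as
+ #   FlashAttentionBackwardSm90(cutlass.BFloat16, head_dim=128, is_causal=True)
+ # and can_implement() below can be used to pre-check dtype and tile constraints.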
139
+ @staticmethod
140
+ def can_implement(
141
+ dtype,
142
+ head_dim,
143
+ head_dim_v,
144
+ tile_m,
145
+ tile_n,
146
+ Q_stage,
147
+ num_threads,
148
+ V_in_regs=False,
149
+ ) -> bool:
150
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
151
+ return False
152
+ if head_dim % 8 != 0:
153
+ return False
154
+ if head_dim_v % 8 != 0:
155
+ return False
156
+ if tile_n % 16 != 0:
157
+ return False
158
+ if num_threads % 32 != 0:
159
+ return False
160
+ if (tile_m * 2) % num_threads != 0:
161
+ return False
162
+ return True
163
+
164
+ def _check_type(
165
+ self,
166
+ mQ_type: Type[cutlass.Numeric],
167
+ mK_type: Type[cutlass.Numeric],
168
+ mV_type: Type[cutlass.Numeric],
169
+ mdO_type: Type[cutlass.Numeric],
170
+ mLSE_type: Type[cutlass.Numeric],
171
+ mdPsum_type: Type[cutlass.Numeric],
172
+ mdQaccum_type: Type[cutlass.Numeric],
173
+ mdK_type: Type[cutlass.Numeric],
174
+ mdV_type: Type[cutlass.Numeric],
175
+ ):
176
+ # Get the data type and check if it is fp16 or bf16
177
+ if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
178
+ raise TypeError("All tensors must have the same data type")
179
+ if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
180
+ raise TypeError("Only Float16 or BFloat16 is supported")
181
+ if const_expr(mLSE_type not in [Float32]):
182
+ raise TypeError("LSE tensor must be Float32")
183
+ if const_expr(mdPsum_type not in [Float32]):
184
+ raise TypeError("dPsum tensor must be Float32")
185
+ if const_expr(mdQaccum_type not in [Float32]):
186
+ raise TypeError("dQaccum tensor must be Float32")
187
+ if const_expr(self.qhead_per_kvhead == 1):
188
+ if const_expr(not (mdK_type == mdV_type == mQ_type)):
189
+ raise TypeError("mdK and mdV tensors must have the same data type as mQ")
190
+ else:
191
+ if const_expr(not (mdK_type == mdV_type == Float32)):
192
+ raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
193
+ assert mQ_type == self.dtype
194
+
195
+ def _setup_attributes(self):
196
+ self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [
197
+ sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage)
198
+ for shape, stage in [
199
+ ((self.tile_m, self.tile_hdim), self.Q_stage),
200
+ ((self.tile_n, self.tile_hdim), None),
201
+ ((self.tile_n, self.tile_hdimv), None),
202
+ ((self.tile_m, self.tile_hdimv), self.dO_stage),
203
+ ((self.tile_m, self.tile_n), self.PdS_stage),
204
+ ]
205
+ ]
206
+ self.sdQaccum_layout = cute.make_layout(
207
+ (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups)
208
+ )
209
+ # dQaccum R->S
210
+ self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
211
+ cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
212
+ # thr_layout
213
+ cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
214
+ cute.make_layout(128 // Float32.width), # val_layout
215
+ )
216
+ # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
217
+ self.sdKVaccum_layout = cute.make_layout(
218
+ (self.tile_n * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups)
219
+ )
220
+ # dKVaccum R->S (same pattern as dQaccum but sized for tile_n)
221
+ self.r2s_tiled_copy_dKVaccum = cute.make_tiled_copy_tv(
222
+ cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
223
+ cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
224
+ cute.make_layout(128 // Float32.width),
225
+ )
226
+
227
+ def _get_tiled_mma(self):
228
+ # S = Q @ K.T, dP = dO @ V.T
229
+ atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP)
230
+ tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
231
+ tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
232
+ self.dtype,
233
+ self.dtype,
234
+ warpgroup.OperandMajorMode.K,
235
+ warpgroup.OperandMajorMode.K,
236
+ Float32,
237
+ atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1])
238
+ + (1,),
239
+ tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1],
240
+ )
241
+ # dV = P.T @ dO, dK = dS.T @ Q
242
+ atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV)
243
+ tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
244
+ tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
245
+ tiled_mma_dK, tiled_mma_dV = [
246
+ sm90_utils_basic.make_trivial_tiled_mma(
247
+ self.dtype,
248
+ self.dtype,
249
+ warpgroup.OperandMajorMode.MN
250
+ if not self.mma_dkv_is_rs
251
+ else warpgroup.OperandMajorMode.K,
252
+ warpgroup.OperandMajorMode.MN,
253
+ Float32,
254
+ atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1])
255
+ + (1,),
256
+ tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1],
257
+ a_source=warpgroup.OperandSource.RMEM
258
+ if self.mma_dkv_is_rs
259
+ else warpgroup.OperandSource.SMEM,
260
+ )
261
+ for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
262
+ ]
263
+ # dQ = dS @ K
264
+ atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ)
265
+ tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
266
+ tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
267
+ self.dtype,
268
+ self.dtype,
269
+ warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
270
+ warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
271
+ Float32,
272
+ atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,),
273
+ tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
274
+ )
275
+ return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
276
+
277
+ def _get_shared_storage_cls(self):
278
+ sQ_alignment = sK_alignment = sV_alignment = sdQaccum_alignment = sdO_alignment = 1024
279
+
280
+ sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [
281
+ cute.struct.Align[cute.struct.MemRange[type, cute.cosize(layout)], alignment]
282
+ for (layout, type, alignment) in [
283
+ (self.sQ_layout, self.dtype, sQ_alignment),
284
+ (self.sK_layout, self.dtype, sK_alignment),
285
+ (self.sV_layout, self.dtype, sV_alignment),
286
+ (self.sdO_layout, self.dtype, sdO_alignment),
287
+ (self.sdQaccum_layout, Float32, sdQaccum_alignment),
288
+ ]
289
+ ]
290
+
291
+ cosize_sdS = cute.cosize(self.sPdS_layout)
292
+ cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0
293
+ sLSE_struct = cute.struct.Align[
294
+ cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128
295
+ ]
296
+ sdPsum_struct = cute.struct.Align[
297
+ cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128
298
+ ]
299
+
300
+ @cute.struct
301
+ class SharedStorageQKV:
302
+ mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2]
303
+ mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2]
304
+ sLSE: sLSE_struct
305
+ sdPsum: sdPsum_struct
306
+ sQ: sQ_struct
307
+ sV: sV_struct
308
+ sK: sK_struct
309
+ sdO: sdO_struct
310
+ sP: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
311
+ sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024]
312
+ sdQaccum: sdQaccum_struct
313
+
314
+ return SharedStorageQKV
315
+
316
+ @cute.jit
317
+ def __call__(
318
+ self,
319
+ mQ: cute.Tensor,
320
+ mK: cute.Tensor,
321
+ mV: cute.Tensor,
322
+ mdO: cute.Tensor,
323
+ mLSE: cute.Tensor,
324
+ mdPsum: cute.Tensor,
325
+ mdQaccum: cute.Tensor,
326
+ mdK: cute.Tensor,
327
+ mdV: cute.Tensor,
328
+ softmax_scale: Float32,
329
+ stream: cuda.CUstream,
330
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
331
+ mCuSeqlensK: Optional[cute.Tensor] = None,
332
+ mSeqUsedQ: Optional[cute.Tensor] = None,
333
+ mSeqUsedK: Optional[cute.Tensor] = None,
334
+ softcap: Float32 | float | None = None,
335
+ window_size_left: Int32 | int | None = None,
336
+ window_size_right: Int32 | int | None = None,
337
+ mdQ_semaphore: Optional[cute.Tensor] = None,
338
+ mdK_semaphore: Optional[cute.Tensor] = None,
339
+ mdV_semaphore: Optional[cute.Tensor] = None,
340
+ aux_tensors: Optional[list] = None,
341
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
342
+ ):
343
+ assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (
344
+ "determinism not supported yet for Sm90"
345
+ )
346
+
347
+ self._check_type(
348
+ *(
349
+ t.element_type if t is not None else None
350
+ for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
351
+ )
352
+ )
353
+
354
+ # Assume all strides are divisible by 128 bits except the last stride
355
+ new_stride = lambda t: (
356
+ *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
357
+ t.stride[-1],
358
+ )
359
+ mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
360
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
361
+ if t is not None
362
+ else None
363
+ for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
364
+ ]
365
+
366
+ layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b)
367
+ mQ, mK, mV, mdO = [utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdO)]
368
+ if const_expr(self.qhead_per_kvhead == 1):
369
+ mdK, mdV = [utils.select(t, layout_transpose) for t in (mdK, mdV)]
370
+ else:
371
+ accum_transpose = [2, 1, 0] # (b, n, s*h) -> (s*h, n, b)
372
+ mdK, mdV = [utils.select(t, accum_transpose) for t in (mdK, mdV)]
373
+ LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b)
374
+ mLSE, mdPsum, mdQaccum = [
375
+ utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
376
+ ]
377
+
378
+ tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()
379
+
380
+ self.num_mma_threads = tiled_mma_SdP.size
381
+ assert self.num_mma_threads + 128 == self.num_threads
382
+
383
+ self.num_threads_per_warp_group = 128
384
+ self.num_producer_threads = 32
385
+
386
+ self.num_mma_regs = 240
387
+ self.num_producer_regs = 24
388
+ # self.num_mma_regs = 232
389
+ # self.num_producer_regs = 40
390
+
391
+ self._setup_attributes()
392
+ SharedStorage = self._get_shared_storage_cls()
393
+
394
+ self.tma_copy_bytes = {
395
+ name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
396
+ for name, mX, layout in [
397
+ ("Q", mQ, self.sQ_layout),
398
+ ("K", mK, self.sK_layout),
399
+ ("V", mV, self.sV_layout),
400
+ ("dO", mdO, self.sdO_layout),
401
+ ]
402
+ }
403
+ self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
404
+ self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
405
+ self.tma_copy_bytes["dQ"] = (
406
+ self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups
407
+ )
408
+ self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
409
+ self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8
410
+
411
+ tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
412
+ cpasync.CopyBulkTensorTileG2SOp(),
413
+ mQ,
414
+ cute.select(self.sQ_layout, mode=[0, 1]),
415
+ (self.tile_m, self.tile_hdim),
416
+ )
417
+ tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
418
+ cpasync.CopyBulkTensorTileG2SOp(),
419
+ mK,
420
+ cute.select(self.sK_layout, mode=[0, 1]),
421
+ (self.tile_n, self.tile_hdim),
422
+ )
423
+ tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
424
+ cpasync.CopyBulkTensorTileG2SOp(),
425
+ mV,
426
+ cute.select(self.sV_layout, mode=[0, 1]),
427
+ (self.tile_n, self.tile_hdimv),
428
+ )
429
+ tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom(
430
+ cpasync.CopyBulkTensorTileG2SOp(),
431
+ mdO,
432
+ cute.select(self.sdO_layout, mode=[0, 1]),
433
+ (self.tile_m, self.tile_hdimv),
434
+ )
435
+ if const_expr(self.qhead_per_kvhead == 1):
436
+ tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
437
+ cpasync.CopyBulkTensorTileS2GOp(),
438
+ mdK,
439
+ cute.select(self.sK_layout, mode=[0, 1]),
440
+ (self.tile_n, self.tile_hdim),
441
+ )
442
+ tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
443
+ cpasync.CopyBulkTensorTileS2GOp(),
444
+ mdV,
445
+ cute.select(self.sV_layout, mode=[0, 1]),
446
+ (self.tile_n, self.tile_hdimv),
447
+ )
448
+ else:
449
+ tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None
450
+
451
+ TileScheduler = SingleTileScheduler
452
+ tile_sched_args = TileSchedulerArguments(
453
+ cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
454
+ cute.size(mQ.shape[2]),
455
+ cute.size(mQ.shape[3]),
456
+ 1, # num_splits
457
+ cute.size(mK.shape[0]),
458
+ mQ.shape[1],
459
+ mV.shape[1],
460
+ total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
461
+ tile_shape_mn=(self.tile_m, self.tile_n),
462
+ mCuSeqlensQ=None,
463
+ mSeqUsedQ=None,
464
+ qhead_per_kvhead_packgqa=1,
465
+ element_size=self.dtype.width // 8,
466
+ is_persistent=False,
467
+ lpt=False,
468
+ )
469
+
470
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
471
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
472
+
473
+ LOG2_E = math.log2(math.e)
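+ # The softmax is evaluated with exp2 instead of exp, so log2(e) is folded into
+ # the scale: exp(s * scale) == exp2(s * scale * log2(e)). When a score_mod is
+ # present, the scale is applied inside the score-mod path instead, so only
+ # log2(e) is folded in here.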
474
+ if const_expr(self.score_mod is None):
475
+ softmax_scale_log2 = softmax_scale * LOG2_E
476
+ else:
477
+ softmax_scale_log2 = LOG2_E
478
+
479
+ fastdiv_mods = None
480
+ if const_expr(aux_tensors is not None):
481
+ seqlen_q = cute.size(mQ.shape[0])
482
+ seqlen_k = cute.size(mK.shape[0])
483
+ seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
484
+ seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
485
+ fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
486
+
487
+ qhead_per_kvhead_divmod = None
488
+ if const_expr(self.qhead_per_kvhead > 1):
489
+ qhead_per_kvhead_divmod = FastDivmodDivisor(self.qhead_per_kvhead)
490
+
491
+ self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
492
+
493
+ self.kernel(
494
+ tma_tensor_Q,
495
+ tma_tensor_K,
496
+ tma_tensor_V,
497
+ tma_tensor_dO,
498
+ tma_tensor_dK if const_expr(self.qhead_per_kvhead == 1) else mdK,
499
+ tma_tensor_dV if const_expr(self.qhead_per_kvhead == 1) else mdV,
500
+ tma_atom_Q,
501
+ tma_atom_K,
502
+ tma_atom_V,
503
+ tma_atom_dO,
504
+ tma_atom_dK,
505
+ tma_atom_dV,
506
+ mLSE,
507
+ mdPsum,
508
+ mdQaccum,
509
+ self.sQ_layout,
510
+ self.sK_layout,
511
+ self.sV_layout,
512
+ self.sPdS_layout,
513
+ self.sdO_layout,
514
+ self.sdQaccum_layout,
515
+ self.sdKVaccum_layout,
516
+ self.r2s_tiled_copy_dQaccum,
517
+ self.r2s_tiled_copy_dKVaccum,
518
+ tiled_mma_SdP,
519
+ tiled_mma_dK,
520
+ tiled_mma_dV,
521
+ tiled_mma_dQ,
522
+ softmax_scale_log2,
523
+ softmax_scale,
524
+ tile_sched_params,
525
+ TileScheduler,
526
+ SharedStorage,
527
+ aux_tensors,
528
+ fastdiv_mods,
529
+ blocksparse_tensors,
530
+ qhead_per_kvhead_divmod,
531
+ ).launch(
532
+ grid=grid_dim,
533
+ block=[self.num_threads, 1, 1],
534
+ smem=SharedStorage.size_in_bytes(),
535
+ stream=stream,
536
+ min_blocks_per_mp=1,
537
+ )
538
+
539
+ @cute.kernel
540
+ def kernel(
541
+ self,
542
+ mQ: cute.Tensor,
543
+ mK: cute.Tensor,
544
+ mV: cute.Tensor,
545
+ mdO: cute.Tensor,
546
+ mdK: cute.Tensor,
547
+ mdV: cute.Tensor,
548
+ tma_atom_Q: cute.CopyAtom,
549
+ tma_atom_K: cute.CopyAtom,
550
+ tma_atom_V: cute.CopyAtom,
551
+ tma_atom_dO: cute.CopyAtom,
552
+ tma_atom_dK: cute.CopyAtom,
553
+ tma_atom_dV: cute.CopyAtom,
554
+ mLSE: cute.Tensor,
555
+ mdPsum: cute.Tensor,
556
+ mdQaccum: cute.Tensor,
557
+ sQ_layout: cute.ComposedLayout,
558
+ sK_layout: cute.ComposedLayout,
559
+ sV_layout: cute.ComposedLayout,
560
+ sPdS_layout: cute.ComposedLayout,
561
+ sdO_layout: cute.ComposedLayout,
562
+ sdQaccum_layout: cute.Layout,
563
+ sdKVaccum_layout: cute.Layout,
564
+ r2s_tiled_copy_dQaccum: cute.TiledCopy,
565
+ r2s_tiled_copy_dKVaccum: cute.TiledCopy,
566
+ tiled_mma_SdP: cute.TiledMma,
567
+ tiled_mma_dK: cute.TiledMma,
568
+ tiled_mma_dV: cute.TiledMma,
569
+ tiled_mma_dQ: cute.TiledMma,
570
+ softmax_scale_log2,
571
+ softmax_scale,
572
+ tile_sched_params: ParamsBase,
573
+ TileScheduler: cutlass.Constexpr[Callable],
574
+ SharedStorage: cutlass.Constexpr[Callable],
575
+ aux_tensors: Optional[list] = None,
576
+ fastdiv_mods=(None, None),
577
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
578
+ qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
579
+ ):
580
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
581
+
582
+ # prefetch TMA descriptors
583
+ if warp_idx == 0:
584
+ cpasync.prefetch_descriptor(tma_atom_Q)
585
+ cpasync.prefetch_descriptor(tma_atom_K)
586
+ cpasync.prefetch_descriptor(tma_atom_V)
587
+ cpasync.prefetch_descriptor(tma_atom_dO)
588
+
589
+ smem = cutlass.utils.SmemAllocator()
590
+ storage = smem.allocate(SharedStorage)
591
+
592
+ pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread)
593
+ pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(
594
+ cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
595
+ )
596
+ pipeline_Q = pipeline.PipelineTmaAsync.create(
597
+ barrier_storage=storage.mbar_ptr_Q.data_ptr(),
598
+ num_stages=self.Q_stage,
599
+ producer_group=pipeline_producer_group,
600
+ consumer_group=pipeline_consumer_group,
601
+ tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"],
602
+ defer_sync=True,
603
+ )
604
+ pipeline_dO = pipeline.PipelineTmaAsync.create(
605
+ barrier_storage=storage.mbar_ptr_dO.data_ptr(),
606
+ num_stages=self.dO_stage,
607
+ producer_group=pipeline_producer_group,
608
+ consumer_group=pipeline_consumer_group,
609
+ tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"],
610
+ defer_sync=False,
611
+ )
612
+
613
+ sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
614
+ sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner)
615
+ sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
616
+ sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
617
+ sP = None
618
+ if const_expr(not self.mma_dkv_is_rs):
619
+ sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
620
+ sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
621
+ sLSE = storage.sLSE.get_tensor(
622
+ cute.make_layout(
623
+ (self.tile_m, self.Q_stage),
624
+ stride=(1, cute.round_up(self.tile_m, 64)),
625
+ )
626
+ )
627
+ sdPsum = storage.sdPsum.get_tensor(
628
+ cute.make_layout(
629
+ (self.tile_m, self.dO_stage),
630
+ stride=(1, cute.round_up(self.tile_m, 64)),
631
+ )
632
+ )
633
+ sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)
634
+
635
+ block_info = BlockInfo(
636
+ self.tile_m,
637
+ self.tile_n,
638
+ self.is_causal,
639
+ self.is_local,
640
+ False, # is_split_kv
641
+ None,
642
+ None,
643
+ qhead_per_kvhead_packgqa=1,
644
+ )
645
+ SeqlenInfoCls = partial(
646
+ SeqlenInfoQK.create,
647
+ seqlen_q_static=mQ.shape[0],
648
+ seqlen_k_static=mK.shape[0],
649
+ mCuSeqlensQ=None,
650
+ mCuSeqlensK=None,
651
+ mSeqUsedQ=None,
652
+ mSeqUsedK=None,
653
+ )
654
+ AttentionMaskCls = partial(
655
+ AttentionMask,
656
+ self.tile_m,
657
+ self.tile_n,
658
+ window_size_left=None,
659
+ window_size_right=None,
660
+ swap_AB=self.SdP_swapAB,
661
+ )
662
+ TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
663
+
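+ # Warp specialization: the first warp group (warps 0-3) acts as producer with a
+ # reduced register budget; warp 0 issues the TMA loads for Q/K/V/dO/LSE/dPsum and
+ # warp 1 drains dQaccum from smem to gmem. The remaining warp groups reallocate
+ # extra registers and run the GEMMs in self.mma().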
664
+ if warp_idx < 4:
665
+ cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
666
+ if warp_idx == 0:
667
+ self.load(
668
+ mQ,
669
+ mK,
670
+ mV,
671
+ mdO,
672
+ mLSE,
673
+ mdPsum,
674
+ sQ,
675
+ sK,
676
+ sV,
677
+ sdO,
678
+ sLSE,
679
+ sdPsum,
680
+ tma_atom_Q,
681
+ tma_atom_K,
682
+ tma_atom_V,
683
+ tma_atom_dO,
684
+ pipeline_Q,
685
+ pipeline_dO,
686
+ block_info,
687
+ SeqlenInfoCls,
688
+ TileSchedulerCls,
689
+ blocksparse_tensors,
690
+ qhead_per_kvhead_divmod,
691
+ )
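+ # Warp 1 pre-arrives on each warp group's dQEmpty barrier so the first sdQaccum
+ # hand-off in the MMA loop does not block, then loops storing dQaccum tiles.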
692
+ if warp_idx == 1:
693
+ for warp_group_idx in cutlass.range(self.num_mma_warp_groups):
694
+ cute.arch.barrier_arrive(
695
+ barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
696
+ number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
697
+ )
698
+ self.dQaccum_store(
699
+ mdQaccum,
700
+ sdQaccum,
701
+ block_info,
702
+ TileSchedulerCls,
703
+ SeqlenInfoCls,
704
+ blocksparse_tensors,
705
+ )
706
+ else:
707
+ cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
708
+ tidx, _, _ = cute.arch.thread_idx()
709
+ tidx = tidx - 128
710
+ self.mma(
711
+ tiled_mma_SdP,
712
+ tiled_mma_dK,
713
+ tiled_mma_dV,
714
+ tiled_mma_dQ,
715
+ mdK,
716
+ mdV,
717
+ mdQaccum,
718
+ sQ,
719
+ sK,
720
+ sV,
721
+ sdO,
722
+ sP,
723
+ sdS,
724
+ sLSE,
725
+ sdPsum,
726
+ sdQaccum,
727
+ pipeline_Q,
728
+ pipeline_dO,
729
+ tidx,
730
+ tma_atom_dK,
731
+ tma_atom_dV,
732
+ r2s_tiled_copy_dQaccum,
733
+ r2s_tiled_copy_dKVaccum,
734
+ sdKVaccum_layout,
735
+ softmax_scale_log2,
736
+ softmax_scale,
737
+ block_info,
738
+ SeqlenInfoCls,
739
+ AttentionMaskCls,
740
+ TileSchedulerCls,
741
+ aux_tensors,
742
+ fastdiv_mods,
743
+ blocksparse_tensors,
744
+ qhead_per_kvhead_divmod,
745
+ )
746
+
747
+ @cute.jit
748
+ def load(
749
+ self,
750
+ mQ: cute.Tensor,
751
+ mK: cute.Tensor,
752
+ mV: cute.Tensor,
753
+ mdO: cute.Tensor,
754
+ mLSE: cute.Tensor,
755
+ mdPsum: cute.Tensor,
756
+ sQ: cute.Tensor,
757
+ sK: cute.Tensor,
758
+ sV: cute.Tensor,
759
+ sdO: cute.Tensor,
760
+ sLSE: cute.Tensor,
761
+ sdPsum: cute.Tensor,
762
+ tma_atom_Q: cute.CopyAtom,
763
+ tma_atom_K: cute.CopyAtom,
764
+ tma_atom_V: cute.CopyAtom,
765
+ tma_atom_dO: cute.CopyAtom,
766
+ pipeline_Q: cutlass.pipeline.PipelineAsync,
767
+ pipeline_dO: cutlass.pipeline.PipelineAsync,
768
+ block_info: BlockInfo,
769
+ SeqlenInfoCls: Callable,
770
+ TileSchedulerCls: Callable,
771
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
772
+ qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
773
+ ):
774
+ warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
775
+
776
+ if warp_idx_in_wg == 0:
777
+ producer_state_Q = cutlass.pipeline.make_pipeline_state(
778
+ cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
779
+ )
780
+ producer_state_dO = cutlass.pipeline.make_pipeline_state(
781
+ cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
782
+ )
783
+ tile_scheduler = TileSchedulerCls()
784
+ work_tile = tile_scheduler.initial_work_tile_info()
785
+ while work_tile.is_valid_tile:
786
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
787
+ seqlen = SeqlenInfoCls(batch_idx)
788
+ head_idx_kv = (
789
+ head_idx
790
+ if const_expr(self.qhead_per_kvhead == 1)
791
+ else head_idx // qhead_per_kvhead_divmod
792
+ )
793
+ mK_cur = mK[None, None, head_idx_kv, batch_idx]
794
+ gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
795
+ mV_cur = mV[None, None, head_idx_kv, batch_idx]
796
+ gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
797
+
798
+ mQ_cur = mQ[None, None, head_idx, batch_idx]
799
+ gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
800
+ mdO_cur = mdO[None, None, head_idx, batch_idx]
801
+ gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
802
+ mLSE_cur = mLSE[None, head_idx, batch_idx]
803
+ gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
804
+ mdPsum_cur = mdPsum[None, head_idx, batch_idx]
805
+ gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
806
+
807
+ load_K, _, _ = copy_utils.tma_get_copy_fn(
808
+ tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True
809
+ )
810
+ load_V, _, _ = copy_utils.tma_get_copy_fn(
811
+ tma_atom_V, 0, cute.make_layout(1), gV, sV, single_stage=True
812
+ )
813
+ load_Q, _, _ = copy_utils.tma_get_copy_fn(
814
+ tma_atom_Q, 0, cute.make_layout(1), gQ, sQ
815
+ )
816
+ load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)
817
+ load_dO, _, _ = copy_utils.tma_get_copy_fn(
818
+ tma_atom_dO, 0, cute.make_layout(1), gdO, sdO
819
+ )
820
+ load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)
821
+ load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE)
822
+ load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q)
823
+ load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum)
824
+ load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO)
825
+
826
+ m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
827
+
828
+ if const_expr(not self.use_block_sparsity):
829
+ total_m_block_cnt = m_block_max - m_block_min
830
+ process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
831
+ else:
832
+ total_m_block_cnt = get_total_q_block_count_bwd(
833
+ blocksparse_tensors,
834
+ batch_idx,
835
+ head_idx,
836
+ n_block,
837
+ subtile_factor=self.subtile_factor,
838
+ m_block_max=m_block_max,
839
+ )
840
+ process_tile = total_m_block_cnt > Int32(0)
841
+
842
+ if process_tile:
843
+ if const_expr(not self.use_block_sparsity):
844
+ first_m_block = m_block_min
845
+ pipeline_Q.producer_acquire(
846
+ producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]
847
+ )
848
+ load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
849
+ load_Q(first_m_block, producer_state=producer_state_Q)
850
+ with cute.arch.elect_one():
851
+ load_LSE(first_m_block, producer_state=producer_state_Q)
852
+ producer_state_dO_cur = (
853
+ producer_state_dO
854
+ if const_expr(self.Q_stage != self.dO_stage)
855
+ else producer_state_Q
856
+ )
857
+ pipeline_dO.producer_acquire(
858
+ producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"]
859
+ )
860
+ load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
861
+ load_dO(first_m_block, producer_state=producer_state_dO_cur)
862
+ with cute.arch.elect_one():
863
+ load_dPsum(first_m_block, producer_state=producer_state_dO_cur)
864
+ producer_state_Q.advance()
865
+ producer_state_dO.advance()
866
+
867
+ for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
868
+ pipeline_Q.producer_acquire(producer_state_Q)
869
+ load_Q(m_block, producer_state=producer_state_Q)
870
+ with cute.arch.elect_one():
871
+ load_LSE(m_block, producer_state=producer_state_Q)
872
+ producer_state_dO_cur = (
873
+ producer_state_dO
874
+ if const_expr(self.Q_stage != self.dO_stage)
875
+ else producer_state_Q
876
+ )
877
+ pipeline_dO.producer_acquire(producer_state_dO_cur)
878
+ load_dO(m_block, producer_state=producer_state_dO_cur)
879
+ with cute.arch.elect_one():
880
+ load_dPsum(m_block, producer_state=producer_state_dO_cur)
881
+ producer_state_Q.advance()
882
+ producer_state_dO.advance()
883
+ else:
884
+ producer_state_Q, producer_state_dO = produce_block_sparse_q_loads_bwd_sm90(
885
+ blocksparse_tensors,
886
+ batch_idx,
887
+ head_idx,
888
+ n_block,
889
+ producer_state_Q,
890
+ producer_state_dO,
891
+ pipeline_Q,
892
+ pipeline_dO,
893
+ load_K,
894
+ load_V,
895
+ load_Q,
896
+ load_dO,
897
+ load_LSE,
898
+ load_dPsum,
899
+ self.tma_copy_bytes["K"],
900
+ self.tma_copy_bytes["V"],
901
+ Q_stage_eq_dO_stage=(self.Q_stage == self.dO_stage),
902
+ subtile_factor=self.subtile_factor,
903
+ m_block_max=m_block_max,
904
+ )
905
+
906
+ tile_scheduler.prefetch_next_work()
907
+ tile_scheduler.advance_to_next_work()
908
+ work_tile = tile_scheduler.get_current_work()
909
+
910
+ @cute.jit
911
+ def apply_score_mod(
912
+ self,
913
+ acc_S: cute.Tensor,
914
+ thr_mma_SdP: cute.core.ThrMma,
915
+ batch_idx,
916
+ head_idx,
917
+ m_block,
918
+ n_block,
919
+ softmax_scale,
920
+ seqlen_info: SeqlenInfoQK,
921
+ aux_tensors=None,
922
+ fastdiv_mods=(None, None),
923
+ ):
924
+ # [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing
925
+ cS = cute.make_identity_tensor(
926
+ (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
927
+ )
928
+ cS = cute.domain_offset(
929
+ (n_block * self.tile_n, m_block * self.tile_m)
930
+ if self.SdP_swapAB
931
+ else (m_block * self.tile_m, n_block * self.tile_n),
932
+ cS,
933
+ )
934
+ tScS = thr_mma_SdP.partition_C(cS)
935
+
936
+ apply_score_mod_inner(
937
+ acc_S,
938
+ tScS,
939
+ self.score_mod,
940
+ batch_idx,
941
+ head_idx,
942
+ softmax_scale,
943
+ self.vec_size,
944
+ self.qk_acc_dtype,
945
+ aux_tensors,
946
+ fastdiv_mods,
947
+ seqlen_info,
948
+ constant_q_idx=None,
949
+ qhead_per_kvhead=self.qhead_per_kvhead,
950
+ transpose_indices=self.SdP_swapAB,
951
+ )
952
+
953
+ @cute.jit
954
+ def apply_score_mod_bwd(
955
+ self,
956
+ grad_tensor: cute.Tensor,
957
+ score_tensor: cute.Tensor,
958
+ thr_mma_SdP: cute.core.ThrMma,
959
+ batch_idx,
960
+ head_idx,
961
+ m_block,
962
+ n_block,
963
+ softmax_scale,
964
+ seqlen_info: SeqlenInfoQK,
965
+ aux_tensors=None,
966
+ fastdiv_mods=(None, None),
967
+ ):
968
+ cS = cute.make_identity_tensor(
969
+ (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
970
+ )
971
+ cS = cute.domain_offset(
972
+ (n_block * self.tile_n, m_block * self.tile_m)
973
+ if self.SdP_swapAB
974
+ else (m_block * self.tile_m, n_block * self.tile_n),
975
+ cS,
976
+ )
977
+ tScS = thr_mma_SdP.partition_C(cS)
978
+
979
+ apply_score_mod_bwd_inner(
980
+ grad_tensor,
981
+ score_tensor,
982
+ tScS,
983
+ self.score_mod_bwd,
984
+ batch_idx,
985
+ head_idx,
986
+ softmax_scale,
987
+ self.vec_size,
988
+ self.qk_acc_dtype,
989
+ aux_tensors,
990
+ fastdiv_mods,
991
+ seqlen_info,
992
+ constant_q_idx=None,
993
+ qhead_per_kvhead=self.qhead_per_kvhead,
994
+ transpose_indices=self.SdP_swapAB,
995
+ )
996
+
997
+ @cute.jit
998
+ def mma(
999
+ self,
1000
+ tiled_mma_SdP: cute.TiledMma,
1001
+ tiled_mma_dK: cute.TiledMma,
1002
+ tiled_mma_dV: cute.TiledMma,
1003
+ tiled_mma_dQ: cute.TiledMma,
1004
+ mdK: cute.Tensor,
1005
+ mdV: cute.Tensor,
1006
+ mdQaccum: cute.Tensor,
1007
+ sQ: cute.Tensor,
1008
+ sK: cute.Tensor,
1009
+ sV: cute.Tensor,
1010
+ sdO: cute.Tensor,
1011
+ sP: Optional[cute.Tensor],
1012
+ sdS: cute.Tensor,
1013
+ sLSE: cute.Tensor,
1014
+ sdPsum: cute.Tensor,
1015
+ sdQaccum: cute.Tensor,
1016
+ pipeline_Q: cutlass.pipeline.PipelineAsync,
1017
+ pipeline_dO: cutlass.pipeline.PipelineAsync,
1018
+ tidx: Int32,
1019
+ tma_atom_dK: cute.CopyAtom,
1020
+ tma_atom_dV: cute.CopyAtom,
1021
+ r2s_tiled_copy_dQaccum: cute.TiledCopy,
1022
+ r2s_tiled_copy_dKVaccum: cute.TiledCopy,
1023
+ sdKVaccum_layout: cute.Layout,
1024
+ softmax_scale_log2: Float32,
1025
+ softmax_scale: Float32,
1026
+ block_info: BlockInfo,
1027
+ SeqlenInfoCls: Callable,
1028
+ AttentionMaskCls: Callable,
1029
+ TileSchedulerCls: Callable,
1030
+ aux_tensors: Optional[list] = None,
1031
+ fastdiv_mods=(None, None),
1032
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
1033
+ qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
1034
+ ):
1035
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
1036
+ warp_group_thread_layout = cute.make_layout(
1037
+ self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
1038
+ )
1039
+ thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
1040
+ wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
1041
+ wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
1042
+ wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
1043
+ wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx))
1044
+ # S = Q @ K.T
1045
+ tSrQ, tSrK = mma_partition_fragment_AB(wg_mma_SdP, sQ, sK, self.SdP_swapAB)
1046
+ # dP = dO @ V.T
1047
+ tdPrdO, tdPrV = mma_partition_fragment_AB(wg_mma_SdP, sdO, sV, self.SdP_swapAB)
1048
+ # dV += P.T @ dO
1049
+ sPt = utils.transpose_view(sP) if sP is not None else None
1050
+ sdOt = utils.transpose_view(sdO)
1051
+ tdVrPt, tdVrdOt = mma_partition_fragment_AB(wg_mma_dV, sPt, sdOt, self.dKV_swapAB)
1052
+ # dK += dS.T @ Q
1053
+ sdSt = utils.transpose_view(sdS)
1054
+ sQt = utils.transpose_view(sQ)
1055
+ tdKrdSt, tdKrQt = mma_partition_fragment_AB(wg_mma_dK, sdSt, sQt, self.dKV_swapAB)
1056
+ # dQ = dS @ K
1057
+ sKt = utils.transpose_view(sK)
1058
+ tdQrdS, tdQrKt = mma_partition_fragment_AB(wg_mma_dQ, sdS, sKt, self.dQ_swapAB)
1059
+
1060
+ # Smem copy atom tiling
1061
+ smem_copy_atom_PdS = utils.get_smem_store_atom(
1062
+ self.arch, self.dtype, transpose=self.SdP_swapAB
1063
+ )
1064
+ smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice(
1065
+ tidx
1066
+ )
1067
+ tPsP = None
1068
+ if const_expr(sP is not None):
1069
+ tPsP = smem_thr_copy_PdS.partition_D(sP if const_expr(not self.SdP_swapAB) else sPt)
1070
+ tdSsdS = smem_thr_copy_PdS.partition_D(sdS if const_expr(not self.SdP_swapAB) else sdSt)
1071
+
1072
+ sLSE_mma = cute.make_tensor(
1073
+ sLSE.iterator,
1074
+ cute.make_layout(
1075
+ (self.tile_m, self.tile_n, self.Q_stage),
1076
+ stride=(1, 0, cute.round_up(self.tile_m, 64)),
1077
+ ),
1078
+ )
1079
+ sdPsum_mma = cute.make_tensor(
1080
+ sdPsum.iterator,
1081
+ cute.make_layout(
1082
+ (self.tile_m, self.tile_n, self.dO_stage),
1083
+ stride=(1, 0, cute.round_up(self.tile_m, 64)),
1084
+ ),
1085
+ )
1086
+ if const_expr(self.SdP_swapAB):
1087
+ sLSE_mma = utils.transpose_view(sLSE_mma)
1088
+ sdPsum_mma = utils.transpose_view(sdPsum_mma)
1089
+ LSEslice = (None, 0, None) if const_expr(not self.SdP_swapAB) else (0, None, None)
1090
+ tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice]
1091
+ tLSEsdPsum = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice]
1092
+
1093
+ smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
1094
+ tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
1095
+
1096
+ dV_shape = (self.tile_n, self.tile_hdimv)
1097
+ acc_dV = cute.make_fragment(
1098
+ tiled_mma_dV.partition_shape_C(dV_shape if not self.dKV_swapAB else dV_shape[::-1]),
1099
+ Float32,
1100
+ )
1101
+ dK_shape = (self.tile_n, self.tile_hdim)
1102
+ acc_dK = cute.make_fragment(
1103
+ tiled_mma_dK.partition_shape_C(dK_shape if not self.dKV_swapAB else dK_shape[::-1]),
1104
+ Float32,
1105
+ )
1106
+
1107
+ mma_qk_fn = partial(
1108
+ gemm_zero_init,
1109
+ tiled_mma_SdP,
1110
+ (self.tile_m, self.tile_n),
1111
+ tSrQ,
1112
+ tSrK,
1113
+ swap_AB=self.SdP_swapAB,
1114
+ )
1115
+ mma_dov_fn = partial(
1116
+ gemm_zero_init,
1117
+ tiled_mma_SdP,
1118
+ (self.tile_m, self.tile_n),
1119
+ tdPrdO,
1120
+ tdPrV,
1121
+ swap_AB=self.SdP_swapAB,
1122
+ )
1123
+ if const_expr(not self.mma_dkv_is_rs):
1124
+ mma_pdo_fn = partial(
1125
+ gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB
1126
+ )
1127
+ mma_dsq_fn = partial(
1128
+ gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB
1129
+ )
1130
+ else:
1131
+ assert not self.dKV_swapAB
1132
+ mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt)
1133
+ mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt)
1134
+ mma_dsk_fn = partial(
1135
+ gemm_zero_init,
1136
+ tiled_mma_dQ,
1137
+ (self.tile_m, self.tile_hdim),
1138
+ tdQrdS,
1139
+ tdQrKt,
1140
+ swap_AB=self.dQ_swapAB,
1141
+ )
1142
+
1143
+ mma_one_m_block_all = partial(
1144
+ self.mma_one_m_block,
1145
+ warp_group_idx=warp_group_idx,
1146
+ mma_qk_fn=mma_qk_fn,
1147
+ mma_dov_fn=mma_dov_fn,
1148
+ mma_pdo_fn=mma_pdo_fn,
1149
+ mma_dsq_fn=mma_dsq_fn,
1150
+ mma_dsk_fn=mma_dsk_fn,
1151
+ pipeline_Q=pipeline_Q,
1152
+ pipeline_dO=pipeline_dO,
1153
+ tLSEsLSE=tLSEsLSE,
1154
+ tLSEsdPsum=tLSEsdPsum,
1155
+ tPsP=tPsP,
1156
+ tdSsdS=tdSsdS,
1157
+ tdQsdQaccum=tdQsdQaccum,
1158
+ smem_thr_copy_PdS=smem_thr_copy_PdS,
1159
+ smem_thr_copy_dQaccum=smem_thr_copy_dQaccum,
1160
+ softmax_scale_log2=softmax_scale_log2,
1161
+ # acc_dV=acc_dV,
1162
+ # acc_dK=acc_dK,
1163
+ )
1164
+
1165
+ consumer_state_Q = cutlass.pipeline.make_pipeline_state(
1166
+ cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
1167
+ )
1168
+ consumer_state_dO = cutlass.pipeline.make_pipeline_state(
1169
+ cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
1170
+ )
1171
+ tile_scheduler = TileSchedulerCls()
1172
+ work_tile = tile_scheduler.initial_work_tile_info()
1173
+ while work_tile.is_valid_tile:
1174
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
1175
+ seqlen = SeqlenInfoCls(batch_idx)
1176
+ mask = AttentionMaskCls(seqlen)
1177
+ m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
1178
+
1179
+ if const_expr(not self.use_block_sparsity):
1180
+ process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
1181
+ else:
1182
+ total_m_block_cnt = get_total_q_block_count_bwd(
1183
+ blocksparse_tensors,
1184
+ batch_idx,
1185
+ head_idx,
1186
+ n_block,
1187
+ subtile_factor=self.subtile_factor,
1188
+ m_block_max=m_block_max,
1189
+ )
1190
+ process_tile = total_m_block_cnt > Int32(0)
1191
+
1192
+ if process_tile:
1193
+ if const_expr(not self.use_block_sparsity):
1194
+ mask_fn = partial(
1195
+ mask.apply_mask,
1196
+ batch_idx=batch_idx,
1197
+ head_idx=head_idx,
1198
+ n_block=n_block,
1199
+ thr_mma=thr_mma_SdP,
1200
+ mask_seqlen=True,
1201
+ mask_causal=self.is_causal,
1202
+ mask_local=self.is_local,
1203
+ mask_mod=self.mask_mod,
1204
+ aux_tensors=aux_tensors,
1205
+ fastdiv_mods=fastdiv_mods,
1206
+ )
1207
+ dKV_accumulate = False
1208
+ for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):
1209
+ consumer_state_Q, consumer_state_dO = mma_one_m_block_all(
1210
+ m_block,
1211
+ consumer_state_Q,
1212
+ consumer_state_dO,
1213
+ mask_fn=mask_fn,
1214
+ dKV_accumulate=dKV_accumulate,
1215
+ thr_mma_SdP=thr_mma_SdP,
1216
+ batch_idx=batch_idx,
1217
+ head_idx=head_idx,
1218
+ n_block=n_block,
1219
+ softmax_scale=softmax_scale,
1220
+ seqlen=seqlen,
1221
+ aux_tensors=aux_tensors,
1222
+ fastdiv_mods=fastdiv_mods,
1223
+ )
1224
+ dKV_accumulate = True
1225
+ else:
1226
+ consumer_state_Q, consumer_state_dO = consume_block_sparse_mma_bwd_sm90(
1227
+ blocksparse_tensors,
1228
+ batch_idx,
1229
+ head_idx,
1230
+ n_block,
1231
+ consumer_state_Q,
1232
+ consumer_state_dO,
1233
+ mma_one_m_block_all,
1234
+ mask,
1235
+ self.mask_mod,
1236
+ is_causal=self.is_causal,
1237
+ is_local=self.is_local,
1238
+ thr_mma_SdP=thr_mma_SdP,
1239
+ softmax_scale=softmax_scale,
1240
+ seqlen=seqlen,
1241
+ subtile_factor=self.subtile_factor,
1242
+ m_block_max=m_block_max,
1243
+ aux_tensors=aux_tensors,
1244
+ fastdiv_mods=fastdiv_mods,
1245
+ )
1246
+
1247
+ if const_expr(self.qhead_per_kvhead == 1):
1248
+ acc_dK.store(acc_dK.load() * softmax_scale)
1249
+ self.epilogue_dKV(
1250
+ acc_dV,
1251
+ mdV,
1252
+ sV,
1253
+ acc_dK,
1254
+ mdK,
1255
+ sK,
1256
+ seqlen,
1257
+ tma_atom_dK,
1258
+ tma_atom_dV,
1259
+ tiled_mma_dK,
1260
+ tiled_mma_dV,
1261
+ r2s_tiled_copy_dKVaccum,
1262
+ sdKVaccum_layout,
1263
+ tidx,
1264
+ n_block,
1265
+ head_idx,
1266
+ batch_idx,
1267
+ qhead_per_kvhead_divmod,
1268
+ )
1269
+ else:
1270
+ # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros.
1271
+ if const_expr(self.use_block_sparsity):
1272
+ acc_dK.fill(0.0)
1273
+ acc_dV.fill(0.0)
1274
+ self.epilogue_dKV(
1275
+ acc_dV,
1276
+ mdV,
1277
+ sV,
1278
+ acc_dK,
1279
+ mdK,
1280
+ sK,
1281
+ seqlen,
1282
+ tma_atom_dK,
1283
+ tma_atom_dV,
1284
+ tiled_mma_dK,
1285
+ tiled_mma_dV,
1286
+ r2s_tiled_copy_dKVaccum,
1287
+ sdKVaccum_layout,
1288
+ tidx,
1289
+ n_block,
1290
+ head_idx,
1291
+ batch_idx,
1292
+ qhead_per_kvhead_divmod,
1293
+ )
1294
+
1295
+ tile_scheduler.advance_to_next_work()
1296
+ work_tile = tile_scheduler.get_current_work()
1297
+
1298
+ @cute.jit
1299
+ def mma_one_m_block(
1300
+ self,
1301
+ m_block: Int32,
1302
+ consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
1303
+ consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
1304
+ warp_group_idx: Int32,
1305
+ mma_qk_fn: Callable,
1306
+ mma_dov_fn: Callable,
1307
+ mma_pdo_fn: Callable,
1308
+ mma_dsq_fn: Callable,
1309
+ mma_dsk_fn: Callable,
1310
+ pipeline_Q: cutlass.pipeline.PipelineAsync,
1311
+ pipeline_dO: cutlass.pipeline.PipelineAsync,
1312
+ tLSEsLSE: cute.Tensor,
1313
+ tLSEsdPsum: cute.Tensor,
1314
+ tPsP: Optional[cute.Tensor],
1315
+ tdSsdS: Optional[cute.Tensor],
1316
+ tdQsdQaccum: cute.Tensor,
1317
+ smem_thr_copy_PdS: cute.TiledCopy,
1318
+ smem_thr_copy_dQaccum: cute.TiledCopy,
1319
+ softmax_scale_log2: Float32,
1320
+ mask_fn: Optional[Callable] = None,
1321
+ dKV_accumulate: Boolean = True,
1322
+ thr_mma_SdP: Optional[cute.core.ThrMma] = None,
1323
+ batch_idx: Int32 = 0,
1324
+ head_idx: Int32 = 0,
1325
+ n_block: Int32 = 0,
1326
+ softmax_scale: Float32 = 1.0,
1327
+ seqlen: Optional[SeqlenInfoQK] = None,
1328
+ aux_tensors: Optional[list] = None,
1329
+ fastdiv_mods=(None, None),
1330
+ ):
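+ # Per-m_block dataflow (the K/V tile stays resident for the whole loop):
+ #   S  = Q @ K^T              P  = exp2(S * softmax_scale_log2 - LSE)
+ #   dP = dO @ V^T             dS = P * (dP - dPsum)
+ #   dV += P^T @ dO            dK += dS^T @ Q
+ #   dQ = dS @ K, staged through sdQaccum and handed to the dQaccum store warp.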
1331
+ consumer_state_dO_cur = (
1332
+ consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q
1333
+ )
1334
+ smem_idx_Q = consumer_state_Q.index
1335
+ smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0
1336
+ smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0
1337
+ # (1) [GEMM 1] S = Q @ K^T
1338
+ pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))
1339
+ acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)
1340
+ tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])
1341
+ # (2) [GEMM 2] dP = dO @ V.T
1342
+ pipeline_dO.consumer_wait(
1343
+ consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur)
1344
+ )
1345
+ acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)
1346
+
1347
+ if const_expr(self.score_mod_bwd is not None):
1348
+ acc_S_pre = cute.make_fragment_like(acc_S)
1349
+ cute.autovec_copy(acc_S, acc_S_pre)
1350
+
1351
+ if const_expr(self.score_mod is not None):
1352
+ self.apply_score_mod(
1353
+ acc_S,
1354
+ thr_mma_SdP,
1355
+ batch_idx,
1356
+ head_idx,
1357
+ m_block,
1358
+ n_block,
1359
+ softmax_scale,
1360
+ seqlen,
1361
+ aux_tensors,
1362
+ fastdiv_mods,
1363
+ )
1364
+
1365
+ # (3) [Pointwise 1] P = exp(S - LSE)
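+ # P is recomputed here from S and the saved row logsumexp rather than reloaded
+ # from the forward pass (the standard FlashAttention recomputation).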
1366
+ if cutlass.const_expr(mask_fn is not None):
1367
+ mask_fn(acc_S, m_block=m_block)
1368
+ acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB)
1369
+ for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
1370
+ for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
1371
+ acc_S_mn[r, c] = cute.math.exp2(
1372
+ acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True
1373
+ )
1374
+ tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])
1375
+
1376
+ # Convert P from f32 -> f16
1377
+ tdVrP = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_S), self.dtype)
1378
+ # R2S for P
1379
+ if const_expr(not self.mma_dkv_is_rs):
1380
+ # sync to ensure P has already been used in the previous iteration before overwriting
1381
+ if const_expr(self.PdS_stage == 1):
1382
+ cute.arch.barrier(
1383
+ barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
1384
+ )
1385
+ tPrP = smem_thr_copy_PdS.retile(tdVrP)
1386
+ cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, smem_idx_PdS])
1387
+
1388
+ # (4) [Pointwise 2] dS = P*(dP-dPsum)
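+ # Softmax backward: with dPsum[i] precomputed upstream as the row-wise sum of
+ # dO[i] * O[i], dS[i, j] = P[i, j] * (dP[i, j] - dPsum[i]) is the gradient of the
+ # loss w.r.t. the (pre-softmax, scaled) scores.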
1389
+ warpgroup.wait_group(0)
1390
+ acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB)
1391
+ for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
1392
+ for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
1393
+ acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])
1394
+
1395
+ if const_expr(self.score_mod_bwd is not None):
1396
+ self.apply_score_mod_bwd(
1397
+ acc_dP,
1398
+ acc_S_pre,
1399
+ thr_mma_SdP,
1400
+ batch_idx,
1401
+ head_idx,
1402
+ m_block,
1403
+ n_block,
1404
+ softmax_scale,
1405
+ seqlen,
1406
+ aux_tensors,
1407
+ fastdiv_mods,
1408
+ )
1409
+
1410
+ # Convert dS from f32 -> f16
1411
+ tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype)
1412
+
1413
+ # If there's double buffering on dS, we don't need to sync here.
1414
+ # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
1415
+ # But because both WGs have to sync at the end of the loop and dS is double buffered,
1416
+ # this race condition is not possible.
1417
+ # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and
1418
+ # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.
1419
+ if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):
1420
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1421
+ cute.arch.barrier(
1422
+ barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
1423
+ )
1424
+
1425
+ # R2S for dS
1426
+ tdSrdS = smem_thr_copy_PdS.retile(tdKrdS)
1427
+ cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, smem_idx_PdS])
1428
+
1429
+ # (5) [GEMM 3] dV += P.T @ dO
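+ # zero_init=not dKV_accumulate: presumably the first m_block processed for this KV tile
+ # initializes the dV/dK accumulators, and later m_blocks accumulate into them.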
+ if const_expr(not self.mma_dkv_is_rs):
+ mma_pdo_fn(
+ A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1
+ )
+ else:
+ mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1)
+
+ # smem fence to make sure sdS is written before it's read by WGMMA
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
+ )
+ # (6) [GEMM 4] dQ = dS @ K
+ acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
+ # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV)
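+ # wg_wait=1 on the dQ GEMM above presumably lowers to warpgroup.wait_group(1), leaving at
+ # most one WGMMA group outstanding; groups retire in order, so the earlier dV GEMM has
+ # finished reading sdO by now and dO can be released below.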
+ pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done
+
+ # (7) [GEMM 5] dK += dS.T @ Q
+ if const_expr(not self.mma_dkv_is_rs):
+ mma_dsq_fn(
+ A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
+ )
+ else:
+ mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
+ # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ)
+
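+ # Hand partial dQ off to the dQ store warp (see dQaccum_store below): wait until this
+ # warp group's dQ smem slot is free (dQEmptyWG0 + wg), write the fp32 accumulator into
+ # smem, fence, then signal dQFullWG0 + wg so the store warp can reduce-add it into gmem.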
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
+ number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
+ )
+ tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape))
+ cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
+ cute.arch.barrier_arrive(
+ barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
+ number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
+ )
+
+ warpgroup.wait_group(0)
+ # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK)
+ pipeline_Q.consumer_release(consumer_state_Q)
+ # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block)
+
+ consumer_state_Q.advance()
+ consumer_state_dO.advance()
+ return consumer_state_Q, consumer_state_dO
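+ # Tile-level math implemented by the loop above (reference only; where softmax_scale is
+ # folded into dQ/dK is handled elsewhere in this file):
+ #   P  = exp(S * softmax_scale - LSE),  dP = dO @ V^T,  dS = P * (dP - dPsum)
+ #   dV += P^T @ dO,  dK += dS^T @ Q,  dQ += dS @ K (accumulated in fp32 via dQaccum)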
+
+ @cute.jit
+ def epilogue_dKV(
+ self,
+ acc_dV: cute.Tensor,
+ mdV: cute.Tensor,
+ sV: cute.Tensor,
+ acc_dK: cute.Tensor,
+ mdK: cute.Tensor,
+ sK: cute.Tensor,
+ seqlen: SeqlenInfoQK,
+ tma_atom_dK: cute.CopyAtom,
+ tma_atom_dV: cute.CopyAtom,
+ tiled_mma_dK: cute.TiledMma,
+ tiled_mma_dV: cute.TiledMma,
+ r2s_tiled_copy_dKVaccum: cute.TiledCopy,
+ sdKVaccum_layout: cute.Layout,
+ tidx: Int32,
+ n_block: Int32,
+ head_idx: Int32,
+ batch_idx: Int32,
+ qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
+ ):
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
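+ # Two epilogue paths: for MHA (qhead_per_kvhead == 1), dK/dV are converted to the output
+ # dtype and stored directly via TMA; for GQA/MQA, the fp32 accumulators are staged in smem
+ # and reduce-added into per-KV-head dKaccum/dVaccum buffers instead.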
+
+ if const_expr(self.qhead_per_kvhead == 1):
+ rdV = cute.make_fragment_like(acc_dV, self.dtype)
+ rdV.store(acc_dV.load().to(self.dtype))
+ rdK = utils.cvt_f16(acc_dK, self.dtype)
+
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
+ )
+
+ smem_copy_atom_dKV = cute.make_copy_atom(
+ cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=self.dKV_swapAB, num_matrices=4),
+ self.dtype,
+ )
+ smem_thr_copy_dK = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dK).get_slice(
+ tidx
+ )
+ smem_thr_copy_dV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dV).get_slice(
+ tidx
+ )
+ mdV_cur = mdV[None, None, head_idx, batch_idx]
+ mdK_cur = mdK[None, None, head_idx, batch_idx]
+ gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
+ gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
+ store_dK, _, _ = copy_utils.tma_get_copy_fn(
+ tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True
+ )
+ store_dV, _, _ = copy_utils.tma_get_copy_fn(
+ tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True
+ )
+
+ taccdVrdV = smem_thr_copy_dV.retile(rdV)
+ sdV = sV if const_expr(not self.dKV_swapAB) else utils.transpose_view(sV)
+ taccdVsdV = smem_thr_copy_dV.partition_D(sdV)
+ cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
+ )
+ if warp_idx == 4:
+ store_dV()
+ taccdKrdK = smem_thr_copy_dK.retile(rdK)
+ sdK = sK if const_expr(not self.dKV_swapAB) else utils.transpose_view(sK)
+ taccdKsdK = smem_thr_copy_dK.partition_D(sdK)
+ cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
+ )
+ if warp_idx == 4:
+ store_dK()
+ cute.arch.cp_async_bulk_commit_group()
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
+ else:
+ head_idx_kv = head_idx // qhead_per_kvhead_divmod
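+ # head_idx_kv selects the KV head shared by this query head (integer division via the
+ # FastDivmodDivisor); all query heads in the group accumulate into the same dK/dV slice.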
+
+ mdKaccum_cur = mdK[None, head_idx_kv, batch_idx]
+ gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))
+ gdKaccum = cute.flat_divide(
+ gdKaccum_, (self.tile_n * self.tile_hdim // self.num_mma_warp_groups,)
+ )
+
+ mdVaccum_cur = mdV[None, head_idx_kv, batch_idx]
+ gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))
+ gdVaccum = cute.flat_divide(
+ gdVaccum_, (self.tile_n * self.tile_hdimv // self.num_mma_warp_groups,)
+ )
+
+ sdKVaccum = cute.make_tensor(
+ cute.recast_ptr(sV.iterator, dtype=Float32),
+ sdKVaccum_layout,
+ )
+
+ smem_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_slice(tidx)
+ tdKsdKVaccum = smem_thr_copy_dKVaccum.partition_D(sdKVaccum)
+
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
+ )
+
+ tdKrdKaccum_flat = cute.make_tensor(
+ acc_dK.iterator, cute.make_layout(tdKsdKVaccum.shape)
+ )
+ cute.autovec_copy(tdKrdKaccum_flat, tdKsdKVaccum)
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
+ )
+
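+ # A single warp (one elected thread) issues one bulk reduce-add per warp-group slice,
+ # adding the fp32 partial dK from smem into global dKaccum; the same pattern repeats
+ # for dV below.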
+ if warp_idx == 4:
+ with cute.arch.elect_one():
+ for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
+ copy_utils.cpasync_reduce_bulk_add_f32(
+ sdKVaccum[None, wg_idx].iterator,
+ gdKaccum[None, wg_idx].iterator,
+ self.tma_copy_bytes["dKacc"] // self.num_mma_warp_groups,
+ )
+ cute.arch.cp_async_bulk_commit_group()
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
+
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
+ )
+
+ tdVrdVaccum_flat = cute.make_tensor(
+ acc_dV.iterator, cute.make_layout(tdKsdKVaccum.shape)
+ )
+ cute.autovec_copy(tdVrdVaccum_flat, tdKsdKVaccum)
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
+ )
+
+ if warp_idx == 4:
+ with cute.arch.elect_one():
+ for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
+ copy_utils.cpasync_reduce_bulk_add_f32(
+ sdKVaccum[None, wg_idx].iterator,
+ gdVaccum[None, wg_idx].iterator,
+ self.tma_copy_bytes["dVacc"] // self.num_mma_warp_groups,
+ )
+ cute.arch.cp_async_bulk_commit_group()
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
+
+ @cute.jit
+ def dQaccum_store(
+ self,
+ mdQaccum: cute.Tensor,
+ sdQaccum: cute.Tensor,
+ block_info: BlockInfo,
+ TileSchedulerCls: cutlass.Constexpr[Callable],
+ SeqlenInfoCls: cutlass.Constexpr[Callable],
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
+ ):
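+ # dQ store path: for each (n_block, head, batch) work tile, wait for each MMA warp group
+ # to publish its partial dQ into smem (dQFullWG0 + wg), reduce-add it into global dQaccum
+ # with cp.async.bulk, then signal dQEmptyWG0 + wg so the smem slot can be reused.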
+ tile_scheduler = TileSchedulerCls()
+ work_tile = tile_scheduler.initial_work_tile_info()
+ while work_tile.is_valid_tile:
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
+ seqlen = SeqlenInfoCls(batch_idx)
+ mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
+ gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
+ # (M * K / WG, WG, _)
+ gdQaccum = cute.flat_divide(
+ gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,)
+ )
+ m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
+ if const_expr(not self.use_block_sparsity):
+ process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
+ loop_count = m_block_max - m_block_min
+ else:
+ total_block_cnt = get_total_q_block_count_bwd(
+ blocksparse_tensors,
+ batch_idx,
+ head_idx,
+ n_block,
+ subtile_factor=self.subtile_factor,
+ m_block_max=m_block_max,
+ )
+ process_tile = total_block_cnt > Int32(0)
+
+ if process_tile:
+ if const_expr(not self.use_block_sparsity):
+ for iter_idx in cutlass.range(loop_count, unroll=1):
+ m_block = m_block_min + iter_idx
+ m_block_safe = m_block
+
+ for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
+ cute.arch.barrier(
+ barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
+ number_of_threads=self.num_threads_per_warp_group
+ + cute.arch.WARP_SIZE,
+ )
+ with cute.arch.elect_one():
+ copy_utils.cpasync_reduce_bulk_add_f32(
+ sdQaccum[None, warp_group_idx].iterator,
+ gdQaccum[None, warp_group_idx, m_block_safe].iterator,
+ self.tma_copy_bytes["dQ"],
+ )
+ cute.arch.cp_async_bulk_commit_group()
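+ # Wait for the bulk reduce-adds in issue order: assuming one bulk group is committed per
+ # warp group above, wait_group(num_mma_warp_groups - 1 - wg) guarantees warp group wg's
+ # transfer has read its smem slot before dQEmptyWG0 + wg is signaled and the slot is reused.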
+ for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
+ cute.arch.cp_async_bulk_wait_group(
+ self.num_mma_warp_groups - 1 - warp_group_idx, read=True
+ )
+ cute.arch.barrier_arrive(
+ barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
+ number_of_threads=self.num_threads_per_warp_group
+ + cute.arch.WARP_SIZE,
+ )
+ else:
+ dQaccum_store_block_sparse_bwd_sm90(
+ blocksparse_tensors,
+ batch_idx,
+ head_idx,
+ n_block,
+ sdQaccum,
+ gdQaccum,
+ subtile_factor=self.subtile_factor,
+ m_block_max=m_block_max,
+ num_mma_warp_groups=self.num_mma_warp_groups,
+ num_threads_per_warp_group=self.num_threads_per_warp_group,
+ tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
+ )
+ tile_scheduler.advance_to_next_work()
+ work_tile = tile_scheduler.get_current_work()