mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/flash_bwd.py
@@ -0,0 +1,1262 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp
4
+ # from Cutlass C++ to Cute-DSL.
5
+ import math
6
+ from types import SimpleNamespace
7
+ from typing import Type, Callable, Optional
8
+ from functools import partial
9
+
10
+ import cuda.bindings.driver as cuda
11
+
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+ from cutlass.cute.nvgpu import cpasync, warp
15
+ from cutlass import Float32, Int32
16
+ import cutlass.utils as utils_basic
17
+
18
+ from mslk.attention.flash_attn import ampere_helpers as sm80_utils
19
+ from mslk.attention.flash_attn import utils
20
+ from mslk.attention.flash_attn.mask import AttentionMask
21
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
22
+ from mslk.attention.flash_attn.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
23
+
24
+
25
+ class FlashAttentionBackwardSm80:
26
+ def __init__(
27
+ self,
28
+ dtype: Type[cutlass.Numeric],
29
+ head_dim: int,
30
+ head_dim_v: Optional[int] = None,
31
+ qhead_per_kvhead: int = 1,
32
+ m_block_size: int = 64,
33
+ n_block_size: int = 128,
34
+ num_stages_Q: int = 2,
35
+ num_stages_dO: int = 2,
36
+ num_threads: int = 256,
37
+ pack_gqa: bool = False,
38
+ is_causal: bool = False,
39
+ SdP_swapAB: bool = False,
40
+ dKV_swapAB: bool = False,
41
+ dQ_swapAB: bool = False,
42
+ AtomLayoutMSdP: int = 1,
43
+ AtomLayoutNdKV: int = 8,
44
+ AtomLayoutMdQ: int = 1,
45
+ V_in_regs: bool = False,
46
+ ):
47
+ """Initializes the configuration for a flash attention v2 kernel.
48
+
49
+ All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
50
+ should be a multiple of 8.
51
+
52
+ :param head_dim: head dimension
53
+ :type head_dim: int
54
+ :param m_block_size: m block size
55
+ :type m_block_size: int
56
+ :param n_block_size: n block size
57
+ :type n_block_size: int
58
+ :param num_threads: number of threads
59
+ :type num_threads: int
60
+ :param is_causal: is causal
61
+ """
62
+ self.dtype = dtype
63
+ # pad head_dim to a multiple of 32 (hdim_multiple_of) to get head_dim_padded, the k extent of the tiles
64
+ hdim_multiple_of = 32
65
+ self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
66
+ head_dim_v = head_dim_v if head_dim_v is not None else head_dim
67
+ self.same_hdim_kv = head_dim == head_dim_v
68
+ self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
69
+ # Can save registers (and hence be faster) if we don't have to check hdim predication
70
+ self.check_hdim_oob = head_dim != self.head_dim_padded
71
+ self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
72
+ self.qhead_per_kvhead = qhead_per_kvhead
73
+ self.m_block_size = m_block_size
74
+ self.n_block_size = n_block_size
75
+ self.num_threads = num_threads
76
+ self.pack_gqa = pack_gqa
77
+ self.is_causal = is_causal
78
+ self.num_stages_Q = num_stages_Q
79
+ self.num_stages_dO = num_stages_dO
80
+ self.SdP_swapAB = SdP_swapAB
81
+ self.dKV_swapAB = dKV_swapAB
82
+ self.dQ_swapAB = dQ_swapAB
83
+ self.AtomLayoutMSdP = AtomLayoutMSdP
84
+ self.AtomLayoutNdKV = AtomLayoutNdKV
85
+ self.AtomLayoutMdQ = AtomLayoutMdQ
86
+ num_mma_warps = self.num_threads // cute.arch.WARP_SIZE
87
+ self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB
88
+ self.V_in_regs = V_in_regs
89
+ self.share_QV_smem = V_in_regs
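+ # With V_in_regs, V is loaded into smem that aliases sQ, copied into registers in the prologue, and the
+ # same smem is then reused for Q (see SharedStorageSharedQV and the max() over Q/V sizes in can_implement).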
90
+
91
+ @staticmethod
92
+ def can_implement(
93
+ dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages_Q, num_stages_dO,
94
+ num_threads, is_causal,
95
+ V_in_regs=False
96
+ ) -> bool:
97
+ """Check if the kernel can be implemented with the given parameters.
98
+
99
+ :param dtype: data type
100
+ :type dtype: cutlass.Numeric
101
+ :param head_dim: head dimension
102
+ :type head_dim: int
103
+ :param m_block_size: m block size
104
+ :type m_block_size: int
105
+ :param n_block_size: n block size
106
+ :type n_block_size: int
107
+ :param num_threads: number of threads
108
+ :type num_threads: int
109
+ :param is_causal: is causal
110
+ :type is_causal: bool
111
+
112
+ :return: True if the kernel can be implemented, False otherwise
113
+ :rtype: bool
114
+ """
115
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
116
+ return False
117
+ if head_dim % 8 != 0:
118
+ return False
119
+ if head_dim_v % 8 != 0:
120
+ return False
121
+ if n_block_size % 16 != 0:
122
+ return False
123
+ if num_threads % 32 != 0:
124
+ return False
125
+ # Check whether this block-size configuration fits in shared memory
126
+ # Shared memory usage: Q tiles (num_stages_Q) + dO tiles (num_stages_dO) + K tile + V tile; Q and V can share one buffer when V_in_regs
127
+ smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2
128
+ smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2
129
+ smem_usage_K = n_block_size * head_dim * 2
130
+ smem_usage_V = n_block_size * head_dim_v * 2
131
+ smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)
132
+ smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K
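+ # e.g. fp16/bf16 with head_dim = head_dim_v = 128, m_block_size = 64, n_block_size = 128 and 2 Q/dO stages:
+ # 32 KiB (Q) + 32 KiB (dO) + 32 KiB (K) + 32 KiB (V) = 128 KiB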
133
+ smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
134
+ if smem_usage > smem_capacity:
135
+ return False
136
+ return True
137
+
138
+ def _check_type(
139
+ self,
140
+ mQ_type: Type[cutlass.Numeric],
141
+ mK_type: Type[cutlass.Numeric],
142
+ mV_type: Type[cutlass.Numeric],
143
+ mdO_type: Type[cutlass.Numeric],
144
+ mLSE_type: Type[cutlass.Numeric],
145
+ mdPsum_type: Type[cutlass.Numeric],
146
+ mdQaccum_type: Type[cutlass.Numeric],
147
+ mdK_type: Type[cutlass.Numeric],
148
+ mdV_type: Type[cutlass.Numeric],
149
+ mCuSeqlensQ_type: Type[cutlass.Numeric] | None,
150
+ mCuSeqlensK_type: Type[cutlass.Numeric] | None,
151
+ mSeqUsedQ_type: Type[cutlass.Numeric] | None,
152
+ mSeqUsedK_type: Type[cutlass.Numeric] | None,
153
+ ):
154
+ if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
155
+ raise TypeError("All tensors must have the same data type")
156
+ if cutlass.const_expr(self.qhead_per_kvhead == 1):
157
+ if cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)):
158
+ raise TypeError("mdK and mdV tensors must have the same data type as mQ")
159
+ else:
160
+ if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)):
161
+ raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
162
+ if cutlass.const_expr(not mQ_type in [cutlass.Float16, cutlass.BFloat16]):
163
+ raise TypeError("Only Float16 or BFloat16 is supported")
164
+ if cutlass.const_expr(not mLSE_type in [cutlass.Float32]):
165
+ raise TypeError("LSE tensor must be Float32")
166
+ if cutlass.const_expr(not mdPsum_type in [cutlass.Float32]):
167
+ raise TypeError("dPsum tensor must be Float32")
168
+ if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]):
169
+ raise TypeError("dQaccum tensor must be Float32")
170
+ if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]):
171
+ raise TypeError("cuSeqlensQ tensor must be Int32")
172
+ if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]):
173
+ raise TypeError("cuSeqlensK tensor must be Int32")
174
+ if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]):
175
+ raise TypeError("SeqUsedQ tensor must be Int32")
176
+ if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]):
177
+ raise TypeError("SeqUsedK tensor must be Int32")
178
+ assert mQ_type == self.dtype
179
+
180
+ def _setup_attributes(self):
181
+ # ///////////////////////////////////////////////////////////////////////////////
182
+ # Shared memory layout: Q/K/V
183
+ # ///////////////////////////////////////////////////////////////////////////////
184
+ sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded)
185
+ self.sQ_layout = cute.tile_to_shape(
186
+ sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2),
187
+ )
188
+ sK_layout_atom = sQ_layout_atom
189
+ self.sK_layout = cute.tile_to_shape(
190
+ sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1),
191
+ )
192
+ sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded)
193
+ self.sV_layout = cute.tile_to_shape(
194
+ sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1),
195
+ )
196
+ sdO_layout_atom = sV_layout_atom
197
+ self.sdO_layout = cute.tile_to_shape(
198
+ sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2),
199
+ )
200
+ # TODO: do we set swizzle to be 3 here explicitly?
201
+ sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size)
202
+ self.sPdS_layout = cute.tile_to_shape(
203
+ sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1),
204
+ )
205
+ # Round the per-stage stride up to a multiple of 64 so that, with ShuffleLSE, threads that read slightly
206
+ # past the end of sLSE still hit a valid smem address.
207
+ self.sLSE_layout = cute.make_layout(
208
+ (self.m_block_size, self.num_stages_Q),
209
+ stride=(1, cute.round_up(self.m_block_size, 64)),
210
+ )
211
+ sLSEMma_layout = cute.make_layout(
212
+ (self.m_block_size, self.n_block_size, self.num_stages_Q),
213
+ stride=(1, 0, cute.round_up(self.m_block_size, 64)),
214
+ )
215
+ sLSEMma_layout_transposed = cute.make_layout(
216
+ (self.n_block_size, self.m_block_size, self.num_stages_Q),
217
+ stride=(0, 1, cute.round_up(self.m_block_size, 64)),
218
+ )
219
+ self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed
220
+
221
+ # ///////////////////////////////////////////////////////////////////////////////
222
+ # GMEM Tiled copy:
223
+ # ///////////////////////////////////////////////////////////////////////////////
224
+ # Thread layouts for copies
225
+ universal_copy_bits = 128
226
+ async_copy_elems = universal_copy_bits // self.dtype.width
227
+ # atom_async_copy: async copy atom for QKV load
228
+ atom_async_copy = cute.make_copy_atom(
229
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
230
+ self.dtype,
231
+ num_bits_per_copy=universal_copy_bits,
232
+ )
233
+ # atom_universal_copy: universal copy atom for dK/dV store
234
+ atom_universal_copy = cute.make_copy_atom(
235
+ cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits,
236
+ )
237
+ # tQK_layout: thread layout for QK load
238
+ tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
239
+ assert self.num_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1"
240
+ tQK_layout = cute.make_ordered_layout(
241
+ (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0),
242
+ )
243
+ # Do we need to check if we overshot kBlockM when we load Q?
244
+ self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0
245
+ # Do we need to check if we overshot kBlockN when we load K?
246
+ self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 0
247
+ tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems
248
+ assert self.num_threads % tVdO_shape_dim_1 == 0, "num_threads must be divisible by tVdO_shape_dim_1"
249
+ tVdO_layout = cute.make_ordered_layout(
250
+ (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0),
251
+ )
252
+ # Do we need to check if we overshot kBlockN when we load V?
253
+ self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0
254
+ self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0
255
+
256
+ # Value layouts for copies
257
+ vQKVdO_layout = cute.make_layout((1, async_copy_elems))
258
+
259
+ # gmem_tiled_copy_QK: tiled copy for QK load
260
+ self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout)
261
+ self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout)
262
+ self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout)
263
+ self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout)
264
+ async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width
265
+
266
+ # I think we wouldn't require this with smarter padding
267
+ if cutlass.const_expr(not self.varlen_q):
268
+ async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width
269
+ atom_async_copy_accum = cute.make_copy_atom(
270
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
271
+ cutlass.Float32,
272
+ num_bits_per_copy=universal_copy_bits,
273
+ )
274
+ else:
275
+ async_copy_elems_accum = 1
276
+ atom_async_copy_accum = cute.make_copy_atom(
277
+ cute.nvgpu.CopyUniversalOp(),
278
+ cutlass.Float32,
279
+ num_bits_per_copy=cutlass.Float32.width,
280
+ )
281
+ self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
282
+ atom_async_copy_accum,
283
+ cute.make_layout(self.num_threads),
284
+ cute.make_layout(async_copy_elems_accum),
285
+ )
286
+ self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
287
+ cute.make_copy_atom(
288
+ cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width
289
+ ),
290
+ cute.make_layout(self.num_threads),
291
+ cute.make_layout(1)
292
+ )
293
+ if cutlass.const_expr(self.qhead_per_kvhead > 1):
294
+ self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum
295
+ self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum
296
+
297
+ def _get_tiled_mma(self):
298
+ num_mma_warps = self.num_threads // 32
299
+ AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1)
300
+ tiled_mma_sdp = cute.make_tiled_mma(
301
+ warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
302
+ AtomLayoutSdP,
303
+ permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16),
304
+ )
305
+ AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1)
306
+ tiled_mma_dkv = cute.make_tiled_mma(
307
+ warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
308
+ AtomLayoutdKV,
309
+ permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16),
310
+ )
311
+ AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)
312
+ tiled_mma_dq = cute.make_tiled_mma(
313
+ warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
314
+ AtomLayoutdQ,
315
+ permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16),
316
+ )
317
+ return tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq
318
+
319
+ def _get_shared_storage_cls(self):
320
+ sQ_struct, sK_struct, sV_struct, sdO_struct = [
321
+ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]
322
+ for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout)
323
+ ]
324
+ cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
325
+ sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
326
+ sLSE_struct, sdPsum_struct = [
327
+ cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128]
328
+ for layout in (self.sLSE_layout, self.sLSE_layout)
329
+ ]
330
+ sP_struct, sdS_struct = [
331
+ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128]
332
+ for layout in (self.sPdS_layout, self.sPdS_layout)
333
+ ]
334
+
335
+ @cute.struct
336
+ class SharedStorageSeparateQV:
337
+ sK: sK_struct
338
+ sV: sV_struct
339
+ sQ: sQ_struct
340
+ sdO: sdO_struct
341
+ sLSE: sLSE_struct
342
+ sdPsum: sdPsum_struct
343
+ sP: sP_struct
344
+ sdS: sdS_struct
345
+ # TODO: the case where there's no sP
346
+
347
+ @cute.struct
348
+ class SharedStorageSharedQV:
349
+ sK: sK_struct
350
+ sV: sV_struct
351
+ sQ: sQV_struct
352
+ sdO: sdO_struct
353
+ sLSE: sLSE_struct
354
+ sdPsum: sdPsum_struct
355
+ sP: sP_struct
356
+ sdS: sdS_struct
357
+
358
+ return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV
359
+
360
+ @cute.jit
361
+ def __call__(
362
+ self,
363
+ mQ: cute.Tensor,
364
+ mK: cute.Tensor,
365
+ mV: cute.Tensor,
366
+ mdO: cute.Tensor,
367
+ mLSE: cute.Tensor,
368
+ mdPsum: cute.Tensor,
369
+ mdQaccum: cute.Tensor,
370
+ mdK: cute.Tensor,
371
+ mdV: cute.Tensor,
372
+ softmax_scale: cutlass.Float32,
373
+ stream: cuda.CUstream,
374
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
375
+ mCuSeqlensK: Optional[cute.Tensor] = None,
376
+ mSeqUsedQ: Optional[cute.Tensor] = None,
377
+ mSeqUsedK: Optional[cute.Tensor] = None,
378
+ softcap: Float32 | float | None = None,
379
+ window_size_left: Int32 | int | None = None,
380
+ window_size_right: Int32 | int | None = None,
381
+ mdQ_semaphore: Optional[cute.Tensor] = None,
382
+ ):
383
+ assert mdQ_semaphore is None, "semaphore not supported yet"
384
+ # Get the data type and check if it is fp16 or bf16
385
+ self._check_type(*(t.element_type if t is not None else None
386
+ for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
387
+ # Assume every stride except the last is 128-bit aligned (i.e. divisible by 128 // element_width elements)
388
+ new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
389
+ mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)]
390
+ self.varlen_q = (mCuSeqlensQ is not None)
391
+ self._setup_attributes()
392
+ SharedStorage = self._get_shared_storage_cls()
393
+ tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma()
394
+
395
+ num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2]
396
+
397
+ if cutlass.const_expr(mCuSeqlensK is not None):
398
+ TileScheduler = SingleTileVarlenScheduler
399
+ num_batch = mCuSeqlensK.shape[0] - 1
400
+ else:
401
+ TileScheduler = SingleTileScheduler
402
+ num_batch = mK.shape[0]
403
+
404
+ # Uses seqlen k, etc. since main bwd kernel's blocks are over n
405
+ tile_sched_args = TileSchedulerArguments(
406
+ num_block=cute.ceil_div(mK.shape[1], self.n_block_size),
407
+ num_head=num_head,
408
+ num_batch=num_batch,
409
+ num_splits=1,
410
+ seqlen_k=0,
411
+ headdim=mK.shape[2],
412
+ headdim_v=mV.shape[2],
413
+ total_q=mK.shape[0],
414
+ tile_shape_mn=(self.n_block_size, self.m_block_size),
415
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
416
+ mCuSeqlensQ=mCuSeqlensK,
417
+ mSeqUsedQ=mSeqUsedK,
418
+ )
419
+
420
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
421
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
422
+
423
+ softmax_scale_log2 = softmax_scale * math.log2(math.e)
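+ # The kernel evaluates softmax with exp2 (exp(x * scale) == 2 ** (x * scale * log2(e))), so fold log2(e)
+ # into the scale once on the host.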
424
+ self.kernel(
425
+ mQ,
426
+ mK,
427
+ mV,
428
+ mdO,
429
+ mLSE,
430
+ mdPsum,
431
+ mdQaccum,
432
+ mdK,
433
+ mdV,
434
+ mCuSeqlensQ,
435
+ mCuSeqlensK,
436
+ mSeqUsedQ,
437
+ mSeqUsedK,
438
+ softmax_scale,
439
+ softmax_scale_log2,
440
+ self.sQ_layout,
441
+ self.sK_layout,
442
+ self.sV_layout,
443
+ self.sdO_layout,
444
+ self.sPdS_layout,
445
+ self.sLSE_layout,
446
+ self.sLSEMma_layout,
447
+ self.gmem_tiled_copy_QK,
448
+ self.gmem_tiled_copy_VdO,
449
+ self.gmem_tiled_copy_dK,
450
+ self.gmem_tiled_copy_dV,
451
+ self.gmem_tiled_copy_LSE,
452
+ self.gmem_tiled_copy_dQaccum,
453
+ tiled_mma_sdp,
454
+ tiled_mma_dkv,
455
+ tiled_mma_dq,
456
+ SharedStorage,
457
+ tile_sched_params,
458
+ TileScheduler,
459
+ ).launch(
460
+ grid=grid_dim,
461
+ block=[self.num_threads, 1, 1],
462
+ smem=SharedStorage.size_in_bytes(),
463
+ stream=stream,
464
+ )
465
+
466
+ @cute.kernel
467
+ def kernel(
468
+ self,
469
+ mQ: cute.Tensor,
470
+ mK: cute.Tensor,
471
+ mV: cute.Tensor,
472
+ mdO: cute.Tensor,
473
+ mLSE: cute.Tensor,
474
+ mdPsum: cute.Tensor,
475
+ mdQaccum: cute.Tensor,
476
+ mdK: cute.Tensor,
477
+ mdV: cute.Tensor,
478
+ mCuSeqlensQ: Optional[cute.Tensor],
479
+ mCuSeqlensK: Optional[cute.Tensor],
480
+ mSeqUsedQ: Optional[cute.Tensor],
481
+ mSeqUsedK: Optional[cute.Tensor],
482
+ softmax_scale: cutlass.Float32,
483
+ softmax_scale_log2: cutlass.Float32,
484
+ sQ_layout: cute.ComposedLayout,
485
+ sK_layout: cute.ComposedLayout,
486
+ sV_layout: cute.ComposedLayout,
487
+ sdO_layout: cute.ComposedLayout,
488
+ sPdS_layout: cute.ComposedLayout,
489
+ sLSE_layout: cute.Layout,
490
+ sLSEMma_layout: cute.Layout,
491
+ gmem_tiled_copy_QK: cute.TiledCopy,
492
+ gmem_tiled_copy_VdO: cute.TiledCopy,
493
+ gmem_tiled_copy_dK: cute.TiledCopy,
494
+ gmem_tiled_copy_dV: cute.TiledCopy,
495
+ gmem_tiled_copy_LSE: cute.TiledCopy,
496
+ gmem_tiled_copy_dQaccum: cute.TiledCopy,
497
+ tiled_mma_sdp: cute.TiledMma,
498
+ tiled_mma_dkv: cute.TiledMma,
499
+ tiled_mma_dq: cute.TiledMma,
500
+ SharedStorage: cutlass.Constexpr,
501
+ tile_sched_params: ParamsBase,
502
+ TileScheduler: cutlass.Constexpr[Callable],
503
+ ):
504
+ # Thread index, block index
505
+ tidx, _, _ = cute.arch.thread_idx()
506
+
507
+ tile_scheduler = TileScheduler.create(tile_sched_params)
508
+ work_tile = tile_scheduler.initial_work_tile_info()
509
+
510
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
511
+
512
+ if work_tile.is_valid_tile:
513
+ seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK)
514
+
515
+ m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
516
+ m_block_min = 0
517
+ if cutlass.const_expr(self.is_causal):
518
+ m_block_min = max(
519
+ (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size,
520
+ m_block_min,
521
+ )
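+ # Causal masking lets query row i see key column j only when j <= i + seqlen_k - seqlen_q, so the first
+ # m-block that can overlap this n-block starts around row n_block * n_block_size + seqlen_q - seqlen_k.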
522
+ # TODO: return early if m_block_max == 0
523
+
524
+ # ///////////////////////////////////////////////////////////////////////////////
525
+ # Get the appropriate tiles for this thread block.
526
+ # ///////////////////////////////////////////////////////////////////////////////
527
+ blkQ_shape = (self.m_block_size, self.head_dim_padded)
528
+ blkK_shape = (self.n_block_size, self.head_dim_padded)
529
+ blkV_shape = (self.n_block_size, self.head_dim_v_padded)
530
+ blkdO_shape = (self.m_block_size, self.head_dim_v_padded)
531
+
532
+ if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
533
+ mQ_cur = mQ[batch_idx, None, head_idx, None]
534
+ mLSE_cur = mLSE[batch_idx, head_idx, None]
535
+ mdO_cur = mdO[batch_idx, None, head_idx, None]
536
+ mdPsum_cur = mdPsum[batch_idx, head_idx, None]
537
+ mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
538
+ else:
539
+ padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
540
+ mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
541
+ mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
542
+ mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
543
+ mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
544
+ mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None])
545
+ head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx
546
+
547
+ if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
548
+ mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)]
549
+ else:
550
+ mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)]
551
+
552
+ # (m_block_size, head_dim, m_block)
553
+ gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0))
554
+ # (n_block_size, head_dim)
555
+ gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0))
556
+ # (n_block_size, head_dim_v)
557
+ gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0))
558
+ # (m_block_size, head_dim_v, m_block)
559
+ gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0))
560
+ gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,))
561
+ gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,))
562
+ gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,))
563
+
564
+ # ///////////////////////////////////////////////////////////////////////////////
565
+ # Get shared memory buffer
566
+ # ///////////////////////////////////////////////////////////////////////////////
567
+ smem = cutlass.utils.SmemAllocator()
568
+ storage = smem.allocate(SharedStorage)
569
+ sQ = storage.sQ.get_tensor(sQ_layout)
570
+ sK = storage.sK.get_tensor(sK_layout)
571
+ if cutlass.const_expr(not self.share_QV_smem):
572
+ sV = storage.sV.get_tensor(sV_layout)
573
+ else:
574
+ sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)
575
+ sdO = storage.sdO.get_tensor(sdO_layout)
576
+ sP = storage.sP.get_tensor(sPdS_layout)
577
+ sdS = storage.sdS.get_tensor(sPdS_layout)
578
+ sLSE = storage.sLSE.get_tensor(sLSE_layout)
579
+ sdPsum = storage.sdPsum.get_tensor(sLSE_layout)
580
+ sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout)
581
+ sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout)
582
+
583
+ # Transpose view of tensors for tiled mma
584
+ sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)]
585
+
586
+ gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx)
587
+ gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx)
588
+ gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx)
589
+ gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
590
+ # (CPY_Atom, CPY_M, CPY_K, m_block)
591
+ tQgQ = gmem_thr_copy_QK.partition_S(gQ)
592
+ tQsQ = gmem_thr_copy_QK.partition_D(sQ)
593
+ # (CPY_Atom, CPY_N, CPY_K)
594
+ tKgK = gmem_thr_copy_QK.partition_S(gK)
595
+ tKsK = gmem_thr_copy_QK.partition_D(sK)
596
+ # (CPY_Atom, CPY_N, CPY_K)
597
+ tVgV = gmem_thr_copy_VdO.partition_S(gV)
598
+ tVsV = gmem_thr_copy_VdO.partition_D(sV)
599
+ # (CPY_Atom, CPY_M, CPY_K, m_block)
600
+ tdOgdO = gmem_thr_copy_VdO.partition_S(gdO)
601
+ tdOsdO = gmem_thr_copy_VdO.partition_D(sdO)
602
+ tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE)
603
+ tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE)
604
+ tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum)
605
+ tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum)
606
+ tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
607
+
608
+ # ///////////////////////////////////////////////////////////////////////////////
609
+ # Tile MMA compute thread partitions and allocate accumulators
610
+ # ///////////////////////////////////////////////////////////////////////////////
611
+ thr_mma_sdp = tiled_mma_sdp.get_slice(tidx)
612
+ thr_mma_dkv = tiled_mma_dkv.get_slice(tidx)
613
+ thr_mma_dq = tiled_mma_dq.get_slice(tidx)
614
+ acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded))
615
+ acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded))
616
+ acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32)
617
+ acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32)
618
+ acc_dK.fill(0.0)
619
+ acc_dV.fill(0.0)
620
+
621
+ tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
622
+ tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB)
623
+ tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
624
+ tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB)
625
+ tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB)
626
+ tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
627
+ tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB)
628
+ tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
629
+ tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB)
630
+ tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)
631
+
632
+ LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)
633
+ tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
634
+ tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]
635
+
636
+ # ///////////////////////////////////////////////////////////////////////////////
637
+ # Smem copy atom tiling
638
+ # ///////////////////////////////////////////////////////////////////////////////
639
+ smem_copy_atom = cute.make_copy_atom(
640
+ warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,
641
+ )
642
+ smem_copy_atom_transposed = cute.make_copy_atom(
643
+ warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype,
644
+ )
645
+ smem_thr_copy_QdO = utils.make_tiled_copy_A(
646
+ smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
647
+ ).get_slice(tidx)
648
+ smem_thr_copy_KV = utils.make_tiled_copy_B(
649
+ smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
650
+ ).get_slice(tidx)
651
+ # TODO: should this be smem_copy_atom_transposed?
652
+ smem_thr_copy_PdSt = utils.make_tiled_copy_A(
653
+ smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
654
+ ).get_slice(tidx)
655
+ smem_thr_copy_QdOt = utils.make_tiled_copy_B(
656
+ smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
657
+ ).get_slice(tidx)
658
+ smem_thr_copy_dS = utils.make_tiled_copy_A(
659
+ smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB
660
+ ).get_slice(tidx)
661
+ smem_thr_copy_Kt = utils.make_tiled_copy_B(
662
+ smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB
663
+ ).get_slice(tidx)
664
+ # TODO: what's the number of bits? What if SdP_swapAB
665
+ r2s_thr_copy_PdS = cute.make_tiled_copy_C(
666
+ cute.make_copy_atom(
667
+ cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
668
+ ),
669
+ tiled_mma_sdp,
670
+ ).get_slice(tidx)
671
+
672
+ tSsQ = smem_thr_copy_QdO.partition_S(sQ)
673
+ tdPsdO = smem_thr_copy_QdO.partition_S(sdO)
674
+ tSsK = smem_thr_copy_KV.partition_S(sK)
675
+ tdPsV = smem_thr_copy_KV.partition_S(sV)
676
+ tdVsPt = smem_thr_copy_PdSt.partition_S(sPt)
677
+ tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt)
678
+ tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt)
679
+ tdKsQt = smem_thr_copy_QdOt.partition_S(sQt)
680
+ tdQsdS = smem_thr_copy_dS.partition_S(sdS)
681
+ tdQsKt = smem_thr_copy_Kt.partition_S(sKt)
682
+ tPsP = r2s_thr_copy_PdS.partition_D(sP)
683
+ tdSsdS = r2s_thr_copy_PdS.partition_D(sdS)
684
+
685
+ # ///////////////////////////////////////////////////////////////////////////////
686
+ # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
687
+ # of tile_shape
688
+ # ///////////////////////////////////////////////////////////////////////////////
689
+ # Construct identity layout for KV
690
+ cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
691
+ tQcQ = gmem_thr_copy_QK.partition_S(cQ)
692
+ t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ)
693
+ if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
694
+ tdOcdO = tQcQ
695
+ t0dOcdO = t0QcQ
696
+ else:
697
+ cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
698
+ tdOcdO = gmem_thr_copy_VdO.partition_S(cdO)
699
+ t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO)
700
+ cLSE = cute.make_identity_tensor((self.m_block_size,))
701
+ tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE)
702
+
703
+ # Only allocate predicate tensors for the k tile; the m/n bounds are handled with "if" checks on the
704
+ # mn dimension instead of per-element predicates.
705
+ # This reduces register pressure and gives a 2-3% performance gain.
706
+
707
+ d_head = mQ.shape[cute.rank(mQ) - 1]
708
+ d_head_v = mdO.shape[cute.rank(mdO) - 1]
709
+
710
+ tQpQ = utils.predicate_k(tQcQ, limit=d_head)
711
+ if cutlass.const_expr(self.same_hdim_kv):
712
+ tdOpdO = tQpQ
713
+ else:
714
+ tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v)
715
+
716
+ # group parameters for compute_one_m_block
717
+ mma_params = SimpleNamespace(
718
+ thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq,
719
+ tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV,
720
+ tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ,
721
+ tdQrdS=tdQrdS, tdQrK=tdQrK,
722
+ acc_dK=acc_dK, acc_dV=acc_dV,
723
+ )
724
+ smem_copy_params = SimpleNamespace(
725
+ smem_thr_copy_QdO=smem_thr_copy_QdO,
726
+ smem_thr_copy_KV=smem_thr_copy_KV,
727
+ smem_thr_copy_PdSt=smem_thr_copy_PdSt,
728
+ smem_thr_copy_QdOt=smem_thr_copy_QdOt,
729
+ smem_thr_copy_dS=smem_thr_copy_dS,
730
+ smem_thr_copy_Kt=smem_thr_copy_Kt,
731
+ r2s_thr_copy_PdS=r2s_thr_copy_PdS,
732
+ tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV,
733
+ tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma,
734
+ tPsP=tPsP, tdSsdS=tdSsdS,
735
+ tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt,
736
+ tdQsdS=tdQsdS, tdQsKt=tdQsKt,
737
+ )
738
+ gmem_copy_params = SimpleNamespace(
739
+ gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum
740
+ )
741
+ load_Q_LSE = partial(
742
+ self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE,
743
+ tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ,
744
+ tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q
745
+ )
746
+ load_dO_dPsum = partial(
747
+ self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE,
748
+ tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO,
749
+ tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q
750
+ )
751
+ compute_one_m_block = partial(
752
+ self.compute_one_m_block, mma_params=mma_params,
753
+ smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params,
754
+ load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum,
755
+ m_block_max=m_block_max,
756
+ softmax_scale_log2=softmax_scale_log2,
757
+ )
758
+
759
+ # ///////////////////////////////////////////////////////////////////////////////
760
+ # Prologue
761
+ # ///////////////////////////////////////////////////////////////////////////////
762
+ # Start async loads of the K and V tiles for this n-block, taking care of the residue at the end of the sequence
763
+ self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k,
764
+ headdim=d_head_v)
765
+ if cutlass.const_expr(self.V_in_regs):
766
+ cute.arch.cp_async_commit_group()
767
+ self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k,
768
+ headdim=d_head)
769
+ cute.arch.cp_async_commit_group()
770
+
771
+ if cutlass.const_expr(self.V_in_regs):
772
+ cute.arch.cp_async_wait_group(1)
773
+ cute.arch.barrier()
774
+ tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV)
775
+ cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view)
776
+ # Sync so the upcoming Q load into smem_q (which overlaps smem_v) cannot start before V has been read into registers
777
+ cute.arch.barrier()
778
+
779
+ m_block = m_block_min
780
+ assert self.num_stages_Q >= self.num_stages_dO
781
+ for stage in cutlass.range_constexpr(self.num_stages_Q):
782
+ if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1):
783
+ if stage == 0 or m_block + stage < m_block_max:
784
+ load_Q_LSE(m_block + stage, smem_pipe_write_q=stage)
785
+ cute.arch.cp_async_commit_group()
786
+ if cutlass.const_expr(stage < self.num_stages_dO):
787
+ if stage == 0 or m_block + stage < m_block_max:
788
+ load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage)
789
+ cute.arch.cp_async_commit_group()
790
+
791
+ # ///////////////////////////////////////////////////////////////////////////////
792
+ # Mainloop
793
+ # ///////////////////////////////////////////////////////////////////////////////
794
+ # Process the m-blocks for this n-block, starting from m_block_min.
795
+ mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k)
796
+ mask_fn = partial(
797
+ mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
798
+ mask_seqlen=True, mask_causal=self.is_causal
799
+ )
800
+ smem_pipe_read_q = cutlass.Int32(0)
801
+ smem_pipe_read_do = cutlass.Int32(0)
802
+ smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1)
803
+ smem_pipe_write_do = cutlass.Int32(0)
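+ # Q/LSE and dO/dPsum tiles are staged through small circular smem pipelines; the read/write indices
+ # below wrap at num_stages_Q / num_stages_dO (see advance_pipeline).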
804
+ for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1):
805
+ compute_one_m_block(
806
+ m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do,
807
+ mask_fn=mask_fn,
808
+ )
809
+ smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q)
810
+ smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO)
811
+ smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q)
812
+ smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO)
813
+
814
+ # ///////////////////////////////////////////////////////////////////////////////
815
+ # Epilogue
816
+ # ///////////////////////////////////////////////////////////////////////////////
817
+ # If GQA, we scale dK in the postprocessing kernel instead
818
+ if cutlass.const_expr(self.qhead_per_kvhead == 1):
819
+ acc_dK.store(acc_dK.load() * softmax_scale)
820
+ # reuse sK and sV data iterator
821
+ sdK = cute.make_tensor(sK.iterator, sK_layout)
822
+ sdV = cute.make_tensor(sV.iterator, sV_layout)
823
+ self.epilogue(
824
+ acc_dK, acc_dV, mdK, mdV, sdK, sdV,
825
+ gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv,
826
+ tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v
827
+ )
828
+
829
+ @cute.jit
830
+ def compute_one_m_block(
831
+ self,
832
+ m_block: cutlass.Int32,
833
+ smem_pipe_read_q: cutlass.Int32,
834
+ smem_pipe_read_do: cutlass.Int32,
835
+ smem_pipe_write_q: cutlass.Int32,
836
+ smem_pipe_write_do: cutlass.Int32,
837
+ mma_params: SimpleNamespace,
838
+ smem_copy_params: SimpleNamespace,
839
+ gmem_copy_params: SimpleNamespace,
840
+ load_Q_LSE: Callable,
841
+ load_dO_dPsum: Callable,
842
+ m_block_max: cutlass.Int32,
843
+ softmax_scale_log2: cutlass.Float32,
844
+ mask_fn: Optional[Callable] = None,
845
+ ):
846
+ def load_Q_next():
847
+ m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1)
848
+ if m_block_next < m_block_max:
849
+ load_Q_LSE(m_block_next, smem_pipe_write_q)
850
+ cute.arch.cp_async_commit_group()
851
+
852
+ def load_dO_next():
853
+ if m_block + self.num_stages_dO < m_block_max:
854
+ load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do)
855
+ cute.arch.cp_async_commit_group()
856
+
857
+ # MMA S
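+ # S = Q @ K^T; P is then recomputed row by row as exp2(S * softmax_scale_log2 - LSE) using the stored LSE.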
858
+ acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C(
859
+ (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size)
860
+ )
861
+ acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
862
+ acc_S.fill(0.0)
863
+ cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0)
864
+ cute.arch.barrier()
865
+ sm80_utils.gemm(
866
+ mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK,
867
+ smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
868
+ smem_copy_params.tSsK,
869
+ smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
870
+ swap_AB=self.SdP_swapAB,
871
+ )
872
+ tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])
873
+ cute.autovec_copy(
874
+ smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE
875
+ )
876
+ if cutlass.const_expr(mask_fn is not None):
877
+ mask_fn(acc_S, m_block=m_block)
878
+ acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
879
+ bidx = 0
880
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
881
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
882
+ assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)
883
+ for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
884
+ acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r]))
885
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
886
+
887
+ # MMA dP
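+ # dP = dO @ V^T; dS = P * (dP - dPsum), where dPsum is the precomputed per-row sum passed in as mdPsum.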
888
+ acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
889
+ acc_dP.fill(0.0)
890
+ cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0)
891
+ cute.arch.barrier()
892
+ sm80_utils.gemm(
893
+ mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV,
894
+ smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
895
+ smem_copy_params.tdPsV,
896
+ smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
897
+ hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None,
898
+ swap_AB=self.SdP_swapAB,
899
+ )
900
+ tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0])
901
+ cute.autovec_copy(
902
+ smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum
903
+ )
904
+ acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP)
905
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
906
+ assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
907
+ for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
908
+ acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))
909
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
910
+ rP = cute.make_fragment_like(acc_S, self.dtype)
911
+ rP.store(acc_S.load().to(self.dtype))
912
+ if cutlass.const_expr(not self.Mma_dKV_is_RS):
913
+ tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP) # ((Atom,AtomNum), MMA_N, MMA_N)
914
+ cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP)
915
+ rdS = cute.make_fragment_like(acc_dP, self.dtype)
916
+ rdS.store(acc_dP.load().to(self.dtype))
917
+ if cutlass.const_expr(not self.Mma_dKV_is_RS):
918
+ cute.arch.barrier() # Make sure P is written
919
+ # For hdim 64, it's faster to write dS to smem before the dV gemm
920
+ if cutlass.const_expr(not self.Mma_dKV_is_RS):
921
+ tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
922
+ cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
923
+ if cutlass.const_expr(self.Mma_dKV_is_RS):
924
+ tdVrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
925
+ else:
926
+ tdVrP = mma_params.tdVrP
927
+
928
+ # MMA dV
929
+ sm80_utils.gemm(
930
+ mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO,
931
+ smem_copy_params.tdVsPt,
932
+ smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
933
+ smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
934
+ A_in_regs=self.Mma_dKV_is_RS,
935
+ swap_AB=self.dKV_swapAB,
936
+ )
937
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(mma_params.acc_dV)
938
+ cute.arch.barrier() # Make sure dS is written
939
+
940
+ # MMA dQ
941
+ def dQ_mma(hook_fn):
942
+ acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C(
943
+ (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, self.m_block_size)
944
+ )
945
+ acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32)
946
+ acc_dQ.fill(0.0)
947
+ sm80_utils.gemm(
948
+ mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK,
949
+ smem_copy_params.tdQsdS, smem_copy_params.tdQsKt,
950
+ smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt,
951
+ swap_AB=self.dQ_swapAB,
952
+ hook_fn=hook_fn
953
+ )
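+ # Every n-block (one thread block each) contributes to the same dQ rows, so dQ is accumulated into the
+ # fp32 dQaccum buffer with atomic adds rather than a plain store.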
954
+ # ((1, 1), num_elements)
955
+ acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ)
956
+ tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block]
957
+ assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic)
958
+ for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True):
959
+ utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i))
960
+ # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1])
961
+ # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ)
962
+
963
+ # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration
964
+ if cutlass.const_expr(self.num_stages_Q > 1):
965
+ dQ_mma(load_dO_next)
966
+
967
+ # MMA dK
968
+ if cutlass.const_expr(self.Mma_dKV_is_RS):
969
+ tdKrdS = cute.make_tensor(rdS.iterator, utils.convert_layout_acc_frgA(rdS.layout))
970
+ else:
971
+ tdKrdS = mma_params.tdKrdS
972
+ sm80_utils.gemm(
973
+ mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ,
974
+ smem_copy_params.tdKsdSt,
975
+ smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
976
+ smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
977
+ A_in_regs=self.Mma_dKV_is_RS,
978
+ swap_AB=self.dKV_swapAB,
979
+ hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None,
980
+ )
981
+ # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(mma_params.acc_dK)
982
+ if cutlass.const_expr(self.num_stages_Q == 1):
983
+ cute.arch.barrier()
984
+ dQ_mma(load_Q_next)
985
+
986
+ @cute.jit
987
+ def epilogue(
988
+ self,
989
+ acc_dK: cute.Tensor,
990
+ acc_dV: cute.Tensor,
991
+ mdK: cute.Tensor,
992
+ mdV: cute.Tensor,
993
+ sdK: cute.Tensor,
994
+ sdV: cute.Tensor,
995
+ gmem_tiled_copy_dK: cute.TiledCopy,
996
+ gmem_tiled_copy_dV: cute.TiledCopy,
997
+ tiled_mma: cute.TiledMma,
998
+ tidx: cutlass.Int32,
999
+ n_block: cutlass.Int32,
1000
+ num_head: cutlass.Int32,
1001
+ batch_size: cutlass.Int32,
1002
+ seqlen: SeqlenInfoQK,
1003
+ d_head: cutlass.Int32,
1004
+ d_head_v: cutlass.Int32
1005
+ ):
1006
+ rdV = cute.make_fragment_like(acc_dV, self.dtype)
1007
+ rdV.store(acc_dV.load().to(self.dtype))
1008
+ rdK = cute.make_fragment_like(acc_dK, self.dtype)
1009
+ rdK.store(acc_dK.load().to(self.dtype))
1010
+ gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx)
1011
+ gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx)
1012
+
1013
+ batch_idx = batch_size
1014
+ head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head
1015
+
1016
+ if cutlass.const_expr(self.qhead_per_kvhead == 1):
1017
+ # Make sure all threads have finished reading K and V, otherwise we get racy dQ
1018
+ # because smem_q could be changed.
1019
+ cute.arch.barrier()
1020
+ # smem copy atom for dKV
1021
+ smem_copy_atom_dKV = cute.make_copy_atom(
1022
+ cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
1023
+ )
1024
+ smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx)
1025
+ taccdVrdV = smem_thr_copy_dKV.retile(rdV)
1026
+ taccdKrdK = smem_thr_copy_dKV.retile(rdK)
1027
+ taccdVsdV = smem_thr_copy_dKV.partition_D(sdV)
1028
+ taccdKsdK = smem_thr_copy_dKV.partition_D(sdK)
1029
+ # copy acc dK and acc dV from rmem to smem with the smem copy atom
1030
+ cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
1031
+ cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)
1032
+
1033
+
1034
+ if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
1035
+ mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)]
1036
+ else:
1037
+ mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)]
1038
+
1039
+ blkdK_shape = (self.n_block_size, self.head_dim_padded)
1040
+ blkdV_shape = (self.n_block_size, self.head_dim_v_padded)
1041
+ gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0))
1042
+ gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0))
1043
+ tdKsdK = gmem_thr_copy_dK.partition_S(sdK)
1044
+ tdKgdK = gmem_thr_copy_dK.partition_D(gdK)
1045
+ tdVsdV = gmem_thr_copy_dV.partition_S(sdV)
1046
+ tdVgdV = gmem_thr_copy_dV.partition_D(gdV)
1047
+ tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype)
1048
+ tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype)
1049
+ # sync to make sure all smem stores are done before reading them back.
1050
+ cute.arch.barrier()
1051
+ # load acc dK and dV from smem to rmem for wider vectorization
1052
+ # Need to check OOB when reading from smem if kBlockN isn't evenly tiled
1053
+ # TODO
1054
+ cute.autovec_copy(tdKsdK, tdKrdK)
1055
+ cute.autovec_copy(tdVsdV, tdVrdV)
1056
+
1057
+ cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
1058
+ tdKcdK = gmem_thr_copy_dK.partition_S(cdK)
1059
+ t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK)
1060
+ if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
1061
+ tdVcdV = tdKcdK
1062
+ t0dVcdV = t0dKcdK
1063
+ else:
1064
+ cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
1065
+ tdVcdV = gmem_thr_copy_dV.partition_S(cdV)
1066
+ t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV)
1067
+ tdKpdK = utils.predicate_k(tdKcdK, limit=d_head)
1068
+ if cutlass.const_expr(self.same_hdim_kv):
1069
+ tdVpdV = tdKpdK
1070
+ else:
1071
+ tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v)
1072
+ # copy acc dK and acc_dV from rmem to gmem
1073
+ for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])):
1074
+ if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]:
1075
+ cute.copy(
1076
+ gmem_tiled_copy_dK,
1077
+ tdKrdK[None, rest_m, None],
1078
+ tdKgdK[None, rest_m, None],
1079
+ pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None,
1080
+ )
1081
+ for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])):
1082
+ if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]:
1083
+ cute.copy(
1084
+ gmem_tiled_copy_dV,
1085
+ tdVrdV[None, rest_m, None],
1086
+ tdVgdV[None, rest_m, None],
1087
+ pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None,
1088
+ )
1089
+
1090
+ else: # qhead_per_kvhead > 1, do atomic add
1091
+ # For Sm90, we need to sync to avoid racy writes to smem_q
1092
+ # For Sm80, we don't need to sync since we're not touching smem
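+ # Several query heads map to one KV head, so their dK/dV contributions are accumulated into the fp32
+ # dKaccum/dVaccum buffers with atomic adds; scaling is deferred to the postprocess kernel (see above).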
1093
+ head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head
1094
+
1095
+ if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
+ mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)]
+ else:
+ padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size
+ mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None])
+ mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None])
+
+ gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,))
+ gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,))
+ tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV)
+ tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK)
+ acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV)
+ acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK)
+ assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum)
+ assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum)
+ for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True):
+ utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i))
+ for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True):
+ utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i))
+
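+ # The helper below cycles the smem pipeline stage index: 0 -> 1 -> ... -> num_stages - 1 -> 0
+ # (e.g. with num_stages = 2 it alternates 0, 1, 0, 1, ...).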
+ @cute.jit
+ def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr):
+ return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0
+
+ @cute.jit
+ def load_K(
+ self,
+ gmem_thr_copy: cute.TiledCopy,
+ tKgK: cute.Tensor,
+ tKsK: cute.Tensor,
+ block: cutlass.Int32,
+ seqlen: cutlass.Int32,
+ headdim: cutlass.Int32,
+ ):
+ cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
+ tKcK = gmem_thr_copy.partition_S(cK)
+ t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK)
+ tKpK = utils.predicate_k(tKcK, limit=headdim)
+ for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
+ # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
+ if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size:
+ # Instead of using tKcK, we use t0KcK and subtract the offset from the limit
+ # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.
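+ # Equivalently: tKcK is t0KcK shifted by this thread's starting coordinate, so the row
+ # check tKcK[0, n, 0][0] < seqlen - block * kBlockN is rewritten with the compile-time-known
+ # t0KcK value on the left and the thread's row offset tKcK[0][0] folded into the limit.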
+ predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0]
+ predicate = cute.make_fragment_like(tKpK[None, 0, None])
+ for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
+ for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
+ predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n
+ cute.copy(
+ gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate,
+ )
+ # We need to clear the sK smem tiles since we'll use sKt for mma_dq
+
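+ # load_V below follows the same predicated gmem -> smem pattern as load_K, but for the V
+ # tile, using the padded V head dim and the is_even_n_smem_v flag.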
+ @cute.jit
+ def load_V(
+ self,
+ gmem_thr_copy: cute.TiledCopy,
+ tVgV: cute.Tensor,
+ tVsV: cute.Tensor,
+ block: cutlass.Int32,
+ seqlen: cutlass.Int32,
+ headdim: cutlass.Int32,
+ ):
+ cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
+ tVcV = gmem_thr_copy.partition_S(cV)
+ t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV)
+ tVpV = utils.predicate_k(tVcV, limit=headdim)
+ for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
+ # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
+ if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:
+ # Instead of using tVcV, we use t0VcV and subtract the offset from the limit
+ # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time.
+ predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0]
+ predicate = cute.make_fragment_like(tVpV[None, 0, None])
+ for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
+ for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
+ predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n
+ cute.copy(
+ gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate,
+ )
+
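+ # load_Q_LSE copies one Q tile plus the matching per-row LSE values from gmem into the
+ # smem stage selected by smem_pipe_write_q, with the same row/head-dim predication as above.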
+ @cute.jit
+ def load_Q_LSE(
+ self,
+ gmem_tiled_copy_Q: cute.TiledCopy,
+ gmem_tiled_copy_LSE: cute.TiledCopy,
+ tQgQ: cute.Tensor,
+ tQsQ: cute.Tensor,
+ tQcQ: cute.Tensor,
+ t0QcQ: cute.Tensor,
+ tQpQ: cute.Tensor,
+ tLSEgLSE: cute.Tensor,
+ tLSEsLSE: cute.Tensor,
+ tLSEcLSE: cute.Tensor,
+ block: cutlass.Int32,
+ smem_pipe_write_q: cutlass.Int32,
+ seqlen: cutlass.Int32,
+ ):
+ for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
+ # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
+ if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size:
+ # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit
+ # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.
+ predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]
+ predicate = cute.make_fragment_like(tQpQ[None, 0, None])
+ for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
+ for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
+ predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
+ cute.copy(
+ gmem_tiled_copy_Q,
+ tQgQ[None, m, None, block],
+ tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
+ pred=predicate,
+ )
+ # We need to clear the sQ smem tiles since we'll use sQt for mma_dK
+ # We made sure LSE length is padded so we read `kBlockM` elements so that all
+ # elements in sLSE are filled. Without this we might have uninitialized sLSE values.
+ for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])):
+ if tLSEcLSE[0, m][0] < self.m_block_size:
+ cute.copy(
+ gmem_tiled_copy_LSE,
+ tLSEgLSE[None, m, block],
+ tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
+ )
+
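+ # load_dO_dPsum mirrors load_Q_LSE for the dO tile and the per-row dPsum values, writing
+ # into the dO pipeline stage selected by smem_pipe_write_q.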
+ @cute.jit
+ def load_dO_dPsum(
+ self,
+ gmem_tiled_copy_dO: cute.TiledCopy,
+ gmem_tiled_copy_dPsum: cute.TiledCopy,
+ tdOgdO: cute.Tensor,
+ tdOsdO: cute.Tensor,
+ tdOcdO: cute.Tensor,
+ t0dOcdO: cute.Tensor,
+ tdOpdO: cute.Tensor,
+ tdPsumgdPsum: cute.Tensor,
+ tdPsumsdPsum: cute.Tensor,
+ tdPsumcdPsum: cute.Tensor,
+ block: cutlass.Int32,
+ smem_pipe_write_q: cutlass.Int32,
+ seqlen: cutlass.Int32,
+ ):
+ for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])):
+ # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
+ if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size:
+ # Instead of using tdOcdO, we use t0dOcdO and subtract the offset from the limit
+ # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time.
+ predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0]
+ predicate = cute.make_fragment_like(tdOpdO[None, 0, None])
+ for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
+ for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
+ predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
+ cute.copy(
+ gmem_tiled_copy_dO,
+ tdOgdO[None, m, None, block],
+ tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
+ pred=predicate,
+ )
+ # We need to clear the sQ smem tiles since we'll use sQt for mma_dK
+ # We made sure dPsum length is padded so we read `kBlockM` elements so that all
+ # elements in sdPsum are filled. Without this we might have uninitialized sdPsum values.
+ for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])):
+ if tdPsumcdPsum[0, m][0] < self.m_block_size:
+ cute.copy(
+ gmem_tiled_copy_dPsum,
+ tdPsumgdPsum[None, m, block],
+ tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
+ )