mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,464 @@
+ # @nolint # fbcode
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
+ # from Cutlass C++ to Cute-DSL.
+ import math
+ from typing import Callable, Optional, Type, Literal
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ import cutlass.utils.hopper_helpers as sm90_utils_basic
+ import cutlass.utils.blackwell_helpers as sm100_utils_basic
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
+ from cutlass import Float32, const_expr
+ from cutlass.utils import LayoutEnum
+
+ from mslk.attention.flash_attn import utils
+ from mslk.attention.flash_attn import copy_utils
+ from mslk.attention.flash_attn import ampere_helpers as sm80_utils
+ from mslk.attention.flash_attn import hopper_helpers as sm90_utils
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
+ import cutlass.cute.nvgpu.tcgen05 as tcgen05
+ from mslk.attention.flash_attn.tile_scheduler import (
+     ParamsBase,
+     SingleTileScheduler,
+     SingleTileVarlenScheduler,
+     TileSchedulerArguments,
+ )
+
+
+ class FlashAttentionBackwardPostprocess:
+     def __init__(
+         self,
+         dtype: Type[cutlass.Numeric],
+         head_dim: int,
+         arch: Literal[80, 90, 100],
+         tile_m: int = 128,
+         num_threads: int = 256,
+         AtomLayoutMdQ: int = 1,
+         dQ_swapAB: bool = False,
+     ):
+         """
+         :param head_dim: head dimension
+         :type head_dim: int
+         :param tile_m: m block size
+         :type tile_m: int
+         """
+         self.dtype = dtype
+         self.tile_m = tile_m
+         assert arch in [80, 90, 100], (
+             "Only Ampere (80), Hopper (90), and Blackwell (100) are supported"
+         )
+         self.arch = arch
+         # padding head_dim to a multiple of 32 as k_block_size
+         hdim_multiple_of = 32
+         self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
+         self.check_hdim_oob = head_dim != self.tile_hdim
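+         # For example (hypothetical values): head_dim=128 stays 128 with no padding,
+         # while head_dim=72 is padded to tile_hdim=96, so check_hdim_oob becomes True
+         # and the head-dim tail must be predicated when writing dQ back to gmem.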
+         self.num_threads = num_threads
+         self.AtomLayoutMdQ = AtomLayoutMdQ
+         self.dQ_swapAB = dQ_swapAB
+
+     @staticmethod
+     def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
+         """Check if the kernel can be implemented with the given parameters.
+
+         :param dtype: data type
+         :type dtype: cutlass.Numeric
+         :param head_dim: head dimension
+         :type head_dim: int
+         :param tile_m: m block size
+         :type tile_m: int
+
+         :return: True if the kernel can be implemented, False otherwise
+         :rtype: bool
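+
+         Example (hypothetical values; bf16 with head_dim 128 passes all checks):
+
+         >>> FlashAttentionBackwardPostprocess.can_implement(
+         ...     cutlass.BFloat16, head_dim=128, tile_m=128, num_threads=256
+         ... )
+         True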
+         """
+         if dtype not in [cutlass.Float16, cutlass.BFloat16]:
+             return False
+         if head_dim % 8 != 0:
+             return False
+         if num_threads % 32 != 0:
+             return False
+         return True
+
+     def _get_tiled_mma(self):
+         if const_expr(self.arch == 80):
+             num_mma_warps = self.num_threads // 32
+             atom_layout_dQ = (
+                 (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
+                 if const_expr(not self.dQ_swapAB)
+                 else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)
+             )
+             tiled_mma = cute.make_tiled_mma(
+                 warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
+                 atom_layout_dQ,
+                 permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
+             )
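+             # For instance, with the defaults (num_threads=256, AtomLayoutMdQ=1,
+             # dQ_swapAB=False) this is 8 MMA warps in a (1, 8, 1) atom layout, so the
+             # 16x8x16 HMMA atom is permuted to a 16x128 tile per MMA iteration.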
+         elif const_expr(self.arch == 90):
+             num_mma_warp_groups = self.num_threads // 128
+             atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ)
+             tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
+             tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
+                 self.dtype,
+                 self.dtype,
+                 warpgroup.OperandMajorMode.K,  # These don't matter, we only care about the accum
+                 warpgroup.OperandMajorMode.K,
+                 Float32,
+                 atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1])
+                 + (1,),
+                 tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
+             )
+         else:
+             cta_group = tcgen05.CtaGroup.ONE
+             tiled_mma = sm100_utils_basic.make_trivial_tiled_mma(
+                 self.dtype,
+                 tcgen05.OperandMajorMode.MN,  # dS_major_mode
+                 tcgen05.OperandMajorMode.MN,  # Kt_major_mode
+                 Float32,
+                 cta_group,
+                 (self.tile_m, self.tile_hdim),
+             )
+         if const_expr(self.arch in [80, 90]):
+             assert self.num_threads == tiled_mma.size
+         return tiled_mma
+
+     def _setup_attributes(self):
+         # ///////////////////////////////////////////////////////////////////////////////
+         # GMEM Tiled copy:
+         # ///////////////////////////////////////////////////////////////////////////////
+         # Thread layouts for copies
+         universal_copy_bits = 128
+         async_copy_elems_accum = universal_copy_bits // Float32.width
+         atom_async_copy_accum = cute.make_copy_atom(
+             cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
+             Float32,
+             num_bits_per_copy=universal_copy_bits,
+         )
+         # We don't do bound checking for the gmem -> smem load so we just assert here.
+         assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0
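+         # Sanity arithmetic (hypothetical defaults): tile_m=128, tile_hdim=128 gives
+         # 16384 fp32 elements, i.e. 4096 16-byte vectors; 4096 % 256 threads == 0,
+         # so every thread issues the same number of 128-bit cp.async loads.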
+         self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
+             atom_async_copy_accum,
+             cute.make_layout(self.num_threads),
+             cute.make_layout(async_copy_elems_accum),
+         )
+         num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4
+         if const_expr(self.arch == 80):
+             self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
+                 Float32, self.num_threads, num_s2r_copy_elems
+             )
+             self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
+         elif const_expr(self.arch == 90):
+             num_threads_per_warp_group = 128
+             num_mma_warp_groups = self.num_threads // 128
+             self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
+                 cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
+                 cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)),  # thr_layout
+                 cute.make_layout(128 // Float32.width),  # val_layout
+             )
+             self.sdQaccum_layout = cute.make_layout(
+                 (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups)
+             )
+         else:
+             self.dQ_reduce_ncol = 32
+             dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
+             assert self.num_threads == 128  # TODO: currently hard-coded
+             self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
+                 Float32, self.num_threads, num_s2r_copy_elems
+             )
+             self.sdQaccum_layout = cute.make_layout(
+                 (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage)
+             )
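+             # e.g. with tile_hdim=128 and dQ_reduce_ncol=32 there are 4 reduce stages,
+             # so a 128x128 fp32 tile is staged as a (4096, 4) smem layout.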
+
+         self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(
+             self.dtype, self.tile_hdim, self.num_threads
+         )
+         # ///////////////////////////////////////////////////////////////////////////////
+         # Shared memory layout: dQ
+         # ///////////////////////////////////////////////////////////////////////////////
+         # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
+         # then setting kBlockKSmem to 32 will cause "Static shape_div failure".
+         # We want to treat it as 64 x 48, so kBlockKSmem should be 16.
+         mma_shape_n = self.tiled_mma.get_tile_size(1)
+         if const_expr(self.arch == 80):
+             sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
+             self.sdQ_layout = cute.tile_to_shape(
+                 sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
+             )
+         elif const_expr(self.arch == 90):
+             self.sdQ_layout = sm90_utils.make_smem_layout(
+                 self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim)
+             )
+         else:
+             # TODO: this is hard-coded for hdim 128
+             self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(
+                 self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1
+             )
+
+     @cute.jit
+     def __call__(
+         self,
+         mdQaccum: cute.Tensor,
+         mdQ: cute.Tensor,
+         scale: cutlass.Float32,
+         mCuSeqlensQ: Optional[cute.Tensor],
+         mSeqUsedQ: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         # Get the data type and check if it is fp16 or bf16
+         if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
+             raise TypeError("Only Float16 or BFloat16 is supported")
+         if const_expr(mdQaccum is not None):
+             if const_expr(mdQaccum.element_type not in [cutlass.Float32]):
+                 raise TypeError("dQaccum tensor must be Float32")
+
+         # Assume all strides are divisible by 128 bits except the last stride
+         new_stride = lambda t: (
+             *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+             t.stride[-1],
+         )
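+         # e.g. for a bf16 dQ tensor divby is 128 // 16 = 8 elements, and for the fp32
+         # dQaccum it is 4, letting the compiler assume aligned, vectorizable strides.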
+         mdQaccum, mdQ = [
+             cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+             for t in (mdQaccum, mdQ)
+         ]
+
+         self.tiled_mma = self._get_tiled_mma()
+         self._setup_attributes()
+
+         smem_size = max(
+             cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout),
+             cute.size_in_bytes(self.dtype, self.sdQ_layout),
+         )
+
+         if const_expr(mCuSeqlensQ is not None):
+             TileScheduler = SingleTileVarlenScheduler
+             num_head = mdQ.shape[1]
+             num_batch = mCuSeqlensQ.shape[0] - 1
+             num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)
+         else:
+             TileScheduler = SingleTileScheduler
+             num_head = mdQ.shape[2]
+             num_batch = mdQ.shape[0]
+             num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)
+
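+         # In the dense path mdQ is (batch, seqlen, head, head_dim); in the varlen path
+         # it is (total_q, head, head_dim). E.g. seqlen=4096 with tile_m=128 gives 32
+         # m-blocks per (head, batch), one thread block each under these schedulers.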
+         tile_sched_args = TileSchedulerArguments(
+             num_block=num_block,
+             num_head=num_head,
+             num_batch=num_batch,
+             num_splits=1,
+             seqlen_k=0,
+             headdim=mdQ.shape[2],
+             headdim_v=0,
+             total_q=mdQ.shape[0],
+             tile_shape_mn=(self.tile_m, 1),
+             mCuSeqlensQ=mCuSeqlensQ,
+             mSeqUsedQ=mSeqUsedQ,
+         )
+
+         tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
+         grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
+
+         # grid_dim: (m_block, num_head, batch_size)
+         self.kernel(
+             mdQaccum,
+             mdQ,
+             mCuSeqlensQ,
+             mSeqUsedQ,
+             scale,
+             self.tiled_mma,
+             self.dQ_swapAB,
+             self.sdQaccum_layout,
+             self.sdQ_layout,
+             self.g2s_tiled_copy_dQaccum,
+             self.s2r_tiled_copy_dQaccum,
+             self.gmem_tiled_copy_dQ,
+             tile_sched_params,
+             TileScheduler,
+         ).launch(
+             grid=grid_dim,
+             block=[self.num_threads, 1, 1],
+             smem=smem_size,
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mdQaccum: cute.Tensor,
+         mdQ: cute.Tensor,
+         mCuSeqlensQ: Optional[cute.Tensor],
+         mSeqUsedQ: Optional[cute.Tensor],
+         scale: cutlass.Float32,
+         tiled_mma: cute.TiledMma,
+         dQ_swapAB: cutlass.Constexpr,
+         sdQaccum_layout: cute.Layout,
+         sdQ_layout: cute.ComposedLayout,
+         g2s_tiled_copy_dQaccum: cute.TiledCopy,
+         s2r_tiled_copy_dQaccum: cute.TiledCopy,
+         gmem_tiled_copy_dQ: cute.TiledCopy,
+         tile_sched_params: ParamsBase,
+         TileScheduler: cutlass.Constexpr[Callable],
+     ):
+         # ///////////////////////////////////////////////////////////////////////////////
+         # Get shared memory buffer
+         # ///////////////////////////////////////////////////////////////////////////////
+         smem = cutlass.utils.SmemAllocator()
+         sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
+         sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
+         if const_expr(self.arch in [80, 90]):
+             sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
+         else:
+             # extra stage dimension
+             sdQ = cute.make_tensor(
+                 cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype),
+                 sdQ_layout.outer,
+             )[None, None, 0]
+         sdQt = utils.transpose_view(sdQ)
+
+         # Thread index, block index
+         tidx, _, _ = cute.arch.thread_idx()
+
+         tile_scheduler = TileScheduler.create(tile_sched_params)
+         work_tile = tile_scheduler.initial_work_tile_info()
+
+         m_block, head_idx, batch_idx, _ = work_tile.tile_idx
+
+         if work_tile.is_valid_tile:
+             # ///////////////////////////////////////////////////////////////////////////////
+             # Get the appropriate tiles for this thread block.
+             # ///////////////////////////////////////////////////////////////////////////////
+
+             seqlen = SeqlenInfoQK.create(
+                 batch_idx,
+                 mdQ.shape[1],
+                 0,
+                 mCuSeqlensQ=mCuSeqlensQ,
+                 mCuSeqlensK=None,
+                 mSeqUsedQ=mSeqUsedQ,
+                 mSeqUsedK=None,
+             )
+             if const_expr(not seqlen.has_cu_seqlens_q):
+                 mdQ_cur = mdQ[batch_idx, None, head_idx, None]
+                 mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
+                 head_dim = mdQ.shape[3]
+             else:
+                 padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
+                 if cutlass.const_expr(self.arch >= 90):
+                     padded_offset_q = padded_offset_q // self.tile_m * self.tile_m
+                 mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
+                 mdQaccum_cur = cute.domain_offset(
+                     (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
+                 )
+                 head_dim = mdQ.shape[2]
+
+             # HACK: The compiler doesn't seem to recognize that offsetting by
+             # padded_offset_q * self.tile_hdim keeps alignment, even though it is
+             # statically divisible by 4, so rebuild the pointer with the alignment.
+
+             mdQaccum_cur_ptr = cute.make_ptr(
+                 dtype=mdQaccum_cur.element_type,
+                 value=mdQaccum_cur.iterator.toint(),
+                 mem_space=mdQaccum_cur.iterator.memspace,
+                 assumed_align=mdQaccum.iterator.alignment,
+             )
+             mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)
+
+             gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))
+             gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
+
+             seqlen_q = seqlen.seqlen_q
+             seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
+
+             # Step 1: load dQaccum from gmem to smem
+             g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx)
+             tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum)
+             tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat)
+             cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s)
+             cute.arch.cp_async_commit_group()
+             cute.arch.cp_async_wait_group(0)
+             cute.arch.barrier()
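+             # cp_async_wait_group(0) drains every outstanding cp.async group, and the
+             # barrier then makes the staged dQaccum tile visible to all threads.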
+
+             # Step 2: load dQ from smem to rmem
+             s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx)
+             tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum)
+             tile_shape = (self.tile_m, self.tile_hdim)
+             acc = None
+             tiled_copy_t2r = None
+             if const_expr(self.arch in [80, 90]):
+                 acc_shape = tiled_mma.partition_shape_C(
+                     tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
+                 )
+                 acc = cute.make_fragment(acc_shape, cutlass.Float32)
+                 assert cute.size(acc) == cute.size(tdQsdQaccum)
+             else:
+                 thr_mma = tiled_mma.get_slice(0)  # 1-CTA
+                 dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim))
+                 tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape)
+                 tdQcdQ = thr_mma.partition_C(
+                     cute.make_identity_tensor((self.tile_m, self.tile_hdim))
+                 )
+                 tmem_load_atom = cute.make_copy_atom(
+                     tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32
+                 )
+                 tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)
+                 thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+                 tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape
+                 acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32)
+             tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape))
+             cute.autovec_copy(tdQsdQaccum, tdQrdQaccum)
+             # Convert tdQrdQaccum from fp32 to fp16/bf16
+             rdQ = cute.make_fragment_like(acc, self.dtype)
+             rdQ.store((acc.load() * scale).to(self.dtype))
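+             # The softmax scale is folded into this single fp32 multiply during the
+             # down-conversion, so dQ = scale * dQaccum costs no extra pass over gmem.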
+
+             # Step 3: Copy dQ from register to smem
+             cute.arch.barrier()  # make sure all threads have finished loading dQaccum
+             if const_expr(self.arch in [80, 90]):
+                 copy_atom_r2s_dQ = utils.get_smem_store_atom(
+                     self.arch, self.dtype, transpose=self.dQ_swapAB
+                 )
+                 tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma)
+             else:
+                 # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op(
+                 #     LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r,
+                 # )
+                 # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r)
+                 thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1))  # 128 threads
+                 val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width))
+                 copy_atom_r2s_dQ = cute.make_copy_atom(
+                     cute.nvgpu.CopyUniversalOp(),
+                     self.dtype,
+                     num_bits_per_copy=128,
+                 )
+                 tiled_copy_r2s_dQ = cute.make_tiled_copy_tv(
+                     copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ
+                 )
+             thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
+             cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
+             if const_expr(self.arch in [80, 90]):
+                 taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
+             else:
+                 taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
+                 taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape)
+             taccdQsdQ = thr_copy_r2s_dQ.partition_D(sdQ if const_expr(not self.dQ_swapAB) else sdQt)
+             cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ)
+
+             # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
+             cute.arch.barrier()  # make sure all smem stores are done
+             gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx)
+             tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ)
+             tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ)
+             tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype)
+             # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled
+             cute.autovec_copy(tdQsdQ, tdQrdQ)
+
+             # Step 5: Copy dQ from register to gmem
+             tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ)
+             tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim)
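+             # predicate_k masks the head-dim tail (head_dim < tile_hdim when
+             # check_hdim_oob is set); the row check in the loop below masks the
+             # seqlen tail of the last m_block.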
+             for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True):
+                 if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m:
+                     cute.copy(
+                         gmem_tiled_copy_dQ,
+                         tdQrdQ[None, rest_m, None],
+                         tdQgdQ[None, rest_m, None],
+                         pred=tdQpdQ[None, rest_m, None],
+                     )
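
For context, the sketch below shows one plausible way to drive this postprocess step from the host on the dense (non-varlen) path. It is a minimal sketch, not part of this diff: the torch tensors, the `from_dlpack` conversion, the `(batch, head, seqlen_rounded * head_dim)` dQaccum layout, and the use of the current CUDA stream are all assumptions inferred from the kernel's indexing above.

    # Hypothetical driver for the dense path; shapes and layouts are assumptions.
    import torch
    import cutlass
    import cuda.bindings.driver as cuda
    from cutlass.cute.runtime import from_dlpack
    from mslk.attention.flash_attn.flash_bwd_postprocess import FlashAttentionBackwardPostprocess

    batch, seqlen, num_heads, head_dim, tile_m = 2, 4096, 16, 128, 128
    seqlen_rounded = (seqlen + tile_m - 1) // tile_m * tile_m

    # fp32 partial dQ written by the backward kernel, flat per (batch, head);
    # bf16 dQ output with the (batch, seqlen, head, head_dim) layout indexed above.
    dqaccum = torch.randn(batch, num_heads, seqlen_rounded * head_dim, dtype=torch.float32, device="cuda")
    dq = torch.empty(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")

    assert FlashAttentionBackwardPostprocess.can_implement(cutlass.BFloat16, head_dim, tile_m, 256)
    postprocess = FlashAttentionBackwardPostprocess(cutlass.BFloat16, head_dim, arch=90)

    softmax_scale = head_dim ** -0.5  # folded into the fp32 -> bf16 convert in Step 2
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    postprocess(from_dlpack(dqaccum), from_dlpack(dq), softmax_scale, None, None, stream)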