mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/flash_fwd_sm100.py
@@ -0,0 +1,2727 @@
1
+ # @nolint # fbcode
2
+ # Supported features:
3
+ # - BF16 & FP16 dtype
4
+ # - noncausal & causal attention
5
+ # - MHA, GQA, MQA
6
+ # - hdim 64, 96, 128, (192, 128).
7
+ # - varlen
8
+ # - sliding window
9
+ # - split-kv
10
+ # Unsupported features that will be added later:
11
+ # - page size != 128
12
+ # - more hdim (192, 256)
13
+ # Based on the CUTLASS example and the CuTe DSL example:
14
+ # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
15
+ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py
16
+
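+ # Illustrative usage sketch (comments only; the names and values below are assumptions
+ # for illustration, based on the __init__ / __call__ signatures defined in this file):
+ #   fa_fwd = FlashAttentionForwardSm100(head_dim=128, qhead_per_kvhead=4, is_causal=True)
+ #   fa_fwd(mQ, mK, mV, mO, mLSE, softmax_scale=1.0 / math.sqrt(128), stream=stream)
+ # where mQ/mO are (b, s_q, h, d)-shaped and mK/mV are (b_k, s_k, h_k, d)-shaped cute.Tensor views.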
17
+ import enum
18
+ import math
19
+ from typing import Type, Tuple, Callable, Optional, Literal
20
+ from functools import partial
21
+
22
+ import cuda.bindings.driver as cuda
23
+
24
+ import cutlass
25
+ import cutlass.cute as cute
26
+ from cutlass import Float32, Int32, const_expr
27
+ from cutlass.cute.nvgpu import cpasync
28
+ import cutlass.cute.nvgpu.tcgen05 as tcgen05
29
+ import cutlass.utils.blackwell_helpers as sm100_utils_basic
30
+
31
+ from mslk.attention.flash_attn.paged_kv import PagedKVManager
32
+ import mslk.attention.flash_attn.utils as utils
33
+ from mslk.attention.flash_attn import copy_utils
34
+ import mslk.attention.flash_attn.pipeline as pipeline
35
+ from mslk.attention.flash_attn.mask import AttentionMask
36
+ from mslk.attention.flash_attn.softmax import SoftmaxSm100, apply_score_mod_inner
37
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
38
+ from mslk.attention.flash_attn.block_info import BlockInfo
39
+ from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
40
+ from mslk.attention.flash_attn.block_sparse_utils import (
41
+ get_total_block_count,
42
+ produce_block_sparse_loads_sm100,
43
+ softmax_block_sparse_sm100,
44
+ handle_block_sparse_empty_tile_correction_sm100,
45
+ )
46
+ from mslk.attention.flash_attn.pack_gqa import PackGQA
47
+ from mslk.attention.flash_attn import mma_sm100_desc as sm100_desc
48
+ from mslk.attention.flash_attn import blackwell_helpers as sm100_utils
49
+ from cutlass.cute import FastDivmodDivisor
50
+ from mslk.attention.flash_attn.tile_scheduler import (
51
+ TileSchedulerArguments,
52
+ SingleTileScheduler,
53
+ StaticPersistentTileScheduler,
54
+ SingleTileLPTScheduler,
55
+ SingleTileVarlenScheduler,
56
+ ParamsBase,
57
+ )
58
+
59
+
60
+ class NamedBarrierFwd(enum.IntEnum):
61
+ Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
62
+ # WarpSchedulerWG1 = enum.auto()
63
+ # WarpSchedulerWG2 = enum.auto()
64
+ # WarpSchedulerWG3 = enum.auto()
65
+ # PFull = enum.auto()
66
+ # PEmpty = enum.auto()
67
+
68
+
69
+ class FlashAttentionForwardSm100:
70
+ arch = 100
71
+
72
+ def __init__(
73
+ self,
74
+ # dtype: Type[cutlass.Numeric],
75
+ head_dim: int,
76
+ head_dim_v: Optional[int] = None,
77
+ qhead_per_kvhead: cutlass.Constexpr[int] = 1,
78
+ is_causal: bool = False,
79
+ is_local: bool = False,
80
+ is_split_kv: bool = False,
81
+ pack_gqa: bool = False,
82
+ m_block_size: int = 128,
83
+ n_block_size: int = 128,
84
+ q_stage: cutlass.Constexpr[int] = 2,
85
+ is_persistent: bool = True,
86
+ score_mod: cutlass.Constexpr | None = None,
87
+ mask_mod: cutlass.Constexpr | None = None,
88
+ has_aux_tensors: cutlass.Constexpr = False,
89
+ paged_kv_non_tma: bool = False,
90
+ is_varlen_q: bool = False,
91
+ ):
92
+ self.use_tma_KV = not paged_kv_non_tma
93
+ # self.dtype = dtype
94
+ # padding head_dim to a multiple of 16 as k_block_size
95
+ hdim_multiple_of = 16
96
+ self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
97
+ head_dim_v = head_dim_v if head_dim_v is not None else head_dim
98
+ self.same_hdim_kv = head_dim == head_dim_v
99
+ self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
100
+ self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded
101
+ self.check_hdim_oob = head_dim != self.head_dim_padded
102
+ self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
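+ # e.g., head_dim=72 -> head_dim_padded = ceil(72 / 16) * 16 = 80 (check_hdim_oob=True),
+ # while head_dim=128 stays 128 and needs no out-of-bounds handling along d.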
103
+ self.m_block_size = m_block_size
104
+ self.n_block_size = n_block_size
105
+ self.q_stage = q_stage
106
+ assert self.q_stage in [1, 2]
107
+
108
+ # q_stage Q tiles per CTA (2 by default)
109
+ self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded)
110
+ self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded)
111
+ self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size)
112
+ self.qk_acc_dtype = Float32
113
+ self.pv_acc_dtype = Float32
114
+ self.cluster_shape_mn = (1, 1)
115
+ self.is_persistent = is_persistent
116
+ self.is_causal = is_causal
117
+ self.is_local = is_local
118
+ self.is_varlen_q = is_varlen_q
119
+ self.use_correction_warps_for_epi = is_varlen_q
120
+ self.qhead_per_kvhead = qhead_per_kvhead
121
+ self.is_split_kv = is_split_kv
122
+ self.pack_gqa = pack_gqa
123
+ if pack_gqa:
124
+ assert m_block_size % self.qhead_per_kvhead == 0, (
125
+ "For PackGQA, m_block_size must be divisible by qhead_per_kvhead"
126
+ )
127
+ assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (
128
+ "SplitKV is not supported for hdim >= 192"
129
+ )
130
+ self.score_mod = score_mod
131
+ self.mask_mod = mask_mod
132
+ if cutlass.const_expr(has_aux_tensors):
133
+ self.vec_size: cutlass.Constexpr = 1
134
+ else:
135
+ self.vec_size: cutlass.Constexpr = 2
136
+ # Does S1 need to wait for S0 to finish
137
+ # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
138
+ self.s0_s1_barrier = False
139
+ self.overlap_sO_sQ = (
140
+ (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or
141
+ (self.head_dim_v_padded >= 128 and self.is_split_kv)
142
+ )
143
+ if self.overlap_sO_sQ:
144
+ self.is_persistent = False
145
+
146
+ assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (
147
+ "Paged KV does not support irregular head dim"
148
+ )
149
+
150
+ self.softmax0_warp_ids = (0, 1, 2, 3)
151
+ self.softmax1_warp_ids = (4, 5, 6, 7)
152
+ self.correction_warp_ids = (8, 9, 10, 11)
153
+ self.mma_warp_id = 12
154
+ self.epilogue_warp_ids = (13,)
155
+ self.load_warp_ids = (14,)
156
+ self.empty_warp_ids = (15,)
157
+ SM100_TMEM_CAPACITY_COLUMNS = 512
158
+ self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
159
+
160
+ self.threads_per_cta = cute.arch.WARP_SIZE * len(
161
+ (
162
+ *self.softmax0_warp_ids,
163
+ *self.softmax1_warp_ids,
164
+ *self.correction_warp_ids,
165
+ self.mma_warp_id,
166
+ *self.load_warp_ids,
167
+ *self.epilogue_warp_ids,
168
+ *self.empty_warp_ids,
169
+ )
170
+ )
171
+
172
+ if self.q_stage == 1:
173
+ if not self.use_tma_KV:
174
+ self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids
175
+ self.load_warp_ids = self.softmax1_warp_ids
176
+ else:
177
+ self.empty_warp_ids = self.empty_warp_ids + self.softmax1_warp_ids
178
+ self.softmax1_warp_ids = ()
179
+ elif not self.use_tma_KV:
180
+ self.load_warp_ids = (14, 15)
181
+ self.empty_warp_ids = ()
182
+
183
+ if self.use_correction_warps_for_epi:
184
+ self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids
185
+ self.epilogue_warp_ids = self.correction_warp_ids
186
+ elif self.is_varlen_q: # fallback
187
+ self.epilogue_warp_ids = (13, 14)
188
+
189
+ self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128
190
+ self.tmem_o_offset = [
191
+ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
192
+ for i in range(self.q_stage)
193
+ ] # e.g., 256, 384
194
+ self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded
195
+ assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS
196
+ self.tmem_s_to_p_offset = self.n_block_size // 2
197
+ self.tmem_p_offset = [
198
+ self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)
199
+ ] # e.g., 64, 192
200
+
201
+ # vec buffer for row_max & row_sum
202
+ self.tmem_vec_offset = self.tmem_s_offset
203
+
204
+ if self.head_dim_padded < 96:
205
+ self.num_regs_softmax = 200
206
+ self.num_regs_correction = 64
207
+ self.num_regs_other = 48
208
+ else:
209
+ # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184
210
+ self.num_regs_softmax = 200
211
+ # self.num_regs_softmax = 176
212
+ # self.num_regs_correction = 96
213
+ # self.num_regs_correction = 80
214
+ # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80
215
+ self.num_regs_correction = 64
216
+ # self.num_regs_other = 32
217
+ # self.num_regs_other = 64
218
+ # self.num_regs_other = 80
219
+ self.num_regs_other = 48
220
+ # self.num_regs_other = 96 if self.is_causal or self.is_local else 80
221
+ # self.num_regs_other = 64 if self.is_causal or self.is_local else 80
222
+ self.num_regs_empty = 24
223
+
224
+ self.buffer_align_bytes = 1024
225
+
226
+ def _setup_attributes(self):
227
+ """Set up configurations and parameters for the FMHA kernel operation.
228
+
229
+ This method initializes and configures various attributes required for the
230
+ execution of the fused multi-head attention kernel, chiefly its pipeline staging:
231
+
232
+ - Sets up staging parameters for Q, K, V inputs and accumulator data
233
+ - Configures pipeline stages for softmax, correction, and epilogue operations
234
+ """
235
+
236
+ self.kv_stage = 4 if self.q_dtype.width == 8 or self.q_stage == 1 else 3
237
+ self.acc_stage = 1
238
+ # For hdim 192,128, we don't have enough smem to store all 3 stages of KV:
239
+ # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q.
240
+ # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is
241
+ # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be
242
+ # 128 * 160, so that indexing the 0th and 2nd stages will get the right address,
243
+ # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64.
244
+ self.uneven_kv_smem = (
245
+ self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3
246
+ )
247
+ self.uneven_kv_smem_offset = (
248
+ self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2
249
+ if self.uneven_kv_smem
250
+ else 0
251
+ )
252
+ assert self.uneven_kv_smem_offset % 1024 == 0
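+ # Worked example for hdim (192, 128) with 2-byte (bf16/fp16) elements and kv_stage=3:
+ # a K stage is 128 x 192 x 2B = 48 KiB, a V stage is 128 x 128 x 2B = 32 KiB, the shared
+ # stage stride is 128 x 160 x 2B = 40 KiB (their average), and uneven_kv_smem_offset =
+ # 128 * (192 - 128) // 2 = 4096 elements (8 KiB), the correction applied to the middle stage.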
253
+
254
+ @cute.jit
255
+ def __call__(
256
+ self,
257
+ mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
258
+ mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
259
+ mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
260
+ mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
261
+ mLSE: Optional[cute.Tensor],
262
+ softmax_scale: Float32,
263
+ stream: cuda.CUstream,
264
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
265
+ mCuSeqlensK: Optional[cute.Tensor] = None,
266
+ mSeqUsedQ: Optional[cute.Tensor] = None,
267
+ mSeqUsedK: Optional[cute.Tensor] = None,
268
+ mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
269
+ window_size_left: Int32 | int | None = None,
270
+ window_size_right: Int32 | int | None = None,
271
+ learnable_sink: Optional[cute.Tensor] = None,
272
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
273
+ aux_tensors: Optional[list] = None,
274
+ ):
275
+ """Execute the Fused Multi-Head Attention operation on the provided tensors.
276
+
277
+ This method prepares the input tensors for processing, validates their shapes and types,
278
+ configures the computation parameters, and launches the CUDA kernel.
279
+
280
+ The method handles:
281
+ 1. Tensor layout transformations for specific memory access patterns
282
+ 2. Validation of tensor shapes and data types
283
+ 3. Initialization of hardware-specific parameters and memory layouts
284
+ 4. Configuration of TMA (Tensor Memory Access) operations
285
+ 5. Grid and work scheduling computation
286
+ 6. Kernel launch with appropriate parameters
287
+ """
288
+ # setup static attributes before smem/grid/tma computation
289
+ self.q_dtype = mQ.element_type
290
+ self.k_dtype = mK.element_type
291
+ self.v_dtype = mV.element_type
292
+ self.o_dtype = mO.element_type
293
+ # Assume all strides are divisible by 128 bits except the last stride
294
+ new_stride = lambda t: (
295
+ *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
296
+ t.stride[-1],
297
+ )
298
+ mQ, mK, mV, mO = [
299
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
300
+ for t in (mQ, mK, mV, mO)
301
+ ]
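+ # e.g., for bf16 (16-bit) elements divby = 128 // 16 = 8, i.e. every stride except the
+ # last is assumed to be a multiple of 8 elements (16 bytes), matching TMA's requirement
+ # that global strides be 16-byte aligned.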
302
+ Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
303
+ mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose))
304
+ # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table
305
+ KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
306
+ mK, mV = [
307
+ cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose))
308
+ for t in (mK, mV)
309
+ ]
310
+ if const_expr(self.is_split_kv):
311
+ O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0]
312
+ LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0]
313
+ num_splits = mO.shape[0]
314
+ else:
315
+ O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
316
+ LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
317
+ num_splits = Int32(1)
318
+ mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
319
+ mLSE = (
320
+ cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
321
+ if const_expr(mLSE is not None)
322
+ else None
323
+ )
324
+ # (s, d, h, b) -> (d, s, h, b)
325
+ V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2]
326
+ mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose))
327
+
328
+ self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode()
329
+ self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode()
330
+ self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode()
331
+ self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO)
332
+
333
+ if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K):
334
+ raise RuntimeError("The layout of mQ is not supported")
335
+ if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K):
336
+ raise RuntimeError("The layout of mK is not supported")
337
+ if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN):
338
+ raise RuntimeError("The layout of mV is not supported")
339
+
340
+ # check type consistency
341
+ if const_expr(self.q_dtype != self.k_dtype):
342
+ raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}")
343
+ if const_expr(self.q_dtype != self.v_dtype):
344
+ raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
345
+ self._setup_attributes()
346
+ self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None
347
+ # This can be tuned
348
+ self.e2e_freq = 16
349
+ if const_expr(
350
+ self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa
351
+ ):
352
+ self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10
353
+
354
+ cta_group = tcgen05.CtaGroup.ONE
355
+ # the intermediate tensor p is from tmem & mK-major
356
+ p_source = tcgen05.OperandSource.TMEM
357
+ p_major_mode = tcgen05.OperandMajorMode.K
358
+ tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma(
359
+ self.q_dtype,
360
+ self.q_major_mode,
361
+ self.k_major_mode,
362
+ self.qk_acc_dtype,
363
+ cta_group,
364
+ self.mma_tiler_qk[:2],
365
+ )
366
+ tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma(
367
+ self.v_dtype,
368
+ p_major_mode,
369
+ self.v_major_mode,
370
+ self.pv_acc_dtype,
371
+ cta_group,
372
+ self.mma_tiler_pv[:2],
373
+ p_source,
374
+ )
375
+
376
+ self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
377
+ self.cluster_layout_vmnk = cute.tiled_divide(
378
+ cute.make_layout(self.cluster_shape_mnk),
379
+ (tiled_mma_qk.thr_id.shape,),
380
+ )
381
+
382
+ self.epi_tile = self.mma_tiler_pv[:2]
383
+
384
+ sQ_layout = sm100_utils_basic.make_smem_layout_a(
385
+ tiled_mma_qk,
386
+ self.mma_tiler_qk,
387
+ self.q_dtype,
388
+ self.q_stage,
389
+ )
390
+ sK_layout = sm100_utils_basic.make_smem_layout_b(
391
+ tiled_mma_qk,
392
+ self.mma_tiler_qk,
393
+ self.k_dtype,
394
+ self.kv_stage,
395
+ )
396
+ tP_layout = sm100_utils_basic.make_smem_layout_a(
397
+ tiled_mma_pv,
398
+ self.mma_tiler_pv,
399
+ self.q_dtype,
400
+ self.acc_stage,
401
+ )
402
+ sV_layout = sm100_utils_basic.make_smem_layout_b(
403
+ tiled_mma_pv,
404
+ self.mma_tiler_pv,
405
+ self.v_dtype,
406
+ self.kv_stage,
407
+ )
408
+ sO_layout = sm100_utils_basic.make_smem_layout_epi(
409
+ self.o_dtype,
410
+ self.o_layout,
411
+ self.epi_tile,
412
+ self.q_stage,
413
+ )
414
+ if const_expr(not self.same_hdim_kv_padded):
415
+ # sK and sV are using the same physical smem so we need to adjust the stride so that they line up
416
+ stride_sK = const_expr(
417
+ max(sK_layout.outer.stride[-1], 0)
418
+ ) # take max to turn tuple to Int32
419
+ stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0))
420
+ stage_stride = const_expr(
421
+ max(stride_sK, stride_sV)
422
+ if not self.uneven_kv_smem
423
+ else (stride_sK + stride_sV) // 2
424
+ )
425
+ sK_layout = cute.make_composed_layout(
426
+ sK_layout.inner,
427
+ 0,
428
+ cute.make_layout(
429
+ (*sK_layout.outer.shape[:-1], self.kv_stage),
430
+ stride=(*sK_layout.outer.stride[:-1], stage_stride),
431
+ ),
432
+ )
433
+ sV_layout = cute.make_composed_layout(
434
+ sV_layout.inner,
435
+ 0,
436
+ cute.make_layout(
437
+ (*sV_layout.outer.shape[:-1], self.kv_stage),
438
+ stride=(*sV_layout.outer.stride[:-1], stage_stride),
439
+ ),
440
+ )
441
+
442
+ if const_expr(self.pack_gqa):
443
+ shape_Q_packed = (
444
+ (self.qhead_per_kvhead, mQ.shape[0]),
445
+ mQ.shape[1],
446
+ mK.shape[2],
447
+ *mQ.shape[3:],
448
+ )
449
+ stride_Q_packed = (
450
+ (mQ.stride[2], mQ.stride[0]),
451
+ mQ.stride[1],
452
+ mQ.stride[2] * self.qhead_per_kvhead,
453
+ *mQ.stride[3:],
454
+ )
455
+ mQ = cute.make_tensor(
456
+ mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
457
+ )
458
+ shape_O_packed = (
459
+ (self.qhead_per_kvhead, mO.shape[0]),
460
+ mO.shape[1],
461
+ mK.shape[2],
462
+ *mO.shape[3:],
463
+ )
464
+ stride_O_packed = (
465
+ (mO.stride[2], mO.stride[0]),
466
+ mO.stride[1],
467
+ mO.stride[2] * self.qhead_per_kvhead,
468
+ *mO.stride[3:],
469
+ )
470
+ mO = cute.make_tensor(
471
+ mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
472
+ )
473
+ if const_expr(mLSE is not None):
474
+ shape_LSE_packed = (
475
+ (self.qhead_per_kvhead, mLSE.shape[0]),
476
+ mK.shape[2],
477
+ *mLSE.shape[2:],
478
+ )
479
+ stride_LSE_packed = (
480
+ (mLSE.stride[1], mLSE.stride[0]),
481
+ mLSE.stride[1] * self.qhead_per_kvhead,
482
+ *mLSE.stride[2:],
483
+ )
484
+ mLSE = cute.make_tensor(
485
+ mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
486
+ )
487
+
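+ # PackGQA example (qhead_per_kvhead = 4): packed coordinate ((r, s), d, h_k, b) aliases
+ # the original element (s, d, 4 * h_k + r, b), i.e. the 4 query heads sharing one KV head
+ # are folded into the row (M) mode so a single CTA tile covers all of them.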
488
+ self.tma_copy_bytes = {
489
+ name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
490
+ for name, mX, layout in [
491
+ ("Q", mQ, sQ_layout),
492
+ ("K", mK, sK_layout),
493
+ ("V", mV, sV_layout),
494
+ ]
495
+ }
496
+
497
+ # TMA load for Q
498
+ tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
499
+ tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
500
+
501
+ tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
502
+ tma_load_op,
503
+ mQ,
504
+ cute.select(sQ_layout, mode=[0, 1, 2]),
505
+ self.mma_tiler_qk,
506
+ tiled_mma_qk,
507
+ self.cluster_layout_vmnk.shape,
508
+ )
509
+
510
+ if const_expr(self.use_tma_KV):
511
+ # TMA load for K
512
+ tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B(
513
+ tma_load_op,
514
+ mK,
515
+ cute.select(sK_layout, mode=[0, 1, 2]),
516
+ self.mma_tiler_qk,
517
+ tiled_mma_qk,
518
+ self.cluster_layout_vmnk.shape,
519
+ )
520
+ # TMA load for V
521
+ tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B(
522
+ tma_load_op,
523
+ mV,
524
+ cute.select(sV_layout, mode=[0, 1, 2]),
525
+ self.mma_tiler_pv,
526
+ tiled_mma_pv,
527
+ self.cluster_layout_vmnk.shape,
528
+ )
529
+ else:
530
+ tma_atom_K = None
531
+ tma_atom_V = None
532
+
533
+ o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile)
534
+
535
+ self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
536
+ if const_expr(self.use_tma_O):
537
+ tma_atom_O, mO = cpasync.make_tiled_tma_atom(
538
+ tma_store_op,
539
+ mO,
540
+ cute.select(sO_layout, mode=[0, 1]),
541
+ o_cta_v_layout,
542
+ )
543
+ gmem_tiled_copy_O = None
544
+ else:
545
+ tma_atom_O = None
546
+ universal_copy_bits = 128
547
+ async_copy_elems = universal_copy_bits // self.o_dtype.width
548
+ atom_universal_copy = cute.make_copy_atom(
549
+ cute.nvgpu.CopyUniversalOp(),
550
+ self.o_dtype,
551
+ num_bits_per_copy=universal_copy_bits,
552
+ )
553
+ tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems
554
+ tO_layout = cute.make_ordered_layout(
555
+ (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1),
556
+ order=(1, 0),
557
+ )
558
+ # So that we don't have to check if we overshoot kBlockM when we store O
559
+ assert self.m_block_size % tO_layout.shape[0] == 0
560
+ vO_layout = cute.make_layout((1, async_copy_elems))
561
+ gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
562
+
563
+ if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
564
+ TileScheduler = SingleTileVarlenScheduler
565
+ else:
566
+ if const_expr(self.is_causal or self.is_local):
567
+ TileScheduler = SingleTileLPTScheduler
568
+ else:
569
+ TileScheduler = (
570
+ SingleTileScheduler
571
+ if const_expr(not self.is_persistent)
572
+ else StaticPersistentTileScheduler
573
+ )
574
+ tile_sched_args = TileSchedulerArguments(
575
+ cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]),
576
+ cute.size(mQ.shape[2]),
577
+ cute.size(mQ.shape[3])
578
+ if const_expr(mCuSeqlensQ is None)
579
+ else cute.size(mCuSeqlensQ.shape[0] - 1),
580
+ num_splits,
581
+ cute.size(mK.shape[0])
582
+ if const_expr(mPageTable is None)
583
+ else mK.shape[0] * mPageTable.shape[1],
584
+ mQ.shape[1],
585
+ mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100
586
+ total_q=cute.size(mQ.shape[0])
587
+ if const_expr(mCuSeqlensQ is not None)
588
+ else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
589
+ tile_shape_mn=self.cta_tiler[:2],
590
+ mCuSeqlensQ=mCuSeqlensQ,
591
+ mSeqUsedQ=mSeqUsedQ,
592
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
593
+ element_size=self.k_dtype.width // 8,
594
+ is_persistent=self.is_persistent,
595
+ lpt=self.is_causal or self.is_local,
596
+ is_split_kv=self.is_split_kv,
597
+ )
598
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
599
+ self.tile_scheduler_cls = TileScheduler
600
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
601
+
602
+ self.mbar_load_q_full_offset = 0
603
+ self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage
604
+ self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage
605
+ self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage
606
+ self.mbar_P_full_O_rescaled_offset = self.mbar_load_kv_empty_offset + self.kv_stage
607
+ self.mbar_S_full_offset = self.mbar_P_full_O_rescaled_offset + self.q_stage
608
+ self.mbar_O_full_offset = self.mbar_S_full_offset + self.q_stage
609
+ self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + self.q_stage
610
+ self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + self.q_stage
611
+ self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.q_stage
612
+ self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.q_stage
613
+ self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + self.q_stage
614
+ self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8
615
+ self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1
616
+ self.mbar_total = self.mbar_P_full_2_offset + self.q_stage
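+ # Worked example (q_stage=2, bf16/fp16 inputs so kv_stage=3): Q full/empty use barriers
+ # 0-3, KV full/empty 4-9, P_full_O_rescaled 10-11, S_full 12-13, O_full 14-15,
+ # softmax/corr full/empty 16-19, corr/epi full/empty 20-23, s0_s1 sequence 24-31,
+ # tmem_dealloc 32, P_full_2 33-34, giving mbar_total = 35 Int64 barriers in shared memory.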
617
+
618
+ sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0
619
+ sQ_size = (
620
+ cute.cosize(sQ_layout) if const_expr(not self.overlap_sO_sQ) else
621
+ cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width)
622
+ )
623
+
624
+ @cute.struct
625
+ class SharedStorage:
626
+ # m_barriers for pipelines
627
+ mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total]
628
+ # Tmem holding buffer
629
+ tmem_holding_buf: Int32
630
+ # Smem tensors
631
+ # store row max and row sum
632
+ sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]
633
+ sO: cute.struct.Align[
634
+ cute.struct.MemRange[self.o_dtype, sO_size],
635
+ self.buffer_align_bytes,
636
+ ]
637
+ sQ: cute.struct.Align[
638
+ cute.struct.MemRange[self.q_dtype, sQ_size],
639
+ self.buffer_align_bytes,
640
+ ]
641
+ sK: cute.struct.Align[
642
+ # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem
643
+ cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)],
644
+ self.buffer_align_bytes,
645
+ ]
646
+
647
+ self.shared_storage = SharedStorage
648
+
649
+ LOG2_E = math.log2(math.e)
650
+ if const_expr(self.score_mod is None):
651
+ softmax_scale_log2 = softmax_scale * LOG2_E
652
+ softmax_scale = None
653
+ else:
654
+ # NB: If a user passes in a score_mod, we want to apply the score_mod to the softmax-scaled qk
655
+ # in the original (base-e) scale. We hijack softmax_scale_log2 to be just the change of base
656
+ # and correctly apply softmax_scale prior to score_mod in the softmax step
657
+ softmax_scale_log2 = LOG2_E
658
+ softmax_scale = softmax_scale
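+ # The softmax is evaluated with exp2, and exp(scale * s) == exp2(scale * s * log2(e)).
+ # So without a score_mod the scale is folded into softmax_scale_log2 = scale * LOG2_E and
+ # softmax_scale is dropped; with a score_mod only the change of base (LOG2_E) is folded in,
+ # and softmax_scale is applied to the raw scores before the score_mod inside the softmax.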
659
+
660
+ if const_expr(window_size_left is not None):
661
+ window_size_left = Int32(window_size_left)
662
+ if const_expr(window_size_right is not None):
663
+ window_size_right = Int32(window_size_right)
664
+
665
+ fastdiv_mods = None
666
+ if cutlass.const_expr(aux_tensors is not None):
667
+ seqlen_q = cute.size(mQ.shape[0]) // (
668
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
669
+ )
670
+ seqlen_k = (
671
+ cute.size(mK.shape[0])
672
+ if const_expr(mPageTable is None)
673
+ else mK.shape[0] * mPageTable.shape[1]
674
+ )
675
+ seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
676
+ seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
677
+ fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
678
+
679
+ self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
680
+ if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None):
681
+ raise NotImplementedError("Block sparsity + paged KV not supported on SM100")
682
+
683
+ # Launch the kernel synchronously
684
+ self.kernel(
685
+ mQ,
686
+ mK,
687
+ mV,
688
+ mO,
689
+ mLSE,
690
+ mCuSeqlensQ,
691
+ mCuSeqlensK,
692
+ mSeqUsedQ,
693
+ mSeqUsedK,
694
+ mPageTable,
695
+ tma_atom_Q,
696
+ tma_atom_K,
697
+ tma_atom_V,
698
+ tma_atom_O,
699
+ softmax_scale_log2,
700
+ softmax_scale,
701
+ window_size_left,
702
+ window_size_right,
703
+ learnable_sink,
704
+ blocksparse_tensors,
705
+ sQ_layout,
706
+ sK_layout,
707
+ tP_layout,
708
+ sV_layout,
709
+ sO_layout,
710
+ gmem_tiled_copy_O,
711
+ tiled_mma_qk,
712
+ tiled_mma_pv,
713
+ tile_sched_params,
714
+ num_splits,
715
+ aux_tensors,
716
+ fastdiv_mods,
717
+ ).launch(
718
+ grid=grid_dim,
719
+ block=[self.threads_per_cta, 1, 1],
720
+ cluster=self.cluster_shape_mnk,
721
+ smem=self.shared_storage.size_in_bytes(),
722
+ stream=stream,
723
+ min_blocks_per_mp=1,
724
+ )
725
+
726
+ # GPU device kernel
727
+ @cute.kernel
728
+ def kernel(
729
+ self,
730
+ mQ: cute.Tensor, # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q
731
+ mK: cute.Tensor, # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table
732
+ mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table
733
+ mO: cute.Tensor,
734
+ mLSE: Optional[cute.Tensor],
735
+ mCuSeqlensQ: Optional[cute.Tensor],
736
+ mCuSeqlensK: Optional[cute.Tensor],
737
+ mSeqUsedQ: Optional[cute.Tensor],
738
+ mSeqUsedK: Optional[cute.Tensor],
739
+ mPageTable: Optional[cute.Tensor],
740
+ tma_atom_Q: cute.CopyAtom,
741
+ tma_atom_K: Optional[cute.CopyAtom],
742
+ tma_atom_V: Optional[cute.CopyAtom],
743
+ tma_atom_O: Optional[cute.CopyAtom],
744
+ softmax_scale_log2: Float32,
745
+ softmax_scale: Float32 | None,
746
+ window_size_left: Optional[Int32],
747
+ window_size_right: Optional[Int32],
748
+ learnable_sink: Optional[cute.Tensor],
749
+ blocksparse_tensors: Optional[BlockSparseTensors],
750
+ sQ_layout: cute.ComposedLayout,
751
+ sK_layout: cute.ComposedLayout,
752
+ tP_layout: cute.ComposedLayout,
753
+ sV_layout: cute.ComposedLayout,
754
+ sO_layout: cute.ComposedLayout,
755
+ gmem_tiled_copy_O: Optional[cute.TiledCopy],
756
+ tiled_mma_qk: cute.TiledMma,
757
+ tiled_mma_pv: cute.TiledMma,
758
+ tile_sched_params: ParamsBase,
759
+ num_splits: Int32,
760
+ aux_tensors: Optional[list] = None,
761
+ fastdiv_mods=(None, None),
762
+ ):
763
+ """The device kernel implementation of the Fused Multi-Head Attention.
764
+
765
+ This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation:
766
+ 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA
767
+ 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V)
768
+ 3. Softmax warps: Compute softmax normalization on attention scores
769
+ 4. Correction warps: Apply adjustments to intermediate results
770
+ 5. Epilogue warp: Handles final output transformation and storage
771
+
772
+ The kernel implements a complex pipeline with overlapping computation and memory operations,
773
+ using tensor memory access (TMA) for efficient data loading, warp specialization for different
774
+ computation phases, and optional attention masking.
775
+ """
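+ # Warp-to-role map used below, with the q_stage=2 / TMA-KV defaults from __init__:
+ # warps 0-3: softmax for Q tile 0; warps 4-7: softmax for Q tile 1;
+ # warps 8-11: correction/rescaling; warp 12: MMA (tcgen05); warp 13: epilogue;
+ # warp 14: TMA loads; warp 15: idle. Several of these roles are remapped in __init__
+ # for q_stage=1, non-TMA paged KV, and varlen-Q epilogue configurations.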
776
+
777
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
778
+
779
+ # Prefetch tma descriptor
780
+ if warp_idx == 0:
781
+ cpasync.prefetch_descriptor(tma_atom_Q)
782
+ if const_expr(tma_atom_K is not None):
783
+ cpasync.prefetch_descriptor(tma_atom_K)
784
+ if const_expr(tma_atom_V is not None):
785
+ cpasync.prefetch_descriptor(tma_atom_V)
786
+ if const_expr(tma_atom_O is not None):
787
+ cpasync.prefetch_descriptor(tma_atom_O)
788
+
789
+ # Alloc
790
+ smem = cutlass.utils.SmemAllocator()
791
+ storage = smem.allocate(self.shared_storage)
792
+
793
+ mbar_ptr = storage.mbar_ptr.data_ptr()
794
+ # Use the first several warps (1-7) to initialize barriers
795
+ if warp_idx == 1:
796
+ # Init "full" barriers with the number of producers, "empty" barriers with the number of consumers
797
+ for i in cutlass.range_constexpr(self.q_stage):
798
+ cute.arch.mbarrier_init(
799
+ mbar_ptr + self.mbar_load_q_full_offset + i, 1
800
+ )
801
+ cute.arch.mbarrier_init(
802
+ mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])
803
+ )
804
+ if warp_idx == 2:
805
+ for i in cutlass.range_constexpr(self.q_stage):
806
+ cute.arch.mbarrier_init(
807
+ mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4
808
+ )
809
+ cute.arch.mbarrier_init(
810
+ mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4
811
+ )
812
+ if warp_idx == 3:
813
+ if const_expr(self.s0_s1_barrier):
814
+ for i in cutlass.range_constexpr(8):
815
+ cute.arch.mbarrier_init(
816
+ mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE
817
+ )
818
+ if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4:
819
+ for i in cutlass.range_constexpr(self.q_stage):
820
+ cute.arch.mbarrier_init(
821
+ mbar_ptr + self.mbar_corr_epi_full_offset + i,
822
+ cute.arch.WARP_SIZE * len(self.correction_warp_ids),
823
+ )
824
+ cute.arch.mbarrier_init(
825
+ mbar_ptr + self.mbar_corr_epi_empty_offset + i,
826
+ cute.arch.WARP_SIZE * len(self.epilogue_warp_ids),
827
+ )
828
+ if warp_idx == 5:
829
+ for i in cutlass.range_constexpr(self.q_stage):
830
+ cute.arch.mbarrier_init(
831
+ mbar_ptr + self.mbar_P_full_O_rescaled_offset + i,
832
+ cute.arch.WARP_SIZE
833
+ * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)),
834
+ )
835
+ cute.arch.mbarrier_init(
836
+ mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])
837
+ )
838
+ cute.arch.mbarrier_init(
839
+ mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])
840
+ )
841
+ if warp_idx == 6:
842
+ for i in cutlass.range_constexpr(self.q_stage):
843
+ cute.arch.mbarrier_init(
844
+ mbar_ptr + self.mbar_P_full_2_offset + i,
845
+ cute.arch.WARP_SIZE * len(self.softmax0_warp_ids),
846
+ )
847
+ if warp_idx == 7:
848
+ cute.arch.mbarrier_init(
849
+ mbar_ptr + self.mbar_tmem_dealloc_offset,
850
+ cute.arch.WARP_SIZE
851
+ * len(
852
+ (
853
+ *self.softmax0_warp_ids,
854
+ *self.softmax1_warp_ids,
855
+ *self.correction_warp_ids,
856
+ )
857
+ ),
858
+ )
859
+ # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync
860
+ pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset)
861
+
862
+ # Generate smem tensor Q/K/V/O
863
+ # (MMA, MMA_Q, MMA_D, PIPE)
864
+ sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
865
+ # (MMA, MMA_K, MMA_D, PIPE)
866
+ sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
867
+ # (MMA, MMA_K, MMA_D, PIPE)
868
+ # Strip swizzle info to reuse smem
869
+ sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer)
870
+ if const_expr(not self.overlap_sO_sQ):
871
+ sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner)
872
+ else:
873
+ sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer)
874
+
875
+ sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2))
876
+
877
+ thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM
878
+ thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM
879
+
880
+ qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2])
881
+ tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape)
882
+ # This is a fake tensor; strictly we would need to retrieve the tmem ptr. But we know that we always
883
+ # request 512 columns of tmem, so we know that it starts at 0.
884
+ tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16)
885
+ tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout)
886
+
887
+ pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2])
888
+ tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape)
889
+
890
+ tStSs = tuple(
891
+ cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout)
892
+ for stage in range(self.q_stage)
893
+ )
894
+ tOtOs = tuple(
895
+ cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout)
896
+ for stage in range(self.q_stage)
897
+ )
898
+
899
+ tP = cute.make_tensor(tStS.iterator, tP_layout.outer)
900
+ tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0]
901
+
902
+ tOrPs = [
903
+ cute.make_tensor(
904
+ tOrP.iterator
905
+ + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage],
906
+ tOrP.layout,
907
+ )
908
+ for stage in range(self.q_stage)
909
+ ]
910
+
911
+ block_info = BlockInfo(
912
+ # This is cta_tiler, not mma_tiler_qk, since we advance block by block in steps of (2 * mma_tiler[0], mma_tiler[1])
913
+ self.cta_tiler[0],
914
+ self.cta_tiler[1],
915
+ self.is_causal,
916
+ self.is_local,
917
+ self.is_split_kv,
918
+ window_size_left,
919
+ window_size_right,
920
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
921
+ )
922
+ SeqlenInfoCls = partial(
923
+ SeqlenInfoQK.create,
924
+ seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
925
+ seqlen_k_static=mK.shape[0]
926
+ if const_expr(mPageTable is None)
927
+ else mK.shape[0] * mPageTable.shape[1],
928
+ mCuSeqlensQ=mCuSeqlensQ,
929
+ mCuSeqlensK=mCuSeqlensK,
930
+ mSeqUsedQ=mSeqUsedQ,
931
+ mSeqUsedK=mSeqUsedK,
932
+ )
933
+ AttentionMaskCls = partial(
934
+ AttentionMask,
935
+ self.m_block_size,
936
+ self.n_block_size,
937
+ window_size_left=window_size_left,
938
+ window_size_right=window_size_right,
939
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
940
+ )
941
+ TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)
942
+
943
+ # ///////////////////////////////////////////////////////////////////////////////
944
+ # EMPTY
945
+ # ///////////////////////////////////////////////////////////////////////////////
946
+ for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
947
+ if warp_idx == self.empty_warp_ids[i]:
948
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
949
+
950
+ # ///////////////////////////////////////////////////////////////////////////////
951
+ # LOAD
952
+ # ///////////////////////////////////////////////////////////////////////////////
953
+ if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:
954
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
955
+ self.load(
956
+ thr_mma_qk,
957
+ thr_mma_pv,
958
+ mQ,
959
+ mK,
960
+ mV,
961
+ sQ,
962
+ sK,
963
+ sV,
964
+ mPageTable,
965
+ tma_atom_Q,
966
+ tma_atom_K,
967
+ tma_atom_V,
968
+ pipeline_kv,
969
+ mbar_ptr,
970
+ block_info,
971
+ num_splits,
972
+ SeqlenInfoCls,
973
+ TileSchedulerCls,
974
+ blocksparse_tensors,
975
+ )
976
+
977
+ # ///////////////////////////////////////////////////////////////////////////////
978
+ # MMA
979
+ # ///////////////////////////////////////////////////////////////////////////////
980
+ if warp_idx == self.mma_warp_id:
981
+ # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids:
982
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
983
+ # Alloc tmem buffer
984
+ tmem_alloc_cols = Int32(self.tmem_alloc_cols)
985
+ if warp_idx == self.mma_warp_id:
986
+ cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf)
987
+ cute.arch.sync_warp()
988
+
989
+ self.mma(
990
+ tiled_mma_qk,
991
+ tiled_mma_pv,
992
+ sQ,
993
+ sK,
994
+ sV,
995
+ tStSs,
996
+ tOtOs,
997
+ tOrPs,
998
+ pipeline_kv,
999
+ mbar_ptr,
1000
+ block_info,
1001
+ num_splits,
1002
+ SeqlenInfoCls,
1003
+ TileSchedulerCls,
1004
+ blocksparse_tensors,
1005
+ )
1006
+
1007
+ # if warp_idx == self.mma_warp_id:
1008
+ # dealloc tmem buffer
1009
+ cute.arch.relinquish_tmem_alloc_permit()
1010
+ cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0)
1011
+ tmem_alloc_cols = Int32(self.tmem_alloc_cols)
1012
+ # Retrieving tmem ptr and make acc
1013
+ tmem_ptr = cute.arch.retrieve_tmem_ptr(
1014
+ Float32,
1015
+ alignment=16,
1016
+ ptr_to_buffer_holding_addr=storage.tmem_holding_buf,
1017
+ )
1018
+ cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols)
1019
+
1020
+ # ///////////////////////////////////////////////////////////////////////////////
1021
+ # Epilogue
1022
+ # ///////////////////////////////////////////////////////////////////////////////
1023
+ if const_expr(not self.use_correction_warps_for_epi):
1024
+ if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
1025
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1026
+ self.epilogue_s2g(
1027
+ mO,
1028
+ sO,
1029
+ gmem_tiled_copy_O,
1030
+ tma_atom_O,
1031
+ mbar_ptr,
1032
+ block_info,
1033
+ num_splits,
1034
+ SeqlenInfoCls,
1035
+ TileSchedulerCls,
1036
+ )
1037
+
1038
+ # ///////////////////////////////////////////////////////////////////////////////
1039
+ # Softmax
1040
+ # ///////////////////////////////////////////////////////////////////////////////
1041
+ if (
1042
+ (const_expr(self.q_stage == 2) and warp_idx <= self.softmax1_warp_ids[-1]) or
1043
+ (const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1])
1044
+ ):
1045
+ # increase registers after the other warp groups have decreased theirs
1046
+ cute.arch.warpgroup_reg_alloc(self.num_regs_softmax)
1047
+ softmax_loop = partial(
1048
+ self.softmax_loop,
1049
+ softmax_scale_log2=softmax_scale_log2,
1050
+ softmax_scale=softmax_scale,
1051
+ thr_mma_qk=thr_mma_qk,
1052
+ sScale=sScale,
1053
+ mLSE=mLSE,
1054
+ learnable_sink=learnable_sink,
1055
+ mbar_ptr=mbar_ptr,
1056
+ block_info=block_info,
1057
+ num_splits=num_splits,
1058
+ SeqlenInfoCls=SeqlenInfoCls,
1059
+ AttentionMaskCls=AttentionMaskCls,
1060
+ TileSchedulerCls=TileSchedulerCls,
1061
+ aux_tensors=aux_tensors,
1062
+ fastdiv_mods=fastdiv_mods,
1063
+ blocksparse_tensors=blocksparse_tensors,
1064
+ )
1065
+
1066
+ if const_expr(not self.s0_s1_barrier):
1067
+ stage = Int32(0 if const_expr(self.q_stage == 1) or warp_idx < self.softmax1_warp_ids[0] else 1)
1068
+ softmax_loop(
1069
+ stage=stage,
1070
+ tStSi=cute.make_tensor(
1071
+ tStS.iterator
1072
+ + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]),
1073
+ tStS.layout,
1074
+ ),
1075
+ )
1076
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)
1077
+ else:
1078
+ # If there's s0_s1_barrier, it's faster to have 2 WGs having different code
1079
+ if warp_idx < self.softmax1_warp_ids[0]:
1080
+ tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[0], tStS.layout)
1081
+ softmax_loop(stage=0, tStSi=tStSi)
1082
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)
1083
+ if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]:
1084
+ tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[1], tStS.layout)
1085
+ softmax_loop(stage=1, tStSi=tStSi)
1086
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)
1087
+
1088
+ # ///////////////////////////////////////////////////////////////////////////////
1089
+ # Correction
1090
+ # ///////////////////////////////////////////////////////////////////////////////
1091
+ if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id:
1092
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_correction)
1093
+ self.correction_loop(
1094
+ thr_mma_qk,
1095
+ thr_mma_pv,
1096
+ tStS,
1097
+ tOtOs,
1098
+ sScale,
1099
+ mO,
1100
+ mLSE,
1101
+ sO,
1102
+ learnable_sink,
1103
+ gmem_tiled_copy_O,
1104
+ tma_atom_O,
1105
+ mbar_ptr,
1106
+ softmax_scale_log2,
1107
+ block_info,
1108
+ num_splits,
1109
+ SeqlenInfoCls,
1110
+ TileSchedulerCls,
1111
+ blocksparse_tensors,
1112
+ )
1113
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)
1114
+
1115
+ return
1116
+
1117
+ @cute.jit
1118
+ def load(
1119
+ self,
1120
+ thr_mma_qk: cute.core.ThrMma,
1121
+ thr_mma_pv: cute.core.ThrMma,
1122
+ mQ: cute.Tensor,
1123
+ mK: cute.Tensor,
1124
+ mV: cute.Tensor,
1125
+ sQ: cute.Tensor,
1126
+ sK: cute.Tensor,
1127
+ sV: cute.Tensor,
1128
+ mPageTable: Optional[cute.Tensor],
1129
+ tma_atom_Q: cute.CopyAtom,
1130
+ tma_atom_K: Optional[cute.CopyAtom],
1131
+ tma_atom_V: Optional[cute.CopyAtom],
1132
+ pipeline_kv: cutlass.pipeline.PipelineAsync,
1133
+ mbar_ptr: cute.Pointer,
1134
+ block_info: BlockInfo,
1135
+ num_splits: Int32,
1136
+ SeqlenInfoCls: Callable,
1137
+ TileSchedulerCls: Callable,
1138
+ blocksparse_tensors: Optional[BlockSparseTensors],
1139
+ ):
1140
+ num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
1141
+ tidx = cute.arch.thread_idx()[0] % num_load_threads
1142
+ q_producer_phase = Int32(1)
1143
+ kv_producer_state = cutlass.pipeline.make_pipeline_state(
1144
+ cutlass.pipeline.PipelineUserType.Producer, self.kv_stage
1145
+ )
1146
+ tile_scheduler = TileSchedulerCls()
1147
+ work_tile = tile_scheduler.initial_work_tile_info()
1148
+ while work_tile.is_valid_tile:
1149
+ m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
1150
+ seqlen = SeqlenInfoCls(batch_idx)
1151
+ mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
1152
+ gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0))
1153
+
1154
+ head_idx_kv = (
1155
+ head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
1156
+ )
1157
+ if const_expr(mPageTable is None):
1158
+ if const_expr(not seqlen.has_cu_seqlens_k):
1159
+ mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)]
1160
+ else:
1161
+ mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv])
1162
+ mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv])
1163
+ gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0))
1164
+ gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None))
1165
+ else:
1166
+ # Need to keep batch coord None since we'll index into it with page idx
1167
+ mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)]
1168
+ gK = cute.local_tile(
1169
+ mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)
1170
+ )
1171
+ gV = cute.local_tile(
1172
+ mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
1173
+ )
1174
+ tSgQ = thr_mma_qk.partition_A(gQ)
1175
+ tSgK = thr_mma_qk.partition_B(gK)
1176
+ tOgV = thr_mma_pv.partition_B(gV)
1177
+ load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
1178
+ tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
1179
+ )
1180
+
1181
+ if const_expr(self.use_tma_KV):
1182
+ tKsK, tKgK = cpasync.tma_partition(
1183
+ tma_atom_K,
1184
+ 0, # no multicast
1185
+ cute.make_layout(1),
1186
+ cute.group_modes(sK, 0, 3),
1187
+ cute.group_modes(tSgK, 0, 3),
1188
+ )
1189
+ tVsV, tVgV = cpasync.tma_partition(
1190
+ tma_atom_V,
1191
+ 0, # no multicast
1192
+ cute.make_layout(1),
1193
+ cute.group_modes(sV, 0, 3),
1194
+ cute.group_modes(tOgV, 0, 3),
1195
+ )
1196
+ paged_kv_manager = None
1197
+ else:
1198
+ page_size = mK.shape[0]
1199
+ paged_kv_manager = PagedKVManager.create(
1200
+ mPageTable,
1201
+ mK,
1202
+ mV,
1203
+ FastDivmodDivisor(page_size),
1204
+ batch_idx,
1205
+ head_idx_kv,
1206
+ tidx,
1207
+ seqlen.seqlen_k,
1208
+ 0, # leftpad_k
1209
+ self.n_block_size,
1210
+ self.head_dim_padded,
1211
+ self.head_dim_v_padded,
1212
+ num_load_threads,
1213
+ mK.element_type,
1214
+ )
1215
+ tKsK, tKgK = None, None
1216
+ tVsV, tVgV = None, None
1217
+
1218
+ load_Q = partial(
1219
+ self.load_Q,
1220
+ load_Q_fn,
1221
+ mbar_ptr + self.mbar_load_q_full_offset,
1222
+ mbar_ptr + self.mbar_load_q_empty_offset,
1223
+ phase=q_producer_phase,
1224
+ )
1225
+ # We have to use mbarrier directly in the load for KV instead of relying on
1227
+ # pipeline_kv, because we could have different numbers of TMA bytes for K and V
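+ # (e.g., for hdim (192, 128) in bf16 a K stage expects 128 x 192 x 2B = 48 KiB per TMA
+ # while a V stage expects 128 x 128 x 2B = 32 KiB, so the expected transaction byte
+ # counts set on the full barrier differ between K and V loads.)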
1227
+ load_K = partial(
1228
+ self.load_KV,
1229
+ tma_atom_K,
1230
+ tKgK,
1231
+ tKsK,
1232
+ paged_kv_manager,
1233
+ sK,
1234
+ mbar_ptr + self.mbar_load_kv_full_offset,
1235
+ mbar_ptr + self.mbar_load_kv_empty_offset,
1236
+ K_or_V="K",
1237
+ )
1238
+ load_V = partial(
1239
+ self.load_KV,
1240
+ tma_atom_V,
1241
+ tVgV,
1242
+ tVsV,
1243
+ paged_kv_manager,
1244
+ sV,
1245
+ mbar_ptr + self.mbar_load_kv_full_offset,
1246
+ mbar_ptr + self.mbar_load_kv_empty_offset,
1247
+ K_or_V="V",
1248
+ )
1249
+
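+ # Dense (non-block-sparse) load order below, walking n_block from high to low:
+ # Q0, K[n_max-1], Q1, V[n_max-1], then (K[i], V[i]) for i = n_max-2 down to n_min,
+ # so the first Q0*K^T MMA can start while Q1 and the first V tile are still in flight.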
1250
+ if const_expr(not self.use_block_sparsity):
1251
+ n_block_min, n_block_max = block_info.get_n_block_min_max(
1252
+ seqlen, m_block, split_idx, num_splits
1253
+ )
1254
+ if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
1255
+ if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE:
1256
+ load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0
1257
+ n_block_first = n_block_max - 1 if n_block_max > 0 else 0
1258
+ page_idx = (
1259
+ mPageTable[batch_idx, n_block_first]
1260
+ if const_expr(mPageTable is not None and self.use_tma_KV)
1261
+ else None
1262
+ )
1263
+ if const_expr(not self.use_tma_KV):
1264
+ paged_kv_manager.load_page_table(n_block_first)
1265
+ load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0
1266
+ kv_producer_state.advance()
1267
+ if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE):
1268
+ load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1
1269
+ q_producer_phase ^= 1
1270
+ load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0
1271
+ kv_producer_state.advance()
1272
+ for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
1273
+ n_block = n_block_max - 2 - i
1274
+ page_idx = (
1275
+ mPageTable[batch_idx, n_block]
1276
+ if const_expr(mPageTable is not None and self.use_tma_KV)
1277
+ else None
1278
+ )
1279
+ if const_expr(not self.use_tma_KV):
1280
+ paged_kv_manager.load_page_table(n_block)
1281
+ # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
1282
+ load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki
1283
+ kv_producer_state.advance()
1284
+ load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi
1285
+ kv_producer_state.advance()
1286
+
1287
+ else:
1288
+ kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
1289
+ blocksparse_tensors,
1290
+ batch_idx,
1291
+ head_idx,
1292
+ m_block,
1293
+ kv_producer_state,
1294
+ load_Q,
1295
+ load_K,
1296
+ load_V,
1297
+ pipeline_kv,
1298
+ self.q_stage,
1299
+ q_producer_phase,
1300
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1301
+ )
1302
+
1303
+
1304
+ tile_scheduler.prefetch_next_work()
1305
+ tile_scheduler.advance_to_next_work()
1306
+ work_tile = tile_scheduler.get_current_work()
1307
+ # End of persistent scheduler loop
1308
+
1309
+ @cute.jit
1310
+ def mma(
1311
+ self,
1312
+ tiled_mma_qk: cute.core.ThrMma,
1313
+ tiled_mma_pv: cute.core.ThrMma,
1314
+ sQ: cute.Tensor,
1315
+ sK: cute.Tensor,
1316
+ sV: cute.Tensor,
1317
+ tStSs: Tuple[cute.Tensor, cute.Tensor],
1318
+ tOtOs: tuple[cute.Tensor],
1319
+ tOrPs: Tuple[cute.Tensor, cute.Tensor],
1320
+ pipeline_kv: cutlass.pipeline.PipelineAsync,
1321
+ mbar_ptr: cute.Pointer,
1322
+ block_info: BlockInfo,
1323
+ num_splits: Int32,
1324
+ SeqlenInfoCls: Callable,
1325
+ TileSchedulerCls: Callable,
1326
+ blocksparse_tensors: Optional[BlockSparseTensors],
1327
+ ):
1328
+ tSrQ = tiled_mma_qk.make_fragment_A(sQ)
1329
+ tSrK = tiled_mma_qk.make_fragment_B(sK)
1330
+ tOrV = tiled_mma_pv.make_fragment_B(sV)
1331
+ if const_expr(self.q_stage == 2):
1332
+ tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1])
1333
+ else:
1334
+ tSrQs = (tSrQ[None, None, None, 0],)
1335
+
1336
+ qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op
1337
+
1338
+ gemm_Si = [
1339
+ partial(
1340
+ sm100_utils.gemm_ptx_partial,
1341
+ qk_mma_op,
1342
+ self.tmem_s_offset[stage],
1343
+ tSrQs[stage],
1344
+ sA=sQ[None, None, None, stage],
1345
+ zero_init=True,
1346
+ )
1347
+ for stage in range(self.q_stage)
1348
+ ]
1349
+ gemm_Pi = [
1350
+ partial(
1351
+ sm100_utils.gemm_ptx_partial,
1352
+ pv_mma_op,
1353
+ self.tmem_o_offset[stage],
1354
+ tOrPs[stage],
1355
+ sA=None,
1356
+ )
1357
+ for stage in range(self.q_stage)
1358
+ ]
1359
+
1360
+ mma_q_consumer_phase = Int32(0)
1361
+ mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state(
1362
+ cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage
1363
+ )
1364
+ P_full_O_rescaled_phase = Int32(0)
1365
+
1366
+ tile_scheduler = TileSchedulerCls()
1367
+ work_tile = tile_scheduler.initial_work_tile_info()
1368
+ while work_tile.is_valid_tile:
1369
+ m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
1370
+ seqlen = SeqlenInfoCls(batch_idx)
1371
+
1372
+ block_iter_count = Int32(0)
1373
+ process_tile = False
1374
+
1375
+ if const_expr(self.use_block_sparsity):
1376
+ block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
1377
+ process_tile = block_iter_count > Int32(0)
1378
+ else:
1379
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
1380
+ block_iter_count = n_block_max - n_block_min
1381
+ if const_expr(not self.is_split_kv):
1382
+ process_tile = True
1383
+ else:
1384
+ process_tile = n_block_min < n_block_max
1385
+
1386
+ if process_tile:
1387
+ for stage in cutlass.range_constexpr(self.q_stage):
1388
+ # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1)
1389
+ # 1. wait for Q0 / Q1
1390
+ cute.arch.mbarrier_wait(
1391
+ mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase
1392
+ )
1393
+ # 2. wait for K0
1394
+ if const_expr(stage == 0):
1395
+ pipeline_kv.consumer_wait(mma_kv_consumer_state)
1396
+ tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index]
1397
+ # We don't need to acquire empty S0 / S1.
1398
+ # For the first iteration, we don't need to wait as we're guaranteed S0 / S1
1399
+ # are empty. For subsequent iterations, the wait happened at the end
1400
+ # of the while loop.
1401
+ # 3. gemm
1402
+ # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True)
1403
+ sK_cur = sK[None, None, None, mma_kv_consumer_state.index]
1404
+ if const_expr(self.uneven_kv_smem):
1405
+ sK_cur = self.offset_kv_smem(
1406
+ sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase
1407
+ )
1408
+ gemm_Si[stage](tCrB=tSrKi, sB=sK_cur)
1409
+ # 4. release S0 / S1
1410
+ with cute.arch.elect_one():
1411
+ tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage)
1412
+ mma_q_consumer_phase ^= 1
1413
+ # 5. release K0
1414
+ pipeline_kv.consumer_release(mma_kv_consumer_state)
1415
+ mma_kv_consumer_state.advance()
1416
+ # End of GEMM (Q1 * K0 -> S1)
1417
+ # Note: Q0 & Q1 are still needed in the seqlen_kv loop
1418
+ # so we need to release them after the seqlen_kv loop
1419
+
1420
+ # O hasn't been accumulated yet, so its first MMA doesn't need to accumulate
1421
+ block_loop_count = block_iter_count - 1
1422
+ O_should_accumulate = False
1423
+ for i in cutlass.range(block_loop_count, unroll=1):
1424
+ # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop
1425
+ # 1. wait for V0
1426
+ pipeline_kv.consumer_wait(mma_kv_consumer_state)
1427
+ mma_kv_release_state = mma_kv_consumer_state.clone()
1428
+ Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase
1429
+ tOrVi = tOrV[None, None, None, Vi_index]
1430
+ for stage in cutlass.range_constexpr(self.q_stage):
1431
+ # 2. acquire corrected O0/O1_partial and P0 / P1
1432
+ # For the first iteration in this work tile, waiting for O0/O1_partial
1433
+ # means that the correction warps have finished reading tO during
1434
+ # the last iteration of the previous work tile.
1435
+ cute.arch.mbarrier_wait(
1436
+ mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage,
1437
+ P_full_O_rescaled_phase,
1438
+ )
1439
+ # 3. gemm
1440
+ # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True)
1441
+ # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate)
1442
+ sV_cur = sV[None, None, None, Vi_index]
1443
+ if const_expr(self.uneven_kv_smem):
1444
+ sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase)
1445
+ gemm_Pi[stage](
1446
+ tCrB=tOrVi,
1447
+ sB=sV_cur,
1448
+ zero_init=not O_should_accumulate,
1449
+ mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage,
1450
+ mbar_phase=P_full_O_rescaled_phase,
1451
+ )
1452
+ # 4. release accumulated O0_partial / O1_partial
1453
+ # Don't need to signal O_full to the correction warps anymore since the
1454
+ # correction warps wait for the softmax warps anyway. By the time the softmax
1455
+ # warps have finished, S_i for the next iteration must have been done, so O_i-1
1456
+ # must have been done as well.
1457
+ # with cute.arch.elect_one():
1458
+ # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage)
1459
+ # 5. release V(i-1)
1460
+ if const_expr(stage == self.q_stage - 1):
1461
+ pipeline_kv.consumer_release(mma_kv_release_state)
1462
+ mma_kv_release_state.advance()
1463
+ # End of GEMM_PV00 (P0 * V0 -> O0_partial)
1464
+
1465
+ # GEMM_QK0i (Q0 * Ki -> S0)
1466
+ # 1. wait for Ki
1467
+ if const_expr(stage == 0):
1468
+ mma_kv_consumer_state.advance()
1469
+ pipeline_kv.consumer_wait(mma_kv_consumer_state)
1470
+ Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase
1471
+ # 2. gemm
1472
+ # Don't need to wait for the softmax warp to have finished reading the previous
1473
+ # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si
1474
+ # has been read and Pi has been written.
1475
+ # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True)
1476
+ sK_cur = sK[None, None, None, Ki_index]
1477
+ if const_expr(self.uneven_kv_smem):
1478
+ sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase)
1479
+ gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur)
1480
+ # 3. release S0
1481
+ with cute.arch.elect_one():
1482
+ tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage)
1483
+ # End of GEMM_QK0i (Q0 * Ki -> S0)
1484
+ # 4. release Ki
1485
+ pipeline_kv.consumer_release(mma_kv_consumer_state)
1486
+ mma_kv_consumer_state.advance()
1487
+ P_full_O_rescaled_phase ^= 1
1488
+ O_should_accumulate = True
1489
+ # End of seqlen_kv loop
1490
+
1491
+ # release Q0 & Q1
1492
+ with cute.arch.elect_one():
1493
+ for stage in cutlass.range_constexpr(self.q_stage):
1494
+ tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage)
1495
+
1496
+ # Final GEMM_PV (P * Vi_end -> O_partial), accumulating onto the O built up in the seqlen_kv loop
1497
+ # 1. wait for V0
1498
+ pipeline_kv.consumer_wait(mma_kv_consumer_state)
1499
+ Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase
1500
+ tOrVi = tOrV[None, None, None, Vi_index]
1501
+ for stage in cutlass.range_constexpr(self.q_stage):
1502
+ # 2. acquire corrected Oi_partial and Pi
1503
+ cute.arch.mbarrier_wait(
1504
+ mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase
1505
+ )
1506
+ # 3. gemm
1507
+ # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True)
1508
+ # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate)
1509
+ sV_cur = sV[None, None, None, Vi_index]
1510
+ if const_expr(self.uneven_kv_smem):
1511
+ sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase)
1512
+ gemm_Pi[stage](
1513
+ tCrB=tOrVi,
1514
+ sB=sV_cur,
1515
+ zero_init=not O_should_accumulate,
1516
+ mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage,
1517
+ mbar_phase=P_full_O_rescaled_phase,
1518
+ )
1519
+ # 4. release accumulated O0_partial
1520
+ # We do need O_full here since for the last tile, by the time the softmax warp
1521
+ # has signaled the correction warps, the softmax warp has only just finished computing
1522
+ # the row sum of the current tile. It does not guarantee that the 1st tile
1523
+ # of the next work tile has been computed yet.
1524
+ with cute.arch.elect_one():
1525
+ tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage)
1526
+ # End of GEMM_PV00 (P0 * V0 -> O0_partial)
1527
+ P_full_O_rescaled_phase ^= 1
1528
+ # 5. release Vi_end
1529
+ pipeline_kv.consumer_release(mma_kv_consumer_state)
1530
+ mma_kv_consumer_state.advance()
1531
+ # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1)
1532
+
1533
+ # Advance to next tile
1534
+ tile_scheduler.advance_to_next_work()
1535
+ work_tile = tile_scheduler.get_current_work()
1536
+ # End of persistent scheduler loop
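The mma warp loop above is software-pipelined: one QK gemm per Q stage against the first K block, then for each remaining KV block a PV gemm (zero-initialized only on the first pass) followed by the next QK gemm, and finally one last PV gemm per stage. A plain-Python sketch of that schedule; the names are hypothetical, and barriers, tmem offsets, and pipeline releases are omitted.

# Illustrative sketch only, not part of this module: the mma-warp issue order for
# a work tile with `blocks` KV blocks and `q_stage` Q tiles.
def mma_schedule(blocks, q_stage=2):
    ops = []
    for stage in range(q_stage):                       # prologue: S_stage = Q_stage @ K(last block)
        ops.append(("QK", stage, blocks - 1, "zero_init"))
    accumulate = False
    for i in range(blocks - 1):                        # steady state
        v_block = blocks - 1 - i
        k_block = blocks - 2 - i
        for stage in range(q_stage):
            ops.append(("PV", stage, v_block, "acc" if accumulate else "zero_init"))
            ops.append(("QK", stage, k_block, "zero_init"))
        accumulate = True
    for stage in range(q_stage):                       # epilogue: last PV per stage
        ops.append(("PV", stage, 0, "acc" if accumulate else "zero_init"))
    return ops

if __name__ == "__main__":
    for op in mma_schedule(blocks=3):
        print(op)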
1537
+
1538
+
1539
+ # for both the softmax0 and softmax1 warp groups
1540
+ @cute.jit
1541
+ def softmax_loop(
1542
+ self,
1543
+ stage: int | Int32,
1544
+ softmax_scale_log2: Float32,
1545
+ softmax_scale: Float32,
1546
+ thr_mma_qk: cute.core.ThrMma,
1547
+ tStSi: cute.Tensor,
1548
+ sScale: cute.Tensor,
1549
+ mLSE: Optional[cute.Tensor],
1550
+ learnable_sink: Optional[cute.Tensor],
1551
+ mbar_ptr: cute.Pointer,
1552
+ block_info: BlockInfo,
1553
+ num_splits: Int32,
1554
+ SeqlenInfoCls: Callable,
1555
+ AttentionMaskCls: Callable,
1556
+ TileSchedulerCls: Callable,
1557
+ aux_tensors: Optional[list] = None,
1558
+ fastdiv_mods=(None, None),
1559
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
1560
+ ):
1561
+ """Compute softmax on attention scores from QK matrix multiplication.
1562
+
1563
+ This method handles the softmax computation for either the first or second half of the
1564
+ attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum
1565
+ and sum values needed for stable softmax computation, applies optional masking, and
1566
+ transforms raw attention scores into probability distributions.
1567
+
1568
+ The implementation uses specialized memory access patterns and efficient math operations
1569
+ for computing exp(x) using exp2 functions. It also coordinates pipeline
1570
+ synchronization between MMA, correction, and sequence processing stages.
1571
+ """
1572
+ tidx = cute.arch.thread_idx()[0] % (
1573
+ cute.arch.WARP_SIZE
1574
+ # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids)
1575
+ * (len(self.softmax0_warp_ids))
1576
+ )
1577
+
1578
+ tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1)))
1579
+ tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))
1580
+ tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))
1581
+
1582
+ tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width
1583
+ tStP_layout = cute.composition(
1584
+ tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))
1585
+ )
1586
+ tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout)
1587
+
1588
+ tmem_load_atom = cute.make_copy_atom(
1589
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)),
1590
+ Float32,
1591
+ )
1592
+ thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx)
1593
+ tStS_t2r = thr_tmem_load.partition_S(tStSi)
1594
+
1595
+ tmem_store_scale_atom = cute.make_copy_atom(
1596
+ tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)),
1597
+ Float32,
1598
+ )
1599
+ thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(
1600
+ tidx
1601
+ )
1602
+
1603
+ tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale)
1604
+ tmem_store_atom = cute.make_copy_atom(
1605
+ tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)),
1606
+ Float32,
1607
+ )
1608
+ thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx)
1609
+ tStP_r2t = thr_tmem_store.partition_D(tStP)
1610
+
1611
+ mma_si_consumer_phase = Int32(0)
1612
+ si_corr_producer_phase = Int32(1)
1613
+ s0_s1_sequence_phase = Int32(1 if stage == 0 else 0)
1614
+
1615
+ # self.warp_scheduler_barrier_init()
1616
+
1617
+ warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
1618
+ mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg
1619
+
1620
+ tile_scheduler = TileSchedulerCls()
1621
+ work_tile = tile_scheduler.initial_work_tile_info()
1622
+ while work_tile.is_valid_tile:
1623
+ m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
1624
+ seqlen = SeqlenInfoCls(batch_idx)
1625
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
1626
+
1627
+ mask = AttentionMaskCls(seqlen)
1628
+ shared_mask_kwargs = dict(
1629
+ m_block=self.q_stage * m_block + stage,
1630
+ thr_mma=thr_mma_qk,
1631
+ thr_tmem_load=thr_tmem_load,
1632
+ mask_causal=self.is_causal,
1633
+ mask_local=self.is_local,
1634
+ batch_idx=batch_idx,
1635
+ head_idx=head_idx,
1636
+ aux_tensors=aux_tensors,
1637
+ )
1638
+
1639
+ # Recompute fastdiv_mods if necessary
1640
+ recompute_fastdiv_mods_q = cutlass.const_expr(
1641
+ aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
1642
+ )
1643
+ recompute_fastdiv_mods_k = cutlass.const_expr(
1644
+ aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
1645
+ )
1646
+
1647
+ if cutlass.const_expr(fastdiv_mods is not None):
1648
+ seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
1649
+ fastdiv_mods = (
1650
+ seqlen_q_divmod
1651
+ if not recompute_fastdiv_mods_q
1652
+ else FastDivmodDivisor(seqlen.seqlen_q),
1653
+ seqlen_k_divmod
1654
+ if not recompute_fastdiv_mods_k
1655
+ else FastDivmodDivisor(seqlen.seqlen_k),
1656
+ )
1657
+
1658
+ mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None
1659
+ mask_fn = partial(
1660
+ mask.apply_mask_sm100,
1661
+ mask_mod=mask_mod,
1662
+ fastdiv_mods=fastdiv_mods,
1663
+ **shared_mask_kwargs,
1664
+ )
1665
+ if const_expr(self.use_block_sparsity):
1666
+ # Full blocks don't need mask_mod
1667
+ mask_fn_none = partial(
1668
+ mask.apply_mask_sm100,
1669
+ mask_mod=None,
1670
+ fastdiv_mods=fastdiv_mods,
1671
+ **shared_mask_kwargs,
1672
+ )
1673
+ else:
1674
+ mask_fn_none = None
1675
+
1676
+ softmax = SoftmaxSm100.create(
1677
+ softmax_scale_log2,
1678
+ rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0,
1679
+ softmax_scale=softmax_scale,
1680
+ )
1681
+ softmax.reset()
1682
+
1683
+ if const_expr(self.use_block_sparsity):
1684
+ tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
1685
+ has_work = tile_block_count > Int32(0)
1686
+ else:
1687
+ tile_block_count = n_block_max - n_block_min
1688
+ has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0)
1689
+
1690
+ softmax_step = partial(
1691
+ self.softmax_step,
1692
+ softmax=softmax,
1693
+ mbar_ptr=mbar_ptr,
1694
+ mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset,
1695
+ thr_mma_qk=thr_mma_qk,
1696
+ thr_tmem_load=thr_tmem_load,
1697
+ thr_tmem_store=thr_tmem_store,
1698
+ thr_tmem_store_scale=thr_tmem_store_scale,
1699
+ tStS_t2r=tStS_t2r,
1700
+ tStScale_r2t=tStScale_r2t,
1701
+ tStP_r2t=tStP_r2t,
1702
+ sScale=sScale,
1703
+ stage=stage,
1704
+ batch_idx=batch_idx,
1705
+ head_idx=head_idx,
1706
+ m_block=self.q_stage * m_block + stage,
1707
+ seqlen=seqlen,
1708
+ aux_tensors=aux_tensors,
1709
+ fastdiv_mods=fastdiv_mods,
1710
+ )
1711
+
1712
+ if has_work:
1713
+ # Softmax acts as the producer: wait until correction signals the stage is empty
1714
+ cute.arch.mbarrier_wait(
1715
+ mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase
1716
+ )
1717
+ si_corr_producer_phase ^= 1
1718
+
1719
+ # Block sparse or dense iteration
1720
+ if const_expr(self.use_block_sparsity):
1721
+ # When aux_tensors exist, Q indices beyond seqlen_q must be wrapped to avoid
1722
+ # OOB aux_tensor access. Only edge tiles (where m_tile_end > seqlen_q) need this.
1723
+ if const_expr(aux_tensors is not None):
1724
+ m_tile_end = (self.q_stage * m_block + stage + 1) * self.m_block_size
1725
+ check_m_boundary = m_tile_end > seqlen.seqlen_q
1726
+ else:
1727
+ check_m_boundary = False
1728
+ (
1729
+ mma_si_consumer_phase,
1730
+ si_corr_producer_phase,
1731
+ s0_s1_sequence_phase,
1732
+ empty_tile,
1733
+ ) = softmax_block_sparse_sm100(
1734
+ blocksparse_tensors,
1735
+ batch_idx,
1736
+ head_idx,
1737
+ m_block,
1738
+ softmax_step,
1739
+ mask_fn,
1740
+ mask_fn_none,
1741
+ mma_si_consumer_phase,
1742
+ si_corr_producer_phase,
1743
+ s0_s1_sequence_phase,
1744
+ mbar_ptr,
1745
+ self.mbar_softmax_corr_full_offset,
1746
+ self.mbar_softmax_corr_empty_offset,
1747
+ self.mbar_P_full_O_rescaled_offset,
1748
+ self.mbar_P_full_2_offset,
1749
+ self.q_stage,
1750
+ Int32(stage),
1751
+ check_m_boundary,
1752
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1753
+ )
1754
+ if not empty_tile:
1755
+ sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0]
1756
+ if const_expr(mLSE is not None or learnable_sink is not None):
1757
+ sScale[
1758
+ tidx + stage * self.m_block_size + self.m_block_size * 2
1759
+ ] = softmax.row_max[0]
1760
+ # if tidx == 0:
1761
+ # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0])
1762
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage)
1763
+ # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0])
1764
+ else:
1765
+ if const_expr(not self.is_split_kv) or tile_block_count > Int32(0):
1766
+ mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(
1767
+ mma_si_consumer_phase,
1768
+ si_corr_producer_phase,
1769
+ s0_s1_sequence_phase,
1770
+ n_block_max - 1,
1771
+ is_first=True,
1772
+ mask_fn=partial(mask_fn, mask_seqlen=True),
1773
+ )
1774
+ n_block_max -= 1
1775
+ # Next couple of iterations with causal masking
1776
+ if const_expr(self.is_causal or self.is_local):
1777
+ n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
1778
+ seqlen, m_block, n_block_min
1779
+ )
1780
+ for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1):
1781
+ n_block = n_block_max - 1 - n_tile
1782
+ mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = (
1783
+ softmax_step(
1784
+ mma_si_consumer_phase,
1785
+ si_corr_producer_phase,
1786
+ s0_s1_sequence_phase,
1787
+ n_block,
1788
+ mask_fn=partial(mask_fn, mask_seqlen=False),
1789
+ )
1790
+ )
1791
+ n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
1792
+ # The remaining iterations have no masking (but may still need mask_mod)
1793
+ n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
1794
+ seqlen, m_block, n_block_min
1795
+ )
1796
+ for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
1797
+ n_block = n_block_max - n_tile - 1
1798
+ if const_expr(self.mask_mod is not None):
1799
+ mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(
1800
+ mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block,
1801
+ mask_fn=partial(mask_fn, mask_seqlen=False),
1802
+ )
1803
+ else:
1804
+ mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(
1805
+ mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block,
1806
+ )
1807
+ # Separate iterations with local masking on the left
1808
+ if const_expr(self.is_local and block_info.window_size_left is not None):
1809
+ n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
1810
+ for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1):
1811
+ n_block = n_block_max - 1 - n_tile
1812
+ mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = (
1813
+ softmax_step(
1814
+ mma_si_consumer_phase,
1815
+ si_corr_producer_phase,
1816
+ s0_s1_sequence_phase,
1817
+ n_block,
1818
+ mask_fn=partial(mask_fn, mask_seqlen=False),
1819
+ )
1820
+ )
1821
+ # Now that the 1st iteration is no longer handled separately, mask_seqlen=True is needed here
1822
+
1823
+ # Dense path always writes scale / signals
1824
+ sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0]
1825
+ if const_expr(mLSE is not None or learnable_sink is not None):
1826
+ sScale[
1827
+ tidx + stage * self.m_block_size + self.m_block_size * 2
1828
+ ] = softmax.row_max[0]
1829
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage)
1830
+
1831
+ # # Write LSE to gmem
1832
+ # if const_expr(mLSE is not None):
1833
+ # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0]
1834
+ # scale = (
1835
+ # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0)
1836
+ # )
1837
+ # LN2 = math.log(2.0)
1838
+ # lse = (
1839
+ # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2
1840
+ # if not acc_O_mn_row_is_zero_or_nan else -Float32.inf
1841
+ # )
1842
+ # if const_expr(not seqlen.has_cu_seqlens_q):
1843
+ # mLSE_cur = mLSE[None, head_idx, batch_idx]
1844
+ # else:
1845
+ # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx])
1846
+ # gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,))
1847
+ # if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size:
1848
+ # gLSE[tidx] = lse
1849
+
1850
+ # Advance to next tile
1851
+ tile_scheduler.advance_to_next_work()
1852
+ work_tile = tile_scheduler.get_current_work()
1853
+ # End of persistent scheduler loop
1854
+
1855
+ @cute.jit
1856
+ def softmax_step(
1857
+ self,
1858
+ mma_si_consumer_phase: Int32,
1859
+ si_corr_producer_phase: Int32,
1860
+ s0_s1_sequence_phase: Int32,
1861
+ n_block: Int32,
1862
+ softmax: SoftmaxSm100,
1863
+ mbar_ptr: cute.Pointer,
1864
+ mbar_s0_s1_sequence_offset: Int32,
1865
+ thr_mma_qk: cute.core.ThrMma,
1866
+ thr_tmem_load: cute.CopyAtom,
1867
+ thr_tmem_store: cute.CopyAtom,
1868
+ thr_tmem_store_scale: cute.CopyAtom,
1869
+ tStS_t2r: cute.Tensor,
1870
+ tStScale_r2t: cute.Tensor,
1871
+ tStP_r2t: cute.Tensor,
1872
+ sScale: cute.Tensor,
1873
+ stage: int | Int32,
1874
+ batch_idx: Int32,
1875
+ head_idx: Int32,
1876
+ m_block: Int32,
1877
+ seqlen,
1878
+ aux_tensors: Optional[list] = None,
1879
+ fastdiv_mods=(None, None),
1880
+ mask_fn: Optional[Callable] = None,
1881
+ is_first: bool = False,
1882
+ ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]:
1883
+ """Perform a single step of the softmax computation on a block of attention scores.
1884
+
1885
+ This method processes one block of the attention matrix, computing numerically stable
1886
+ softmax by first finding the row maximum, subtracting it from all elements, applying
1887
+ the exponential function, and then normalizing by the sum of exponentials. It also handles
1888
+ optional masking of attention scores.
1889
+
1890
+ The method involves several key operations:
1891
+ 1. Loading attention scores from tensor memory
1892
+ 2. Applying optional masking based on position
1893
+ 3. Computing row-wise maximum values for numerical stability
1894
+ 4. Transforming scores using exp2(x*scale - max*scale)
1895
+ 5. Computing row sums for normalization
1896
+ 6. Coordinating pipeline synchronization between different processing stages
1897
+ """
1898
+ tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width
1899
+ tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))
1900
+ tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))
1901
+ tScP = cute.composition(tScS, cute.make_layout((self.m_block_size, tilePlikeFP32)))
1902
+
1903
+ # Wait for Si
1904
+ cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase)
1905
+ tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype)
1906
+ cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r)
1907
+ if cutlass.const_expr(self.score_mod is not None):
1908
+ self.apply_score_mod(
1909
+ tSrS_t2r,
1910
+ thr_tmem_load,
1911
+ thr_mma_qk,
1912
+ batch_idx,
1913
+ head_idx,
1914
+ m_block,
1915
+ n_block,
1916
+ softmax,
1917
+ seqlen,
1918
+ aux_tensors,
1919
+ fastdiv_mods,
1920
+ )
1921
+
1922
+ if const_expr(mask_fn is not None):
1923
+ mask_fn(tSrS_t2r, n_block=n_block)
1924
+ row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first)
1925
+
1926
+ if const_expr(not is_first):
1927
+ # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScScale).shape, Float32)
1928
+ # tSrScale_r2t[0] = acc_scale
1929
+ # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t)
1930
+ # cute.arch.fence_view_async_tmem_store()
1931
+ thread_idx = thr_tmem_load.thr_idx
1932
+ sScale[thread_idx + stage * self.m_block_size] = acc_scale
1933
+ # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max)
1934
+ # Notify correction wg that row_max is ready
1935
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage)
1936
+
1937
+ # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r)
1938
+ # print(tSrS_t2r)
1939
+ softmax.scale_subtract_rowmax(tSrS_t2r, row_max)
1940
+ # Sequence barrier wait
1941
+ if const_expr(self.s0_s1_barrier):
1942
+ cute.arch.mbarrier_wait(
1943
+ mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase
1944
+ )
1945
+ tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32)
1946
+ tSrP_r2t = cute.make_tensor(
1947
+ cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype),
1948
+ tSrS_t2r.layout,
1949
+ )
1950
+ # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t)
1951
+ softmax.apply_exp2_convert(
1952
+ tSrS_t2r,
1953
+ tSrP_r2t,
1954
+ e2e=mask_fn is None and self.head_dim_padded <= 128,
1955
+ e2e_freq=self.e2e_freq,
1956
+ )
1957
+ # Sequence barrier arrive
1958
+ if const_expr(self.s0_s1_barrier):
1959
+ cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4)
1960
+ # print(tSrP_r2t_f32, tStP_r2t)
1961
+ # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t)
1962
+ for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3):
1963
+ cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i])
1964
+ cute.arch.fence_view_async_tmem_store()
1965
+ # Notify mma warp that P is ready
1966
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
1967
+ for i in cutlass.range_constexpr(
1968
+ cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2])
1969
+ ):
1970
+ cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i])
1971
+ cute.arch.fence_view_async_tmem_store()
1972
+ # Notify mma warp that the 2nd half of P is ready
1973
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage)
1974
+ cute.arch.mbarrier_wait(
1975
+ mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase
1976
+ )
1977
+ softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first)
1978
+ # acc_scale = cute.arch.exp2(acc_scale_)
1979
+ return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1
1980
+
1981
+ @cute.jit
1982
+ def correction_loop(
1983
+ self,
1984
+ thr_mma_qk: cute.core.ThrMma,
1985
+ thr_mma_pv: cute.core.ThrMma,
1986
+ tStS: cute.Tensor,
1987
+ tOtOs: tuple[cute.Tensor],
1988
+ sScale: cute.Tensor,
1989
+ mO: cute.Tensor,
1990
+ mLSE: cute.Tensor,
1991
+ sO: cute.Tensor,
1992
+ learnable_sink: Optional[cute.Tensor],
1993
+ gmem_tiled_copy_O: cute.TiledCopy,
1994
+ tma_atom_O: cute.CopyAtom,
1995
+ mbar_ptr: cute.Pointer,
1996
+ softmax_scale_log2: Float32,
1997
+ block_info: BlockInfo,
1998
+ num_splits: Int32,
1999
+ SeqlenInfoCls: Callable,
2000
+ TileSchedulerCls: Callable,
2001
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
2002
+ ):
2003
+ tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))
2004
+ tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))
2005
+ tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1)))
2006
+ tStScales = tuple(
2007
+ cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout)
2008
+ for stage in range(self.q_stage)
2009
+ )
2010
+ tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))
2011
+ tmem_load_v_atom = cute.make_copy_atom(
2012
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)),
2013
+ self.qk_acc_dtype,
2014
+ )
2015
+ thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx)
2016
+
2017
+ tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(self.q_stage)]
2018
+ tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape
2019
+
2020
+ # First iter: no correction is required
2021
+ for stage in cutlass.range_constexpr(self.q_stage):
2022
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
2023
+
2024
+ softmax_corr_consumer_phase = Int32(0)
2025
+ o_corr_consumer_phase = Int32(0)
2026
+ corr_epi_producer_phase = Int32(1)
2027
+
2028
+ tile_scheduler = TileSchedulerCls()
2029
+ work_tile = tile_scheduler.initial_work_tile_info()
2030
+ while work_tile.is_valid_tile:
2031
+ m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
2032
+ seqlen = SeqlenInfoCls(batch_idx)
2033
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
2034
+
2035
+ if const_expr(self.is_split_kv):
2036
+ mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
2037
+ else:
2038
+ mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
2039
+ gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0))
2040
+
2041
+ # Default LSE to -inf for invalid split_idx tiles
2042
+ stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage
2043
+
2044
+ if const_expr(self.use_block_sparsity):
2045
+ total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
2046
+ has_work = total_block_count > Int32(0)
2047
+ else:
2048
+ total_block_count = n_block_max - n_block_min
2049
+ has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0)
2050
+
2051
+ if has_work:
2052
+ # Ignore first signal from softmax as no correction is required
2053
+ cute.arch.mbarrier_wait(
2054
+ mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase
2055
+ )
2056
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0)
2057
+ if const_expr(self.q_stage == 2):
2058
+ cute.arch.mbarrier_wait(
2059
+ mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase
2060
+ )
2061
+ softmax_corr_consumer_phase ^= 1
2062
+
2063
+ tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32)
2064
+ for i in cutlass.range(total_block_count - 1, unroll=1):
2065
+ for stage in cutlass.range_constexpr(self.q_stage):
2066
+ # wait for S0 / S1
2067
+ cute.arch.mbarrier_wait(
2068
+ mbar_ptr + self.mbar_softmax_corr_full_offset + stage,
2069
+ softmax_corr_consumer_phase,
2070
+ )
2071
+ # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r)
2072
+ # cute.arch.fence_view_async_tmem_load()
2073
+ # scale = tSrScale_t2r[0]
2074
+ scale = sScale[tidx + stage * self.m_block_size]
2075
+ should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0
2076
+ # should_rescale = True
2077
+ # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale)
2078
+ # Don't need O_full anymore, since by the time softmax has signaled the correction
2079
+ # warps, S_i must have been done, so O_i-1 must have been done as well.
2080
+ # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase)
2081
+ if should_rescale:
2082
+ self.correction_rescale(
2083
+ thr_mma_pv, tOtOs[stage], tidx, scale
2084
+ )
2085
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
2086
+ if const_expr(self.q_stage == 2):
2087
+ cute.arch.mbarrier_arrive(
2088
+ mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)
2089
+ )
2090
+ else:
2091
+ cute.arch.mbarrier_arrive(
2092
+ mbar_ptr + self.mbar_softmax_corr_empty_offset + stage
2093
+ )
2094
+ softmax_corr_consumer_phase ^= 1
2095
+ # o_corr_consumer_phase ^= 1
2096
+ if const_expr(self.q_stage == 2):
2097
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1)
2098
+ # End of seqlen_corr_loop_steps
2099
+
2100
+ # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without
2101
+ # additional sync because the MMA in the top half must have been done.
2102
+ # Similarly we can write to stage 1 of sO without additional sync.
2103
+ learnable_sink_val = [None] * self.q_stage
2104
+ if const_expr(learnable_sink is not None):
2105
+ if const_expr(not self.pack_gqa):
2106
+ sink_val = Float32(learnable_sink[head_idx])
2107
+ learnable_sink_val = [sink_val] * self.q_stage
2108
+ else: # Each thread might have a different sink value due to different q_head
2109
+ for stage in cutlass.range_constexpr(self.q_stage):
2110
+ q_head_idx = (
2111
+ (self.q_stage * m_block + stage) * self.m_block_size + tidx
2112
+ ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
2113
+ learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx])
2114
+ for stage in cutlass.range_constexpr(self.q_stage):
2115
+ cute.arch.mbarrier_wait(
2116
+ mbar_ptr + self.mbar_softmax_corr_full_offset + stage,
2117
+ softmax_corr_consumer_phase,
2118
+ )
2119
+ # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r)
2120
+ # cute.arch.fence_view_async_tmem_load()
2121
+ # scale = tSrScale_t2r[0]
2122
+ row_sum = sScale[tidx + stage * self.m_block_size]
2123
+ if const_expr(mLSE is not None or learnable_sink is not None):
2124
+ row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2]
2125
+ else:
2126
+ row_max = None
2127
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage)
2128
+ if const_expr(learnable_sink is not None):
2129
+ LOG2_E = math.log2(math.e)
2130
+ sink_val = learnable_sink_val[stage]
2131
+ if const_expr(not self.is_split_kv) or split_idx == 0:
2132
+ if row_max == -Float32.inf:
2133
+ # It's possible to have an empty row with splitKV.
2134
+ row_max = sink_val * (LOG2_E / softmax_scale_log2)
2135
+ row_sum = Float32(1.0)
2136
+ else:
2137
+ row_sum += utils.exp2f(
2138
+ sink_val * LOG2_E - row_max * softmax_scale_log2
2139
+ )
2140
+ acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum
2141
+ stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
2142
+ scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0)
2143
+ cute.arch.mbarrier_wait(
2144
+ mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase
2145
+ )
2146
+ if const_expr(not self.use_correction_warps_for_epi):
2147
+ cute.arch.mbarrier_wait(
2148
+ mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase
2149
+ )
2150
+ self.correction_epilogue(
2151
+ thr_mma_pv,
2152
+ tOtOs[stage],
2153
+ tidx,
2154
+ stage,
2155
+ m_block,
2156
+ seqlen.seqlen_q,
2157
+ scale,
2158
+ sO[None, None, stage],
2159
+ mO_cur,
2160
+ gO,
2161
+ gmem_tiled_copy_O,
2162
+ )
2163
+ if const_expr(not self.use_correction_warps_for_epi):
2164
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage)
2165
+ # Signal for the next work tile that O buffers in tmem are already read, so
2166
+ # mma warp can write to them
2167
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
2168
+ # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale)
2169
+
2170
+ o_corr_consumer_phase ^= 1
2171
+ softmax_corr_consumer_phase ^= 1
2172
+ corr_epi_producer_phase ^= 1
2173
+ else:
2174
+ # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781
2175
+ if const_expr(self.use_correction_warps_for_epi):
2176
+ gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O
2177
+ else:
2178
+ gmem_tiled_copy_O_for_empty_tile = None
2179
+ if const_expr(self.use_block_sparsity):
2180
+ (
2181
+ softmax_corr_consumer_phase,
2182
+ o_corr_consumer_phase,
2183
+ corr_epi_producer_phase,
2184
+ ) = handle_block_sparse_empty_tile_correction_sm100(
2185
+ tidx,
2186
+ self.q_stage,
2187
+ self.m_block_size,
2188
+ self.qhead_per_kvhead,
2189
+ self.pack_gqa,
2190
+ self.is_split_kv,
2191
+ learnable_sink,
2192
+ mLSE,
2193
+ seqlen,
2194
+ m_block,
2195
+ head_idx,
2196
+ batch_idx,
2197
+ split_idx,
2198
+ sScale,
2199
+ stats,
2200
+ self.correction_epilogue,
2201
+ thr_mma_pv,
2202
+ tOtOs,
2203
+ sO,
2204
+ mbar_ptr,
2205
+ self.mbar_softmax_corr_full_offset,
2206
+ self.mbar_softmax_corr_empty_offset,
2207
+ self.mbar_P_full_O_rescaled_offset,
2208
+ self.mbar_P_full_2_offset,
2209
+ self.mbar_corr_epi_full_offset,
2210
+ self.mbar_corr_epi_empty_offset,
2211
+ softmax_corr_consumer_phase,
2212
+ o_corr_consumer_phase,
2213
+ corr_epi_producer_phase,
2214
+ softmax_scale_log2,
2215
+ mO_cur,
2216
+ gO,
2217
+ gmem_tiled_copy_O_for_empty_tile,
2218
+ )
2219
+
2220
+ if const_expr(mLSE is not None):
2221
+ if const_expr(not seqlen.has_cu_seqlens_q):
2222
+ if const_expr(self.is_split_kv):
2223
+ mLSE_cur = mLSE[None, head_idx, batch_idx, split_idx]
2224
+ else:
2225
+ mLSE_cur = mLSE[None, head_idx, batch_idx]
2226
+ else:
2227
+ offset = (
2228
+ seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
2229
+ )
2230
+ if const_expr(self.is_split_kv):
2231
+ mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx, split_idx])
2232
+ else:
2233
+ mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
2234
+ for stage in cutlass.range_constexpr(self.q_stage):
2235
+ gLSE = cute.local_tile(
2236
+ mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,)
2237
+ )
2238
+ row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]
2239
+ # if tidx == 0 and stage <= 1:
2240
+ # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
2241
+ LN2 = math.log(2.0)
2242
+ lse = (
2243
+ (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2
2244
+ if not acc_O_mn_row_is_zero_or_nan
2245
+ else -Float32.inf
2246
+ )
2247
+ seqlen_q = (
2248
+ seqlen.seqlen_q
2249
+ if const_expr(not self.pack_gqa)
2250
+ else seqlen.seqlen_q * self.qhead_per_kvhead
2251
+ )
2252
+ if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size:
2253
+ # This actually just works with PackGQA too
2254
+ gLSE[tidx] = lse
2255
+
2256
+ # Advance to next tile
2257
+ tile_scheduler.advance_to_next_work()
2258
+ work_tile = tile_scheduler.get_current_work()
2259
+ # End of persistent scheduler loop
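The LSE expression used above, lse = (row_max * softmax_scale_log2 + log2(row_sum)) * ln(2), recovers the natural-log log-sum-exp from the base-2 running statistics. A standalone numerical check of that identity, again assuming softmax_scale_log2 = softmax_scale * log2(e); the scale and score values are made up.

# Illustrative sketch only, not part of this module.
import math
import numpy as np

scale = 0.1
scale_log2 = scale * math.log2(math.e)
s = np.random.default_rng(1).standard_normal(256)                 # raw scores for one row

row_max = s.max()
row_sum = np.exp2(s * scale_log2 - row_max * scale_log2).sum()    # base-2 running statistics
LN2 = math.log(2.0)
lse = (row_max * scale_log2 + math.log2(row_sum)) * LN2

assert np.isclose(lse, np.log(np.exp(s * scale).sum()))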
2260
+
2261
+ @cute.jit
2262
+ def correction_rescale(
2263
+ self,
2264
+ thr_mma: cute.core.ThrMma,
2265
+ tOtO: cute.Tensor,
2266
+ tidx: Int32,
2267
+ scale: Float32,
2268
+ ):
2269
+ """Rescale intermediate attention results based on softmax normalization factor.
2270
+
2271
+ This method performs a crucial correction step in the attention computation pipeline.
2272
+ When processing attention in blocks, the softmax normalization factors may change
2273
+ as new blocks are processed. This method rescales previously computed partial
2274
+ output values to account for updated normalization factors.
2275
+
2276
+ The implementation uses efficient tensor memory operations to:
2277
+ 1. Load existing partial attention output from tensor memory
2278
+ 2. Apply the scaling factor to all elements
2279
+ 3. Store the rescaled results back to tensor memory
2280
+ """
2281
+ tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))
2282
+ corr_tile_size = 16 # tuneable parameter
2283
+ tmem_load_atom = cute.make_copy_atom(
2284
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
2285
+ self.pv_acc_dtype,
2286
+ )
2287
+ tmem_store_atom = cute.make_copy_atom(
2288
+ tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
2289
+ self.pv_acc_dtype,
2290
+ )
2291
+ tOtO_i = cute.composition(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))
2292
+ tOcO_i = cute.composition(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))
2293
+ thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i).get_slice(tidx)
2294
+ thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i).get_slice(tidx)
2295
+ tOtO_t2r = thr_tmem_load.partition_S(tOtO_i)
2296
+ tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape
2297
+ tOtO_r2t = thr_tmem_store.partition_D(tOtO_i)
2298
+
2299
+ frg_count = self.head_dim_v_padded // corr_tile_size
2300
+ tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype)
2301
+ for i in cutlass.range_constexpr(frg_count):
2302
+ tOrO_frg = cute.make_fragment(tOrO_t2r_shape, self.pv_acc_dtype)
2303
+ tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout)
2304
+ cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg)
2305
+ for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True):
2306
+ tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2(
2307
+ (tOrO_frg[j], tOrO_frg[j + 1]),
2308
+ (scale, scale),
2309
+ )
2310
+ tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout)
2311
+ cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i)
2312
+ cute.arch.fence_view_async_tmem_store()
2313
+
2314
+ @cute.jit
2315
+ def correction_epilogue(
2316
+ self,
2317
+ thr_mma: cute.core.ThrMma,
2318
+ tOtO: cute.Tensor,
2319
+ tidx: Int32,
2320
+ stage: Int32,
2321
+ m_block: Int32,
2322
+ seqlen_q: Int32,
2323
+ scale: Float32,
2324
+ sO: cute.Tensor,
2325
+ mO_cur: Optional[cute.Tensor] = None,
2326
+ gO: Optional[cute.Tensor] = None,
2327
+ gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
2328
+ ):
2329
+ """Apply final scaling and transformation to attention output before writing to global memory.
2330
+
2331
+ This correction_epilogue function handles the final processing step for attention output values.
2332
+ It applies a scaling factor to the accumulated attention results and prepares the
2333
+ data for efficient transfer back to global memory.
2334
+
2335
+ The method performs:
2336
+ 1. Loading of accumulated attention results from tensor memory
2337
+ 2. Application of the final output scaling factor
2338
+ 3. Type conversion if necessary (typically from higher precision accumulator to output precision)
2339
+ 4. Reorganization of data for optimal memory access patterns
2340
+ 5. Preparation for efficient TMA store operations
2341
+
2342
+ :param thr_mma: Thread MMA operation for the computation
2343
+ :type thr_mma: cute.core.ThrMma
2344
+ :param tOtO: Tensor containing accumulated attention output
2345
+ :type tOtO: cute.Tensor
2346
+ :param scale: Final scaling factor to apply to the output
2347
+ :type scale: Float32
2348
+ :param sO: Shared memory tensor for the final output
2349
+ :type sO: cute.Tensor
2350
+ """
2351
+
2352
+ corr_tile_size = 32 * 8 // self.o_dtype.width
2353
+ tOsO = thr_mma.partition_C(sO)
2354
+ tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))
2355
+
2356
+ tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))
2357
+ tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))
2358
+ tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size)))
2359
+
2360
+ epi_subtile = (self.epi_tile[0], corr_tile_size)
2361
+ tmem_copy_atom = sm100_utils_basic.get_tmem_load_op(
2362
+ self.mma_tiler_pv,
2363
+ self.o_layout,
2364
+ self.o_dtype,
2365
+ self.pv_acc_dtype,
2366
+ epi_subtile,
2367
+ use_2cta_instrs=False,
2368
+ )
2369
+ tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice(
2370
+ tidx
2371
+ )
2372
+ thr_tmem_load = tiled_tmem_load.get_slice(tidx)
2373
+ smem_copy_atom = sm100_utils_basic.get_smem_store_op(
2374
+ self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
2375
+ )
2376
+ tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
2377
+
2378
+ tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
2379
+ tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
2380
+ tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None])
2381
+ for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size):
2382
+ tOtO_t2r_i = tOtO_t2r[None, 0, 0, i]
2383
+ tOsO_r2s_i = tOsO_s2r[None, 0, 0, i]
2384
+ tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype)
2385
+ cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg)
2386
+ for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2):
2387
+ tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2(
2388
+ (tOrO_frg[j], tOrO_frg[j + 1]),
2389
+ (scale, scale),
2390
+ )
2391
+ tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype)
2392
+ tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype))
2393
+ cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i)
2394
+ # fence view async shared
2395
+ cute.arch.fence_proxy(
2396
+ cute.arch.ProxyKind.async_shared,
2397
+ space=cute.arch.SharedSpace.shared_cta,
2398
+ )
2399
+
2400
+ if const_expr(self.use_correction_warps_for_epi):
2401
+ assert not self.use_tma_O
2402
+ assert gmem_tiled_copy_O is not None
2403
+ cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue),
2404
+ number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
2405
+ gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
2406
+ tOsO = gmem_thr_copy_O.partition_S(sO)
2407
+ cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
2408
+ tOgO = gmem_thr_copy_O.partition_D(gO)
2409
+ tOcO = gmem_thr_copy_O.partition_S(cO)
2410
+ t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
2411
+ tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1])
2412
+ pack_gqa = PackGQA(
2413
+ self.m_block_size,
2414
+ self.head_dim_v_padded,
2415
+ self.check_hdim_v_oob,
2416
+ self.qhead_per_kvhead,
2417
+ )
2418
+
2419
+ # load acc O from smem to rmem for wider vectorization
2420
+ tOrO = cute.make_fragment_like(tOsO, self.o_dtype)
2421
+ cute.autovec_copy(tOsO, tOrO)
2422
+ # copy acc O from rmem to gmem
2423
+ if const_expr(not self.pack_gqa):
2424
+ for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
2425
+ if (
2426
+ t0OcO[0, rest_m, 0][0]
2427
+ < seqlen_q
2428
+ - (self.q_stage * m_block + stage) * self.m_block_size
2429
+ - tOcO[0][0]
2430
+ ):
2431
+ cute.copy(
2432
+ gmem_tiled_copy_O,
2433
+ tOrO[None, rest_m, None],
2434
+ tOgO[None, rest_m, None, self.q_stage * m_block + stage],
2435
+ pred=tOpO[None, rest_m, None]
2436
+ if const_expr(self.check_hdim_v_oob)
2437
+ else None,
2438
+ )
2439
+ else:
2440
+ pack_gqa.store_O(
2441
+ mO_cur,
2442
+ tOrO,
2443
+ gmem_tiled_copy_O,
2444
+ tidx,
2445
+ self.q_stage * m_block + stage,
2446
+ seqlen_q,
2447
+ )
2448
+
2449
+ @cute.jit
2450
+ def epilogue_s2g(
2451
+ self,
2452
+ mO: cute.Tensor,
2453
+ sO: cute.Tensor,
2454
+ gmem_tiled_copy_O: cute.TiledCopy,
2455
+ tma_atom_O: Optional[cute.CopyAtom],
2456
+ mbar_ptr: cute.Pointer,
2457
+ block_info: BlockInfo,
2458
+ num_splits: int,
2459
+ SeqlenInfoCls: Callable,
2460
+ TileSchedulerCls: Callable,
2461
+ ):
2462
+ epi_consumer_phase = Int32(0)
2463
+ tile_scheduler = TileSchedulerCls()
2464
+ work_tile = tile_scheduler.initial_work_tile_info()
2465
+ while work_tile.is_valid_tile:
2466
+ m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
2467
+ seqlen = SeqlenInfoCls(batch_idx)
2468
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
2469
+
2470
+ if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
2471
+ if const_expr(self.is_split_kv):
2472
+ mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
2473
+ else:
2474
+ mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
2475
+ gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0))
2476
+ if const_expr(self.use_tma_O):
2477
+ store_O, _, _ = copy_utils.tma_get_copy_fn(
2478
+ tma_atom_O, 0, cute.make_layout(1), sO, gO
2479
+ )
2480
+ for stage in cutlass.range_constexpr(self.q_stage):
2481
+ # wait for the correction warps, then issue the TMA store from smem
2482
+ # 1. wait for O0 / O1 final
2483
+ cute.arch.mbarrier_wait(
2484
+ mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase
2485
+ )
2486
+ # 2. copy O0 / O1 to gmem
2487
+ store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage)
2488
+ cute.arch.cp_async_bulk_commit_group()
2489
+ for stage in cutlass.range_constexpr(self.q_stage):
2490
+ # Ensure O0 / O1 buffer is ready to be released
2491
+ if const_expr(self.q_stage == 2):
2492
+ cute.arch.cp_async_bulk_wait_group(1 - stage, read=True)
2493
+ else:
2494
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
2495
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage)
2496
+ else:
2497
+ tidx = cute.arch.thread_idx()[0] % (
2498
+ cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
2499
+ )
2500
+ gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
2501
+ tOsO = gmem_thr_copy_O.partition_S(sO)
2502
+ cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
2503
+ tOgO = gmem_thr_copy_O.partition_D(gO)
2504
+ tOcO = gmem_thr_copy_O.partition_S(cO)
2505
+ t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
2506
+ tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
2507
+ pack_gqa = PackGQA(
2508
+ self.m_block_size,
2509
+ self.head_dim_v_padded,
2510
+ self.check_hdim_v_oob,
2511
+ self.qhead_per_kvhead,
2512
+ )
2513
+ for stage in cutlass.range_constexpr(self.q_stage):
2514
+ # wait for the correction warps, then copy O from smem to gmem
2515
+ # 1. wait for O0 / O1 final
2516
+ cute.arch.mbarrier_wait(
2517
+ mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase
2518
+ )
2519
+ # 2. copy O0 / O1 to gmem
2520
+ # load acc O from smem to rmem for wider vectorization
2521
+ tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype)
2522
+ cute.autovec_copy(tOsO[None, None, None, stage], tOrO)
2523
+ # copy acc O from rmem to gmem
2524
+ if const_expr(not self.pack_gqa):
2525
+ for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
2526
+ if (
2527
+ t0OcO[0, rest_m, 0][0]
2528
+ < seqlen.seqlen_q
2529
+ - (self.q_stage * m_block + stage) * self.m_block_size
2530
+ - tOcO[0][0]
2531
+ ):
2532
+ cute.copy(
2533
+ gmem_tiled_copy_O,
2534
+ tOrO[None, rest_m, None],
2535
+ tOgO[None, rest_m, None, self.q_stage * m_block + stage],
2536
+ pred=tOpO[None, rest_m, None]
2537
+ if const_expr(self.check_hdim_v_oob)
2538
+ else None,
2539
+ )
2540
+ else:
2541
+ pack_gqa.store_O(
2542
+ mO_cur,
2543
+ tOrO,
2544
+ gmem_tiled_copy_O,
2545
+ tidx,
2546
+ self.q_stage * m_block + stage,
2547
+ seqlen.seqlen_q,
2548
+ )
2549
+ cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage)
2550
+
2551
+ epi_consumer_phase ^= 1
2552
+
2553
+ # Advance to next tile
2554
+ tile_scheduler.advance_to_next_work()
2555
+ work_tile = tile_scheduler.get_current_work()
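The TMA path above releases the sO buffers in order by pairing cp_async_bulk_commit_group with cp_async_bulk_wait_group(1 - stage): waiting until at most one commit group is still outstanding frees stage 0, and waiting for zero frees stage 1. A toy simulation of that drain rule, purely illustrative of the semantics assumed here (groups retire oldest-first from the waiter's point of view).

# Illustrative sketch only, not part of this module: "wait_group(n)" as
# "block until at most n commit groups remain outstanding".
from collections import deque

def drain(outstanding: deque, n: int):
    completed = []
    while len(outstanding) > n:
        completed.append(outstanding.popleft())   # oldest committed group retires first
    return completed

if __name__ == "__main__":
    groups = deque(["store_O0", "store_O1"])      # committed in stage order
    print(drain(groups, 1))   # ['store_O0']  -> safe to release sO stage 0
    print(drain(groups, 0))   # ['store_O1']  -> safe to release sO stage 1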
2556
+
2557
+ def load_Q(
2558
+ self,
2559
+ load_Q_fn: Callable,
2560
+ mbar_full_ptr: cute.Pointer,
2561
+ mbar_empty_ptr: cute.Pointer,
2562
+ block: Int32,
2563
+ stage: int,
2564
+ phase: Int32,
2565
+ ):
2566
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase)
2567
+ with cute.arch.elect_one():
2568
+ cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes["Q"])
2569
+ load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=mbar_full_ptr + stage)
2570
+
2571
+ @cute.jit
2572
+ def load_KV(
2573
+ self,
2574
+ tma_atom: Optional[cute.CopyAtom],
2575
+ tXgX: Optional[cute.Tensor],
2576
+ tXsX: Optional[cute.Tensor],
2577
+ paged_kv_manager: Optional[PagedKVManager],
2578
+ sX: cute.Tensor,
2579
+ mbar_full_ptr: cute.Pointer,
2580
+ mbar_empty_ptr: cute.Pointer,
2581
+ block: Int32,
2582
+ producer_state: cutlass.pipeline.PipelineState,
2583
+ K_or_V: Literal["K", "V"],
2584
+ page_idx: Optional[Int32] = None,
2585
+ ):
2586
+ assert K_or_V in ("K", "V")
2587
+ stage, phase = producer_state.index, producer_state.phase
2588
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase)
2589
+ if const_expr(K_or_V == "K" and self.uneven_kv_smem):
2590
+ # Before this round, the smem location was occupied by V, which is smaller than
2591
+ # K. So we need to wait for the stage after that (stage 1) to be empty as well.
2592
+ if stage == 0:
2593
+ cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase)
2594
+
2595
+ if const_expr(self.use_tma_KV):
2596
+ assert (
2597
+ tXgX is not None and
2598
+ tXsX is not None and
2599
+ tma_atom is not None
2600
+ )
2601
+ with cute.arch.elect_one():
2602
+ cute.arch.mbarrier_arrive_and_expect_tx(
2603
+ mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V],
2604
+ )
2605
+ tXsX_cur = tXsX[None, stage]
2606
+ if const_expr(self.uneven_kv_smem):
2607
+ # Since this is the producer_state, the phase starts at 1, so we have to invert it
2608
+ tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1)
2609
+ # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0
2610
+ tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx]
2611
+ cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage)
2612
+ else:
2613
+ assert paged_kv_manager is not None
2614
+ paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V)
2615
+ cute.arch.cp_async_commit_group()
2616
+ cute.arch.cp_async_mbarrier_arrive_noinc(mbar_full_ptr + stage)
2617
+
2618
+ @cute.jit
2619
+ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32):
2620
+ if const_expr(self.uneven_kv_smem):
2621
+ # smem layout is [smem_large, smem_small, smem_large], and the current stride is
2622
+ # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if
2623
+ # phase == 0, or left by offset if phase == 1.
2624
+ offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)
2625
+ return cute.make_tensor(sX.iterator + offset, sX.layout)
2626
+ else:
2627
+ return sX
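The offset arithmetic above can be checked in isolation. Under the interpretation suggested by the comments (a three-deep [smem_large, smem_small, smem_large] arrangement addressed with stride (smem_large + smem_small) // 2, and uneven_kv_smem_offset assumed to equal (smem_large - smem_small) // 2), stage 1 shifts right on even phases and left on odd phases so K and V always land back to back. The buffer sizes below are made up for illustration.

# Illustrative sketch only, not part of this module: base offsets for an uneven
# K/V smem layout under the assumptions stated above.
def uneven_kv_base(stage, phase, smem_large, smem_small):
    stride = (smem_large + smem_small) // 2
    correction = (smem_large - smem_small) // 2          # assumed value of uneven_kv_smem_offset
    offset = 0 if stage != 1 else correction * (1 - 2 * phase)
    return stage * stride + offset

if __name__ == "__main__":
    large, small = 96, 64
    for phase in (0, 1):
        print([uneven_kv_base(s, phase, large, small) for s in (0, 1, 2)])
    # phase 0 (K, V, K): [0, 96, 160]  -> V starts right after the 96-unit K
    # phase 1 (V, K, V): [0, 64, 160]  -> K starts right after the 64-unit V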
2628
+
2629
+ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr):
2630
+ load_kv_consumer_group = cutlass.pipeline.CooperativeGroup(
2631
+ cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
2632
+ )
2633
+ if self.use_tma_KV:
2634
+ load_kv_producer_group = cutlass.pipeline.CooperativeGroup(
2635
+ cutlass.pipeline.Agent.Thread, len(self.load_warp_ids)
2636
+ )
2637
+ return cutlass.pipeline.PipelineTmaUmma.create(
2638
+ barrier_storage=load_kv_mbar_ptr,
2639
+ num_stages=self.kv_stage,
2640
+ producer_group=load_kv_producer_group,
2641
+ consumer_group=load_kv_consumer_group,
2642
+ tx_count=self.tma_copy_bytes["K"],
2643
+ )
2644
+ else:
2645
+ load_kv_producer_group = cutlass.pipeline.CooperativeGroup(
2646
+ cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE
2647
+ )
2648
+ return cutlass.pipeline.PipelineAsyncUmma.create(
2649
+ num_stages=self.kv_stage,
2650
+ producer_group=load_kv_producer_group,
2651
+ consumer_group=load_kv_consumer_group,
2652
+ barrier_storage=load_kv_mbar_ptr,
2653
+ )
2654
+
2655
+ # @cute.jit
2656
+ # def warp_scheduler_barrier_init(self):
2657
+ # warp_group_idx = utils.canonical_warp_group_idx(sync=False)
2658
+ # if warp_group_idx == 0:
2659
+ # cute.arch.barrier_arrive(
2660
+ # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128,
2661
+ # )
2662
+
2663
+ # def warp_scheduler_barrier_sync(self):
2664
+ # cute.arch.barrier(
2665
+ # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False),
2666
+ # number_of_threads=2 * 128
2667
+ # )
2668
+
2669
+ # def warp_scheduler_barrier_arrive(self):
2670
+ # cur_wg = utils.canonical_warp_group_idx(sync=False)
2671
+ # next_wg = 1 - cur_wg
2672
+ # cute.arch.barrier_arrive(
2673
+ # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128,
2674
+ # )
2675
+
2676
+ @cute.jit
2677
+ def apply_score_mod(
2678
+ self,
2679
+ tSrS_t2r,
2680
+ thr_tmem_load,
2681
+ thr_mma_qk,
2682
+ batch_idx,
2683
+ head_idx,
2684
+ m_block,
2685
+ n_block,
2686
+ softmax,
2687
+ seqlen: SeqlenInfoQK,
2688
+ aux_tensors=None,
2689
+ fastdiv_mods=(None, None),
2690
+ ):
2691
+ """Apply score modification for SM100 (constant q_idx)."""
2692
+ # Prepare index tensor with extra partition
2693
+ cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size))
2694
+ cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS)
2695
+ tScS = thr_mma_qk.partition_C(cS)
2696
+ tScS_t2r = thr_tmem_load.partition_D(tScS)
2697
+
2698
+ # Shared q_idx for all scores
2699
+ q_idx_logical = tScS_t2r[0][0]
2700
+
2701
+ # For Pack-GQA, compute the logical head index for this tile
2702
+ if cutlass.const_expr(self.pack_gqa):
2703
+ # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)
2704
+ q_physical = q_idx_logical
2705
+ q_idx_logical = q_physical // self.qhead_per_kvhead
2706
+ head_offset = q_physical - q_idx_logical * self.qhead_per_kvhead
2707
+ head_idx = head_idx * self.qhead_per_kvhead + head_offset
2708
+
2709
+ if cutlass.const_expr(aux_tensors is not None):
2710
+ seqlen_q_divmod, _ = fastdiv_mods
2711
+ _, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod)
2712
+
2713
+ apply_score_mod_inner(
2714
+ tSrS_t2r,
2715
+ tScS_t2r,
2716
+ self.score_mod,
2717
+ batch_idx,
2718
+ head_idx,
2719
+ softmax.softmax_scale,
2720
+ self.vec_size,
2721
+ self.qk_acc_dtype,
2722
+ aux_tensors,
2723
+ fastdiv_mods,
2724
+ seqlen_info=seqlen,
2725
+ constant_q_idx=q_idx_logical,
2726
+ qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
2727
+ )
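For Pack-GQA, each tile row packs qhead_per_kvhead query heads per logical query position, and the code above unpacks a packed row index into (logical q index, global q head). A small standalone sketch of that mapping with illustrative values; the function name is hypothetical.

# Illustrative sketch only, not part of this module: unpack a Pack-GQA row index
# into the logical query index and the global query-head index, as done above.
def unpack_gqa_row(q_physical, kv_head_idx, qhead_per_kvhead):
    q_idx_logical = q_physical // qhead_per_kvhead
    head_offset = q_physical - q_idx_logical * qhead_per_kvhead   # q_physical % qhead_per_kvhead
    q_head_idx = kv_head_idx * qhead_per_kvhead + head_offset
    return q_idx_logical, q_head_idx

if __name__ == "__main__":
    for row in range(8):   # packed rows 0..7 under kv head 2, 4 q heads per kv head
        print(row, unpack_gqa_row(row, kv_head_idx=2, qhead_per_kvhead=4))
    # rows 0..3 -> q position 0, heads 8..11; rows 4..7 -> q position 1, heads 8..11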