mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/flash_fwd.py
@@ -0,0 +1,2471 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ # A reimplementation of
4
+ # https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm80.h
5
+ # and https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm90.h
6
+ # from Cutlass C++ to Cute-DSL.
7
+ # Built on Cute-DSL example: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py
8
+
9
+ import math
10
+ from types import SimpleNamespace
11
+ from typing import Type, Callable, Optional, List
12
+ from functools import partial
13
+
14
+ import cuda.bindings.driver as cuda
15
+
16
+ import cutlass
17
+ import cutlass.cute as cute
18
+ from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
19
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
20
+ from cutlass.cute.arch import ProxyKind, SharedSpace
21
+ import cutlass.utils as utils_basic
22
+ from cutlass.utils import LayoutEnum
23
+ import cutlass.utils.hopper_helpers as sm90_utils_basic
24
+
25
+ from mslk.attention.flash_attn import ampere_helpers as sm80_utils
26
+ from mslk.attention.flash_attn import hopper_helpers as sm90_utils
27
+ from mslk.attention.flash_attn import utils
28
+ from mslk.attention.flash_attn import copy_utils
29
+ from mslk.attention.flash_attn.mask import AttentionMask
30
+ from mslk.attention.flash_attn.softmax import Softmax, apply_score_mod_inner
31
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
32
+ from mslk.attention.flash_attn.block_info import BlockInfo
33
+ from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
34
+ from mslk.attention.flash_attn.block_sparse_utils import (
35
+ produce_block_sparse_loads,
36
+ consume_block_sparse_loads,
37
+ )
38
+ from mslk.attention.flash_attn import pipeline
39
+ from mslk.attention.flash_attn.pack_gqa import PackGQA
40
+ from mslk.attention.flash_attn.named_barrier import NamedBarrierFwd
41
+ from mslk.attention.flash_attn.tile_scheduler import (
42
+ TileSchedulerArguments,
43
+ SingleTileScheduler,
44
+ SingleTileLPTScheduler,
45
+ SingleTileVarlenScheduler,
46
+ ParamsBase,
47
+ )
48
+ from cutlass.cute import FastDivmodDivisor
49
+
50
+
51
+ class FlashAttentionForwardBase:
52
+ arch: int = 80
53
+
54
+ def __init__(
55
+ self,
56
+ dtype: Type[cutlass.Numeric],
57
+ head_dim: int,
58
+ head_dim_v: Optional[int] = None,
59
+ qhead_per_kvhead: int = 1,
60
+ is_causal: bool = False,
61
+ is_local: bool = False,
62
+ pack_gqa: bool = True,
63
+ tile_m: int = 128,
64
+ tile_n: int = 128,
65
+ num_stages: int = 1,
66
+ num_threads: int = 128,
67
+ Q_in_regs: bool = False,
68
+ score_mod: Optional[cutlass.Constexpr] = None,
69
+ mask_mod: Optional[cutlass.Constexpr] = None,
70
+ has_aux_tensors: bool = False,
71
+ ):
72
+ """Initializes the configuration for a flash attention kernel.
73
+
74
+ All contiguous dimensions must be aligned to at least 16 bytes, which means that the head dimension
75
+ should be a multiple of 8.
76
+
77
+ :param head_dim: head dimension
78
+ :type head_dim: int
79
+ :param tile_m: m block size
80
+ :type tile_m: int
81
+ :param tile_n: n block size
82
+ :type tile_n: int
83
+ :param num_threads: number of threads
84
+ :type num_threads: int
85
+ :param is_causal: is causal
86
+ :param score_mod: A callable that takes the attention scores and applies a modification.
87
+ Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any``
88
+ :param mask_mod: A callable that, given the batch, head, query, and key/value indices, returns a Boolean indicating whether that score should be masked.
89
+ Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean``
90
+ """
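For illustration, callables matching the two signatures documented above might look like the following minimal sketch (hypothetical names, not part of this wheel; per the docstring, a True return from mask_mod means the score is masked out):

    def rel_bias_score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        # Add a simple distance-based bias to the already-computed scores.
        return scores + (kv_idx - q_idx)

    def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        # True -> the score at (q_idx, kv_idx) is masked out.
        return kv_idx > q_idx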
91
+ self.dtype = dtype
92
+ # padding head_dim to a multiple of 16 as k_block_size
93
+ hdim_multiple_of = 16
94
+ self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
95
+ head_dim_v = head_dim_v if head_dim_v is not None else head_dim
96
+ self.same_hdim_kv = head_dim == head_dim_v
97
+ self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
98
+ # Can save registers (and hence be faster) if we don't have to check hdim predication
99
+ self.check_hdim_oob = head_dim != self.tile_hdim
100
+ self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
101
+ self.qhead_per_kvhead = qhead_per_kvhead
102
+ self.is_causal = is_causal
103
+ self.is_local = is_local
104
+ self.pack_gqa = pack_gqa
105
+ self.tile_m = tile_m
106
+ self.tile_n = tile_n
107
+ self.num_threads = num_threads
108
+ self.num_stages = num_stages
109
+ self.Q_in_regs = Q_in_regs
110
+ self.score_mod = score_mod
111
+ self.mask_mod = mask_mod
112
+ self.qk_acc_dtype = Float32
113
+ if const_expr(has_aux_tensors):
114
+ self.vec_size: cutlass.Constexpr = 1
115
+ else:
116
+ self.vec_size: cutlass.Constexpr = 2
117
+
118
+ @staticmethod
119
+ def can_implement(
120
+ dtype,
121
+ head_dim,
122
+ head_dim_v,
123
+ tile_m,
124
+ tile_n,
125
+ num_stages,
126
+ num_threads,
127
+ is_causal,
128
+ Q_in_regs=False,
129
+ ) -> bool:
130
+ """Check if the kernel can be implemented with the given parameters.
131
+
132
+ :param dtype: data type
133
+ :type dtype: cutlass.Numeric
134
+ :param head_dim: head dimension
135
+ :type head_dim: int
136
+ :param tile_m: m block size
137
+ :type tile_m: int
138
+ :param tile_n: n block size
139
+ :type tile_n: int
140
+ :param num_threads: number of threads
141
+ :type num_threads: int
142
+ :param is_causal: is causal
143
+ :type is_causal: bool
144
+
145
+ :return: True if the kernel can be implemented, False otherwise
146
+ :rtype: bool
147
+ """
148
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
149
+ return False
150
+ if head_dim % 8 != 0:
151
+ return False
152
+ if head_dim_v % 8 != 0:
153
+ return False
154
+ if tile_n % 16 != 0:
155
+ return False
156
+ if num_threads % 32 != 0:
157
+ return False
158
+ # Check whether the chosen block sizes exceed the shared memory capacity
159
+ # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size
160
+ smem_usage_Q = tile_m * head_dim * 2
161
+ smem_usage_K = tile_n * head_dim * num_stages * 2
162
+ smem_usage_V = tile_n * head_dim_v * num_stages * 2
163
+ smem_usage_QV = (
164
+ (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V)
165
+ )
166
+ smem_usage = smem_usage_QV + smem_usage_K
167
+ # TODO: sm86 and sm89
168
+ smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
169
+ if smem_usage > smem_capacity:
170
+ return False
171
+ # Check if twice the block size is divisible by the number of threads
172
+ if (tile_m * 2) % num_threads != 0:
173
+ return False
174
+ return True
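As a worked example of the shared-memory check above (illustrative numbers only): with fp16 (2 bytes per element), tile_m = tile_n = 128, head_dim = head_dim_v = 128, num_stages = 1 and Q_in_regs = False:

    smem_Q = 128 * 128 * 2          # 32 KiB for the Q tile
    smem_K = 128 * 128 * 1 * 2      # 32 KiB for one K stage
    smem_V = 128 * 128 * 1 * 2      # 32 KiB for one V stage
    smem_total = (smem_Q + smem_V) + smem_K   # 96 KiB; must not exceed the sm_80 capacity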
175
+
176
+ def _check_type(
177
+ self,
178
+ mQ_type: Type[cutlass.Numeric],
179
+ mK_type: Type[cutlass.Numeric],
180
+ mV_type: Type[cutlass.Numeric],
181
+ mO_type: Type[cutlass.Numeric],
182
+ mLSE_type: Type[cutlass.Numeric] | None,
183
+ mCuSeqlensQ_type: Type[cutlass.Numeric] | None,
184
+ mCuSeqlensK_type: Type[cutlass.Numeric] | None,
185
+ mSeqUsedQ_type: Type[cutlass.Numeric] | None,
186
+ mSeqUsedK_type: Type[cutlass.Numeric] | None,
187
+ ):
188
+ # Get the data type and check if it is fp16 or bf16
189
+ if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
190
+ raise TypeError("All tensors must have the same data type")
191
+ if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
192
+ raise TypeError("Only Float16 or BFloat16 is supported")
193
+ if const_expr(mLSE_type not in [None, Float32]):
194
+ raise TypeError("LSE tensor must be Float32")
195
+ if const_expr(mCuSeqlensQ_type not in [None, Int32]):
196
+ raise TypeError("cu_seqlens_q tensor must be Int32")
197
+ if const_expr(mCuSeqlensK_type not in [None, Int32]):
198
+ raise TypeError("cu_seqlens_k tensor must be Int32")
199
+ if const_expr(mSeqUsedQ_type not in [None, Int32]):
200
+ raise TypeError("seqused_q tensor must be Int32")
201
+ if const_expr(mSeqUsedK_type not in [None, Int32]):
202
+ raise TypeError("seqused_k tensor must be Int32")
203
+ assert mQ_type == self.dtype
204
+
205
+ def _setup_attributes(self):
206
+ # ///////////////////////////////////////////////////////////////////////////////
207
+ # Shared memory layout: Q/K/V
208
+ # ///////////////////////////////////////////////////////////////////////////////
209
+ sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = (
210
+ self._get_smem_layout_atom()
211
+ )
212
+ self.sQ_layout = cute.tile_to_shape(
213
+ sQ_layout_atom,
214
+ (self.tile_m, self.tile_hdim),
215
+ (0, 1),
216
+ )
217
+ self.sK_layout = cute.tile_to_shape(
218
+ sK_layout_atom,
219
+ (self.tile_n, self.tile_hdim, self.num_stages),
220
+ (0, 1, 2),
221
+ )
222
+ self.sV_layout = cute.tile_to_shape(
223
+ sV_layout_atom,
224
+ (self.tile_n, self.tile_hdimv, self.num_stages),
225
+ (0, 1, 2),
226
+ )
227
+ self.sO_layout = cute.tile_to_shape(
228
+ sO_layout_atom,
229
+ (self.tile_m, self.tile_hdimv),
230
+ (0, 1),
231
+ )
232
+ if const_expr(sP_layout_atom is not None):
233
+ self.sP_layout = cute.tile_to_shape(
234
+ sP_layout_atom,
235
+ (self.tile_m, self.tile_n),
236
+ (0, 1),
237
+ )
238
+ else:
239
+ self.sP_layout = None
240
+
241
+ # ///////////////////////////////////////////////////////////////////////////////
242
+ # GMEM Tiled copy:
243
+ # ///////////////////////////////////////////////////////////////////////////////
244
+ # Thread layouts for copies
245
+ universal_copy_bits = 128
246
+ async_copy_elems = universal_copy_bits // self.dtype.width
247
+ # atom_async_copy: async copy atom for QKV load
248
+ atom_async_copy = cute.make_copy_atom(
249
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
250
+ self.dtype,
251
+ num_bits_per_copy=universal_copy_bits,
252
+ )
253
+ # atom_universal_copy: universal copy atom for O store
254
+ atom_universal_copy = cute.make_copy_atom(
255
+ cute.nvgpu.CopyUniversalOp(),
256
+ self.dtype,
257
+ num_bits_per_copy=universal_copy_bits,
258
+ )
259
+ # tQ_layout and tK_layout: thread layout for QK load
260
+ tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
261
+ assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, (
262
+ "num_Q_load_threads must be divisible by tQK_shape_dim_1"
263
+ )
264
+ assert self.num_producer_threads % tQK_shape_dim_1 == 0, (
265
+ "num_producer_threads must be divisible by tQK_shape_dim_1"
266
+ )
267
+ tQ_layout = cute.make_ordered_layout(
268
+ (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1),
269
+ order=(1, 0),
270
+ )
271
+ tK_layout = cute.make_ordered_layout(
272
+ (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1),
273
+ order=(1, 0),
274
+ )
275
+ # So that we don't have to check if we overshoot kBlockM when we load Q
276
+ assert self.tile_m % tQ_layout.shape[0] == 0
277
+ tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems
278
+ tV_layout = cute.make_ordered_layout(
279
+ (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1),
280
+ order=(1, 0),
281
+ )
282
+ # TODO: need a different layout for O if O dtype is not the same as V dtype
283
+ # tO_layout: thread layout for O store
284
+ tO_layout = cute.make_ordered_layout(
285
+ (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1),
286
+ order=(1, 0),
287
+ )
288
+ # So that we don't have to check if we overshoot kBlockM when we store O
289
+ assert self.tile_m % tO_layout.shape[0] == 0
290
+
291
+ # Value layouts for copies
292
+ vQKV_layout = cute.make_layout((1, async_copy_elems))
293
+ vO_layout = vQKV_layout
294
+
295
+ self.gmem_tiled_copy_Q = cute.make_tiled_copy_tv(atom_async_copy, tQ_layout, vQKV_layout)
296
+ self.gmem_tiled_copy_K = cute.make_tiled_copy_tv(atom_async_copy, tK_layout, vQKV_layout)
297
+ self.gmem_tiled_copy_V = cute.make_tiled_copy_tv(atom_async_copy, tV_layout, vQKV_layout)
298
+ # gmem_tiled_copy_O: tiled copy for O store
299
+ self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
300
+
301
+ def _get_smem_layout_atom(self):
302
+ raise NotImplementedError()
303
+
304
+ def _get_tiled_mma(self):
305
+ raise NotImplementedError()
306
+
307
+ def _get_shared_storage_cls(self):
308
+ raise NotImplementedError()
309
+
310
+ @cute.jit
311
+ def __call__(
312
+ self,
313
+ mQ: cute.Tensor,
314
+ mK: cute.Tensor,
315
+ mV: cute.Tensor,
316
+ mO: cute.Tensor,
317
+ mLSE: Optional[cute.Tensor],
318
+ softmax_scale: Float32,
319
+ stream: cuda.CUstream,
320
+ ):
321
+ """Configures and launches the flash attention kernel.
322
+
323
+ mQ/mK/mV/mO have the same data type (fp16 or bf16 supported) and the same layout:
324
+ (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
325
+ """
326
+ raise NotImplementedError()
327
+
328
+ @cute.jit
329
+ def epilogue(
330
+ self,
331
+ acc_O: cute.Tensor,
332
+ lse: cute.Tensor,
333
+ mO: cute.Tensor,
334
+ mLSE: Optional[cute.Tensor],
335
+ sO: cute.Tensor,
336
+ seqlen: SeqlenInfoQK,
337
+ gmem_tiled_copy_O: cute.TiledCopy,
338
+ tma_atom_O: Optional[cute.CopyAtom],
339
+ tiled_mma: cute.TiledMma,
340
+ tidx: Int32,
341
+ m_block: Int32,
342
+ head_idx: Int32,
343
+ batch_idx: Int32,
344
+ ):
345
+ # store acc_O
346
+ rO = cute.make_fragment_like(acc_O, self.dtype)
347
+ rO.store(acc_O.load().to(self.dtype))
348
+ # Make sure all threads have finished reading V
349
+ cute.arch.barrier(
350
+ barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
351
+ )
352
+ smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype)
353
+ smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
354
+ taccOrO = smem_thr_copy_O.retile(rO)
355
+ taccOsO = smem_thr_copy_O.partition_D(sO)
356
+ # taccOsO = quack_copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
357
+ # copy acc O from rmem to smem with the smem copy atom
358
+ cute.copy(smem_copy_atom_O, taccOrO, taccOsO)
359
+
360
+ cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv))
361
+ pack_gqa = PackGQA(
362
+ self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead
363
+ )
364
+
365
+ # Write LSE from rmem -> gmem
366
+ if const_expr(mLSE is not None):
367
+ if const_expr(not seqlen.has_cu_seqlens_q):
368
+ mLSE_cur = mLSE[None, head_idx, batch_idx]
369
+ else:
370
+ offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
371
+ mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
372
+ if const_expr(not self.pack_gqa):
373
+ gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
374
+ gLSE_expanded_layout = cute.append(
375
+ gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,))
376
+ )
377
+ gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout)
378
+ thr_mma = tiled_mma.get_slice(tidx)
379
+ taccOgLSE = utils.make_acc_tensor_mn_view(thr_mma.partition_C(gLSE_expanded))
380
+ assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse)
381
+ taccOcO = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cO))
382
+ t0accOcO = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cO))
383
+ # Only the thread corresponding to column 0 writes out the lse to gmem
384
+ if taccOcO[0][1] == 0:
385
+ for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])):
386
+ if (
387
+ t0accOcO[m, 0][0]
388
+ < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]
389
+ ):
390
+ taccOgLSE[m, 0] = lse[m]
391
+ else:
392
+ pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)
393
+
394
+ if const_expr(not seqlen.has_cu_seqlens_q):
395
+ mO_cur = mO[None, None, head_idx, batch_idx]
396
+ else:
397
+ offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
398
+ mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx])
399
+ # thr_mma = tiled_mma.get_slice(tidx)
400
+ # taccOgO = thr_mma.partition_C(gO)
401
+ # cute.autovec_copy(rO, taccOgO)
402
+ # sync to make sure all smem stores are done
403
+ if const_expr(self.use_tma_O):
404
+ # ensure smem writes are visible to TMA
405
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
406
+ cute.arch.barrier_arrive(
407
+ barrier_id=int(NamedBarrierFwd.Epilogue),
408
+ number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
409
+ )
410
+ gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
411
+ store_O, _, _ = copy_utils.tma_get_copy_fn(
412
+ tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
413
+ )
414
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
415
+ if warp_idx == 4:
416
+ cute.arch.barrier(
417
+ barrier_id=int(NamedBarrierFwd.Epilogue),
418
+ number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
419
+ )
420
+ store_O()
421
+ cute.arch.cp_async_bulk_commit_group()
422
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
423
+ else:
424
+ cute.arch.barrier(
425
+ barrier_id=int(NamedBarrierFwd.Epilogue),
426
+ number_of_threads=self.num_epilogue_threads,
427
+ )
428
+ gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
429
+ tOsO = gmem_thr_copy_O.partition_S(sO)
430
+ tOrO = cute.make_fragment_like(tOsO, self.dtype)
431
+ # load acc O from smem to rmem for wider vectorization
432
+ cute.autovec_copy(tOsO, tOrO)
433
+ if const_expr(not self.pack_gqa):
434
+ gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
435
+ tOgO = gmem_thr_copy_O.partition_D(gO)
436
+ tOcO = gmem_thr_copy_O.partition_S(cO)
437
+ t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
438
+ tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
439
+ # copy acc O from rmem to gmem
440
+ for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
441
+ if (
442
+ t0OcO[0, rest_m, 0][0]
443
+ < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
444
+ ):
445
+ cute.copy(
446
+ gmem_tiled_copy_O,
447
+ tOrO[None, rest_m, None],
448
+ tOgO[None, rest_m, None],
449
+ pred=tOpO[None, rest_m, None]
450
+ if const_expr(self.check_hdim_v_oob)
451
+ else None,
452
+ )
453
+ else:
454
+ pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)
455
+
456
+ @cute.jit
457
+ def advance_pipeline(self, pipeline_index):
458
+ return pipeline_index + 1 if pipeline_index < self.num_stages - 1 else 0
459
+
460
+ @cute.jit
461
+ def load_Q(
462
+ self,
463
+ gmem_thr_copy: cute.TiledCopy,
464
+ gQ: cute.Tensor,
465
+ sQ: cute.Tensor,
466
+ block: Int32,
467
+ seqlen: Int32,
468
+ headdim: Int32,
469
+ ):
470
+ tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ)
471
+ cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
472
+ tQcQ = gmem_thr_copy.partition_S(cQ)
473
+ t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
474
+ tQpQ = utils.predicate_k(tQcQ, limit=headdim)
475
+ for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
476
+ # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit
477
+ # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.
478
+ if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]:
479
+ cute.copy(
480
+ gmem_thr_copy,
481
+ tQgQ[None, m, None],
482
+ tQsQ[None, m, None],
483
+ pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None,
484
+ )
485
+ # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
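In effect, the predicate above relies on the row coordinate splitting into thread 0's compile-time coordinate plus this thread's row offset, so the bound check can be rewritten with a constant left-hand side:

    #    tQcQ[0, m, 0][0]  < seqlen - block * tile_m
    #    <=> t0QcQ[0, m, 0][0] < seqlen - block * tile_m - tQcQ[0][0]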
486
+
487
+ @cute.jit
488
+ def load_K(
489
+ self,
490
+ gmem_tiled_copy: cute.TiledCopy,
491
+ tKgK: cute.Tensor,
492
+ tKsK: cute.Tensor,
493
+ tKcK: cute.Tensor,
494
+ t0KcK: cute.Tensor,
495
+ tKpK: cute.Tensor,
496
+ block: Int32,
497
+ smem_pipe_write: Int32,
498
+ seqlen: Int32,
499
+ need_predicates: cutlass.Constexpr,
500
+ ):
501
+ # Do we need to check if we overshoot kBlockN when we load K?
502
+ is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0
503
+ if const_expr(need_predicates or not is_even_n_smem_k):
504
+ # Instead of using tKcK, we use t0KcK and subtract the offset from the limit
505
+ # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.
506
+ if const_expr(is_even_n_smem_k):
507
+ seqlen_limit = seqlen - block * self.tile_n
508
+ else:
509
+ if const_expr(not need_predicates):
510
+ seqlen_limit = self.tile_n
511
+ else:
512
+ seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n)
513
+ seqlen_limit -= tKcK[0][0]
514
+ for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
515
+ if t0KcK[0, n, 0][0] < seqlen_limit:
516
+ cute.copy(
517
+ gmem_tiled_copy,
518
+ tKgK[None, n, None, block],
519
+ tKsK[
520
+ None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0
521
+ ],
522
+ pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None,
523
+ )
524
+ # We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
525
+ else:
526
+ cute.copy(
527
+ gmem_tiled_copy,
528
+ tKgK[None, None, None, block],
529
+ tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0],
530
+ pred=tKpK if const_expr(self.check_hdim_oob) else None,
531
+ )
532
+
533
+ @cute.jit
534
+ def load_V(
535
+ self,
536
+ gmem_tiled_copy: cute.TiledCopy,
537
+ tVgV: cute.Tensor,
538
+ tVsV: cute.Tensor,
539
+ tVcV: cute.Tensor,
540
+ t0VcV: cute.Tensor,
541
+ tVpV: cute.Tensor,
542
+ block: Int32,
543
+ smem_pipe_write: Int32,
544
+ seqlen: Int32,
545
+ need_predicates: cutlass.Constexpr,
546
+ ):
547
+ # Do we need to check if we overshoot kBlockN when we load V?
548
+ is_even_n_smem_v = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0
549
+ if const_expr(need_predicates or not is_even_n_smem_v):
550
+ for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
551
+ # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
552
+ if (
553
+ is_even_n_smem_v
554
+ or n < cute.size(tVsV.shape[1]) - 1
555
+ or tVcV[0, n, 0][0] < self.tile_n
556
+ ):
557
+ predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None
558
+ if const_expr(need_predicates):
559
+ seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0]
560
+ predicate_n = t0VcV[0, n, 0][0] < seqlen_limit
561
+ predicate = cute.make_fragment_like(tVpV[None, 0, None])
562
+ for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
563
+ for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
564
+ predicate[i, k] = (
565
+ tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True
566
+ ) and predicate_n
567
+ cute.copy(
568
+ gmem_tiled_copy,
569
+ tVgV[None, n, None, block],
570
+ tVsV[
571
+ None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0
572
+ ],
573
+ pred=predicate,
574
+ )
575
+ else:
576
+ cute.copy(
577
+ gmem_tiled_copy,
578
+ tVgV[None, None, None, block],
579
+ tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0],
580
+ pred=tVpV if const_expr(self.check_hdim_v_oob) else None,
581
+ )
582
+
583
+
584
+ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
585
+ def _get_smem_layout_atom(self):
586
+ sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim)
587
+ sK_layout_atom = sQ_layout_atom
588
+ sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv)
589
+ sO_layout_atom = sV_layout_atom
590
+ sP_layout_atom = None
591
+ return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom
592
+
593
+ def _get_tiled_mma(self):
594
+ tiled_mma_qk = cute.make_tiled_mma(
595
+ warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
596
+ (self.num_threads // 32, 1, 1),
597
+ permutation_mnk=(self.num_threads // 32 * 16, 16, 16),
598
+ )
599
+ tiled_mma_pv = cute.make_tiled_mma(
600
+ warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
601
+ (self.num_threads // 32, 1, 1),
602
+ permutation_mnk=(self.num_threads // 32 * 16, 16, 16),
603
+ )
604
+ return tiled_mma_qk, tiled_mma_pv
605
+
606
+ def _get_shared_storage_cls(self):
607
+ sQ_struct, sK_struct, sV_struct = [
608
+ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]
609
+ for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
610
+ ]
611
+ cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
612
+ sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
613
+
614
+ @cute.struct
615
+ class SharedStorageQKV:
616
+ sV: sV_struct
617
+ sQ: sQ_struct
618
+ sK: sK_struct
619
+
620
+ @cute.struct
621
+ class SharedStorageSharedQV:
622
+ sQ: sQV_struct
623
+ sK: sK_struct
624
+
625
+ return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
626
+
627
+ @cute.jit
628
+ def __call__(
629
+ self,
630
+ mQ: cute.Tensor,
631
+ mK: cute.Tensor,
632
+ mV: cute.Tensor,
633
+ mO: cute.Tensor,
634
+ mLSE: Optional[cute.Tensor],
635
+ stream: cuda.CUstream,
636
+ softmax_scale: Optional[Float32] = None,
637
+ window_size_left: Optional[Int32] = None,
638
+ window_size_right: Optional[Int32] = None,
639
+ learnable_sink: Optional[cute.Tensor] = None,
640
+ aux_tensors=None,
641
+ ):
642
+ """Configures and launches the flash attention kernel.
643
+
644
+ mQ/mK/mV/mO have the same data type (fp16 or bf16 supported) and the same layout:
645
+ (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
646
+ """
647
+ assert learnable_sink is None, "Learnable sink is not supported in this kernel"
648
+ self._check_type(
649
+ *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))
650
+ )
651
+ tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
652
+ self.num_mma_threads = tiled_mma_pv.size
653
+ self.num_producer_threads = self.num_threads
654
+ self.num_Q_load_threads = self.num_threads
655
+ self.num_epilogue_threads = self.num_threads
656
+ # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None
657
+ self.use_tma_O = self.arch >= 90
658
+ self._setup_attributes()
659
+ SharedStorage = self._get_shared_storage_cls()
660
+ # Assume all strides are divisible by 128 bits except the last stride
661
+ new_stride = lambda t: (
662
+ *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
663
+ t.stride[-1],
664
+ )
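For example (illustrative arithmetic): with fp16 or bf16 the element width is 16 bits, so divby = 128 // 16 = 8, i.e. every stride except the innermost is assumed to be a multiple of 8 elements (16 bytes).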
665
+ mQ, mK, mV, mO = [
666
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
667
+ for t in (mQ, mK, mV, mO)
668
+ ]
669
+ mQ, mK, mV, mO = [
670
+ cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0]))
671
+ for t in (mQ, mK, mV, mO)
672
+ ]
673
+ mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0]))
674
+ # grid_dim: (m_block, num_head, batch_size)
675
+ grid_dim = (
676
+ cute.ceil_div(mQ.shape[0], self.tile_m),
677
+ cute.size(mQ.shape[2]),
678
+ cute.size(mQ.shape[3]),
679
+ )
680
+ LOG2_E = math.log2(math.e)
681
+ if const_expr(self.score_mod is None):
682
+ softmax_scale_log2 = Float32(softmax_scale * LOG2_E)
683
+ softmax_scale = None
684
+ else:
685
+ # NB: If a user passes in a score_mod, we want to apply it to the softmax-scaled qk scores,
686
+ # but in the original (natural) base rather than base 2. We hijack softmax_scale_log2 to be just the change of base
687
+ # and correctly apply softmax_scale prior to score_mod in the softmax step.
688
+ softmax_scale_log2 = Float32(LOG2_E)
689
+ softmax_scale = Float32(softmax_scale)
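The two branches above follow the usual exp2 rewrite of the softmax exponent (a worked identity, with log2(e) as the change-of-base factor):

    # exp(softmax_scale * s) == exp2((softmax_scale * log2(e)) * s)
    # score_mod is None:     softmax_scale_log2 = softmax_scale * log2(e)  (scale folded into the exponent)
    # score_mod is not None: the scores are first multiplied by softmax_scale (then score_mod is applied),
    #                        so softmax_scale_log2 carries only log2(e)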
690
+
691
+ fastdiv_mods = None
692
+ if const_expr(aux_tensors is not None):
693
+ seqlen_q = cute.size(mQ.shape[0]) // (
694
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
695
+ )
696
+ seqlen_k = cute.size(mK.shape[0])
697
+ seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
698
+ seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
699
+ fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
700
+
701
+ self.kernel(
702
+ mQ,
703
+ mK,
704
+ mV,
705
+ mO,
706
+ mLSE,
707
+ softmax_scale_log2,
708
+ softmax_scale,
709
+ window_size_left,
710
+ window_size_right,
711
+ self.sQ_layout,
712
+ self.sK_layout,
713
+ self.sV_layout,
714
+ self.sO_layout,
715
+ self.sP_layout,
716
+ self.gmem_tiled_copy_Q,
717
+ self.gmem_tiled_copy_K,
718
+ self.gmem_tiled_copy_V,
719
+ self.gmem_tiled_copy_O,
720
+ tiled_mma_qk,
721
+ tiled_mma_pv,
722
+ SharedStorage,
723
+ aux_tensors,
724
+ fastdiv_mods,
725
+ ).launch(
726
+ grid=grid_dim,
727
+ block=[self.num_threads, 1, 1],
728
+ smem=SharedStorage.size_in_bytes(),
729
+ stream=stream,
730
+ )
731
+
732
+ @cute.kernel
733
+ def kernel(
734
+ self,
735
+ mQ: cute.Tensor,
736
+ mK: cute.Tensor,
737
+ mV: cute.Tensor,
738
+ mO: cute.Tensor,
739
+ mLSE: Optional[cute.Tensor],
740
+ softmax_scale_log2: Float32,
741
+ softmax_scale: Optional[Float32],
742
+ window_size_left: Optional[Int32],
743
+ window_size_right: Optional[Int32],
744
+ sQ_layout: cute.ComposedLayout,
745
+ sK_layout: cute.ComposedLayout,
746
+ sV_layout: cute.ComposedLayout,
747
+ sO_layout: cute.ComposedLayout,
748
+ sP_layout: cute.ComposedLayout | None,
749
+ gmem_tiled_copy_Q: cute.TiledCopy,
750
+ gmem_tiled_copy_K: cute.TiledCopy,
751
+ gmem_tiled_copy_V: cute.TiledCopy,
752
+ gmem_tiled_copy_O: cute.TiledCopy,
753
+ tiled_mma_qk: cute.TiledMma,
754
+ tiled_mma_pv: cute.TiledMma,
755
+ SharedStorage: cutlass.Constexpr,
756
+ aux_tensors=None,
757
+ fastdiv_mods=None,
758
+ ):
759
+ # Thread index, block index
760
+ tidx, _, _ = cute.arch.thread_idx()
761
+ m_block, num_head, batch_size = cute.arch.block_idx()
762
+
763
+ block_info = BlockInfo(
764
+ self.tile_m,
765
+ self.tile_n,
766
+ self.is_causal,
767
+ self.is_local,
768
+ False, # is_split_kv
769
+ window_size_left,
770
+ window_size_right,
771
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
772
+ )
773
+ seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0])
774
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
775
+ # TODO: return early if n_block_max == 0
776
+ # if self.is_causal:
777
+ # if n_block_max <= 0:
778
+ # return
779
+ n_block = n_block_max - 1
780
+
781
+ # ///////////////////////////////////////////////////////////////////////////////
782
+ # Get the appropriate tiles for this thread block.
783
+ # ///////////////////////////////////////////////////////////////////////////////
784
+ blkQ_shape = (self.tile_m, self.tile_hdim)
785
+ blkK_shape = (self.tile_n, self.tile_hdim)
786
+ blkV_shape = (self.tile_n, self.tile_hdimv)
787
+ gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0))
788
+ num_head_kv = num_head // self.qhead_per_kvhead
789
+ gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0))
790
+ gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0))
791
+
792
+ # ///////////////////////////////////////////////////////////////////////////////
793
+ # Get shared memory buffer
794
+ # ///////////////////////////////////////////////////////////////////////////////
795
+ smem = cutlass.utils.SmemAllocator()
796
+ storage = smem.allocate(SharedStorage)
797
+ sQ = storage.sQ.get_tensor(sQ_layout)
798
+ sK = storage.sK.get_tensor(sK_layout)
799
+ if const_expr(not self.Q_in_regs):
800
+ sV = storage.sV.get_tensor(sV_layout)
801
+ else:
802
+ sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)
803
+ # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
804
+ sVt = utils.transpose_view(sV)
805
+
806
+ gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx)
807
+ gmem_thr_copy_V = gmem_tiled_copy_V.get_slice(tidx)
808
+ # (CPY_Atom, CPY_N, CPY_K, n_block)
809
+ tKsK, tKgK = gmem_thr_copy_K.partition_D(sK), gmem_thr_copy_K.partition_S(gK)
810
+ # (CPY_Atom, CPY_N, CPY_K, n_block)
811
+ tVsV, tVgV = gmem_thr_copy_V.partition_D(sV), gmem_thr_copy_V.partition_S(gV)
812
+
813
+ # ///////////////////////////////////////////////////////////////////////////////
814
+ # Tile MMA compute thread partitions and allocate accumulators
815
+ # ///////////////////////////////////////////////////////////////////////////////
816
+ thr_mma_qk = tiled_mma_qk.get_slice(tidx)
817
+ thr_mma_pv = tiled_mma_pv.get_slice(tidx)
818
+ tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ))
819
+ tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0]))
820
+ tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0]))
821
+ acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv))
822
+ acc_O = cute.make_fragment(acc_shape_O, Float32)
823
+ acc_O.fill(0.0)
824
+
825
+ # ///////////////////////////////////////////////////////////////////////////////
826
+ # Smem copy atom tiling
827
+ # ///////////////////////////////////////////////////////////////////////////////
828
+ smem_copy_atom_QK = cute.make_copy_atom(
829
+ warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4),
830
+ self.dtype,
831
+ )
832
+ smem_copy_atom_V = cute.make_copy_atom(
833
+ warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
834
+ self.dtype,
835
+ )
836
+ smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx)
837
+ smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx)
838
+ smem_thr_copy_V = utils.make_tiled_copy_B(smem_copy_atom_V, tiled_mma_pv).get_slice(tidx)
839
+
840
+ tSsQ = smem_thr_copy_Q.partition_S(sQ)
841
+ tSsK = smem_thr_copy_K.partition_S(sK)
842
+ tOsVt = smem_thr_copy_V.partition_S(sVt)
843
+
844
+ # ///////////////////////////////////////////////////////////////////////////////
845
+ # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
846
+ # of tile_shape
847
+ # ///////////////////////////////////////////////////////////////////////////////
848
+ # Construct identity layout for KV
849
+ cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim))
850
+ tKcK = gmem_thr_copy_K.partition_S(cK)
851
+ t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK)
852
+ if const_expr(self.tile_hdim == self.tile_hdimv):
853
+ tVcV = tKcK
854
+ t0VcV = t0KcK
855
+ else:
856
+ cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv))
857
+ tVcV = gmem_thr_copy_V.partition_S(cV)
858
+ t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV)
859
+ # Allocate predicate tensors for m and n; here we only allocate the tile of k, and
860
+ # use "if" on the mn dimension.
861
+ # This reduces register pressure and gives a 2-3% performance gain.
862
+ tKpK = utils.predicate_k(tKcK, limit=mK.shape[1])
863
+ if const_expr(self.same_hdim_kv):
864
+ tVpV = tKpK
865
+ else:
866
+ tVpV = utils.predicate_k(tVcV, limit=mV.shape[1])
867
+
868
+ # shape: (atom_v_m * rest_m)
869
+ softmax = Softmax.create(
870
+ softmax_scale_log2,
871
+ num_rows=acc_O.shape[0][0] * acc_O.shape[1],
872
+ softmax_scale=softmax_scale,
873
+ )
874
+ softmax.reset()
875
+
876
+ # group parameters for compute_one_n_block
877
+ mma_params = SimpleNamespace(
878
+ thr_mma_qk=thr_mma_qk,
879
+ thr_mma_pv=thr_mma_pv,
880
+ tSrQ=tSrQ,
881
+ tSrK=tSrK,
882
+ tOrVt=tOrVt,
883
+ acc_O=acc_O,
884
+ )
885
+ smem_copy_params = SimpleNamespace(
886
+ smem_thr_copy_Q=smem_thr_copy_Q,
887
+ smem_thr_copy_K=smem_thr_copy_K,
888
+ smem_thr_copy_V=smem_thr_copy_V,
889
+ tSsQ=tSsQ,
890
+ tSsK=tSsK,
891
+ tOsVt=tOsVt,
892
+ )
893
+ load_K = partial(
894
+ self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k
895
+ )
896
+ load_V = partial(
897
+ self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k
898
+ )
899
+
900
+ compute_one_n_block = partial(
901
+ self.compute_one_n_block,
902
+ mma_params=mma_params,
903
+ smem_copy_params=smem_copy_params,
904
+ softmax=softmax,
905
+ load_K=load_K,
906
+ load_V=load_V,
907
+ score_mod=self.score_mod,
908
+ batch_idx=batch_size,
909
+ head_idx=num_head,
910
+ m_block=m_block,
911
+ aux_tensors=aux_tensors,
912
+ fastdiv_mods=fastdiv_mods,
913
+ )
914
+
915
+ # ///////////////////////////////////////////////////////////////////////////////
916
+ # Prologue
917
+ # ///////////////////////////////////////////////////////////////////////////////
918
+ # Start async loads of the last mn-tile, where we take care of the mn residue
919
+ gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
920
+ self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, headdim=mQ.shape[1])
921
+ cute.arch.cp_async_commit_group()
922
+
923
+ def preprocess_Q():
924
+ cute.arch.cp_async_wait_group(self.num_stages * 2 - 1)
925
+ if const_expr(self.Q_in_regs):
926
+ cute.arch.barrier()
927
+ tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ)
928
+ cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view)
929
+
930
+ # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and
931
+ # read from smem_q to registers, then load V.
932
+ # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q.
933
+ if const_expr(self.Q_in_regs):
934
+ load_K(n_block, smem_pipe_write=0, need_predicates=True)
935
+ cute.arch.cp_async_commit_group()
936
+ preprocess_Q()
937
+ cute.arch.barrier() # Make sure all threads have read smem_q before loading V
938
+
939
+ for stage in cutlass.range_constexpr(self.num_stages):
940
+ if const_expr(not self.Q_in_regs or stage > 0):
941
+ if stage == 0 or n_block - stage >= 0:
942
+ load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0)
943
+ cute.arch.cp_async_commit_group()
944
+ if const_expr(stage < self.num_stages - 1):
945
+ if stage == 0 or n_block - stage >= 0:
946
+ load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0)
947
+ cute.arch.cp_async_commit_group()
948
+ if const_expr(not self.Q_in_regs):
949
+ preprocess_Q()
950
+
951
+ # ///////////////////////////////////////////////////////////////////////////////
952
+ # Mainloop
953
+ # ///////////////////////////////////////////////////////////////////////////////
954
+ # Start processing of the first n-block.
955
+ # For performance reason, we separate out two kinds of iterations:
956
+ # those that need masking on S, and those that don't.
957
+ # We need masking on S for the very last block when the K/V length is not a multiple of tile_n.
958
+ # We also need masking on S if it's causal, for the last several blocks.
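For example, with seqlen_k = 1000 and tile_n = 128 there are 8 n_blocks; only the last one (columns 896..999) needs the seqlen mask, and with causal attention only the block(s) crossing the diagonal for this m_block additionally need the causal mask, so the bulk of the iterations below run without any masking.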
959
+ mask = AttentionMask(
960
+ self.tile_m,
961
+ self.tile_n,
962
+ seqlen.seqlen_q,
963
+ seqlen.seqlen_k,
964
+ window_size_left,
965
+ window_size_right,
966
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
967
+ )
968
+ mask_fn = partial(
969
+ mask.apply_mask,
970
+ m_block=m_block,
971
+ thr_mma=thr_mma_qk,
972
+ mask_causal=self.is_causal,
973
+ mask_local=self.is_local,
974
+ fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,
975
+ )
976
+
977
+ # First iteration with seqlen masking
978
+ smem_pipe_read = Int32(0)
979
+ smem_pipe_write = Int32(self.num_stages - 1)
980
+ compute_one_n_block(
981
+ n_block,
982
+ smem_pipe_read,
983
+ smem_pipe_write,
984
+ is_first_n_block=True,
985
+ check_inf=True,
986
+ mask_fn=partial(mask_fn, mask_seqlen=True),
987
+ )
988
+ smem_pipe_read = self.advance_pipeline(smem_pipe_read)
989
+ smem_pipe_write = self.advance_pipeline(smem_pipe_write)
990
+ # Next couple of iterations with causal masking
991
+ if const_expr(self.is_causal or self.is_local):
992
+ n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
993
+ seqlen, m_block, n_block_min
994
+ )
995
+ for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1):
996
+ n_block = n_block_max - 2 - n_tile
997
+ compute_one_n_block(
998
+ n_block,
999
+ smem_pipe_read,
1000
+ smem_pipe_write,
1001
+ check_inf=True,
1002
+ mask_fn=partial(mask_fn, mask_seqlen=False),
1003
+ )
1004
+ smem_pipe_read = self.advance_pipeline(smem_pipe_read)
1005
+ smem_pipe_write = self.advance_pipeline(smem_pipe_write)
1006
+ # The remaining iterations have no masking
1007
+ for n_tile in cutlass.range(n_block, unroll=1):
1008
+ compute_one_n_block(
1009
+ n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True
1010
+ )
1011
+ smem_pipe_read = self.advance_pipeline(smem_pipe_read)
1012
+ smem_pipe_write = self.advance_pipeline(smem_pipe_write)
1013
+ # TODO: local
1014
+
1015
+ # normalize acc_O by row_sum and calculate the lse
1016
+ row_scale = softmax.finalize()
1017
+ softmax.rescale_O(acc_O, row_scale)
1018
+
1019
+ # ///////////////////////////////////////////////////////////////////////////////
1020
+ # Epilogue
1021
+ # ///////////////////////////////////////////////////////////////////////////////
1022
+ # reuse sQ's data iterator
1023
+ sO = cute.make_tensor(sQ.iterator, sO_layout)
1024
+ self.epilogue(
1025
+ acc_O,
1026
+ softmax.row_sum,
1027
+ mO,
1028
+ mLSE,
1029
+ sO,
1030
+ seqlen,
1031
+ gmem_tiled_copy_O,
1032
+ None,
1033
+ tiled_mma_pv,
1034
+ tidx,
1035
+ m_block,
1036
+ num_head,
1037
+ batch_size,
1038
+ )
1039
+
1040
+ @cute.jit
1041
+ def compute_one_n_block(
1042
+ self,
1043
+ n_block: Int32,
1044
+ smem_pipe_read: Int32,
1045
+ smem_pipe_write: Int32,
1046
+ mma_params: SimpleNamespace,
1047
+ smem_copy_params: SimpleNamespace,
1048
+ softmax: Softmax,
1049
+ load_K: Callable,
1050
+ load_V: Callable,
1051
+ score_mod: Callable | None,
1052
+ batch_idx: cutlass.Int32,
1053
+ head_idx: cutlass.Int32,
1054
+ m_block: cutlass.Int32,
1055
+ seqlen: SeqlenInfoQK,
1056
+ aux_tensors=None,
1057
+ fastdiv_mods=None,
1058
+ mask_fn: Optional[Callable] = None,
1059
+ is_first_n_block: cutlass.Constexpr = False,
1060
+ check_inf: cutlass.Constexpr = True,
1061
+ ):
1062
+ """Compute one n_block of S/O.
1063
+
1064
+ This function provides different variants for processing the first n block versus
1065
+ subsequent blocks.
1066
+ """
1067
+
1068
+ def sync():
1069
+ cute.arch.cp_async_wait_group(self.num_stages * 2 - 2)
1070
+ cute.arch.barrier()
1071
+
1072
+ acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n))
1073
+ acc_S = cute.make_fragment(acc_shape_S, Float32)
1074
+ acc_S.fill(0.0)
1075
+ # wait for smem tile QK before mma calculation for S
1076
+ sync()
1077
+
1078
+ # need predicates for the first tile
1079
+ def load_V_next():
1080
+ if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0:
1081
+ load_V(
1082
+ n_block - self.num_stages + 1,
1083
+ smem_pipe_write,
1084
+ need_predicates=is_first_n_block and self.num_stages == 1,
1085
+ )
1086
+ cute.arch.cp_async_commit_group()
1087
+
1088
+ load_V_next()
1089
+ sm80_utils.gemm(
1090
+ mma_params.thr_mma_qk,
1091
+ acc_S,
1092
+ mma_params.tSrQ,
1093
+ mma_params.tSrK,
1094
+ smem_copy_params.tSsQ,
1095
+ smem_copy_params.tSsK[
1096
+ None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0
1097
+ ],
1098
+ smem_copy_params.smem_thr_copy_Q,
1099
+ smem_copy_params.smem_thr_copy_K,
1100
+ # hook_fn=load_V_next,
1101
+ A_in_regs=self.Q_in_regs,
1102
+ )
1103
+ if const_expr(score_mod is not None):
1104
+ self.apply_score_mod(
1105
+ mma_params.thr_mma_qk,
1106
+ batch_idx,
1107
+ head_idx,
1108
+ m_block,
1109
+ acc_S,
1110
+ n_block,
1111
+ seqlen,
1112
+ softmax_scale=softmax.softmax_scale,
1113
+ aux_tensors=aux_tensors,
1114
+ fastdiv_mods=fastdiv_mods,
1115
+ )
1116
+
1117
+ smem_pipe_write = self.advance_pipeline(smem_pipe_write)
1118
+
1119
+ def load_K_next():
1120
+ if n_block - self.num_stages >= 0:
1121
+ load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False)
1122
+ cute.arch.cp_async_commit_group()
1123
+
1124
+ # wait for smem tile V for O
1125
+ if const_expr(self.num_stages == 1):
1126
+ sync()
1127
+ load_K_next()
1128
+ if const_expr(mask_fn is not None):
1129
+ mask_fn(acc_S, n_block=n_block)
1130
+ row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
1131
+ softmax.rescale_O(mma_params.acc_O, row_scale)
1132
+ rP = cute.make_fragment_like(acc_S, self.dtype)
1133
+ rP.store(acc_S.load().to(self.dtype))
1134
+ tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
1135
+ if const_expr(self.num_stages > 1):
1136
+ sync()
1137
+ load_K_next()
1138
+ sm80_utils.gemm_rs(
1139
+ mma_params.thr_mma_pv,
1140
+ mma_params.acc_O,
1141
+ tOrP,
1142
+ mma_params.tOrVt,
1143
+ smem_copy_params.tOsVt[
1144
+ None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0
1145
+ ],
1146
+ smem_copy_params.smem_thr_copy_V,
1147
+ # hook_fn=load_K_next,
1148
+ )
1149
+ # if const_expr(self.num_stages > 1):
1150
+ # load_K_next()
1151
+
1152
+
1153
+ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
1154
+ arch = 90
1155
+
1156
+ def __init__(
1157
+ self,
1158
+ *args,
1159
+ intra_wg_overlap: bool = True,
1160
+ mma_pv_is_rs: bool = True,
1161
+ **kwargs,
1162
+ ):
1163
+ super().__init__(*args, **kwargs)
1164
+ self.intra_wg_overlap = intra_wg_overlap
1165
+ self.mma_pv_is_rs = mma_pv_is_rs
1166
+ self.buffer_align_bytes = 1024
1167
+
1168
+ def _get_smem_layout_atom(self):
1169
+ sQ_layout_atom = warpgroup.make_smem_layout_atom(
1170
+ sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),
1171
+ self.dtype,
1172
+ )
1173
+ sK_layout_atom = sQ_layout_atom
1174
+ sV_layout_atom = warpgroup.make_smem_layout_atom(
1175
+ sm90_utils_basic.get_smem_layout_atom(
1176
+ LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv
1177
+ ),
1178
+ self.dtype,
1179
+ )
1180
+ sO_layout_atom = sV_layout_atom
1181
+ if not self.mma_pv_is_rs:
1182
+ sP_layout_atom = warpgroup.make_smem_layout_atom(
1183
+ sm90_utils_basic.get_smem_layout_atom(
1184
+ LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n
1185
+ ),
1186
+ self.dtype,
1187
+ )
1188
+ else:
1189
+ sP_layout_atom = None
1190
+ return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom
1191
+
1192
+ def _get_tiled_mma(self):
1193
+ tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
1194
+ self.dtype,
1195
+ self.dtype,
1196
+ warpgroup.OperandMajorMode.K,
1197
+ warpgroup.OperandMajorMode.K,
1198
+ Float32,
1199
+ atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
1200
+ tiler_mn=(64, self.tile_n),
1201
+ )
1202
+ tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
1203
+ self.dtype,
1204
+ self.dtype,
1205
+ warpgroup.OperandMajorMode.K,
1206
+ warpgroup.OperandMajorMode.MN,
1207
+ Float32,
1208
+ atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
1209
+ tiler_mn=(64, self.tile_hdimv),
1210
+ a_source=warpgroup.OperandSource.RMEM
1211
+ if self.mma_pv_is_rs
1212
+ else warpgroup.OperandSource.SMEM,
1213
+ )
1214
+ tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma(
1215
+ self.dtype,
1216
+ self.dtype,
1217
+ warpgroup.OperandMajorMode.K,
1218
+ warpgroup.OperandMajorMode.MN,
1219
+ Float32,
1220
+ atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
1221
+ tiler_mn=(64, self.tile_hdimv),
1222
+ a_source=warpgroup.OperandSource.RMEM,
1223
+ )
1224
+ return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs
1225
+
1226
+ def _get_shared_storage_cls(self):
1227
+ # If we use cp.async to load Q, we want sQ to align to 1024 bytes
1228
+ sQ_struct, sK_struct, sV_struct = [
1229
+ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes]
1230
+ for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
1231
+
1232
+ ]
1233
+ cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
1234
+ sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
1235
+ cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0
1236
+ sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
1237
+ # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V,
1238
+ mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2]
1239
+ mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
1240
+ mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
1241
+
1242
+ @cute.struct
1243
+ class SharedStorageQKV:
1244
+ mbar_ptr: mbar_ptr_QO_struct
1245
+ mbar_ptr_K: mbar_ptr_K_struct
1246
+ mbar_ptr_V: mbar_ptr_V_struct
1247
+ sV: sV_struct
1248
+ sQ: sQ_struct
1249
+ sK: sK_struct
1250
+ sP: sP_struct
1251
+
1252
+ @cute.struct
1253
+ class SharedStorageSharedQV:
1254
+ mbar_ptr: mbar_ptr_QO_struct
1255
+ mbar_ptr_K: mbar_ptr_K_struct
1256
+ mbar_ptr_V: mbar_ptr_V_struct
1257
+ sQ: sQV_struct
1258
+ sK: sK_struct
1259
+ sP: sP_struct
1260
+
1261
+ return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
1262
+
1263
+ @cute.jit
1264
+ def __call__(
1265
+ self,
1266
+ mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1267
+ mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
1268
+ mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
1269
+ mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
1270
+ mLSE: Optional[cute.Tensor],
1271
+ softmax_scale: Float32,
1272
+ stream: cuda.CUstream,
1273
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
1274
+ mCuSeqlensK: Optional[cute.Tensor] = None,
1275
+ mSeqUsedQ: Optional[cute.Tensor] = None,
1276
+ mSeqUsedK: Optional[cute.Tensor] = None,
1277
+ mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
1278
+ window_size_left: Int32 | int | None = None,
1279
+ window_size_right: Int32 | int | None = None,
1280
+ learnable_sink: Optional[cute.Tensor] = None,
1281
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
1282
+ aux_tensors: Optional[list] = None,
1283
+ ):
1284
+ """Configures and launches the flash attention kernel.
1285
+
1286
+ mQ/mK/mV/mO have the same data type (fp16 or bf16 supported) and the same layout:
1287
+ (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
1288
+ """
1289
+
1290
+ self._check_type(
1291
+ *(
1292
+ t.element_type if t is not None else None
1293
+ for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
1294
+ )
1295
+ )
1296
+
1297
+ # Assume all strides are divisible by 128 bits except the last stride
1298
+ new_stride = lambda t: (
1299
+ *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
1300
+ t.stride[-1],
1301
+ )
1302
+
1303
+ mQ, mK, mV, mO = [
1304
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
1305
+ for t in (mQ, mK, mV, mO)
1306
+ ]
1307
+ QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
1308
+ mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
1309
+ KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
1310
+ mK, mV = [utils.select(t, KV_layout_transpose) for t in (mK, mV)]
1311
+ LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
1312
+ mLSE = utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None
1313
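The permutations above reorder the tensor modes so the kernel sees the sequence dimension first and batch last; a small illustration using the mode names from the docstring (illustrative only):

    # [1, 3, 2, 0] on (b, s, h, d) -> (s, d, h, b); [0, 2, 1] on (total, h, d) -> (total, d, h).
    batched = ("b", "s_q", "h", "d")
    varlen = ("total_q", "h", "d")
    print(tuple(batched[i] for i in [1, 3, 2, 0]))   # ('s_q', 'd', 'h', 'b')
    print(tuple(varlen[i] for i in [0, 2, 1]))       # ('total_q', 'd', 'h')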
+
1314
+ tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma()
1315
+ self.num_mma_threads = tiled_mma_qk.size
1316
+ self.num_threads_per_warp_group = 128
1317
+ self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group
1318
+ self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1)
1319
+ self.num_producer_threads = 32
1320
+ self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q
1321
+ self.num_epilogue_threads = self.num_mma_threads
1322
+ self.num_mma_regs = (
1323
+ 256
1324
+ if self.num_mma_warp_groups == 1
1325
+ else (240 if self.num_mma_warp_groups == 2 else 160)
1326
+ )
1327
+ self.num_producer_regs = (
1328
+ 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32)
1329
+ )
1330
+ # self.num_mma_regs = 232
1331
+ # self.num_producer_regs = 40
1332
+ self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
1333
+
1334
+ self.use_scheduler_barrier = (
1335
+ (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128)
1336
+ if const_expr(self.intra_wg_overlap)
1337
+ else (self.num_mma_warp_groups == 2)
1338
+ )
1339
+ self.use_tma_Q = self.arch >= 90 and not (
1340
+ self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
1341
+ )
1342
+ self.use_tma_O = (
1343
+ self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa
1344
+ )
1345
+ # TODO: rescale_O_before_gemm
1346
+ self._setup_attributes()
1347
+ # TODO: we prob don't need most of what's in _setup_attributes
1348
+ self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [
1349
+ sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)
1350
+ for mX, shape, stage in [
1351
+ (mQ, (self.tile_m, self.tile_hdim), None),
1352
+ (mK, (self.tile_n, self.tile_hdim), self.num_stages),
1353
+ (mV, (self.tile_n, self.tile_hdimv), self.num_stages),
1354
+ (mO, (self.tile_m, self.tile_hdimv), None),
1355
+ ]
1356
+ ]
1357
+ self.sP_layout = None
1358
+ if const_expr(not self.mma_pv_is_rs):
1359
+ self.sP_layout = sm90_utils.make_smem_layout(
1360
+ mV.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
1361
+ )
1362
+
1363
+ SharedStorage = self._get_shared_storage_cls()
1364
+
1365
+ if const_expr(self.pack_gqa):
1366
+ shape_Q_packed = (
1367
+ (self.qhead_per_kvhead, mQ.shape[0]),
1368
+ mQ.shape[1],
1369
+ mK.shape[2],
1370
+ *mQ.shape[3:],
1371
+ )
1372
+ stride_Q_packed = (
1373
+ (mQ.stride[2], mQ.stride[0]),
1374
+ mQ.stride[1],
1375
+ mQ.stride[2] * self.qhead_per_kvhead,
1376
+ *mQ.stride[3:],
1377
+ )
1378
+ mQ = cute.make_tensor(
1379
+ mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
1380
+ )
1381
+ shape_O_packed = (
1382
+ (self.qhead_per_kvhead, mO.shape[0]),
1383
+ mK.shape[1],
1384
+ mK.shape[2],
1385
+ *mO.shape[3:],
1386
+ )
1387
+ stride_O_packed = (
1388
+ (mO.stride[2], mO.stride[0]),
1389
+ mO.stride[1],
1390
+ mO.stride[2] * self.qhead_per_kvhead,
1391
+ *mO.stride[3:],
1392
+ )
1393
+ mO = cute.make_tensor(
1394
+ mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
1395
+ )
1396
+ if const_expr(mLSE is not None):
1397
+ shape_LSE_packed = (
1398
+ (self.qhead_per_kvhead, mLSE.shape[0]),
1399
+ mK.shape[2],
1400
+ *mLSE.shape[2:],
1401
+ )
1402
+ stride_LSE_packed = (
1403
+ (mLSE.stride[1], mLSE.stride[0]),
1404
+ mLSE.stride[1] * self.qhead_per_kvhead,
1405
+ *mLSE.stride[2:],
1406
+ )
1407
+ mLSE = cute.make_tensor(
1408
+ mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
1409
+ )
1410
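The packed layouts above fold the q-heads of one kv-head group into the row mode, so a flat packed-row index walks q-heads fastest and query positions second. A plain-Python model of that addressing with toy sizes and strides (not the kernel's actual strides):

    # Toy case: 1 kv-head with 8 q-heads, seqlen 4, head_dim 64, contiguous (b, s, h, d) input.
    qhead_per_kvhead, s_q, d = 8, 4, 64
    stride_h, stride_s = d, qhead_per_kvhead * d          # element strides of the h and s modes

    def packed_row_offset(row):
        # Mirrors shape ((qhead_per_kvhead, s_q)) with stride ((stride_h, stride_s)).
        head_in_group, q_pos = row % qhead_per_kvhead, row // qhead_per_kvhead
        return head_in_group * stride_h + q_pos * stride_s

    # Consecutive packed rows step through q-heads first, then move to the next query position.
    print([packed_row_offset(r) for r in range(qhead_per_kvhead + 2)])
    # [0, 64, 128, 192, 256, 320, 384, 448, 512, 576]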
+
1411
+ # TMA
1412
+ gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()
1413
+ gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast
1414
+ gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()
1415
+ self.tma_copy_bytes = {
1416
+ name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
1417
+ for name, mX, layout in [
1418
+ ("Q", mQ, self.sQ_layout),
1419
+ ("K", mK, self.sK_layout),
1420
+ ("V", mV, self.sV_layout),
1421
+ ]
1422
+ }
1423
+ tma_atom_Q, tma_tensor_Q = None, None
1424
+ if const_expr(self.use_tma_Q):
1425
+ tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
1426
+ gmem_tiled_copy_Q,
1427
+ mQ,
1428
+ self.sQ_layout,
1429
+ (self.tile_m, self.tile_hdim), # No mcast
1430
+ )
1431
+ tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
1432
+ gmem_tiled_copy_KV,
1433
+ mK,
1434
+ cute.select(self.sK_layout, mode=[0, 1]),
1435
+ (self.tile_n, self.tile_hdim),
1436
+ 1, # No mcast for now
1437
+ )
1438
+ tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
1439
+ gmem_tiled_copy_KV,
1440
+ mV,
1441
+ cute.select(self.sV_layout, mode=[0, 1]),
1442
+ (self.tile_n, self.tile_hdimv),
1443
+ 1, # No mcast for now
1444
+ )
1445
+ tma_atom_O, tma_tensor_O = None, None
1446
+ if const_expr(self.use_tma_O):
1447
+ tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom(
1448
+ gmem_tiled_copy_O,
1449
+ mO,
1450
+ self.sO_layout,
1451
+ (self.tile_m, self.tile_hdimv), # No mcast
1452
+ )
1453
+ if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
1454
+ TileScheduler = SingleTileVarlenScheduler
1455
+ else:
1456
+ TileScheduler = (
1457
+ SingleTileScheduler
1458
+ if const_expr(not self.is_causal or self.is_local)
1459
+ else SingleTileLPTScheduler
1460
+ )
1461
+ tile_sched_args = TileSchedulerArguments(
1462
+ cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),
1463
+ cute.size(mQ.shape[2]),
1464
+ cute.size(mQ.shape[3])
1465
+ if const_expr(mCuSeqlensQ is None)
1466
+ else cute.size(mCuSeqlensQ.shape[0] - 1),
1467
+ 1, # num_splits
1468
+ cute.size(mK.shape[0]),
1469
+ mQ.shape[1],
1470
+ mV.shape[1],
1471
+ total_q=cute.size(mQ.shape[0])
1472
+ if const_expr(mCuSeqlensQ is not None)
1473
+ else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
1474
+ tile_shape_mn=(self.tile_m, self.tile_n),
1475
+ mCuSeqlensQ=mCuSeqlensQ,
1476
+ mSeqUsedQ=mSeqUsedQ,
1477
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1478
+ element_size=self.dtype.width // 8,
1479
+ is_persistent=False,
1480
+ lpt=self.is_causal or self.is_local,
1481
+ )
1482
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
1483
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
1484
+ LOG2_E = math.log2(math.e)
1485
+ if const_expr(self.score_mod is None):
1486
+ softmax_scale_log2 = softmax_scale * LOG2_E
1487
+ softmax_scale = None
1488
+ else:
1489
+ # NB: If a user passes in a score_mod, we want to apply it to the scale-multiplied QK scores
1490
+ # in the natural base (base e). We hijack softmax_scale_log2 to be just the change-of-base
1491
+ # factor and apply softmax_scale before score_mod in the softmax step.
1492
+ softmax_scale_log2 = LOG2_E
1493
+ softmax_scale = softmax_scale
1494
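The branch above only changes where the softmax scale is folded in; either way the exponentials agree, as this standalone check shows (arbitrary example values):

    import math

    scale, x = 0.125, 3.7                 # e.g. 1/sqrt(head_dim) and an arbitrary QK score
    LOG2_E = math.log2(math.e)
    # score_mod is None: the scale is folded into the exp2 argument.
    assert math.isclose(math.exp(scale * x), 2.0 ** (scale * x * LOG2_E))
    # score_mod is present: the scale is applied first (so the mod sees scaled scores),
    # and softmax_scale_log2 carries only the change-of-base factor LOG2_E.
    scaled = scale * x
    assert math.isclose(math.exp(scaled), 2.0 ** (scaled * LOG2_E))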
+ if const_expr(window_size_left is not None):
1495
+ window_size_left = Int32(window_size_left)
1496
+ if const_expr(window_size_right is not None):
1497
+ window_size_right = Int32(window_size_right)
1498
+
1499
+ fastdiv_mods = None
1500
+ if const_expr(aux_tensors is not None):
1501
+ seqlen_q = cute.size(mQ.shape[0]) // (
1502
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
1503
+ )
1504
+ seqlen_k = (
1505
+ cute.size(mK.shape[0])
1506
+ if const_expr(mPageTable is None)
1507
+ else mK.shape[0] * mPageTable.shape[1]
1508
+ )
1509
+ seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
1510
+ seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
1511
+ fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
1512
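FastDivmodDivisor-style helpers typically precompute a magic multiplier so that repeated division by a runtime sequence length becomes a multiply and a shift; a self-contained sketch of that classic trick (not this library's implementation):

    def make_fast_divmod(d):
        # q = (n * magic) >> 64 equals n // d for all 32-bit n and d; Python ints keep the product exact.
        magic = -(-(1 << 64) // d)        # ceil(2**64 / d)
        def divmod_fast(n):
            q = (n * magic) >> 64
            return q, n - q * d
        return divmod_fast

    fast = make_fast_divmod(4096)         # e.g. a sequence length
    for n in (0, 1, 4095, 4096, 123_456_789, 2**32 - 1):
        assert fast(n) == divmod(n, 4096)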
+
1513
+ self.kernel(
1514
+ tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,
1515
+ tma_tensor_K,
1516
+ tma_tensor_V,
1517
+ tma_tensor_O if const_expr(self.use_tma_O) else mO,
1518
+ mLSE,
1519
+ mCuSeqlensQ,
1520
+ mCuSeqlensK,
1521
+ mSeqUsedQ,
1522
+ mSeqUsedK,
1523
+ tma_atom_Q,
1524
+ tma_atom_K,
1525
+ tma_atom_V,
1526
+ tma_atom_O,
1527
+ softmax_scale_log2,
1528
+ softmax_scale,
1529
+ window_size_left,
1530
+ window_size_right,
1531
+ learnable_sink,
1532
+ blocksparse_tensors,
1533
+ self.sQ_layout,
1534
+ self.sK_layout,
1535
+ self.sV_layout,
1536
+ self.sO_layout,
1537
+ self.sP_layout,
1538
+ self.gmem_tiled_copy_Q,
1539
+ self.gmem_tiled_copy_K,
1540
+ self.gmem_tiled_copy_V,
1541
+ self.gmem_tiled_copy_O,
1542
+ tiled_mma_qk,
1543
+ tiled_mma_pv,
1544
+ tiled_mma_pv_rs,
1545
+ tile_sched_params,
1546
+ TileScheduler,
1547
+ SharedStorage,
1548
+ aux_tensors,
1549
+ fastdiv_mods,
1550
+ ).launch(
1551
+ grid=grid_dim,
1552
+ block=[self.num_threads, 1, 1],
1553
+ stream=stream,
1554
+ min_blocks_per_mp=1,
1555
+ )
1556
+
1557
+ @cute.kernel
1558
+ def kernel(
1559
+ self,
1560
+ mQ: cute.Tensor,
1561
+ mK: cute.Tensor,
1562
+ mV: cute.Tensor,
1563
+ mO: cute.Tensor,
1564
+ mLSE: Optional[cute.Tensor],
1565
+ mCuSeqlensQ: Optional[cute.Tensor],
1566
+ mCuSeqlensK: Optional[cute.Tensor],
1567
+ mSeqUsedQ: Optional[cute.Tensor],
1568
+ mSeqUsedK: Optional[cute.Tensor],
1569
+ tma_atom_Q: Optional[cute.CopyAtom],
1570
+ tma_atom_K: Optional[cute.CopyAtom],
1571
+ tma_atom_V: Optional[cute.CopyAtom],
1572
+ tma_atom_O: Optional[cute.CopyAtom],
1573
+ softmax_scale_log2: Float32,
1574
+ softmax_scale: Optional[Float32],
1575
+ window_size_left: Optional[Int32],
1576
+ window_size_right: Optional[Int32],
1577
+ learnable_sink: Optional[cute.Tensor],
1578
+ blocksparse_tensors: Optional[BlockSparseTensors],
1579
+ sQ_layout: cute.ComposedLayout,
1580
+ sK_layout: cute.ComposedLayout,
1581
+ sV_layout: cute.ComposedLayout,
1582
+ sO_layout: cute.ComposedLayout,
1583
+ sP_layout: cute.ComposedLayout | None,
1584
+ gmem_tiled_copy_Q: cute.TiledCopy,
1585
+ gmem_tiled_copy_K: cute.TiledCopy,
1586
+ gmem_tiled_copy_V: cute.TiledCopy,
1587
+ gmem_tiled_copy_O: cute.TiledCopy,
1588
+ tiled_mma_qk: cute.TiledMma,
1589
+ tiled_mma_pv: cute.TiledMma,
1590
+ tiled_mma_pv_rs: cute.TiledMma,
1591
+ tile_sched_params: ParamsBase,
1592
+ TileScheduler: cutlass.Constexpr[Callable],
1593
+ SharedStorage: cutlass.Constexpr[Callable],
1594
+ aux_tensors: Optional[list[cute.Tensor]] = None,
1595
+ fastdiv_mods=None,
1596
+ ):
1597
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1598
+ # Prefetch tma descriptor
1599
+ if warp_idx == 0:
1600
+ for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):
1601
+ if const_expr(tma_atom is not None):
1602
+ cpasync.prefetch_descriptor(tma_atom)
1603
+
1604
+ smem = cutlass.utils.SmemAllocator()
1605
+ storage = smem.allocate(SharedStorage)
1606
+
1607
+ # Mbarrier init
1608
+ mbar_ptr_Q = storage.mbar_ptr.data_ptr()
1609
+ if warp_idx == 1:
1610
+ # if tidx < 2:
1611
+ # # barrierO num threads should be self.num_mma_threads
1612
+ # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads)
1613
+ if const_expr(not self.use_tma_Q):
1614
+ cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads)
1615
+ # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads)
1616
+ # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync
1617
+ pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(
1618
+ cutlass.pipeline.Agent.Thread
1619
+ )
1620
+ pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup(
1621
+ cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
1622
+ )
1623
+ pipeline_k = pipeline.PipelineTmaAsync.create(
1624
+ barrier_storage=storage.mbar_ptr_K.data_ptr(),
1625
+ num_stages=self.num_stages,
1626
+ producer_group=pipeline_kv_producer_group,
1627
+ consumer_group=pipeline_kv_consumer_group,
1628
+ tx_count=self.tma_copy_bytes["K"],
1629
+ defer_sync=True,
1630
+ )
1631
+ pipeline_v = pipeline.PipelineTmaAsync.create(
1632
+ barrier_storage=storage.mbar_ptr_V.data_ptr(),
1633
+ num_stages=self.num_stages,
1634
+ producer_group=pipeline_kv_producer_group,
1635
+ consumer_group=pipeline_kv_consumer_group,
1636
+ tx_count=self.tma_copy_bytes["V"],
1637
+ defer_sync=False,
1638
+ )
1639
+
1640
+ # ///////////////////////////////////////////////////////////////////////////////
1641
+ # Get shared memory buffer
1642
+ # ///////////////////////////////////////////////////////////////////////////////
1643
+ sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
1644
+ sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
1645
+ if const_expr(not self.Q_in_regs):
1646
+ sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
1647
+ else:
1648
+ sV = storage.sQ.get_tensor(
1649
+ sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type
1650
+ )
1651
+ # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
1652
+ sVt = utils.transpose_view(sV)
1653
+ sP = None
1654
+ if const_expr(sP_layout is not None):
1655
+ sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner)
1656
+ # reuse sQ's data iterator
1657
+ sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype)
1658
+
1659
+ block_info = BlockInfo(
1660
+ self.tile_m,
1661
+ self.tile_n,
1662
+ self.is_causal,
1663
+ self.is_local,
1664
+ False, # is_split_kv
1665
+ window_size_left,
1666
+ window_size_right,
1667
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1668
+ )
1669
+ SeqlenInfoCls = partial(
1670
+ SeqlenInfoQK.create,
1671
+ seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
1672
+ seqlen_k_static=mK.shape[0],
1673
+ mCuSeqlensQ=mCuSeqlensQ,
1674
+ mCuSeqlensK=mCuSeqlensK,
1675
+ mSeqUsedQ=mSeqUsedQ,
1676
+ mSeqUsedK=mSeqUsedK,
1677
+ )
1678
+ AttentionMaskCls = partial(
1679
+ AttentionMask,
1680
+ self.tile_m,
1681
+ self.tile_n,
1682
+ window_size_left=window_size_left,
1683
+ window_size_right=window_size_right,
1684
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1685
+ )
1686
+ TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
1687
+
1688
+ if warp_idx < 4: # Producer
1689
+ cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
1690
+ self.load(
1691
+ mQ,
1692
+ mK,
1693
+ mV,
1694
+ sQ,
1695
+ sK,
1696
+ sV,
1697
+ tma_atom_Q,
1698
+ tma_atom_K,
1699
+ tma_atom_V,
1700
+ pipeline_k,
1701
+ pipeline_v,
1702
+ mbar_ptr_Q,
1703
+ blocksparse_tensors,
1704
+ block_info,
1705
+ SeqlenInfoCls,
1706
+ TileSchedulerCls,
1707
+ )
1708
+
1709
+ else: # Consumer
1710
+ cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
1711
+ # ///////////////////////////////////////////////////////////////////////////////
1712
+ # Tile MMA compute thread partitions and allocate accumulators
1713
+ # ///////////////////////////////////////////////////////////////////////////////
1714
+ tidx, _, _ = cute.arch.thread_idx()
1715
+ tidx = tidx - 128
1716
+ self.mma(
1717
+ tiled_mma_qk,
1718
+ tiled_mma_pv,
1719
+ tiled_mma_pv_rs,
1720
+ mQ,
1721
+ mO,
1722
+ mLSE,
1723
+ sQ,
1724
+ sK,
1725
+ sVt,
1726
+ sP,
1727
+ sO,
1728
+ learnable_sink,
1729
+ pipeline_k,
1730
+ pipeline_v,
1731
+ mbar_ptr_Q,
1732
+ gmem_tiled_copy_Q,
1733
+ gmem_tiled_copy_O,
1734
+ tma_atom_O,
1735
+ tidx,
1736
+ softmax_scale_log2,
1737
+ softmax_scale,
1738
+ block_info,
1739
+ SeqlenInfoCls,
1740
+ AttentionMaskCls,
1741
+ TileSchedulerCls,
1742
+ blocksparse_tensors,
1743
+ aux_tensors,
1744
+ fastdiv_mods,
1745
+ )
1746
+
1747
+ @cute.jit
1748
+ def load(
1749
+ self,
1750
+ mQ: cute.Tensor,
1751
+ mK: cute.Tensor,
1752
+ mV: cute.Tensor,
1753
+ sQ: cute.Tensor,
1754
+ sK: cute.Tensor,
1755
+ sV: cute.Tensor,
1756
+ tma_atom_Q: cute.CopyAtom,
1757
+ tma_atom_K: cute.CopyAtom,
1758
+ tma_atom_V: cute.CopyAtom,
1759
+ pipeline_k: cutlass.pipeline.PipelineAsync,
1760
+ pipeline_v: cutlass.pipeline.PipelineAsync,
1761
+ mbar_ptr_Q: cutlass.Pointer,
1762
+ blocksparse_tensors: Optional[BlockSparseTensors],
1763
+ block_info: BlockInfo,
1764
+ SeqlenInfoCls: Callable,
1765
+ TileSchedulerCls: Callable,
1766
+ ):
1767
+ warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
1768
+ if warp_idx_in_wg == 0:
1769
+ q_producer_phase = Int32(1)
1770
+ kv_producer_state = pipeline.make_pipeline_state(
1771
+ cutlass.pipeline.PipelineUserType.Producer, self.num_stages
1772
+ )
1773
+ tile_scheduler = TileSchedulerCls()
1774
+ work_tile = tile_scheduler.initial_work_tile_info()
1775
+ while work_tile.is_valid_tile:
1776
+ # if work_tile.is_valid_tile:
1777
+ m_block, head_idx, batch_idx, _ = work_tile.tile_idx
1778
+ seqlen = SeqlenInfoCls(batch_idx)
1779
+ mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
1780
+ head_idx_kv = (
1781
+ head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
1782
+ )
1783
+ mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
1784
+ mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
1785
+ gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0))
1786
+ gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
1787
+ if const_expr(self.use_tma_Q):
1788
+ gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
1789
+ load_Q, _, _ = copy_utils.tma_get_copy_fn(
1790
+ tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True
1791
+ )
1792
+ # TODO: mcast
1793
+ # TODO check warp_idx if we have 128 producer threads
1794
+ load_K, _, _ = copy_utils.tma_get_copy_fn(
1795
+ tma_atom_K, 0, cute.make_layout(1), gK, sK
1796
+ )
1797
+ load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k)
1798
+ load_V, _, _ = copy_utils.tma_get_copy_fn(
1799
+ tma_atom_V, 0, cute.make_layout(1), gV, sV
1800
+ )
1801
+ load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v)
1802
+
1803
+ if const_expr(not self.use_block_sparsity):
1804
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
1805
+ # if cute.arch.thread_idx()[0] == 0:
1806
+ # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max)
1807
+ # First iteration: load both Q & K with the same mbarrier
1808
+ n_block = n_block_max - 1
1809
+ pipeline_k.producer_acquire(
1810
+ kv_producer_state,
1811
+ extra_tx_count=self.tma_copy_bytes["Q"]
1812
+ if const_expr(self.use_tma_Q)
1813
+ else 0,
1814
+ )
1815
+ if const_expr(self.use_tma_Q):
1816
+ load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
1817
+ load_K(src_idx=n_block, producer_state=kv_producer_state)
1818
+
1819
+ if const_expr(not self.intra_wg_overlap):
1820
+ pipeline_v.producer_acquire(kv_producer_state)
1821
+ load_V(src_idx=n_block, producer_state=kv_producer_state)
1822
+ kv_producer_state.advance()
1823
+ for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
1824
+ n_block = n_block_max - 1 - i - 1
1825
+ pipeline_k.producer_acquire(kv_producer_state)
1826
+ load_K(src_idx=n_block, producer_state=kv_producer_state)
1827
+ pipeline_v.producer_acquire(kv_producer_state)
1828
+ load_V(src_idx=n_block, producer_state=kv_producer_state)
1829
+ kv_producer_state.advance()
1830
+ else:
1831
+ for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
1832
+ n_block_prev = n_block_max - i - 1
1833
+ n_block = n_block_prev - 1
1834
+ kv_producer_state_prev = kv_producer_state.clone()
1835
+ kv_producer_state.advance()
1836
+ pipeline_k.producer_acquire(kv_producer_state)
1837
+ load_K(src_idx=n_block, producer_state=kv_producer_state)
1838
+ pipeline_v.producer_acquire(kv_producer_state_prev)
1839
+ load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
1840
+ n_block = n_block_min
1841
+ pipeline_v.producer_acquire(kv_producer_state)
1842
+ load_V(src_idx=n_block, producer_state=kv_producer_state)
1843
+ kv_producer_state.advance()
1844
+ else:
1845
+ kv_producer_state = produce_block_sparse_loads(
1846
+ blocksparse_tensors,
1847
+ batch_idx,
1848
+ head_idx,
1849
+ m_block,
1850
+ kv_producer_state,
1851
+ load_Q,
1852
+ load_K,
1853
+ load_V,
1854
+ pipeline_k,
1855
+ pipeline_v,
1856
+ self.use_tma_Q,
1857
+ self.tma_copy_bytes["Q"],
1858
+ self.intra_wg_overlap,
1859
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1860
+ )
1861
+
1862
+ tile_scheduler.prefetch_next_work()
1863
+ tile_scheduler.advance_to_next_work()
1864
+ work_tile = tile_scheduler.get_current_work()
1865
+ # End of persistent scheduler loop
1866
+
1867
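The first-iteration trick in load() above (Q and K behind one mbarrier, via extra_tx_count) relies on transaction barriers completing only once all announced bytes have arrived; a toy model of that bookkeeping (illustrative only, not the real pipeline implementation, with hypothetical byte counts):

    class ToyTxBarrier:
        """Flips once the producer-announced byte count has fully arrived."""
        def __init__(self):
            self.expected = 0
            self.arrived = 0
        def expect_tx(self, nbytes):       # producer_acquire(..., extra_tx_count=...)
            self.expected += nbytes
        def complete_tx(self, nbytes):     # TMA engine delivering data
            self.arrived += nbytes
        def ready(self):
            return self.expected > 0 and self.arrived >= self.expected

    bar = ToyTxBarrier()
    tma_bytes_Q, tma_bytes_K = 32 * 1024, 32 * 1024   # hypothetical tile byte counts
    bar.expect_tx(tma_bytes_K + tma_bytes_Q)          # K's acquire also announces Q's bytes
    bar.complete_tx(tma_bytes_Q)                      # Q lands first
    assert not bar.ready()
    bar.complete_tx(tma_bytes_K)                      # K lands -> consumers may proceed
    assert bar.ready()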
+ @cute.jit
1868
+ def mma(
1869
+ self,
1870
+ tiled_mma_qk: cute.TiledMma,
1871
+ tiled_mma_pv: cute.TiledMma,
1872
+ tiled_mma_pv_rs: cute.TiledMma,
1873
+ # softmax: Softmax,
1874
+ # acc_O: cute.Tensor,
1875
+ mQ: cute.Tensor,
1876
+ mO: cute.Tensor,
1877
+ mLSE: Optional[cute.Tensor],
1878
+ sQ: cute.Tensor,
1879
+ sK: cute.Tensor,
1880
+ sVt: cute.Tensor,
1881
+ sP: Optional[cute.Tensor],
1882
+ sO: cute.Tensor,
1883
+ learnable_sink: Optional[cute.Tensor],
1884
+ pipeline_k: cutlass.pipeline.PipelineAsync,
1885
+ pipeline_v: cutlass.pipeline.PipelineAsync,
1886
+ mbar_ptr_Q: cutlass.Pointer,
1887
+ gmem_tiled_copy_Q: cute.TiledCopy,
1888
+ gmem_tiled_copy_O: cute.TiledCopy,
1889
+ tma_atom_O: Optional[cute.CopyAtom],
1890
+ tidx: Int32,
1891
+ softmax_scale_log2: Float32,
1892
+ softmax_scale: Optional[Float32],
1893
+ block_info: BlockInfo,
1894
+ SeqlenInfoCls: Callable,
1895
+ AttentionMaskCls: Callable,
1896
+ TileSchedulerCls: Callable,
1897
+ blocksparse_tensors: Optional[BlockSparseTensors],
1898
+ aux_tensors: Optional[list],
1899
+ fastdiv_mods=None,
1900
+ ):
1901
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
1902
+ warp_group_thread_layout = cute.make_layout(
1903
+ self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
1904
+ )
1905
+ thr_mma_qk = tiled_mma_qk.get_slice(tidx)
1906
+ wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
1907
+ wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
1908
+ tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ))
1909
+ tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK))
1910
+ if const_expr(self.mma_pv_is_rs):
1911
+ acc_S_shape = tiled_mma_qk.partition_shape_C((self.tile_m, self.tile_n))
1912
+ tOrP = cute.make_fragment(
1913
+ utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype
1914
+ )
1915
+ else:
1916
+ tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP))
1917
+ tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt))
1918
+
1919
+ # ///////////////////////////////////////////////////////////////////////////////
1920
+ # Smem copy atom tiling
1921
+ # ///////////////////////////////////////////////////////////////////////////////
1922
+ smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype)
1923
+ smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
1924
+ # tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None
1925
+ tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
1926
+ # if cute.arch.thread_idx()[0] == 0:
1927
+ # cute.printf(sP_pi.layout, sP_pi.iterator)
1928
+ # cute.printf(sP.layout, sP.iterator)
1929
+ # cute.printf(tPsP.layout, tPsP.iterator)
1930
+
1931
+ self.mma_init()
1932
+
1933
+ acc_shape_O = tiled_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv))
1934
+ acc_O = cute.make_fragment(acc_shape_O, Float32)
1935
+ smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
1936
+
1937
+ mma_qk_fn = partial(
1938
+ sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
1939
+ )
1940
+ mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
1941
+
1942
+ mma_one_n_block_all = partial(
1943
+ self.mma_one_n_block_intrawg_overlap
1944
+ if const_expr(self.intra_wg_overlap)
1945
+ else self.mma_one_n_block,
1946
+ mma_qk_fn=mma_qk_fn,
1947
+ tiled_mma_pv_rs=tiled_mma_pv_rs,
1948
+ pipeline_k=pipeline_k,
1949
+ pipeline_v=pipeline_v,
1950
+ acc_O=acc_O,
1951
+ tOrP=tOrP,
1952
+ smem_copy_params=smem_copy_params,
1953
+ check_inf=True,
1954
+ )
1955
+
1956
+ q_consumer_phase = Int32(0)
1957
+ kv_consumer_state = pipeline.make_pipeline_state(
1958
+ cutlass.pipeline.PipelineUserType.Consumer, self.num_stages
1959
+ )
1960
+
1961
+ tile_scheduler = TileSchedulerCls()
1962
+ work_tile = tile_scheduler.initial_work_tile_info()
1963
+ softmax = Softmax.create(
1964
+ softmax_scale_log2,
1965
+ num_rows=acc_O.shape[0][0] * acc_O.shape[1],
1966
+ softmax_scale=softmax_scale,
1967
+ )
1968
+
1969
+ process_first_half_block = partial(
1970
+ self.first_half_block_overlap,
1971
+ mma_qk_fn=mma_qk_fn,
1972
+ pipeline_k=pipeline_k,
1973
+ tOrP=tOrP,
1974
+ smem_copy_params=smem_copy_params,
1975
+ softmax=softmax,
1976
+ )
1977
+ process_last_half_block = partial(
1978
+ self.last_half_block_overlap,
1979
+ pipeline_v=pipeline_v,
1980
+ mma_pv_fn=mma_pv_fn,
1981
+ )
1982
+ while work_tile.is_valid_tile:
1983
+ # if work_tile.is_valid_tile:
1984
+
1985
+ # shape: (atom_v_m * rest_m)
1986
+ m_block, head_idx, batch_idx, _ = work_tile.tile_idx
1987
+ seqlen = SeqlenInfoCls(batch_idx)
1988
+
1989
+ # Recompute fastdiv_mods if necessary for varlen with aux_tensors
1990
+ recompute_fastdiv_mods_q = cutlass.const_expr(
1991
+ aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
1992
+ )
1993
+ recompute_fastdiv_mods_k = cutlass.const_expr(
1994
+ aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
1995
+ )
1996
+ if cutlass.const_expr(fastdiv_mods is not None):
1997
+ seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
1998
+ fastdiv_mods = (
1999
+ seqlen_q_divmod
2000
+ if not recompute_fastdiv_mods_q
2001
+ else FastDivmodDivisor(seqlen.seqlen_q),
2002
+ seqlen_k_divmod
2003
+ if not recompute_fastdiv_mods_k
2004
+ else FastDivmodDivisor(seqlen.seqlen_k),
2005
+ )
2006
+
2007
+ mask = AttentionMaskCls(seqlen)
2008
+ mask_fn = partial(
2009
+ mask.apply_mask,
2010
+ batch_idx=batch_idx,
2011
+ head_idx=head_idx,
2012
+ m_block=m_block,
2013
+ thr_mma=thr_mma_qk,
2014
+ mask_causal=self.is_causal,
2015
+ mask_local=self.is_local,
2016
+ aux_tensors=aux_tensors,
2017
+ fastdiv_mods=fastdiv_mods,
2018
+ )
2019
+ score_mod_fn = None
2020
+ if const_expr(self.score_mod is not None):
2021
+ score_mod_fn = partial(
2022
+ self.apply_score_mod,
2023
+ thr_mma_qk,
2024
+ batch_idx,
2025
+ head_idx,
2026
+ m_block,
2027
+ softmax_scale=softmax_scale,
2028
+ aux_tensors=aux_tensors,
2029
+ fastdiv_mods=fastdiv_mods,
2030
+ )
2031
+ mma_one_n_block = partial(
2032
+ mma_one_n_block_all,
2033
+ seqlen=seqlen,
2034
+ softmax=softmax,
2035
+ score_mod_fn=score_mod_fn,
2036
+ )
2037
+ # Load Q if not TMA_Q
2038
+ if const_expr(not self.use_tma_Q):
2039
+ pack_gqa = PackGQA(
2040
+ self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
2041
+ )
2042
+ mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
2043
+ # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
2044
+ # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
2045
+ # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q,
2046
+ # headdim=mQ.shape[1])
2047
+ pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q)
2048
+ cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q)
2049
+
2050
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
2051
+ if const_expr(not self.use_tma_Q):
2052
+ cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase)
2053
+ q_consumer_phase ^= 1
2054
+ # For performance reasons, we separate out two kinds of iterations:
2055
+ # those that need masking on S, and those that don't.
2056
+ # We need masking on S for the very last block when the K/V length is not a multiple of tile_n.
2057
+ # We also need masking on S for the last several blocks when attention is causal.
2058
+ # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True
2059
+ O_should_accumulate = False
2060
+
2061
+ # ==========================================
2062
+ # MAINLOOP
2063
+ # ==========================================
2064
+ if const_expr(not self.use_block_sparsity):
2065
+ # ==========================================
2066
+ # No block-sparsity (original path)
2067
+ # ==========================================
2068
+ # First iteration with seqlen masking
2069
+ if const_expr(self.intra_wg_overlap):
2070
+ kv_consumer_state = process_first_half_block(
2071
+ n_block=n_block_max - 1,
2072
+ seqlen=seqlen,
2073
+ kv_consumer_state=kv_consumer_state,
2074
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
2075
+ score_mod_fn=score_mod_fn,
2076
+ is_first_block=True,
2077
+ )
2078
+ # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter
2079
+ # acc_O.fill(0.0)
2080
+ else:
2081
+ self.warp_scheduler_barrier_sync()
2082
+ kv_consumer_state = mma_one_n_block(
2083
+ kv_consumer_state,
2084
+ n_block=n_block_max - 1,
2085
+ seqlen=seqlen,
2086
+ mma_pv_fn=partial(mma_pv_fn, zero_init=True),
2087
+ is_first_n_block=True,
2088
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
2089
+ )
2090
+ O_should_accumulate = True
2091
+ # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min)
2092
+ n_block_max -= 1
2093
+ # Next couple of iterations with causal masking
2094
+ if const_expr(self.is_causal or self.is_local):
2095
+ n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
2096
+ seqlen, m_block, n_block_min
2097
+ )
2098
+ # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask)
2099
+ for n_tile in cutlass.range(
2100
+ n_block_max - n_block_min_causal_local_mask, unroll=1
2101
+ ):
2102
+ kv_consumer_state = mma_one_n_block(
2103
+ kv_consumer_state,
2104
+ n_block=n_block_max - 1 - n_tile,
2105
+ seqlen=seqlen,
2106
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
2107
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
2108
+ )
2109
+ O_should_accumulate = True
2110
+ n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
2111
+ # The remaining iterations have no masking
2112
+ n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
2113
+ seqlen, m_block, n_block_min
2114
+ )
2115
+ # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min)
2116
+ for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
2117
+ kv_consumer_state = mma_one_n_block(
2118
+ kv_consumer_state,
2119
+ n_block=n_block_max - 1 - n_tile,
2120
+ seqlen=seqlen,
2121
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
2122
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
2123
+ )
2124
+ O_should_accumulate = True
2125
+ # Separate iterations with local masking on the left
2126
+ if const_expr(self.is_local and block_info.window_size_left is not None):
2127
+ n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
2128
+ for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):
2129
+ kv_consumer_state = mma_one_n_block(
2130
+ kv_consumer_state,
2131
+ n_block=n_block_max - 1 - n_tile,
2132
+ seqlen=seqlen,
2133
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
2134
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
2135
+ )
2136
+ O_should_accumulate = True
2137
+ # Last "half" iteration
2138
+ if const_expr(self.intra_wg_overlap):
2139
+ kv_consumer_state = process_last_half_block(
2140
+ kv_consumer_state=kv_consumer_state,
2141
+ zero_init=not O_should_accumulate,
2142
+ )
2143
+ O_should_accumulate = True
2144
+ else:
2145
+ self.warp_scheduler_barrier_arrive()
2146
+
2147
+ else:
2148
+ # ==========================================
2149
+ # Block sparsity
2150
+ # ==========================================
2151
+ kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads(
2152
+ blocksparse_tensors,
2153
+ batch_idx,
2154
+ head_idx,
2155
+ m_block,
2156
+ seqlen,
2157
+ kv_consumer_state,
2158
+ mma_pv_fn,
2159
+ mma_one_n_block,
2160
+ process_first_half_block,
2161
+ process_last_half_block,
2162
+ mask_fn,
2163
+ score_mod_fn,
2164
+ O_should_accumulate,
2165
+ self.mask_mod,
2166
+ fastdiv_mods,
2167
+ self.intra_wg_overlap,
2168
+ self.warp_scheduler_barrier_sync,
2169
+ self.warp_scheduler_barrier_arrive,
2170
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
2171
+ )
2172
+
2173
+ # Handle empty case (when no blocks to process)
2174
+ if not processed_any:
2175
+ softmax.reset()
2176
+ acc_O.fill(0.0)
2177
+
2178
+ sink_val = None
2179
+ if const_expr(learnable_sink is not None):
2180
+ if const_expr(not self.pack_gqa):
2181
+ sink_val = Float32(learnable_sink[head_idx])
2182
+ else: # Each thread might have a different sink value due to different q_head
2183
+ sink_val = cute.make_fragment_like(softmax.row_max, Float32)
2184
+ cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
2185
+ tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS))
2186
+ for r in cutlass.range(cute.size(sink_val), unroll_full=True):
2187
+ row = m_block * self.tile_m + tScS_mn[r][0]
2188
+ q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
2189
+ sink_val[r] = Float32(learnable_sink[q_head_idx])
2190
+
2191
+ # normalize acc_O by row_sum and calculate the lse
2192
+ row_scale = softmax.finalize(sink_val=sink_val)
2193
+ softmax.rescale_O(acc_O, row_scale)
2194
+
2195
+ # ///////////////////////////////////////////////////////////////////////////////
2196
+ # Epilogue
2197
+ # ///////////////////////////////////////////////////////////////////////////////
2198
+ self.epilogue(
2199
+ acc_O,
2200
+ softmax.row_sum,
2201
+ mO,
2202
+ mLSE,
2203
+ sO,
2204
+ seqlen,
2205
+ gmem_tiled_copy_O,
2206
+ tma_atom_O,
2207
+ tiled_mma_pv,
2208
+ tidx,
2209
+ m_block,
2210
+ head_idx,
2211
+ batch_idx,
2212
+ )
2213
+
2214
+ tile_scheduler.advance_to_next_work()
2215
+ work_tile = tile_scheduler.get_current_work()
2216
+
2217
+
2218
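To make the mainloop's phases above concrete (seqlen-masked first block, causally masked blocks, then mask-free blocks), here is a simplified standalone sketch of the n_block bookkeeping with toy sizes; the kernel derives the equivalent bounds from BlockInfo, including local windows and varlen, which this sketch ignores:

    tile_m, tile_n, seqlen_q, seqlen_k = 128, 128, 1024, 1024   # toy sizes, square and tile-aligned

    def n_block_partition(m_block):
        last_q_row = min((m_block + 1) * tile_m, seqlen_q) - 1
        n_block_max = (min(last_q_row + 1, seqlen_k) + tile_n - 1) // tile_n
        first_q_row = m_block * tile_m
        n_full_max = first_q_row // tile_n            # blocks with index < n_full_max need no causal mask
        return list(range(n_full_max, n_block_max)), list(range(0, n_full_max))

    for m_block in (0, 3, 7):
        masked, unmasked = n_block_partition(m_block)
        print(f"m_block={m_block}: masked {masked}, mask-free {unmasked}")
    # m_block=0: masked [0], mask-free []
    # m_block=3: masked [3], mask-free [0, 1, 2]
    # m_block=7: masked [7], mask-free [0, 1, 2, 3, 4, 5, 6]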
+ @cute.jit
2219
+ def first_half_block_overlap(
2220
+ self,
2221
+ n_block: Int32,
2222
+ mma_qk_fn: Callable,
2223
+ kv_consumer_state,
2224
+ pipeline_k,
2225
+ tOrP: cute.Tensor,
2226
+ smem_copy_params: SimpleNamespace,
2227
+ softmax: Softmax,
2228
+ seqlen: SeqlenInfoQK,
2229
+ mask_fn: Optional[Callable] = None,
2230
+ score_mod_fn: Optional[Callable] = None,
2231
+ is_first_block: bool = False,
2232
+ ):
2233
+ """Processes the first half block when using intra-warpgroup-overlap"""
2234
+
2235
+ pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))
2236
+ acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)
2237
+ pipeline_k.consumer_release(kv_consumer_state)
2238
+
2239
+ # Apply score modification if present
2240
+ if const_expr(score_mod_fn is not None):
2241
+ score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
2242
+
2243
+ # Apply mask; mask_seqlen always True for first block
2244
+ # Caveat: if a fully unmasked block lies further right than the masked block, seqlen masking is redundant;
2245
+ # the mask is applied anyway, however, so there is essentially no perf hit.
2246
+ mask_fn(acc_S, n_block=n_block, mask_seqlen=True)
2247
+
2248
+ softmax.online_softmax(acc_S, is_first=is_first_block)
2249
+
2250
+ tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
2251
+ tOrP_cur = (
2252
+ tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
2253
+ )
2254
+ tOrP_cur.store(tOrP_acc.load().to(self.dtype))
2255
+
2256
+ # if pv gemm not rs
2257
+ if const_expr(not self.mma_pv_is_rs):
2258
+ tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
2259
+ cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
2260
+ # Fence and barrier to make smem store visible to WGMMA
2261
+ cute.arch.fence_proxy(
2262
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2263
+ )
2264
+ cute.arch.sync_warp()
2265
+
2266
+ return kv_consumer_state
2267
+
2268
+ @cute.jit
2269
+ def last_half_block_overlap(
2270
+ self,
2271
+ kv_consumer_state,
2272
+ pipeline_v,
2273
+ mma_pv_fn: Callable,
2274
+ zero_init: bool,
2275
+ ):
2276
+ """Processes the final PV GEMM when using intra-warpgroup-overlap"""
2277
+
2278
+ pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
2279
+ mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)
2280
+ pipeline_v.consumer_release(kv_consumer_state)
2281
+ kv_consumer_state.advance()
2282
+ return kv_consumer_state
2283
+
2284
+ @cute.jit
2285
+ def mma_one_n_block(
2286
+ self,
2287
+ smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
2288
+ n_block: Int32,
2289
+ mma_qk_fn: Callable,
2290
+ mma_pv_fn: Callable,
2291
+ tiled_mma_pv_rs: cute.TiledMma,
2292
+ pipeline_k: cutlass.pipeline.PipelineAsync,
2293
+ pipeline_v: cutlass.pipeline.PipelineAsync,
2294
+ acc_O: cute.Tensor,
2295
+ tOrP: cute.Tensor,
2296
+ smem_copy_params: SimpleNamespace,
2297
+ softmax: Softmax,
2298
+ seqlen: SeqlenInfoQK,
2299
+ score_mod_fn: Optional[Callable] = None,
2300
+ mask_fn: Optional[Callable] = None,
2301
+ is_first_n_block: cutlass.Constexpr = False,
2302
+ check_inf: cutlass.Constexpr = True,
2303
+ ):
2304
+ pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
2305
+ # S = Q @ K.T
2306
+ acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
2307
+ self.warp_scheduler_barrier_arrive()
2308
+ warpgroup.wait_group(0)
2309
+ pipeline_k.consumer_release(smem_pipe_read)
2310
+
2311
+ # handle score mods and masking
2312
+ if const_expr(score_mod_fn is not None):
2313
+ score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
2314
+ if const_expr(mask_fn is not None):
2315
+ mask_fn(acc_S=acc_S, n_block=n_block)
2316
+
2317
+ row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
2318
+ # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
2319
+ tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
2320
+ tOrP_cur = (
2321
+ tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
2322
+ )
2323
+ # tOrP.store(tOrP_acc.load().to(self.dtype))
2324
+ # the "to(self.dtype)" conversion fails to vectorize for block sizes other
2325
+ # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
2326
+ # 2 elements. So we just call ptx directly.
2327
+ utils.cvt_f16(tOrP_acc, tOrP_cur)
2328
+ if const_expr(not self.mma_pv_is_rs):
2329
+ tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
2330
+ cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
2331
+ softmax.rescale_O(acc_O, row_scale)
2332
+ if const_expr(not self.mma_pv_is_rs):
2333
+ # Fence and barrier to make sure smem store is visible to WGMMA
2334
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
2335
+ cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
2336
+ pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
2337
+ self.warp_scheduler_barrier_sync()
2338
+ # O += P @ V
2339
+ mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)
2340
+ pipeline_v.consumer_release(smem_pipe_read)
2341
+ smem_pipe_read.advance()
2342
+ return smem_pipe_read
2343
+
2344
+ @cute.jit
2345
+ def mma_one_n_block_intrawg_overlap(
2346
+ self,
2347
+ smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
2348
+ n_block: Int32,
2349
+ mma_qk_fn: Callable,
2350
+ mma_pv_fn: Callable,
2351
+ tiled_mma_pv_rs: cute.TiledMma,
2352
+ pipeline_k: cutlass.pipeline.PipelineAsync,
2353
+ pipeline_v: cutlass.pipeline.PipelineAsync,
2354
+ acc_O: cute.Tensor,
2355
+ tOrP: cute.Tensor,
2356
+ smem_copy_params: SimpleNamespace,
2357
+ softmax: Softmax,
2358
+ seqlen: SeqlenInfoQK,
2359
+ score_mod_fn: Optional[Callable] = None,
2360
+ mask_fn: Optional[Callable] = None,
2361
+ check_inf: cutlass.Constexpr = True,
2362
+ ):
2363
+ smem_pipe_read_v = smem_pipe_read.clone()
2364
+ smem_pipe_read.advance()
2365
+ pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
2366
+ self.warp_scheduler_barrier_sync()
2367
+ # S = Q @ K.T
2368
+ acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
2369
+ pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))
2370
+ # O += P @ V
2371
+ mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)
2372
+ self.warp_scheduler_barrier_arrive()
2373
+ warpgroup.wait_group(1)
2374
+ pipeline_k.consumer_release(smem_pipe_read)
2375
+
2376
+ # handle score mods and masking
2377
+ if const_expr(score_mod_fn is not None):
2378
+ score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
2379
+ if const_expr(mask_fn is not None):
2380
+ mask_fn(acc_S=acc_S, n_block=n_block)
2381
+ # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
2382
+
2383
+ row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
2384
+ warpgroup.wait_group(0)
2385
+ pipeline_v.consumer_release(smem_pipe_read_v)
2386
+ tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
2387
+ tOrP_cur = (
2388
+ tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
2389
+ )
2390
+ # tOrP_cur.store(tOrP_acc.load().to(self.dtype))
2391
+ # the "to(self.dtype)" conversion fails to vectorize for block sizes other
2392
+ # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
2393
+ # 2 elements. So we just call ptx directly.
2394
+ utils.cvt_f16(tOrP_acc, tOrP_cur)
2395
+ if const_expr(not self.mma_pv_is_rs):
2396
+ tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
2397
+ cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
2398
+ softmax.rescale_O(acc_O, row_scale)
2399
+ if const_expr(not self.mma_pv_is_rs):
2400
+ # Fence and barrier to make sure smem store is visible to WGMMA
2401
+ cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
2402
+ cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
2403
+ return smem_pipe_read
2404
+
2405
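For intuition, the overlapped variant above issues the QK GEMM of the current n_block back-to-back with the PV GEMM of the previous one, bracketed by the first/last half-block helpers; a schematic trace of that issue order (plain Python, illustrative block indices only):

    n_blocks = [3, 2, 1, 0]                 # processed right-to-left, as in the mainloop
    trace = [("QK", n_blocks[0])]           # first_half_block_overlap
    for prev, cur in zip(n_blocks, n_blocks[1:]):
        trace.append(("QK", cur))           # mma_one_n_block_intrawg_overlap
        trace.append(("PV", prev))
    trace.append(("PV", n_blocks[-1]))      # last_half_block_overlap
    print(trace)
    # [('QK', 3), ('QK', 2), ('PV', 3), ('QK', 1), ('PV', 2), ('QK', 0), ('PV', 1), ('PV', 0)]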
+ @cute.jit
2406
+ def mma_init(self):
2407
+ warp_group_idx = utils.canonical_warp_group_idx(sync=False)
2408
+ if const_expr(self.use_scheduler_barrier):
2409
+ if warp_group_idx == 1:
2410
+ cute.arch.barrier_arrive(
2411
+ barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1),
2412
+ number_of_threads=2 * self.num_threads_per_warp_group,
2413
+ )
2414
+
2415
+ @cute.jit
2416
+ def apply_score_mod(
2417
+ self,
2418
+ thr_mma_qk,
2419
+ batch_idx,
2420
+ head_idx,
2421
+ m_block,
2422
+ acc_S,
2423
+ n_block,
2424
+ softmax_scale,
2425
+ seqlen,
2426
+ aux_tensors: Optional[list] = None,
2427
+ fastdiv_mods=None,
2428
+ ):
2429
+ # Prepare index tensor
2430
+ cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
2431
+ cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
2432
+ tScS = thr_mma_qk.partition_C(cS)
2433
+
2434
+ apply_score_mod_inner(
2435
+ acc_S,
2436
+ tScS,
2437
+ self.score_mod,
2438
+ batch_idx,
2439
+ head_idx,
2440
+ softmax_scale,
2441
+ self.vec_size,
2442
+ self.qk_acc_dtype,
2443
+ aux_tensors,
2444
+ fastdiv_mods,
2445
+ seqlen_info=seqlen,
2446
+ constant_q_idx=None,
2447
+ qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
2448
+ )
2449
+
2450
+ def warp_scheduler_barrier_sync(self):
2451
+ if const_expr(self.use_scheduler_barrier):
2452
+ cute.arch.barrier(
2453
+ barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1)
2454
+ - 1
2455
+ + utils.canonical_warp_group_idx(sync=False),
2456
+ number_of_threads=2 * self.num_threads_per_warp_group,
2457
+ )
2458
+
2459
+ def warp_scheduler_barrier_arrive(self):
2460
+ if const_expr(self.use_scheduler_barrier):
2461
+ assert self.num_mma_warp_groups in [2, 3]
2462
+ cur_wg = utils.canonical_warp_group_idx(sync=False) - 1
2463
+ if const_expr(self.num_mma_warp_groups == 2):
2464
+ next_wg = 1 - cur_wg
2465
+ else:
2466
+ t = cur_wg + 1
2467
+ next_wg = t % self.num_mma_warp_groups
2468
+ cute.arch.barrier_arrive(
2469
+ barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
2470
+ number_of_threads=2 * self.num_threads_per_warp_group,
2471
+ )
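As a closing sanity check on the named-barrier indexing above: each warp group's arrive targets exactly the barrier the next group syncs on, giving a round-robin handoff. A standalone check, with an arbitrary base ID standing in for int(NamedBarrierFwd.WarpSchedulerWG1):

    BASE, num_mma_warp_groups = 3, 3        # base ID is arbitrary for this check

    def sync_barrier(wg_idx):               # canonical warp-group index of an MMA warp group (>= 1)
        return BASE - 1 + wg_idx

    def arrive_barrier(wg_idx):
        cur_wg = wg_idx - 1
        next_wg = (cur_wg + 1) % num_mma_warp_groups   # same result as the 2-group special case above
        return BASE + next_wg

    # Each group's arrive targets exactly the next group's sync barrier.
    for wg in range(1, num_mma_warp_groups + 1):
        following_wg = wg % num_mma_warp_groups + 1
        assert arrive_barrier(wg) == sync_barrier(following_wg)
    print("round-robin handoff consistent")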