mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,705 @@
+ # @nolint # fbcode
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h
+ # from Cutlass C++ to Cute-DSL.
+ import math
+ import operator
+ from typing import Type, Optional
+ from functools import partial
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.nvgpu import cpasync
+ from cutlass import Float32, Int32, const_expr
+
+ from mslk.attention.flash_attn import utils
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfo
+ from cutlass.cute import FastDivmodDivisor
+
+
+ class FlashAttentionForwardCombine:
+     def __init__(
+         self,
+         dtype: Type[cutlass.Numeric],
+         dtype_partial: Type[cutlass.Numeric],
+         head_dim: int,
+         m_block_size: int = 8,
+         k_block_size: int = 64,
+         log_max_splits: int = 4,
+         num_threads: int = 256,
+         stages: int = 4,
+     ):
+         """
+         Forward combine kernel for split attention computation.
+
+         :param dtype: output data type
+         :param dtype_partial: partial accumulation data type
+         :param head_dim: head dimension
+         :param m_block_size: m block size
+         :param k_block_size: k block size
+         :param log_max_splits: log2 of maximum splits
+         :param num_threads: number of threads
+         :param varlen: whether using variable length sequences
+         :param stages: number of pipeline stages
+         """
+         self.dtype = dtype
+         self.dtype_partial = dtype_partial
+         self.head_dim = head_dim
+         self.m_block_size = m_block_size
+         self.k_block_size = k_block_size
+         self.max_splits = 1 << log_max_splits
+         self.num_threads = num_threads
+         self.is_even_k = head_dim % k_block_size == 0
+         self.stages = stages
+
+     @staticmethod
+     def can_implement(
+         dtype,
+         dtype_partial,
+         head_dim,
+         m_block_size,
+         k_block_size,
+         log_max_splits,
+         num_threads,
+     ) -> bool:
+         """Check if the kernel can be implemented with the given parameters."""
+         if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
+             return False
+         if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]:
+             return False
+         if head_dim % 8 != 0:
+             return False
+         if num_threads % 32 != 0:
+             return False
+         if m_block_size % 8 != 0:
+             return False
+         max_splits = 1 << log_max_splits
+         if max_splits > 256:
+             return False
+         if (m_block_size * max_splits) % num_threads != 0:
+             return False
+         return True
+
+     def _setup_attributes(self):
+         # GMEM copy setup for O partial
+         universal_copy_bits = 128
+         async_copy_elems = universal_copy_bits // self.dtype_partial.width
+         assert self.k_block_size % async_copy_elems == 0
+
+         k_block_gmem = (
+             128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
+         )
+         gmem_threads_per_row = k_block_gmem // async_copy_elems
+         assert self.num_threads % gmem_threads_per_row == 0
+
+         # Async copy atom for O partial load
+         atom_async_copy_partial = cute.make_copy_atom(
+             cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
+             self.dtype_partial,
+             num_bits_per_copy=universal_copy_bits,
+         )
+         tOpartial_layout = cute.make_ordered_layout(
+             (self.num_threads // gmem_threads_per_row, gmem_threads_per_row),
+             order=(1, 0),
+         )
+         vOpartial_layout = cute.make_layout((1, async_copy_elems))  # 4 vals per load
+         self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
+             atom_async_copy_partial, tOpartial_layout, vOpartial_layout
+         )
+
+         # GMEM copy setup for final O (use universal copy for store)
+         atom_universal_copy = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(),
+             self.dtype,
+             num_bits_per_copy=async_copy_elems * self.dtype.width,
+         )
+         self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
+             atom_universal_copy,
+             tOpartial_layout,
+             vOpartial_layout,  # 4 vals per store
+         )
+
+         # LSE copy setup with async copy (alignment = 1)
+         lse_copy_bits = Float32.width  # 1 element per copy, width is in bits
+         m_block_smem = (
+             128
+             if self.m_block_size % 128 == 0
+             else (
+                 64
+                 if self.m_block_size % 64 == 0
+                 else (
+                     32
+                     if self.m_block_size % 32 == 0
+                     else (16 if self.m_block_size % 16 == 0 else 8)
+                 )
+             )
+         )
+         gmem_threads_per_row_lse = m_block_smem
+         assert self.num_threads % gmem_threads_per_row_lse == 0
+
+         # Async copy atom for LSE load
+         atom_async_copy_lse = cute.make_copy_atom(
+             cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
+             Float32,
+             num_bits_per_copy=lse_copy_bits,
+         )
+         tLSE_layout = cute.make_ordered_layout(
+             (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),
+             order=(1, 0),
+         )
+         vLSE_layout = cute.make_layout(1)
+         self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
+             atom_async_copy_lse, tLSE_layout, vLSE_layout
+         )
+
+         # ///////////////////////////////////////////////////////////////////////////////
+         # Shared memory
+         # ///////////////////////////////////////////////////////////////////////////////
+
+         # Shared memory to register copy for LSE
+         self.smem_threads_per_col_lse = self.num_threads // m_block_smem
+         assert 32 % self.smem_threads_per_col_lse == 0  # Must divide warp size
+
+         s2r_layout_atom_lse = cute.make_ordered_layout(
+             (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
+             order=(0, 1),
+         )
+         self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
+             cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
+             s2r_layout_atom_lse,
+             cute.make_layout(1),
+         )
+
+         # LSE shared memory layout with swizzling to avoid bank conflicts
+         # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
+         if const_expr(m_block_smem == 8):
+             smem_lse_swizzle = cute.make_swizzle(5, 0, 5)
+         elif const_expr(m_block_smem == 16):
+             smem_lse_swizzle = cute.make_swizzle(4, 0, 4)
+         else:
+             smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
+         smem_layout_atom_lse = cute.make_composed_layout(
+             smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
+         )
+         self.smem_layout_lse = cute.tile_to_shape(
+             smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
+         )
+
+         # O partial shared memory layout (simple layout for pipeline stages)
+         self.smem_layout_o = cute.make_ordered_layout(
+             (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
+         )
+
+     @cute.jit
+     def __call__(
+         self,
+         mO_partial: cute.Tensor,
+         mLSE_partial: cute.Tensor,
+         mO: cute.Tensor,
+         mLSE: Optional[cute.Tensor] = None,
+         cu_seqlens: Optional[cute.Tensor] = None,
+         seqused: Optional[cute.Tensor] = None,
+         num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
+         semaphore_to_reset: Optional[cute.Tensor] = None,
+         stream: cuda.CUstream = None,
+     ):
+         # Type checking
+         if const_expr(not (mO_partial.element_type == self.dtype_partial)):
+             raise TypeError("O partial tensor must match dtype_partial")
+         if const_expr(not (mO.element_type == self.dtype)):
+             raise TypeError("O tensor must match dtype")
+         if const_expr(mLSE_partial.element_type not in [Float32]):
+             raise TypeError("LSE partial tensor must be Float32")
+         if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
+             raise TypeError("LSE tensor must be Float32")
+
+         # Shape validation - input tensors are in user format, need to be converted to kernel format
+         if const_expr(len(mO_partial.shape) not in [4, 5]):
+             raise ValueError(
+                 "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
+             )
+         if const_expr(len(mLSE_partial.shape) not in [3, 4]):
+             raise ValueError(
+                 "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
+             )
+         if const_expr(len(mO.shape) not in [3, 4]):
+             raise ValueError(
+                 "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
+             )
+         if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
+             raise ValueError(
+                 "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
+             )
+
+         # Assume all strides are divisible by 128 bits except the last stride
+         new_stride = lambda t: (
+             *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+             t.stride[-1],
+         )
+         mO_partial, mO = [
+             cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+             for t in (mO_partial, mO)
+         ]
+         # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
+         # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
+         O_partial_layout_transpose = (
+             [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
+         )
+         # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
+         mO_partial = cute.make_tensor(
+             mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
+         )
+         O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
+         mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
+         # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
+         # or (num_splits, total_q, h) -> (total_q, num_splits, h)
+         LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
+         mLSE_partial = cute.make_tensor(
+             mLSE_partial.iterator,
+             cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
+         )
+         # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
+         LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
+         mLSE = (
+             cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
+             if mLSE is not None
+             else None
+         )
+
+         # Determine if we have variable length sequences
+         varlen = const_expr(cu_seqlens is not None or seqused is not None)
+
+         self._setup_attributes()
+
+         @cute.struct
+         class SharedStorage:
+             sLSE: cute.struct.Align[
+                 cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
+             ]
+             sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
+             sO: cute.struct.Align[
+                 cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
+             ]
+
+         smem_size = SharedStorage.size_in_bytes()
+
+         # Grid dimensions: (ceil_div(seqlen * num_head, m_block_size), ceil_div(head_dim, k_block_size), batch)
+         seqlen = mO_partial.shape[0]
+         num_head = mO_partial.shape[3]
+         batch_size = (
+             mO_partial.shape[4]
+             if const_expr(cu_seqlens is None)
+             else Int32(cu_seqlens.shape[0] - 1)
+         )
+
+         # Create FastDivmodDivisor objects for efficient division
+         seqlen_divmod = FastDivmodDivisor(seqlen)
+         head_divmod = FastDivmodDivisor(num_head)
+
+         grid_dim = (
+             cute.ceil_div(seqlen * num_head, self.m_block_size),
+             cute.ceil_div(self.head_dim, self.k_block_size),
+             batch_size,
+         )
+
+         self.kernel(
+             mO_partial,
+             mLSE_partial,
+             mO,
+             mLSE,
+             cu_seqlens,
+             seqused,
+             num_splits_dynamic_ptr,
+             semaphore_to_reset,
+             SharedStorage,
+             self.smem_layout_lse,
+             self.smem_layout_o,
+             self.gmem_tiled_copy_O_partial,
+             self.gmem_tiled_copy_O,
+             self.gmem_tiled_copy_LSE,
+             self.s2r_tiled_copy_LSE,
+             seqlen_divmod,
+             head_divmod,
+             varlen,
+         ).launch(
+             grid=grid_dim,
+             block=[self.num_threads, 1, 1],
+             smem=smem_size,
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mO_partial: cute.Tensor,
+         mLSE_partial: cute.Tensor,
+         mO: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         cu_seqlens: Optional[cute.Tensor],
+         seqused: Optional[cute.Tensor],
+         num_splits_dynamic_ptr: Optional[cute.Tensor],
+         semaphore_to_reset: Optional[cute.Tensor],
+         SharedStorage: cutlass.Constexpr,
+         smem_layout_lse: cute.Layout | cute.ComposedLayout,
+         smem_layout_o: cute.Layout,
+         gmem_tiled_copy_O_partial: cute.TiledCopy,
+         gmem_tiled_copy_O: cute.TiledCopy,
+         gmem_tiled_copy_LSE: cute.TiledCopy,
+         s2r_tiled_copy_LSE: cute.TiledCopy,
+         seqlen_divmod: FastDivmodDivisor,
+         head_divmod: FastDivmodDivisor,
+         varlen: cutlass.Constexpr[bool],
+     ):
+         # Thread and block indices
+         tidx, _, _ = cute.arch.thread_idx()
+         m_block, k_block, batch_idx = cute.arch.block_idx()
+
+         # ///////////////////////////////////////////////////////////////////////////////
+         # Get shared memory buffer
+         # ///////////////////////////////////////////////////////////////////////////////
+         smem = cutlass.utils.SmemAllocator()
+         storage = smem.allocate(SharedStorage)
+         sLSE = storage.sLSE.get_tensor(smem_layout_lse)
+         sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,))
+         sO = storage.sO.get_tensor(smem_layout_o)
+
+         # Handle semaphore reset
+         if const_expr(semaphore_to_reset is not None):
+             if (
+                 tidx == 0
+                 and m_block == cute.arch.grid_dim()[0] - 1
+                 and k_block == cute.arch.grid_dim()[1] - 1
+                 and batch_idx == cute.arch.grid_dim()[2] - 1
+             ):
+                 semaphore_to_reset[0] = 0
+
+         # Get number of splits
+         num_splits = (
+             num_splits_dynamic_ptr[batch_idx]
+             if const_expr(num_splits_dynamic_ptr is not None)
+             else mLSE_partial.shape[1]
+         )
+         # Handle variable length sequences using SeqlenInfo
+         seqlen_info = SeqlenInfo.create(
+             batch_idx=batch_idx,
+             seqlen_static=mO_partial.shape[0],
+             cu_seqlens=cu_seqlens,
+             seqused=seqused,
+         )
+         seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
+
+         # Extract number of heads (head index will be determined dynamically)
+         num_head = mO_partial.shape[3]
+         max_idx = seqlen * num_head
+
+         # Early exit for single split if dynamic
+         if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
+             const_expr(not varlen) or m_block * self.m_block_size < max_idx
+         ):
+             # ===============================
+             # Step 1: Load LSE_partial from gmem to shared memory
+             # ===============================
+
+             if const_expr(cu_seqlens is None):
+                 # mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx]
+                 mLSE_partial_cur = utils.coord_offset_i64(mLSE_partial, batch_idx, dim=3)
+             else:
+                 # mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial)
+                 mLSE_partial_cur = utils.domain_offset_i64((offset, 0, 0), mLSE_partial)
+             mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
+
+             gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
+             tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
+
+             # Create identity tensor for coordinate tracking
+             cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size))
+             tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
+
+             # Load LSE partial values
+             for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
+                 mi = tLSEcLSE[0, 0, m][1]  # Get m coordinate
+                 idx = m_block * self.m_block_size + mi
+                 if idx < max_idx:
+                     # Calculate actual sequence position and head using FastDivmodDivisor
+                     if const_expr(not varlen):
+                         head_idx, m_idx = divmod(idx, seqlen_divmod)
+                     else:
+                         head_idx = idx // seqlen
+                         m_idx = idx - head_idx * seqlen
+                     mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx]
+                     for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
+                         si = tLSEcLSE[0, s, 0][0]  # Get split coordinate
+                         if si < num_splits:
+                             cute.copy(
+                                 gmem_thr_copy_LSE,
+                                 mLSE_partial_cur_copy[None, si],
+                                 tLSEsLSE[None, s, m],
+                             )
+                         else:
+                             tLSEsLSE[None, s, m].fill(-Float32.inf)
+             # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
+             cute.arch.cp_async_commit_group()
+
+             # ===============================
+             # Step 2: Load O_partial for pipeline stages
+             # ===============================
+
+             gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
+             cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size))
+             tOcO = gmem_thr_copy_O_partial.partition_D(cO)
+             tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
+             if const_expr(cu_seqlens is None):
+                 # mO_partial_cur = mO_partial[None, None, None, None, batch_idx]
+                 mO_partial_cur = utils.coord_offset_i64(mO_partial, batch_idx, dim=4)
+             else:
+                 # mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial)
+                 mO_partial_cur = utils.domain_offset_i64((offset, 0, 0, 0), mO_partial)
+
+             # Precompute these values to avoid recomputing them in the loop
+             num_rows = const_expr(cute.size(tOcO, mode=[1]))
+             tOmidx = cute.make_fragment(num_rows, cutlass.Int32)
+             tOhidx = cute.make_fragment(num_rows, cutlass.Int32)
+             tOrOptr = cute.make_fragment(num_rows, cutlass.Int64)
+             for m in cutlass.range(num_rows, unroll_full=True):
+                 mi = tOcO[0, m, 0][0]  # m coordinate
+                 idx = m_block * self.m_block_size + mi
+                 if const_expr(not varlen):
+                     tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
+                 else:
+                     tOhidx[m] = idx // seqlen
+                     tOmidx[m] = idx - tOhidx[m] * seqlen
+                 tOrOptr[m] = utils.elem_pointer_i64(
+                     mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
+                 ).toint()
+                 if idx >= max_idx:
+                     tOhidx[m] = -1
+
+             tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean)
+             if const_expr(not self.is_even_k):
+                 for k in cutlass.range(cute.size(tOpO), unroll_full=True):
+                     tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
+             # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)
+
+             load_O_partial = partial(
+                 self.load_O_partial,
+                 gmem_tiled_copy_O_partial,
+                 tOrOptr,
+                 tOsO_partial,
+                 tOhidx,
+                 tOpO,
+                 tOcO,
+                 mO_partial_cur.layout,
+             )
+
+             # Load first few stages of O_partial
+             for stage in cutlass.range(self.stages - 1, unroll_full=True):
+                 if stage < num_splits:
+                     load_O_partial(stage, stage)
+                 cute.arch.cp_async_commit_group()
+
+             # ===============================
+             # Step 3: Load and transpose LSE from smem to registers
+             # ===============================
+
+             # Wait for LSE and initial O partial stages to complete
+             cute.arch.cp_async_wait_group(self.stages - 1)
+             cute.arch.sync_threads()
+             # if cute.arch.thread_idx()[0] == 0:
+             #     # cute.print_tensor(sLSE)
+             #     for i in range(64):
+             #         cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0])
+             # cute.arch.sync_threads()
+
+             s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
+             ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
+             ts2rrLSE = cute.make_fragment_like(ts2rsLSE)
+             cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
+
+             # ===============================
+             # Step 4: Compute final LSE along split dimension
+             # ===============================
+
+             lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32)
+             ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
+             # We compute the max valid split for each row to short-circuit the computation later
+             max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32)
+             assert cute.size(ts2rrLSE, mode=[0]) == 1
+             # Compute max, scales, and final LSE for each row
+             for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
+                 # Find max LSE value across splits
+                 threads_per_col = const_expr(self.smem_threads_per_col_lse)
+                 lse_max = utils.warp_reduce(
+                     ts2rrLSE[None, None, m]
+                     .load()
+                     .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+                     op=cute.arch.fmax,
+                     width=threads_per_col,
+                 )
+                 # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max)
+                 # Find max valid split index
+                 max_valid_idx = -1
+                 for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
+                     if ts2rrLSE[0, s, m] != -Float32.inf:
+                         max_valid_idx = ts2rcLSE[0, s, 0][0]  # Get split coordinate
+                 # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
+                 max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col)
+                 # Compute exp scales and sum
+                 lse_max_cur = (
+                     0.0 if lse_max == -Float32.inf else lse_max
+                 )  # In case all local LSEs are -inf
+                 LOG2_E = math.log2(math.e)
+                 lse_sum_cur = 0.0
+                 for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
+                     scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E))
+                     lse_sum_cur += scale
+                     ts2rrLSE[0, s, m] = scale  # Store scale for later use
+                 lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col)
+                 lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
+                 # Normalize scales
+                 inv_sum = (
+                     0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+                 )
+                 ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
+             # Store the scales exp(lse - lse_logsum) back to smem
+             cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
+
+             # Store max valid split to smem
+             for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
+                 if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes
+                     mi = ts2rcLSE[0, 0, m][1]
+                     if mi < self.m_block_size:
+                         sMaxValidSplit[mi] = max_valid_split[m]
+
+             # ===============================
+             # Step 5: Store final LSE to gmem
+             # ===============================
+
+             if const_expr(mLSE is not None):
+                 if const_expr(cu_seqlens is None):
+                     # mLSE_cur = mLSE[None, None, batch_idx]
+                     mLSE_cur = utils.coord_offset_i64(mLSE, batch_idx, dim=2)
+                 else:
+                     # mLSE_cur = cute.domain_offset((offset, 0), mLSE)
+                     mLSE_cur = utils.domain_offset_i64((offset, 0), mLSE)
+                 if k_block == 0:  # Only first k_block writes LSE when mLSE is provided
+                     for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
+                         if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes
+                             mi = ts2rcLSE[0, 0, m][1]
+                             idx = m_block * self.m_block_size + mi
+                             if idx < max_idx:
+                                 if const_expr(not varlen):
+                                     head_idx, m_idx = divmod(idx, seqlen_divmod)
+                                 else:
+                                     head_idx = idx // seqlen
+                                     m_idx = idx - head_idx * seqlen
+                                 mLSE_cur[m_idx, head_idx] = lse_sum[m]
+
+             # ===============================
+             # Step 6: Read O_partial and accumulate final O
+             # ===============================
+
+             cute.arch.sync_threads()
+
+             # Get max valid split for this thread
+             thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
+             for m in cutlass.range(1, cute.size(tOcO, mode=[1])):
+                 thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
+
+             tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0])
+             tOrO = cute.make_fragment_like(tOrO_partial, Float32)
+             tOrO.fill(0.0)
+
+             stage_load = self.stages - 1
+             stage_compute = 0
+
+             # Main accumulation loop
+             for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
+                 # Get scales for this split
+                 scale = cute.make_fragment(num_rows, Float32)
+                 for m in cutlass.range(num_rows, unroll_full=True):
+                     scale[m] = sLSE[s, tOcO[0, m, 0][0]]  # Get scale from smem
+
+                 # Load next stage if needed
+                 split_to_load = s + self.stages - 1
+                 if split_to_load <= thr_max_valid_split:
+                     load_O_partial(split_to_load, stage_load)
+                 cute.arch.cp_async_commit_group()
+                 stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1
+
+                 # Wait for the current stage to be ready
+                 cute.arch.cp_async_wait_group(self.stages - 1)
+                 # We don't need __syncthreads() because each thread is just reading its own data from smem
+                 # Copy from smem to registers
+                 cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
+                 stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1
+
+                 # Accumulate scaled partial results
+                 for m in cutlass.range(num_rows, unroll_full=True):
+                     if tOhidx[m] >= 0 and scale[m] > 0.0:
+                         tOrO[None, m, None].store(
+                             tOrO[None, m, None].load()
+                             + scale[m] * tOrO_partial[None, m, None].load().to(Float32)
+                         )
+
+             # ===============================
+             # Step 7: Write final O to gmem
+             # ===============================
+
+             rO = cute.make_fragment_like(tOrO, self.dtype)
+             rO.store(tOrO.load().to(self.dtype))
+             if const_expr(cu_seqlens is None):
+                 # mO_cur = mO[None, None, None, batch_idx]
+                 mO_cur = utils.coord_offset_i64(mO, batch_idx, dim=3)
+             else:
+                 # mO_cur = cute.domain_offset((offset, 0, 0), mO)
+                 mO_cur = utils.domain_offset_i64((offset, 0, 0), mO)
+             mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur)
+             elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
+             # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,))
+             gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
+             # Write final results
+             for m in cutlass.range(num_rows, unroll_full=True):
+                 if tOhidx[m] >= 0:
+                     mO_cur_copy = cute.tiled_divide(
+                         mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
+                     )
+                     for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
+                         k_idx = tOcO[0, 0, k][1] // elems_per_store
+                         if const_expr(self.is_even_k) or tOpO[k]:
+                             cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx])
+
+     @cute.jit
+     def load_O_partial(
+         self,
+         gmem_tiled_copy_O_partial: cute.TiledCopy,
+         tOrOptr: cute.Tensor,
+         tOsO_partial: cute.Tensor,
+         tOhidx: cute.Tensor,
+         tOpO: cute.Tensor,
+         tOcO: cute.Tensor,
+         mO_cur_partial_layout: cute.Layout,
+         split: Int32,
+         stage: Int32,
+     ) -> None:
+         elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
+         tOsO_partial_cur = tOsO_partial[None, None, None, stage]
+         for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True):
+             if tOhidx[m] >= 0:
+                 o_gmem_ptr = cute.make_ptr(
+                     tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
+                 )
+                 mO_partial_cur = cute.make_tensor(
+                     o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
+                 )
+                 mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
+                 for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
+                     k_idx = tOcO[0, 0, k][1] // elems_per_load
+                     if const_expr(self.is_even_k) or tOpO[k]:
+                         cute.copy(
+                             gmem_tiled_copy_O_partial,
+                             # mO_partial_cur_copy[None, k_idx, split],
+                             utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx],
+                             tOsO_partial_cur[None, m, k],
+                         )
+ )