mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/flash_bwd_preprocess.py
@@ -0,0 +1,366 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
+# from Cutlass C++ to Cute-DSL.
+import math
+import operator
+from typing import Callable, Type, Optional, Literal
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32
+
+from mslk.attention.flash_attn import utils
+from mslk.attention.flash_attn import copy_utils
+from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
+from mslk.attention.flash_attn.tile_scheduler import (
+    ParamsBase,
+    SingleTileScheduler,
+    SingleTileVarlenScheduler,
+    TileSchedulerArguments,
+)
+
+
+class FlashAttentionBackwardPreprocess:
+    def __init__(
+        self,
+        dtype: Type[cutlass.Numeric],
+        head_dim: int,
+        arch: Literal[80, 90, 100],
+        m_block_size: int = 128,
+        num_threads: int = 128,
+    ):
+        """
+        All contiguous dimensions must be at least 16-byte aligned, which means the head dimension
+        should be a multiple of 8.
+
+        :param head_dim: head dimension
+        :type head_dim: int
+        :param m_block_size: m block size
+        :type m_block_size: int
+        :param num_threads: number of threads
+        :type num_threads: int
+        """
+        self.dtype = dtype
+        self.m_block_size = m_block_size
+        self.arch = arch
+        # Pad head_dim to a multiple of 32 and use it as k_block_size
+        hdim_multiple_of = 32
+        self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
+        self.check_hdim_oob = head_dim != self.head_dim_padded
+        self.num_threads = num_threads
+
+    @staticmethod
+    def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool:
+        """Check if the kernel can be implemented with the given parameters.
+
+        :param dtype: data type
+        :type dtype: cutlass.Numeric
+        :param head_dim: head dimension
+        :type head_dim: int
+        :param m_block_size: m block size
+        :type m_block_size: int
+        :param num_threads: number of threads
+        :type num_threads: int
+
+        :return: True if the kernel can be implemented, False otherwise
+        :rtype: bool
+        """
+        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
+            return False
+        if head_dim % 8 != 0:
+            return False
+        if num_threads % 32 != 0:
+            return False
+        if num_threads < m_block_size:  # For multiplying lse with log2
+            return False
+        return True
+
+    def _setup_attributes(self):
+        # ///////////////////////////////////////////////////////////////////////////////
+        # GMEM Tiled copy:
+        # ///////////////////////////////////////////////////////////////////////////////
+        # Thread layouts for copies
+        # We want kBlockKGmem to be a power of 2 so that when we do the summing,
+        # it's just between threads in the same warp
+        gmem_k_block_size = (
+            128
+            if self.head_dim_padded % 128 == 0
+            else (
+                64
+                if self.head_dim_padded % 64 == 0
+                else (32 if self.head_dim_padded % 32 == 0 else 16)
+            )
+        )
+        self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(
+            self.dtype, gmem_k_block_size, self.num_threads
+        )
+        universal_copy_bits = 128
+        num_copy_elems_dQaccum = universal_copy_bits // Float32.width
+        assert (
+            self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum
+        ) % self.num_threads == 0
+        self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
+            Float32, self.num_threads, num_copy_elems_dQaccum
+        )
+
+    @cute.jit
+    def __call__(
+        self,
+        mO: cute.Tensor,
+        mdO: cute.Tensor,
+        mdPsum: cute.Tensor,
+        mLSE: Optional[cute.Tensor],
+        mLSElog2: Optional[cute.Tensor],
+        mdQaccum: Optional[cute.Tensor],
+        mCuSeqlensQ: Optional[cute.Tensor],
+        mSeqUsedQ: Optional[cute.Tensor],
+        stream: cuda.CUstream,
+    ):
+        # Get the data type and check if it is fp16 or bf16
+        if cutlass.const_expr(not (mO.element_type == mdO.element_type)):
+            raise TypeError("All tensors must have the same data type")
+        if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):
+            raise TypeError("Only Float16 or BFloat16 is supported")
+        if cutlass.const_expr(mdPsum.element_type not in [Float32]):
+            raise TypeError("dPsum tensor must be Float32")
+        if cutlass.const_expr(mdQaccum is not None):
+            if cutlass.const_expr(mdQaccum.element_type not in [Float32]):
+                raise TypeError("dQaccum tensor must be Float32")
+        if cutlass.const_expr(mLSE is not None):
+            assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
+            if cutlass.const_expr(mLSE.element_type not in [Float32]):
+                raise TypeError("LSE tensor must be Float32")
+            if cutlass.const_expr(mLSElog2.element_type not in [Float32]):
+                raise TypeError("LSElog2 tensor must be Float32")
+
+        # Assume all strides except the last are multiples of 128 bits
+        new_stride = lambda t: (
+            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+            t.stride[-1],
+        )
+        mO, mdO, mdQaccum = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (mO, mdO, mdQaccum)
+        ]
+
+        self._setup_attributes()
+
+        if cutlass.const_expr(mCuSeqlensQ is not None):
+            TileScheduler = SingleTileVarlenScheduler
+            num_head = mO.shape[1]
+            num_batch = mCuSeqlensQ.shape[0] - 1
+        else:
+            TileScheduler = SingleTileScheduler
+            num_head = mO.shape[2]
+            num_batch = mO.shape[0]
+
+        tile_sched_args = TileSchedulerArguments(
+            num_block=cute.ceil_div(mO.shape[1], self.m_block_size),
+            num_head=num_head,
+            num_batch=num_batch,
+            num_splits=1,
+            seqlen_k=0,
+            headdim=0,
+            headdim_v=mO.shape[2],
+            total_q=mO.shape[0],
+            tile_shape_mn=(self.m_block_size, 1),
+            mCuSeqlensQ=mCuSeqlensQ,
+            mSeqUsedQ=mSeqUsedQ,
+        )
+
+        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
+        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
+
+        self.kernel(
+            mO,
+            mdO,
+            mdPsum,
+            mLSE,
+            mLSElog2,
+            mdQaccum,
+            mCuSeqlensQ,
+            mSeqUsedQ,
+            self.gmem_tiled_copy_O,
+            self.gmem_tiled_copy_dQaccum,
+            tile_sched_params,
+            TileScheduler,
+        ).launch(
+            grid=grid_dim,
+            block=[self.num_threads, 1, 1],
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mO: cute.Tensor,
+        mdO: cute.Tensor,
+        mdPsum: cute.Tensor,
+        mLSE: Optional[cute.Tensor],
+        mLSElog2: Optional[cute.Tensor],
+        mdQaccum: Optional[cute.Tensor],
+        mCuSeqlensQ: Optional[cute.Tensor],
+        mSeqUsedQ: Optional[cute.Tensor],
+        gmem_tiled_copy_O: cute.TiledCopy,
+        gmem_tiled_copy_dQaccum: cute.TiledCopy,
+        tile_sched_params: ParamsBase,
+        TileScheduler: cutlass.Constexpr[Callable],
+    ):
+        # Thread index, block index
+        tidx, _, _ = cute.arch.thread_idx()
+
+        tile_scheduler = TileScheduler.create(tile_sched_params)
+        work_tile = tile_scheduler.initial_work_tile_info()
+        m_block, head_idx, batch_idx, _ = work_tile.tile_idx
+
+        if work_tile.is_valid_tile:
+            # ///////////////////////////////////////////////////////////////////////////////
+            # Get the appropriate tiles for this thread block.
+            # ///////////////////////////////////////////////////////////////////////////////
+            seqlen = SeqlenInfoQK.create(
+                batch_idx,
+                mO.shape[1],
+                0,
+                mCuSeqlensQ=mCuSeqlensQ,
+                mCuSeqlensK=None,
+                mSeqUsedQ=mSeqUsedQ,
+                mSeqUsedK=None,
+            )
+
+            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
+                mO_cur = mO[batch_idx, None, head_idx, None]
+                mdO_cur = mdO[batch_idx, None, head_idx, None]
+                mdPsum_cur = mdPsum[batch_idx, head_idx, None]
+                headdim_v = mO.shape[3]
+            else:
+                mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
+                mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
+
+                padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
+                if cutlass.const_expr(self.arch >= 90):
+                    padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
+                mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
+                headdim_v = mO.shape[2]
+
+            blkOdO_shape = (self.m_block_size, self.head_dim_padded)
+            # (m_block_size, head_dim)
+            gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0))
+            gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0))
+
+            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
+            # (CPY_Atom, CPY_M, CPY_K)
+            tOgO = gmem_thr_copy_O.partition_S(gO)
+            tOgdO = gmem_thr_copy_O.partition_S(gdO)
+
+            # ///////////////////////////////////////////////////////////////////////////////
+            # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
+            # of tile_shape
+            # ///////////////////////////////////////////////////////////////////////////////
+            # Construct identity layout for O
+            cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
+            tOcO = gmem_thr_copy_O.partition_S(cO)
+            t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
+            tOpO = utils.predicate_k(tOcO, limit=headdim_v)
+            tOpdO = utils.predicate_k(tOcO, limit=headdim_v)
+
+            seqlen_q = seqlen.seqlen_q
+            seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size)
+
+            if cutlass.const_expr(mLSE is not None):
+                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
+                    mLSE_cur = mLSE[batch_idx, head_idx, None]
+                else:
+                    mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None])
+
+                gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
+                lse = Float32.inf
+                if tidx < seqlen_q - m_block * self.m_block_size:
+                    lse = gLSE[tidx]
+
+            tOrO = cute.make_fragment_like(tOgO)
+            tOrdO = cute.make_fragment_like(tOgdO)
+            assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0])
+            assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1])
+            assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2])
+            for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
+                # Instead of using tOcO, we use t0OcO and subtract the offset from the limit
+                # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time.
+                if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]:
+                    cute.copy(
+                        gmem_thr_copy_O,
+                        tOgO[None, m, None],
+                        tOrO[None, m, None],
+                        pred=tOpO[None, m, None]
+                        if cutlass.const_expr(self.check_hdim_oob)
+                        else None,
+                    )
+                    cute.copy(
+                        gmem_thr_copy_O,
+                        tOgdO[None, m, None],
+                        tOrdO[None, m, None],
+                        pred=tOpdO[None, m, None]
+                        if cutlass.const_expr(self.check_hdim_oob)
+                        else None,
+                    )
+            # Sum across the "k" dimension
+            dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(
+                cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
+            )
+            threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
+            assert cute.arch.WARP_SIZE % threads_per_row == 0
+            dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row)
+            dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32)
+            dP_sum.store(dpsum)
+
+            # Write dPsum from rmem -> gmem
+            gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,))
+            # Only the thread corresponding to column 0 writes out the dPsum to gmem
+            if tOcO[0, 0, 0][1] == 0:
+                for m in cutlass.range(cute.size(dP_sum), unroll_full=True):
+                    row = tOcO[0, m, 0][0]
+                    gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0
+
+            # Clear dQaccum
+            if cutlass.const_expr(mdQaccum is not None):
+                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
+                    mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
+                else:
+                    mdQaccum_cur = cute.domain_offset(
+                        (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
+                    )
+
+                    # HACK: The compiler doesn't seem to recognize that offsetting by
+                    # padded_offset_q * self.head_dim_padded preserves alignment,
+                    # since it is statically divisible by 4.
+
+                    mdQaccum_cur_ptr = cute.make_ptr(
+                        dtype=mdQaccum_cur.element_type,
+                        value=mdQaccum_cur.iterator.toint(),
+                        mem_space=mdQaccum_cur.iterator.memspace,
+                        assumed_align=mdQaccum.iterator.alignment,
+                    )
+                    mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)
+
+                blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,)
+                gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
+                gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
+                tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
+                zero = cute.make_fragment_like(tdQgdQaccum)
+                zero.fill(0.0)
+                cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)
+
+            if cutlass.const_expr(mLSE is not None):
+                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
+                    mLSElog2_cur = mLSElog2[batch_idx, head_idx, None]
+                else:
+                    mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None])
+
+                gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
+                LOG2_E = math.log2(math.e)
+                if tidx < seqlen_q_rounded - m_block * self.m_block_size:
+                    gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
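
As a quick orientation, the parameter checks and tile-size selection in __init__ / _setup_attributes above can be traced in plain Python. This is a minimal sketch, not part of the wheel, using hypothetical values (head_dim=72, bf16) and a module path inferred from the file list above:

import math

import cutlass

# Module path inferred from the file list above (entry 20).
from mslk.attention.flash_attn.flash_bwd_preprocess import FlashAttentionBackwardPreprocess

# can_implement: fp16/bf16 only, head_dim a multiple of 8,
# num_threads a multiple of 32 and >= m_block_size.
assert FlashAttentionBackwardPreprocess.can_implement(
    cutlass.BFloat16, head_dim=72, m_block_size=128, num_threads=128
)

# __init__ pads head_dim up to a multiple of 32.
head_dim, hdim_multiple_of = 72, 32
head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)  # 96
check_hdim_oob = head_dim != head_dim_padded  # True -> k-dim copies get predicated

# _setup_attributes picks the largest power-of-2 block that divides the padded head dim,
# so the per-row sum stays within threads of a single warp.
gmem_k_block_size = (
    128 if head_dim_padded % 128 == 0
    else 64 if head_dim_padded % 64 == 0
    else 32 if head_dim_padded % 32 == 0
    else 16
)  # 32 for head_dim_padded == 96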
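Functionally, the kernel computes dPsum = rowsum(dO * O) in fp32, rescales LSE by log2(e), and zero-fills the fp32 dQ accumulator. A rough reference for the non-varlen path, again not shipped in the wheel, assuming O/dO laid out as (batch, seqlen_q, num_heads, headdim_v) and LSE as (batch, num_heads, seqlen_q) as the indexing above implies (padding rows simply left at zero here):

import math

import torch


def bwd_preprocess_reference(o, do, lse, m_block_size=128):
    # o, do: (batch, seqlen_q, num_heads, headdim_v); lse: (batch, num_heads, seqlen_q)
    batch, seqlen_q, num_heads, _ = o.shape
    seqlen_q_rounded = math.ceil(seqlen_q / m_block_size) * m_block_size
    # dPsum = rowsum(O * dO), accumulated in fp32 and padded to a multiple of m_block_size
    dpsum = torch.zeros(batch, num_heads, seqlen_q_rounded, dtype=torch.float32, device=o.device)
    dpsum[:, :, :seqlen_q] = (o.float() * do.float()).sum(dim=-1).transpose(1, 2)
    # LSE rescaled to base 2, matching the LOG2_E factor used above
    lse_log2 = torch.zeros_like(dpsum)
    lse_log2[:, :, :seqlen_q] = lse.float() * math.log2(math.e)
    # The kernel additionally zero-initializes the fp32 dQ accumulator; omitted here.
    return dpsum, lse_log2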