mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/mask.py
@@ -0,0 +1,610 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Tri Dao.
3
+
4
+ from typing import Optional, Callable
5
+ from dataclasses import dataclass
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Float32, Int32, const_expr
10
+
11
+ import mslk.attention.flash_attn.utils as utils
12
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
13
+
14
+
15
+ @cute.jit
16
+ def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None:
17
+ # Bit manipulation, compiles down to the R2P instruction
18
+ # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using.
19
+ # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ...,
20
+ # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
21
+ if const_expr(arch == 90):
22
+ col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2)
23
+ else:
24
+ col_limit_transformed = col_limit
25
+ ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
26
+ # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
27
+ for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
28
+ # Don't need to clamp to 32 since the shr.u32 instruction does that already
29
+ col_limit_right_s = max(col_limit_transformed - s * 24, 0)
30
+ # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
31
+ mask = (1 << col_limit_right_s) - 1
32
+ # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
33
+ for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
34
+ in_bound = cutlass.Boolean(mask & (1 << i))
35
+ c = s * 24 + i
36
+ if const_expr(rank1):
37
+ X[c] = X[c] if in_bound else -Float32.inf
38
+ # This is the equivalent of:
39
+ # X[s * 24 + i] = X[s * 24 + i] if i < col_limit_right_s else -Float32.inf
40
+ else:
41
+ for r in cutlass.range_constexpr(cute.size(X.shape[0])):
42
+ X[r, c] = X[r, c] if in_bound else -Float32.inf
43
+
44
+
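# --- Illustrative sketch in plain Python (not part of the packaged code) ---
# Why the sm90 transform above works: each thread of the sm90 MMA owns
# accumulator columns 0, 1, 8, 9, 16, 17, ...  Mapping the column limit onto
# the dense scale 0, 1, 2, 3, ... lets a single bitmask (1 << limit) - 1
# answer "is dense column i in bound?" for up to 24 columns per step.
def _dense_limit_sm90(col_limit):
    # Same arithmetic as col_limit_transformed above.
    return col_limit // 8 * 2 + min(col_limit % 8, 2)

def _inbound_count(col_limit, n_dense):
    # Brute force: dense index i owns actual column (i // 2) * 8 + i % 2.
    return sum((i // 2) * 8 + i % 2 < col_limit for i in range(n_dense))

for limit in range(64):
    dense_limit = min(_dense_limit_sm90(limit), 16)
    assert dense_limit == _inbound_count(limit, 16)
    mask = (1 << dense_limit) - 1
    # Bit i of `mask` is set exactly when dense column i is in bound.
    assert all(bool(mask & (1 << i)) == ((i // 2) * 8 + i % 2 < limit)
               for i in range(16))
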
45
+ @cute.jit
46
+ def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None:
47
+ # Bit manipulation, compiles down to the R2P instruction
48
+ # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127
49
+ # or 0, 1, ..., 15, 32, ..., 47, 64, ...
50
+ # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
51
+ # Here we hardcode for the case of 2 warp groups.
52
+ num_wg = 2
53
+ row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min(
54
+ row_limit_top % (num_rep * num_wg), num_rep
55
+ )
56
+ ncol = cute.size(X.shape)
57
+ # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
58
+ for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
59
+ row_limit_top_s = max(row_limit_top_transformed - s * 24, 0)
60
+ # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
61
+ mask = (1 << row_limit_top_s) - 1
62
+ # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
63
+ for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
64
+ out_bound = cutlass.Boolean(mask & (1 << i))
65
+ c = s * 24 + i
66
+ X[c] = -Float32.inf if out_bound else X[c]
67
+ # tidx = cute.arch.thread_idx()[0] % 256
68
+ # if tidx == 128:
69
+ # cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
70
+
71
+
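# --- Illustrative sketch in plain Python (not part of the packaged code) ---
# Quick check of the transposed transform above: with num_rep = 16 and two
# warp groups, dense index i corresponds to actual row (i // 16) * 32 + i % 16,
# i.e. rows 0..15, 32..47, 64..79, ...  The transform counts how many of those
# rows have index < row_limit_top, i.e. how many are masked to -inf at the top.
num_rep, num_wg = 16, 2

def _dense_row_limit(top):
    period = num_rep * num_wg
    return top // period * num_rep + min(top % period, num_rep)

assert [_dense_row_limit(t) for t in (0, 1, 16, 31, 32, 33, 48, 64)] == [
    0, 1, 16, 16, 16, 17, 32, 32
]
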
72
+ @dataclass(frozen=True)
73
+ class AttentionMask:
74
+ tile_m: cutlass.Constexpr[int]
75
+ tile_n: cutlass.Constexpr[int]
76
+ seqlen_info: SeqlenInfoQK
77
+ window_size_left: Optional[Int32] = None
78
+ window_size_right: Optional[Int32] = None
79
+ qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA
80
+ swap_AB: cutlass.Constexpr[bool] = False
81
+
82
+ @property
83
+ def seqlen_q(self) -> Int32:
84
+ return self.seqlen_info.seqlen_q
85
+
86
+ @property
87
+ def seqlen_k(self) -> Int32:
88
+ return self.seqlen_info.seqlen_k
89
+
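# --- Illustrative sketch in plain Python (not part of the packaged code) ---
# Reference semantics of the masks that the apply_mask* methods below realize
# on register / tmem fragments, assuming the bottom-right-aligned causal and
# sliding-window convention used by FlashAttention: with
# offset = seqlen_k - seqlen_q, query q_idx keeps key kv_idx when
#     kv_idx <= q_idx + offset                                   (causal)
#     q_idx + offset - window_left <= kv_idx <= q_idx + offset + window_right  (local)
# and both indices are inside their sequence lengths (seqlen mask).
def _reference_keep(q_idx, kv_idx, seqlen_q, seqlen_k,
                    causal=False, window_left=None, window_right=None):
    # Hypothetical helper for illustration only; not used by the kernels.
    if q_idx >= seqlen_q or kv_idx >= seqlen_k:
        return False  # seqlen / padding mask
    offset = seqlen_k - seqlen_q  # bottom-right alignment
    if causal and kv_idx > q_idx + offset:
        return False
    if window_right is not None and kv_idx > q_idx + offset + window_right:
        return False
    if window_left is not None and kv_idx < q_idx + offset - window_left:
        return False
    return True
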
90
+ @cute.jit
91
+ def apply_mask(
92
+ self,
93
+ acc_S: cute.Tensor,
94
+ batch_idx: cutlass.Int32,
95
+ head_idx: cutlass.Int32,
96
+ m_block: cutlass.Int32,
97
+ n_block: cutlass.Int32,
98
+ thr_mma: cute.TiledMma,
99
+ mask_seqlen: cutlass.Constexpr[bool],
100
+ mask_causal: cutlass.Constexpr[bool],
101
+ mask_local: cutlass.Constexpr[bool] = False,
102
+ mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
103
+ aux_tensors: Optional[list] = None,
104
+ fastdiv_mods=(None, None),
105
+ ) -> None:
106
+ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
107
+ acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB)
108
+ acc_shape = (self.tile_m, self.tile_n)
109
+ cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
110
+ tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS), transpose=self.swap_AB)
111
+ # We use t0ScS as these indices are known at compile time. We then must subtract the
112
+ # thread column offset from the column limit.
113
+ t0ScS_mn = utils.make_acc_tensor_mn_view(
114
+ thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
115
+ )
116
+ ROW = 0 if const_expr(not self.swap_AB) else 1
117
+ COL = 1 if const_expr(not self.swap_AB) else 0
118
+ thr_col_offset = tScS_mn[0][COL]
119
+ # To handle edge cases of completely masked out rows where n_block_max = 0,
120
+ # we treat negative n_blocks as 0th n_block
121
+ # TODO: find more transparent solution
122
+ if n_block < 0:
123
+ n_block = 0
124
+ seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
125
+ if const_expr(not mask_causal and not mask_local and mask_mod is None):
126
+ if const_expr(mask_seqlen):
127
+ # The compiler now chooses not to use R2P
128
+ r2p = const_expr(False and not self.swap_AB)
129
+ if const_expr(not r2p):
130
+ # traverse column index.
131
+ for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
132
+ oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit
133
+ for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
134
+ acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
135
+ else:
136
+ mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90)
137
+
138
+ elif const_expr(
139
+ not mask_causal and not mask_local and mask_mod is not None
140
+ ): # FlexAttention mask mod
141
+ nrow = const_expr(cute.size(tScS_mn.shape[0]))
142
+ ncol = const_expr(cute.size(tScS_mn.shape[1]))
143
+ has_fastdiv = const_expr(
144
+ fastdiv_mods is not None
145
+ and fastdiv_mods[0] is not None
146
+ and fastdiv_mods[1] is not None
147
+ )
148
+ wrap_aux_indices = const_expr(
149
+ has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
150
+ )
151
+
152
+ for r in cutlass.range_constexpr(nrow):
153
+ # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.
154
+ local_row = tScS_mn[r, 0][ROW]
155
+ global_row_idx = local_row + m_block * self.tile_m
156
+ row_for_mod = global_row_idx
157
+ head_idx_for_mod = head_idx
158
+ if const_expr(self.qhead_per_kvhead_packgqa != 1):
159
+ head_offset = global_row_idx % self.qhead_per_kvhead_packgqa
160
+ head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
161
+ row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa
162
+ row_for_seqlen = row_for_mod
163
+ if const_expr(wrap_aux_indices):
164
+ _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])
165
+
166
+ for col in cutlass.range_constexpr(ncol):
167
+ col_idx_local = t0ScS_mn[0, col][COL]
168
+ # Convert to absolute column index
169
+ global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
170
+ col_for_mod = global_col_idx
171
+ if const_expr(wrap_aux_indices):
172
+ _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])
173
+
174
+ batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
175
+ head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
176
+ q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
177
+ kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
178
+ mask_value = mask_mod(
179
+ batch_idx_ssa,
180
+ head_idx_ssa,
181
+ q_idx_ssa,
182
+ kv_idx_ssa,
183
+ self.seqlen_info,
184
+ aux_tensors,
185
+ )
186
+ cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
187
+ if const_expr(mask_seqlen):
188
+ out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (
189
+ global_col_idx >= self.seqlen_k
190
+ )
191
+ if out_of_bounds:
192
+ acc_S_mn[r, col] = -cutlass.Float32.inf
193
+ else:
194
+ acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
195
+ else:
196
+ acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
197
+
198
+ else: # Causal or local
199
+ if const_expr(not self.swap_AB):
200
+ # If PackGQA, we split the work of compute divmod among threads in the same row
201
+ threads_per_row = thr_mma.tv_layout_C.shape[0][0]
202
+ mma_m_idx = None
203
+ if const_expr(self.qhead_per_kvhead_packgqa != 1):
204
+ assert not self.swap_AB, "swap_AB with PackGQA not supported yet"
205
+ assert cute.arch.WARP_SIZE % threads_per_row == 0, (
206
+ "threads_per_row must divide WARP_SIZE"
207
+ )
208
+ assert cute.size(acc_S_mn.shape[0]) <= threads_per_row
209
+ tidx = thr_mma.thr_idx
210
+ mma_m_idx = (
211
+ m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0]
212
+ ) // self.qhead_per_kvhead_packgqa
213
+ causal_row_offset = (
214
+ 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset
215
+ )
216
+ if const_expr(mask_causal):
217
+ r2p = const_expr(not self.swap_AB) # R2P trick, see apply_mask_sm100
218
+ for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
219
+ # get the column index limit based on the current row. Only consider the row index, so the column index is set to 0.
220
+ if const_expr(self.qhead_per_kvhead_packgqa == 1):
221
+ row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
222
+ else:
223
+ row_idx = utils.shuffle_sync(
224
+ mma_m_idx, r % threads_per_row, width=threads_per_row
225
+ )
226
+ col_limit_right = row_idx + causal_row_offset
227
+ if const_expr(mask_seqlen):
228
+ col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
229
+ if const_expr(not r2p):
230
+ # traverse column index.
231
+ for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
232
+ acc_S_mn[r, c] = (
233
+ -Float32.inf
234
+ if t0ScS_mn[0, c][1] >= col_limit_right
235
+ else acc_S_mn[r, c]
236
+ )
237
+ else:
238
+ mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True)
239
+ else: # Local
240
+ local_row_offset_right = (
241
+ causal_row_offset + self.window_size_right
242
+ if const_expr(self.window_size_right is not None)
243
+ else None
244
+ )
245
+ local_row_offset_left = (
246
+ causal_row_offset - 1 - self.window_size_left
247
+ if const_expr(self.window_size_left is not None)
248
+ else None
249
+ )
250
+ for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
251
+ if const_expr(self.qhead_per_kvhead_packgqa == 1):
252
+ row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
253
+ else:
254
+ row_idx = utils.shuffle_sync(
255
+ mma_m_idx, r % threads_per_row, width=threads_per_row
256
+ )
257
+ if const_expr(self.window_size_right is not None):
258
+ col_limit_right = row_idx + local_row_offset_right
259
+ else:
260
+ col_limit_right = self.tile_n
261
+ if const_expr(mask_seqlen):
262
+ col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
263
+ col_limit_left = (
264
+ row_idx + local_row_offset_left
265
+ if const_expr(self.window_size_left is not None)
266
+ else 0
267
+ )
268
+ # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left)
269
+ # traverse column index.
270
+ for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
271
+ col_idx = t0ScS_mn[0, c][1]
272
+ # only consider the column index, so the row index is set to 0.
273
+ if col_idx >= col_limit_right or col_idx < col_limit_left:
274
+ acc_S_mn[r, c] = -Float32.inf
275
+ else: # swap_AB
276
+ assert self.qhead_per_kvhead_packgqa == 1
277
+ thr_row_offset = tScS_mn[0][ROW]
278
+ causal_row_offset = (
279
+ seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset
280
+ )
281
+ if const_expr(mask_causal):
282
+ for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
283
+ col0 = t0ScS_mn[0, c][COL]
284
+ # If col0 is beyond the column limit, we want to mask out the entire
285
+ # column, by setting row limit to be self.tile_m.
286
+ row_limit_top = (
287
+ self.tile_m
288
+ if col0 >= seqlenk_col_limit and mask_seqlen
289
+ else col0 - causal_row_offset
290
+ )
291
+ for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
292
+ acc_S_mn[r, c] = (
293
+ -Float32.inf
294
+ if t0ScS_mn[r, 0][ROW] < row_limit_top
295
+ else acc_S_mn[r, c]
296
+ )
297
+ else:
298
+ for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
299
+ col0 = t0ScS_mn[0, c][COL]
300
+ # If col0 is beyond the column limit, we want to mask out the entire
301
+ # column, by setting the row limit to self.tile_m.
302
+ row_limit_top = (
303
+ self.tile_m
304
+ if col0 >= seqlenk_col_limit
305
+ else col0 - causal_row_offset - self.window_size_right
306
+ )
307
+ # TODO: do we need col_limit_sink?
308
+ row_limit_bot = col0 - causal_row_offset + self.window_size_left
309
+ for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
310
+ row_idx = t0ScS_mn[r, 0][ROW]
311
+ acc_S_mn[r, c] = (
312
+ -Float32.inf
313
+ if row_idx < row_limit_top or row_idx > row_limit_bot
314
+ else acc_S_mn[r, c]
315
+ )
316
+
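# --- Illustrative sketch in plain Python (not part of the packaged code) ---
# Shape of a FlexAttention-style mask_mod as apply_mask calls it above:
# mask_mod(batch, head, q_idx, kv_idx, seqlen_info, aux_tensors) returns a
# truthy value to keep the element and a falsy value to set it to -inf.
# A hypothetical causal mask_mod written in plain-Python terms (inside the
# kernel the index arguments are SSA Int32 values and the body must stay
# within what the cute DSL can trace):
def example_causal_mask_mod(batch, head, q_idx, kv_idx, seqlen_info, aux_tensors=None):
    # Keep keys at or before the query position.
    return kv_idx <= q_idx
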
317
+ @cute.jit
318
+ def apply_mask_sm100(
319
+ self,
320
+ acc_S: cute.Tensor,
321
+ m_block: Int32,
322
+ n_block: Int32,
323
+ thr_mma: cute.TiledMma,
324
+ thr_tmem_load: cute.TiledCopy,
325
+ mask_seqlen: cutlass.Constexpr[bool],
326
+ mask_causal: cutlass.Constexpr[bool],
327
+ mask_local: cutlass.Constexpr[bool] = False,
328
+ mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
329
+ batch_idx: Int32 = None,
330
+ head_idx: Int32 = None,
331
+ aux_tensors: Optional[list] = None,
332
+ fastdiv_mods=(None, None),
333
+ check_q_boundary: bool = False,
334
+ ) -> None:
335
+ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
336
+ acc_shape = (self.tile_m, self.tile_n)
337
+ cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
338
+ tScS = thr_mma.partition_C(cS)
339
+ tScS_t2r = thr_tmem_load.partition_D(tScS)
340
+ # To handle edge cases of completely masked out rows where n_block_max = 0,
341
+ # we treat negative n_blocks as 0th n_block
342
+ # TODO: find more transparent solution
343
+ if n_block < 0:
344
+ n_block = 0
345
+ seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n
346
+ r2p = True
347
+ if const_expr(not mask_causal and not mask_local and mask_mod is None):
348
+ if const_expr(mask_seqlen):
349
+ if const_expr(not r2p):
350
+ for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
351
+ # if tScS_t2r[i][1] >= seqlenk_col_limit:
352
+ # acc_S[i] = -Float32.inf
353
+ # For some reason the 2 lines above generate really bad SASS
354
+ acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
355
+ else:
356
+ mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)
357
+
358
+ elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
359
+ # Block sparse case w/ mask_mod
360
+ has_fastdiv = const_expr(
361
+ fastdiv_mods is not None
362
+ and fastdiv_mods[0] is not None
363
+ and fastdiv_mods[1] is not None
364
+ )
365
+ batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
366
+
367
+ ncol = const_expr(cute.size(tScS_t2r.shape))
368
+ for i in cutlass.range_constexpr(ncol):
369
+ row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
370
+ col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
371
+ global_row = row_coord + m_block * self.tile_m
372
+ global_col = col_coord + n_block * self.tile_n
373
+
374
+ if const_expr(self.qhead_per_kvhead_packgqa != 1):
375
+ head_offset = global_row % self.qhead_per_kvhead_packgqa
376
+ head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
377
+ mask_row = global_row // self.qhead_per_kvhead_packgqa
378
+ else:
379
+ head_idx_for_mod = head_idx
380
+ mask_row = global_row
381
+
382
+ mask_row_for_mod = mask_row
383
+ if const_expr(has_fastdiv and aux_tensors is not None):
384
+ if check_q_boundary:
385
+ _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
386
+ global_col_for_mod = global_col
387
+ if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
388
+ _, global_col_for_mod = divmod(global_col, fastdiv_mods[1])
389
+
390
+ head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
391
+ mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
392
+ kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
393
+ mask_value = mask_mod(
394
+ batch_idx_ssa,
395
+ head_idx_ssa,
396
+ mask_row_ssa,
397
+ kv_idx_ssa,
398
+ self.seqlen_info,
399
+ aux_tensors,
400
+ )
401
+ cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
402
+ acc_S[i] = acc_S[i] if cond else -Float32.inf
403
+ if const_expr(mask_seqlen):
404
+ acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
405
+ if check_q_boundary:
406
+ acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]
407
+
408
+ else: # Causal or local
409
+ causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q
410
+ row_idx = tScS_t2r[0][0] + m_block * self.tile_m
411
+ if const_expr(self.qhead_per_kvhead_packgqa != 1):
412
+ row_idx = row_idx // self.qhead_per_kvhead_packgqa
413
+ if const_expr(mask_causal):
414
+ col_limit_right = row_idx + causal_row_offset
415
+ if const_expr(mask_seqlen):
416
+ col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
417
+ # if cute.arch.thread_idx()[0] % 32 == 0:
418
+ # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset)
419
+ ncol = const_expr(cute.size(tScS_t2r.shape))
420
+ if const_expr(not r2p):
421
+ for i in cutlass.range(ncol, unroll_full=True):
422
+ acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
423
+ else:
424
+ mask_r2p(acc_S, col_limit_right, arch=100, rank1=True)
425
+ else:
426
+ local_row_offset_right = (
427
+ causal_row_offset + self.window_size_right
428
+ if const_expr(self.window_size_right is not None)
429
+ else None
430
+ )
431
+ local_row_offset_left = (
432
+ causal_row_offset - 1 - self.window_size_left
433
+ if const_expr(self.window_size_left is not None)
434
+ else None
435
+ )
436
+ if const_expr(self.window_size_right is not None):
437
+ col_limit_right = row_idx + local_row_offset_right
438
+ else:
439
+ col_limit_right = self.tile_n
440
+ if const_expr(mask_seqlen):
441
+ col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
442
+ col_limit_left = (
443
+ row_idx + local_row_offset_left
444
+ if const_expr(self.window_size_left is not None)
445
+ else 0
446
+ )
447
+ # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
448
+ for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
449
+ col_idx = tScS_t2r[i][1]
450
+ acc_S[i] = (
451
+ -Float32.inf
452
+ if col_idx >= col_limit_right or col_idx < col_limit_left
453
+ else acc_S[i]
454
+ )
455
+
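# --- Illustrative sketch in plain Python (not part of the packaged code) ---
# Unpacking causal_row_offset in apply_mask_sm100 above: the in-tile column
# coordinate is col = kv_idx - n_block * tile_n, and the causal branch keeps
# col < col_limit_right with
#     col_limit_right = q_idx + 1 + seqlen_k - n_block * tile_n - seqlen_q.
# Substituting shows this is exactly kv_idx <= q_idx + (seqlen_k - seqlen_q),
# i.e. the bottom-right-aligned causal mask.  A small numeric spot check:
seqlen_q, seqlen_k, tile_n = 5, 8, 4
for n_block in range(2):
    causal_row_offset = 1 + seqlen_k - n_block * tile_n - seqlen_q
    for q_idx in range(seqlen_q):
        col_limit_right = q_idx + causal_row_offset
        for kv_idx in range(n_block * tile_n, (n_block + 1) * tile_n):
            col = kv_idx - n_block * tile_n
            assert (col < col_limit_right) == (kv_idx <= q_idx + seqlen_k - seqlen_q)
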
456
+ @cute.jit
457
+ def apply_mask_sm100_transposed(
458
+ self,
459
+ acc_S: cute.Tensor,
460
+ tScS_t2r: cute.Tensor,
461
+ t0ScS_t2r: cute.Tensor,
462
+ m_block: cutlass.Int32,
463
+ n_block: cutlass.Int32,
464
+ mask_seqlen: cutlass.Constexpr,
465
+ mask_causal: cutlass.Constexpr,
466
+ mask_local: cutlass.Constexpr,
467
+ mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
468
+ batch_idx: Int32 = None,
469
+ head_idx: Int32 = None,
470
+ aux_tensors: Optional[list] = None,
471
+ fastdiv_mods=(None, None),
472
+ is_full_block: bool = False,
473
+ check_m_boundary: bool = True,
474
+ ) -> None:
475
+ """
476
+ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.
477
+
478
+ Coordinate convention:
479
+ - ROW corresponds to Q (m_block)
480
+ - COL corresponds to KV (n_block)
481
+
482
+ is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking.
483
+ check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks).
484
+ When iterating m_blocks in forward order, only the last m_block may be partial.
485
+ """
486
+ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
487
+ ROW = 0 if const_expr(not self.swap_AB) else 1
488
+ COL = 1 if const_expr(not self.swap_AB) else 0
489
+ assert t0ScS_t2r[0][COL] == 0, "col0 == 0"
490
+ thr_col_offset = tScS_t2r[0][COL]
491
+ seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
492
+
493
+ if const_expr(not mask_causal and not mask_local and mask_mod is not None):
494
+ # Block sparse case with mask_mod (backward)
495
+ #
496
+ # Coordinate convention: ROW → Q (m_block), COL → KV (n_block).
497
+ # These already account for swap_AB.
498
+ #
499
+ # FULL blocks: mask_mod returns True for all elements, so skip it.
500
+ # Still need seqlen bounds check (elements may be OOB on last m_block).
501
+ # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds.
502
+ if is_full_block:
503
+ if const_expr(mask_seqlen):
504
+ if seqlenk_col_limit <= 0:
505
+ # Entire tile is OOB for K
506
+ for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
507
+ acc_S[i] = -cutlass.Float32.inf
508
+ elif check_m_boundary:
509
+ # Last m_block: check Q and K boundaries
510
+ ncol = const_expr(cute.size(tScS_t2r.shape))
511
+ for i in cutlass.range_constexpr(ncol):
512
+ row_coord = tScS_t2r[i][ROW]
513
+ col_coord = tScS_t2r[i][COL]
514
+ global_q = row_coord + m_block * self.tile_m
515
+ global_kv = col_coord + n_block * self.tile_n
516
+ q_out_of_bounds = global_q >= self.seqlen_q
517
+ kv_out_of_bounds = global_kv >= self.seqlen_k
518
+ out_of_bounds = q_out_of_bounds or kv_out_of_bounds
519
+ acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
520
+ else:
521
+ # Partial block
522
+ has_fastdiv = const_expr(
523
+ fastdiv_mods is not None
524
+ and fastdiv_mods[0] is not None
525
+ and fastdiv_mods[1] is not None
526
+ )
527
+ wrap_aux_indices = const_expr(
528
+ has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
529
+ )
530
+ batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
531
+ head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)
532
+
533
+ ncol = const_expr(cute.size(tScS_t2r.shape))
534
+ for i in cutlass.range_constexpr(ncol):
535
+ row_coord = tScS_t2r[i][ROW]
536
+ col_coord = tScS_t2r[i][COL]
537
+ global_q = row_coord + m_block * self.tile_m
538
+ global_kv = col_coord + n_block * self.tile_n
539
+
540
+ q_idx_for_mod = global_q
541
+ kv_idx_for_mod = global_kv
542
+ if const_expr(wrap_aux_indices):
543
+ _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0])
544
+ _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1])
545
+
546
+ q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32)
547
+ kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32)
548
+
549
+ mask_value = mask_mod(
550
+ batch_idx_ssa,
551
+ head_idx_ssa,
552
+ q_idx_ssa,
553
+ kv_idx_ssa,
554
+ self.seqlen_info,
555
+ aux_tensors,
556
+ )
557
+ cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
558
+ acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf
559
+
560
+ if const_expr(mask_seqlen):
561
+ # check_m_boundary=False skips q check for non-boundary m_blocks
562
+ q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q)
563
+ kv_out_of_bounds = global_kv >= self.seqlen_k
564
+ out_of_bounds = q_out_of_bounds or kv_out_of_bounds
565
+ acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
566
+
567
+ elif const_expr(not mask_causal and not mask_local):
568
+ if const_expr(mask_seqlen):
569
+ if seqlenk_col_limit <= 0:
570
+ for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
571
+ acc_S[i] = -cutlass.Float32.inf
572
+ else: # Causal or local
573
+ thr_row_offset = tScS_t2r[0][ROW]
574
+ seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
575
+ causal_offset = seqlenq_row_limit - seqlenk_col_limit
576
+ if const_expr(mask_causal):
577
+ # tidx = cute.arch.thread_idx()[0] % 256
578
+ # if tidx < 32:
579
+ # cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1])
580
+ row_limit_top = causal_offset
581
+ if const_expr(mask_seqlen):
582
+ # If col is beyond the column limit, we want to mask out the entire
583
+ # column, by setting the row limit to self.tile_m.
584
+ if seqlenk_col_limit <= 0:
585
+ row_limit_top = self.tile_m
586
+ r2p = True
587
+ if const_expr(not r2p):
588
+ for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
589
+ acc_S[i] = (
590
+ -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i]
591
+ )
592
+ else:
593
+ num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32
594
+ mask_r2p_transposed(acc_S, row_limit_top, num_rep)
595
+ else:
596
+ if const_expr(self.window_size_right is not None):
597
+ row_limit_top = causal_offset - self.window_size_right
598
+ else:
599
+ row_limit_top = 0
600
+ if const_expr(self.window_size_left is not None):
601
+ row_limit_bot = causal_offset + self.window_size_left
602
+ if const_expr(mask_seqlen):
603
+ if seqlenk_col_limit <= 0:
604
+ row_limit_top = self.tile_m
605
+ for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
606
+ row_idx = t0ScS_t2r[i][ROW]
607
+ local_mask = row_idx < row_limit_top
608
+ if const_expr(self.window_size_left is not None):
609
+ local_mask |= row_idx > row_limit_bot
610
+ acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
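
# --- Illustrative sketch in plain Python (not part of the packaged code) ---
# PackGQA index math used by the qhead_per_kvhead_packgqa paths in apply_mask
# and apply_mask_sm100 above: along the packed M dimension, row
# r = q_row * qhead_per_kvhead + head_offset, so the query head becomes
# head_idx * qhead_per_kvhead + head_offset and the causal/local limits are
# computed from q_row rather than from r.  Example values are hypothetical.
qhead_per_kvhead = 4   # query heads packed per KV head
kv_head_idx = 2
packed_row = 10
head_offset = packed_row % qhead_per_kvhead                     # -> 2
q_row = packed_row // qhead_per_kvhead                          # -> 2
query_head_idx = kv_head_idx * qhead_per_kvhead + head_offset   # -> 10
assert (q_row, head_offset, query_head_idx) == (2, 2, 10)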