mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/softmax.py
@@ -0,0 +1,583 @@
+ # @nolint # fbcode
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ import operator
+ from typing import Tuple
+ from dataclasses import dataclass
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32
+
+ import mslk.attention.flash_attn.utils as utils
+ from mslk.attention.flash_attn.cute_dsl_utils import ParamsBase
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
+
+
+ @dataclass
+ class Softmax(ParamsBase):
+     scale_log2: Float32
+     num_rows: cutlass.Constexpr[int]
+     row_max: cute.Tensor
+     row_sum: cute.Tensor
+     arch: cutlass.Constexpr[int] = 80
+     softmax_scale: Float32 | None = None
+
+     @staticmethod
+     def create(
+         scale_log2: Float32,
+         num_rows: cutlass.Constexpr[int],
+         arch: cutlass.Constexpr[int] = 80,
+         softmax_scale: Float32 | None = None,
+     ):
+         row_max = cute.make_rmem_tensor(num_rows, Float32)
+         row_sum = cute.make_rmem_tensor(num_rows, Float32)
+         return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale)
+
+     def reset(self) -> None:
+         self.row_max.fill(-Float32.inf)
+         self.row_sum.fill(0.0)
+
+     def _compute_row_max(
+         self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None
+     ) -> Float32:
+         return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch)
+
+     def _compute_row_sum(
+         self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None
+     ) -> Float32:
+         return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch)
+
+     @cute.jit
+     def online_softmax(
+         self,
+         acc_S: cute.Tensor,
+         is_first: cutlass.Constexpr[bool] = False,
+         check_inf: cutlass.Constexpr[bool] = True,
+     ) -> cute.Tensor:
+         """Apply online softmax and return the row_scale to rescale O.
+
+         :param acc_S: acc_S tensor
+         :type acc_S: cute.Tensor
+         :param is_first: whether this is the first n_block
+         :type is_first: cutlass.Constexpr
+         """
+         # Change acc_S to M,N layout view.
+         acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
+         row_scale = cute.make_fragment_like(self.row_max, Float32)
+
+         row_max = self.row_max
+         row_sum = self.row_sum
+         scale_log2 = self.scale_log2
+         arch = self.arch
+
+         # Each iteration processes one row of acc_S
+         for r in cutlass.range(cute.size(row_max), unroll_full=True):
+             acc_S_row = acc_S_mn[r, None].load()  # (n_block_size)
+
+             row_max_cur = utils.fmax_reduce(
+                 acc_S_row,
+                 init_val=row_max[r] if cutlass.const_expr(not is_first) else None,
+                 arch=arch,
+             )
+
+             row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4)
+             # Update row_max before replacing row_max_cur with a safe value for -inf
+             row_max_prev = row_max[r]
+             row_max[r] = row_max_cur
+
+             if cutlass.const_expr(check_inf):
+                 row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur
+
+             if cutlass.const_expr(is_first):
+                 row_max_cur_scaled = row_max_cur * scale_log2
+                 acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
+
+                 acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
+                 row_scale[r] = 1.0
+             else:
+                 row_max_cur_scaled = row_max_cur * scale_log2
+                 acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
+                 # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled)
+                 row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2)
+
+                 acc_S_row_sum = utils.fadd_reduce(
+                     acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
+                 )
+
+             row_sum[r] = acc_S_row_sum
+             acc_S_mn[r, None].store(acc_S_row_exp)
+
+         return row_scale
+
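+     # Note on the recurrence implemented by online_softmax above: assuming callers pass
+     # scale_log2 = softmax_scale * log2(e) (which the logsumexp math in finalize implies),
+     # exp2((s - m) * scale_log2) == exp(softmax_scale * (s - m)), and after each block
+     #     new_sum = old_sum * exp2((old_max - new_max) * scale_log2)
+     #               + sum_j exp2(s_j * scale_log2 - new_max * scale_log2).
+     # The returned row_scale = exp2((old_max - new_max) * scale_log2) is the factor the
+     # caller applies to the O accumulator (see rescale_O below).
+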
+     @cute.jit
+     def finalize(
+         self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None
+     ) -> cute.Tensor:
+         """Finalize the online softmax by computing the scale and logsumexp."""
+         if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):
+             assert cute.size(sink_val) == cute.size(self.row_sum)
+         row_sum = self.row_sum
+         row_max = self.row_max
+         scale_log2 = self.scale_log2
+
+         # quad reduction for row_sum as we didn't do it during each iteration of online softmax
+         row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))
+         row_scale = cute.make_fragment_like(row_max, Float32)
+
+         for r in cutlass.range(cute.size(row_sum), unroll_full=True):
+             if cutlass.const_expr(sink_val is not None):
+                 sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
+                 LOG2_E = math.log2(math.e)
+                 row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2)
+
+             # if row_sum is zero or nan, set acc_O_mn_row to 1.0
+             acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
+             row_scale[r] = (
+                 cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)
+             ) * final_scale
+             row_sum_cur = row_sum[r]
+             LN2 = math.log(2.0)
+             row_sum[r] = (
+                 (row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2
+                 if not acc_O_mn_row_is_zero_or_nan
+                 else -Float32.inf
+             )
+         return row_scale
+
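+     # After finalize: row_scale holds final_scale / row_sum (for normalizing O), while
+     # row_sum is overwritten with the logsumexp in natural-log units,
+     #     lse = row_max * scale_log2 * ln(2) + ln(row_sum),
+     # or -inf for rows whose sum is zero or NaN (e.g. fully masked rows). An optional
+     # attention-sink term exp2(sink_val * log2(e) - row_max * scale_log2) is folded into
+     # row_sum first.
+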
+     @cute.jit
+     def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:
+         """Scale each row of acc_O by the given scale tensor.
+         :param acc_O: input tensor
+         :type acc_O: cute.Tensor
+         :param row_scale: row_scale tensor
+         :type row_scale: cute.Tensor
+         """
+         acc_O_mn = utils.make_acc_tensor_mn_view(acc_O)
+         assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
+         for r in cutlass.range(cute.size(row_scale), unroll_full=True):
+             acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])
+
+
+ @dataclass
+ class SoftmaxSm100(Softmax):
+     rescale_threshold: cutlass.Constexpr[float] = 0.0
+
+     @staticmethod
+     def create(
+         scale_log2: Float32,
+         rescale_threshold: cutlass.Constexpr[float] = 0.0,
+         softmax_scale: Float32 | None = None,
+     ):
+         num_rows = 1
+         arch = 100
+         row_max = cute.make_rmem_tensor(num_rows, Float32)
+         row_sum = cute.make_rmem_tensor(num_rows, Float32)
+         return SoftmaxSm100(
+             scale_log2,
+             num_rows,
+             row_max,
+             row_sum,
+             arch,
+             softmax_scale,
+             rescale_threshold=rescale_threshold,
+         )
+
+     @cute.jit
+     def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]:
+         if cutlass.const_expr(is_first):
+             row_max_new = self._compute_row_max(acc_S_row)
+             row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
+             acc_scale = 0.0
+         else:
+             row_max_old = self.row_max[0]
+             row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
+             row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
+             acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
+             acc_scale = utils.exp2f(acc_scale_)
+             if cutlass.const_expr(self.rescale_threshold > 0.0):
+                 if acc_scale_ >= -self.rescale_threshold:
+                     row_max_new = row_max_old
+                     row_max_safe = row_max_old
+                     acc_scale = 1.0
+         self.row_max[0] = row_max_new
+         return row_max_safe, acc_scale
+
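+     # update_row_max implements an optional hysteresis: when rescale_threshold > 0 and
+     # (row_max_old - row_max_new) * scale_log2 >= -rescale_threshold, the old max is kept
+     # and acc_scale is reported as 1.0, so callers can skip rescaling the O accumulator on
+     # iterations where the running max barely changes.
+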
+     def update_row_sum(
+         self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False
+     ) -> None:
+         init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None
+         # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale)
+         self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val)
+         # tmp = self._compute_row_sum(acc_S_row_exp)
+         # self.row_sum[0] = self.row_sum[0] * row_scale + tmp
+
+     @cute.jit
+     def scale_subtract_rowmax(
+         self,
+         acc_S_row: cute.Tensor,
+         row_max: Float32,
+     ):
+         assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
+         row_max_scaled = row_max * self.scale_log2
+         for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True):
+             acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2(
+                 (acc_S_row[i], acc_S_row[i + 1]),
+                 (self.scale_log2, self.scale_log2),
+                 (-row_max_scaled, -row_max_scaled),
+             )
+
+     @cute.jit
+     def apply_exp2_convert(
+         self,
+         acc_S_row: cute.Tensor,
+         acc_S_row_converted: cute.Tensor,
+         e2e: cutlass.Constexpr[bool] = False,
+         e2e_freq: cutlass.Constexpr[int] = 16,
+         e2e_res: cutlass.Constexpr[int] = 4,
+         e2e_frg_limit: cutlass.Constexpr[int] = 1,
+     ):
+         assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
+         frg_tile = 32
+         assert frg_tile % 2 == 0
+         frg_cnt = cute.size(acc_S_row) // frg_tile
+         assert cute.size(acc_S_row) % frg_tile == 0
+         acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
+         acc_S_row_converted_frg = cute.logical_divide(
+             acc_S_row_converted, cute.make_layout(frg_tile)
+         )
+         for j in cutlass.range_constexpr(frg_cnt):
+             for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
+                 # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
+                 # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
+                 if cutlass.const_expr(not e2e):
+                     acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
+                     acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
+                 else:
+                     if cutlass.const_expr(
+                         k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit
+                     ):
+                         acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
+                         acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
+                     else:
+                         # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j])
+                         acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(
+                             acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]
+                         )
+             acc_S_row_converted_frg[None, j].store(
+                 acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
+             )
+
+     @cute.jit
+     def scale_apply_exp2_convert(
+         self,
+         acc_S_row: cute.Tensor,
+         row_max: Float32,
+         acc_S_row_converted: cute.Tensor,
+     ):
+         assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
+         minus_row_max_scaled = -row_max * self.scale_log2
+         for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
+             acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2(
+                 (acc_S_row[i], acc_S_row[i + 1]),
+                 (self.scale_log2, self.scale_log2),
+                 (minus_row_max_scaled, minus_row_max_scaled),
+             )
+
+         # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
+         #     acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2(
+         #         (acc_S_row[i], acc_S_row[i + 1]),
+         #         (self.scale_log2, self.scale_log2),
+         #         (minus_row_max_scaled, minus_row_max_scaled),
+         #     )
+         #     acc_S_row[i] = cute.arch.exp2(acc_S_row[i])
+         #     acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1])
+
+         frg_tile = 32
+         assert frg_tile % 2 == 0
+         frg_cnt = cute.size(acc_S_row) // frg_tile
+         assert cute.size(acc_S_row) % frg_tile == 0
+         acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
+         acc_S_row_converted_frg = cute.logical_divide(
+             acc_S_row_converted, cute.make_layout(frg_tile)
+         )
+         for j in cutlass.range_constexpr(frg_cnt):
+             for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
+                 # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = (
+                 #     utils.fma_packed_f32x2(
+                 #         (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]),
+                 #         (self.scale_log2, self.scale_log2),
+                 #         (minus_row_max_scaled, minus_row_max_scaled),
+                 #     )
+                 # )
+                 # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
+                 # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
+                 acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
+                 acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
+             acc_S_row_converted_frg[None, j].store(
+                 acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
+             )
+
+
+ @cute.jit
+ def floor_if_packed(
+     q_idx,
+     qhead_per_kvhead: cutlass.Constexpr[int],
+ ) -> cute.Tensor:
+     """Convert a packed q_idx to the logical query row index for Pack-GQA."""
+     if cutlass.const_expr(qhead_per_kvhead == 1):
+         return q_idx
+     return q_idx // qhead_per_kvhead
+
+
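+ # With Pack-GQA, qhead_per_kvhead query heads are packed along the row dimension, so a
+ # packed row index is q_row * qhead_per_kvhead + head_offset; floor_if_packed recovers
+ # q_row, and apply_score_mod_inner below recovers head_offset to rebuild the logical
+ # query-head index.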
+ @cute.jit
+ def apply_score_mod_inner(
+     score_tensor,
+     index_tensor,
+     score_mod: cutlass.Constexpr,
+     batch_idx,
+     head_idx,
+     softmax_scale,
+     vec_size: cutlass.Constexpr,
+     qk_acc_dtype: cutlass.Constexpr,
+     aux_tensors,
+     fastdiv_mods,
+     seqlen_info: SeqlenInfoQK,
+     constant_q_idx: cutlass.Constexpr,
+     qhead_per_kvhead: cutlass.Constexpr[int] = 1,
+     transpose_indices: cutlass.Constexpr[bool] = False,
+ ):
+     """Shared implementation for applying score modification.
+
+     Args:
+         score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100)
+         index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100)
+         score_mod: The score modification function to apply
+         batch_idx: Batch index
+         head_idx: Head index
+         softmax_scale: Scale to apply
+         vec_size: Vector size for processing elements
+         qk_acc_dtype: Data type for accumulator
+         aux_tensors: Optional aux_tensors for FlexAttention
+         fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping
+         seqlen_info: Sequence length info
+         constant_q_idx: If provided, use this constant for all q_idx values.
+             If None, compute q_idx per-element
+         qhead_per_kvhead: Pack-GQA replication factor. Divide q_idx by this
+             when greater than 1 so score mods see logical query indices.
+         transpose_indices: If True, swap q_idx/kv_idx in index_tensor (for bwd kernel where S is transposed)
+     """
+     # Index positions in the index_tensor tuple
+     # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
+     # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
+     if cutlass.const_expr(transpose_indices):
+         q_idx_pos = cutlass.const_expr(1)
+         kv_idx_pos = cutlass.const_expr(0)
+     else:
+         q_idx_pos = cutlass.const_expr(0)
+         kv_idx_pos = cutlass.const_expr(1)
+
+     n_vals = cutlass.const_expr(cute.size(score_tensor.shape))
+     score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)
+     kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)
+
+     # SSA values for batch (constant across all elements)
+     batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))
+
+     # Handle q_idx based on whether it's constant
+     q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)
+
+     # For Pack-GQA with non-constant q_idx, we need per-element head indices
+     # since a thread may process multiple query head indices
+     if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
+         head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)
+
+     for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
+         for j in cutlass.range(vec_size, unroll_full=True):
+             score_vec[j] = score_tensor[i + j] * softmax_scale
+
+             # Extract head offset from packed q_idx for Pack-GQA
+             if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
+                 q_idx_packed = index_tensor[i + j][q_idx_pos]
+                 # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)
+                 q_idx_logical = q_idx_packed // qhead_per_kvhead
+                 head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
+                 head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset
+
+             # If we will load from aux tensors, wrap the indices with mod so we do not read OOB
+             if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
+                 if cutlass.const_expr(constant_q_idx is None):
+                     seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
+                     q_idx_floored = floor_if_packed(
+                         index_tensor[i + j][q_idx_pos], qhead_per_kvhead
+                     )
+                     _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
+                     q_idx_vec[j] = q_idx_wrapped
+                 else:
+                     _, seqlen_k_divmod = fastdiv_mods
+
+                 _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
+                 kv_idx_vec[j] = kv_idx_wrapped
+             else:
+                 # No bounds checking - direct indexing
+                 if constant_q_idx is None:
+                     q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
+                 kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]
+
+         # Convert to SSA for score_mod call
+         score_ssa = score_vec.load()
+         kv_idx_ssa = kv_idx_vec.load()
+         if cutlass.const_expr(constant_q_idx is None):
+             q_idx_ssa = q_idx_vec.load()
+         else:
+             # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical
+             q_idx_const = constant_q_idx
+             q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,))
+
+         # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise
+         if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
+             head_idx_ssa = head_idx_vec.load()
+         else:
+             head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))
+
+         aux_args = []
+         if cutlass.const_expr(aux_tensors is not None):
+             aux_args = aux_tensors
+
+         post_mod_scores = score_mod(
+             score_ssa,
+             batch_idx_ssa,
+             head_idx_ssa,
+             q_idx=q_idx_ssa,
+             kv_idx=kv_idx_ssa,
+             seqlen_info=seqlen_info,
+             aux_tensors=aux_args,
+         )
+
+         # Write back modified scores
+         score_vec.store(post_mod_scores)
+         for j in cutlass.range(vec_size, unroll_full=True):
+             score_tensor[i + j] = score_vec[j]
+
+
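+ # For reference, a score_mod callable compatible with the call above receives SSA vectors
+ # and returns the modified scores. A minimal, hypothetical example (illustrative only; the
+ # exact contract is defined by the flash_fwd/flash_bwd callers):
+ #
+ #     def rel_bias_score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
+ #         # add a relative-position bias of (kv_idx - q_idx) to each score
+ #         return scores + (kv_idx - q_idx).to(cutlass.Float32)
+
+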
+ @cute.jit
+ def apply_score_mod_bwd_inner(
+     grad_tensor,
+     score_tensor,
+     index_tensor,
+     score_mod_bwd: cutlass.Constexpr,
+     batch_idx,
+     head_idx,
+     softmax_scale,
+     vec_size: cutlass.Constexpr,
+     qk_acc_dtype: cutlass.Constexpr,
+     aux_tensors,
+     fastdiv_mods,
+     seqlen_info,
+     constant_q_idx: cutlass.Constexpr,
+     qhead_per_kvhead: cutlass.Constexpr[int] = 1,
+     transpose_indices: cutlass.Constexpr[bool] = False,
+ ):
+     """Apply backward score modification (joint graph).
+
+     Args:
+         grad_tensor: in/out: dlogits rewritten in-place with d(scaled_scores)
+         score_tensor: pre-mod scores (unscaled QK tile), scaled by softmax_scale internally
+         index_tensor: Index positions (same as forward)
+         score_mod_bwd: The backward score modification function (joint graph)
+         batch_idx: Batch index
+         head_idx: Head index
+         softmax_scale: Scale to apply to score_tensor
+         vec_size: Vector size for processing elements
+         qk_acc_dtype: Data type for accumulator
+         aux_tensors: Optional aux_tensors for FlexAttention
+         fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping
+         seqlen_info: Sequence length info
+         constant_q_idx: If provided, use this constant for all q_idx values
+         qhead_per_kvhead: Pack-GQA replication factor
+         transpose_indices: If True, swap q_idx/kv_idx in index_tensor
+     """
+     # Index positions in the index_tensor tuple
+     # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
+     # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
+     if cutlass.const_expr(transpose_indices):
+         q_idx_pos = cutlass.const_expr(1)
+         kv_idx_pos = cutlass.const_expr(0)
+     else:
+         q_idx_pos = cutlass.const_expr(0)
+         kv_idx_pos = cutlass.const_expr(1)
+     n_vals = cutlass.const_expr(cute.size(grad_tensor.shape))
+     grad_vec = cute.make_fragment(vec_size, qk_acc_dtype)
+     score_vec = cute.make_fragment(vec_size, qk_acc_dtype)
+     kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
+     batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))
+     q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
+
+     # For Pack-GQA with non-constant q_idx, we need per-element head indices
+     if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
+         head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
+
+     for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
+         for j in cutlass.range(vec_size, unroll_full=True):
+             grad_vec[j] = grad_tensor[i + j]
+             # Scale score so joint graph sees same value as forward score_mod
+             score_vec[j] = score_tensor[i + j] * softmax_scale
+
+             if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
+                 q_idx_packed = index_tensor[i + j][q_idx_pos]
+                 q_idx_logical = q_idx_packed // qhead_per_kvhead
+                 head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
+                 head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset
+
+             if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
+                 if cutlass.const_expr(constant_q_idx is None):
+                     seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
+                     q_idx_floored = floor_if_packed(
+                         index_tensor[i + j][q_idx_pos], qhead_per_kvhead
+                     )
+                     _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
+                     q_idx_vec[j] = q_idx_wrapped
+                 else:
+                     _, seqlen_k_divmod = fastdiv_mods
+
+                 _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
+                 kv_idx_vec[j] = kv_idx_wrapped
+             else:
+                 # No bounds checking - direct indexing
+                 if constant_q_idx is None:
+                     q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
+                 kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]
+
+         grad_ssa = grad_vec.load()
+         score_ssa = score_vec.load()
+         kv_idx_ssa = kv_idx_vec.load()
+
+         if cutlass.const_expr(constant_q_idx is None):
+             q_idx_ssa = q_idx_vec.load()
+         else:
+             q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,))
+
+         if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
+             head_idx_ssa = head_idx_vec.load()
+         else:
+             head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))
+
+         aux_args = []
+         if cutlass.const_expr(aux_tensors is not None):
+             aux_args = aux_tensors
+
+         grad_out_ssa = score_mod_bwd(
+             grad_ssa,
+             score_ssa,
+             batch_idx_ssa,
+             head_idx_ssa,
+             q_idx=q_idx_ssa,
+             kv_idx=kv_idx_ssa,
+             seqlen_info=seqlen_info,
+             aux_tensors=aux_args,
+         )
+
+         grad_vec.store(grad_out_ssa)
+         for j in cutlass.range(vec_size, unroll_full=True):
+             grad_tensor[i + j] = grad_vec[j]
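
For orientation, the streaming recurrence implemented by Softmax.online_softmax / Softmax.finalize above corresponds to the following minimal NumPy sketch (an illustration only, not part of the wheel; it handles a single query row and uses natural-base exponentials rather than the kernel's exp2-based arithmetic):

import numpy as np

def online_softmax_reference(blocks, softmax_scale=1.0):
    """blocks: iterable of (scores, values) pairs for one query row,
    with scores of shape (n,) and values of shape (n, d)."""
    row_max, row_sum, acc_o = -np.inf, 0.0, None
    for s, v in blocks:
        new_max = max(row_max, float(np.max(s)))
        # row_scale: the factor online_softmax() returns for rescaling the O accumulator
        row_scale = np.exp((row_max - new_max) * softmax_scale)
        p = np.exp((s - new_max) * softmax_scale)
        row_sum = row_sum * row_scale + float(np.sum(p))
        o_block = p @ v
        acc_o = o_block if acc_o is None else acc_o * row_scale + o_block
        row_max = new_max
    # finalize(): normalize O by 1/row_sum and report the logsumexp
    lse = row_max * softmax_scale + np.log(row_sum)
    return acc_o / row_sum, lse

With a single block this reduces to ordinary single-row attention: softmax(softmax_scale * scores) @ values.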