mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,583 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Tri Dao.
|
|
3
|
+
|
|
4
|
+
import math
|
|
5
|
+
import operator
|
|
6
|
+
from typing import Tuple
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
import cutlass
|
|
10
|
+
import cutlass.cute as cute
|
|
11
|
+
from cutlass import Float32
|
|
12
|
+
|
|
13
|
+
import mslk.attention.flash_attn.utils as utils
|
|
14
|
+
from mslk.attention.flash_attn.cute_dsl_utils import ParamsBase
|
|
15
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class Softmax(ParamsBase):
    """Online-softmax state (running row max / row sum) for a tile of score rows.

    The calling code owns ``num_rows`` rows of the score accumulator. For each
    row this object keeps the running maximum (``row_max``) and the running sum
    of exponentials (``row_sum``). It provides three steps of the online-softmax
    recurrence:

    * ``online_softmax``: exponentiate a new score block in place and update
      the row statistics.
    * ``rescale_O``: rescale the output accumulator by the per-row factors
      returned by ``online_softmax`` / ``finalize``.
    * ``finalize``: produce the final normalization factors and overwrite
      ``row_sum`` with the per-row logsumexp.
    """

    # Multiplier applied to raw scores before exp2 (scores * scale_log2 - max * scale_log2).
    scale_log2: Float32
    # Number of rows tracked; size of row_max / row_sum.
    num_rows: cutlass.Constexpr[int]
    # Running per-row maximum of the raw (unscaled) scores; reset() fills it with -inf.
    row_max: cute.Tensor
    # Running per-row sum of exponentials. After finalize() it holds the
    # natural-log logsumexp of the scaled scores (or -inf for empty/NaN rows).
    row_sum: cute.Tensor
    # Architecture id forwarded to the utils.fmax_reduce / fadd_reduce helpers.
    arch: cutlass.Constexpr[int] = 80
    # Optional softmax scale; stored here but not read by the methods of this class.
    softmax_scale: Float32 | None = None

    @staticmethod
    def create(
        scale_log2: Float32,
        num_rows: cutlass.Constexpr[int],
        arch: cutlass.Constexpr[int] = 80,
        softmax_scale: Float32 | None = None,
    ):
        """Allocate row_max / row_sum (via cute.make_rmem_tensor) and build a Softmax.

        The allocated tensors are not initialized here; call ``reset()`` before
        the first ``online_softmax`` that relies on their contents.
        """
        row_max = cute.make_rmem_tensor(num_rows, Float32)
        row_sum = cute.make_rmem_tensor(num_rows, Float32)
        return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale)

    def reset(self) -> None:
        """Set every running max to -inf and every running sum to 0."""
        self.row_max.fill(-Float32.inf)
        self.row_sum.fill(0.0)

    def _compute_row_max(
        self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        """Max-reduce one row (optionally seeded with ``init_val``) via utils.fmax_reduce."""
        return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch)

    def _compute_row_sum(
        self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        """Sum-reduce one row (optionally seeded with ``init_val``) via utils.fadd_reduce."""
        return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch)

    @cute.jit
    def online_softmax(
        self,
        acc_S: cute.Tensor,
        is_first: cutlass.Constexpr[bool] = False,
        check_inf: cutlass.Constexpr[bool] = True,
    ) -> cute.Tensor:
        """Apply online softmax and return the row_scale to rescale O.

        For every row ``r`` of ``acc_S``:

        * The new max is reduced across the row and then across groups of 4
          lanes (``utils.warp_reduce(width=4)``), and stored in ``row_max[r]``.
        * The row is replaced in place by ``exp2(s * scale_log2 - max * scale_log2)``.
        * ``row_sum[r]`` is updated with this block's exponentials. The sum is
          only thread-local here; the cross-lane reduction happens in
          ``finalize``.

        :param acc_S: acc_S tensor (overwritten with the exponentiated scores)
        :type acc_S: cute.Tensor
        :param is_first: is first n_block. When True, the previous row_max /
            row_sum are ignored and the returned scale is 1.0.
        :type is_first: cutlass.Constexpr
        :param check_inf: when True, a row whose max is -inf uses 0.0 as the
            subtracted max, so exp2 does not see (-inf) - (-inf).
        :type check_inf: cutlass.Constexpr
        :return: per-row factor ``exp2((old_max - new_max) * scale_log2)`` to
            apply to the output accumulator (1.0 for the first block).
        """
        # Change acc_S to M,N layout view.
        acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
        row_scale = cute.make_fragment_like(self.row_max, Float32)

        row_max = self.row_max
        row_sum = self.row_sum
        scale_log2 = self.scale_log2
        arch = self.arch

        # Each iteration processes one row of acc_S
        for r in cutlass.range(cute.size(row_max), unroll_full=True):
            acc_S_row = acc_S_mn[r, None].load()  # (n_block_size)

            # Seed with the previous max unless this is the first block.
            row_max_cur = utils.fmax_reduce(
                acc_S_row,
                init_val=row_max[r] if cutlass.const_expr(not is_first) else None,
                arch=arch,
            )

            row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4)
            # Update row_max before changing row_max_cur to safe value for -inf
            row_max_prev = row_max[r]
            row_max[r] = row_max_cur

            if cutlass.const_expr(check_inf):
                # A fully masked row keeps -inf in row_max but uses 0.0 for the exponent.
                row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur

            if cutlass.const_expr(is_first):
                row_max_cur_scaled = row_max_cur * scale_log2
                acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)

                acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
                row_scale[r] = 1.0
            else:
                row_max_cur_scaled = row_max_cur * scale_log2
                acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
                # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled)
                row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2)

                # Rescale the previous sum to the new max, then add this block.
                acc_S_row_sum = utils.fadd_reduce(
                    acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
                )

            row_sum[r] = acc_S_row_sum
            acc_S_mn[r, None].store(acc_S_row_exp)

        return row_scale

    @cute.jit
    def finalize(
        self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None
    ) -> cute.Tensor:
        """Finalize the online softmax by computing the scale and logsumexp.

        Steps:

        * Reduces ``row_sum`` across groups of 4 lanes.
        * Optionally adds an attention-sink term ``exp(sink_val)`` (relative to
          the running max) to each row sum.
        * Returns ``final_scale / row_sum`` per row. The reciprocal comes from
          ``cute.arch.rcp_approx``, and ``final_scale`` alone is returned when
          the sum is 0 or NaN.
        * As a side effect, overwrites ``row_sum[r]`` with the natural-log
          logsumexp ``(row_max * scale_log2 + log2(row_sum)) * ln(2)``, or -inf
          for zero/NaN rows.

        :param final_scale: extra factor multiplied into every returned scale.
        :param sink_val: optional sink logit; a scalar shared by all rows or a
            tensor with one entry per row (size must match ``row_sum``).
        """
        if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):
            assert cute.size(sink_val) == cute.size(self.row_sum)
        row_sum = self.row_sum
        row_max = self.row_max
        scale_log2 = self.scale_log2

        # quad reduction for row_sum as we didn't do it during each iteration of online softmax
        row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))
        row_scale = cute.make_fragment_like(row_max, Float32)

        for r in cutlass.range(cute.size(row_sum), unroll_full=True):
            if cutlass.const_expr(sink_val is not None):
                sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
                # Convert the natural-log sink value to the log2 domain used by exp2f.
                LOG2_E = math.log2(math.e)
                row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2)

            # if row_sum is zero or nan, divide by 1.0 instead (row_scale becomes final_scale)
            acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
            row_scale[r] = (
                cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)
            ) * final_scale
            row_sum_cur = row_sum[r]
            # log2 -> natural log conversion factor.
            LN2 = math.log(2.0)
            row_sum[r] = (
                (row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2
                if not acc_O_mn_row_is_zero_or_nan
                else -Float32.inf
            )
        return row_scale

    @cute.jit
    def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:
        """Scale each row of acc_O by the given scale tensor.

        The MN view of ``acc_O`` must have exactly one row per ``row_scale`` entry
        (asserted at trace time). ``acc_O`` is modified in place.

        :param acc_O: input tensor
        :type acc_O: cute.Tensor
        :param row_scale: row_scale tensor
        :type row_scale: cute.Tensor
        """
        acc_O_mn = utils.make_acc_tensor_mn_view(acc_O)
        assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
        for r in cutlass.range(cute.size(row_scale), unroll_full=True):
            acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@dataclass
class SoftmaxSm100(Softmax):
    """Single-row online-softmax state configured with ``arch = 100``.

    ``create`` fixes ``num_rows = 1``, so the helpers below read and write only
    ``row_max[0]`` / ``row_sum[0]`` and operate on one row of scores at a time.
    Unlike ``Softmax.online_softmax``, the max update, the exponentiation and
    the sum update are exposed as separate steps for the caller to sequence.
    """

    # When > 0, update_row_max skips rescaling (keeps the old max and returns a
    # scale of 1.0) whenever the log2-domain rescale exponent
    # (old_max - new_max) * scale_log2 is >= -rescale_threshold.
    rescale_threshold: cutlass.Constexpr[float] = 0.0

    @staticmethod
    def create(
        scale_log2: Float32,
        rescale_threshold: cutlass.Constexpr[float] = 0.0,
        softmax_scale: Float32 | None = None,
    ):
        """Allocate a one-row state (via cute.make_rmem_tensor) with arch = 100.

        row_max / row_sum are not initialized here.
        """
        num_rows = 1
        arch = 100
        row_max = cute.make_rmem_tensor(num_rows, Float32)
        row_sum = cute.make_rmem_tensor(num_rows, Float32)
        return SoftmaxSm100(
            scale_log2,
            num_rows,
            row_max,
            row_sum,
            arch,
            softmax_scale,
            rescale_threshold=rescale_threshold,
        )

    @cute.jit
    def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]:
        """Fold ``acc_S_row`` into the running max and return ``(row_max_safe, acc_scale)``.

        ``is_first`` is evaluated at compile time (``cutlass.const_expr``).

        * ``row_max_safe`` is the new max, or 0.0 when it is -inf, so it can be
          subtracted without producing NaN.
        * ``acc_scale`` is the factor for the previous output/sum: 0.0 on the
          first block, otherwise ``exp2((old_max - row_max_safe) * scale_log2)``.
        * With a positive ``rescale_threshold``, small max increases are
          ignored: the old max is kept and ``acc_scale`` is 1.0.

        ``self.row_max[0]`` is updated in place.
        """
        if cutlass.const_expr(is_first):
            row_max_new = self._compute_row_max(acc_S_row)
            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
            acc_scale = 0.0
        else:
            row_max_old = self.row_max[0]
            row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
            acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
            acc_scale = utils.exp2f(acc_scale_)
            if cutlass.const_expr(self.rescale_threshold > 0.0):
                # Max grew by less than the threshold: keep the stale max and skip rescaling.
                if acc_scale_ >= -self.rescale_threshold:
                    row_max_new = row_max_old
                    row_max_safe = row_max_old
                    acc_scale = 1.0
        self.row_max[0] = row_max_new
        return row_max_safe, acc_scale

    def update_row_sum(
        self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False
    ) -> None:
        """Set ``row_sum[0]`` to ``sum(acc_S_row_exp)`` plus the rescaled previous sum.

        On the first block (compile-time ``is_first``), the previous sum is
        ignored. Otherwise it is multiplied by ``row_scale`` and used as the
        reduction's initial value.
        """
        init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None
        # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale)
        self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val)
        # tmp = self._compute_row_sum(acc_S_row_exp)
        # self.row_sum[0] = self.row_sum[0] * row_scale + tmp

    @cute.jit
    def scale_subtract_rowmax(
        self,
        acc_S_row: cute.Tensor,
        row_max: Float32,
    ):
        """In place: ``acc_S_row[i] = acc_S_row[i] * scale_log2 - row_max * scale_log2``.

        Elements are processed in pairs with ``utils.fma_packed_f32x2``, so the
        element count must be even (asserted).
        """
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        row_max_scaled = row_max * self.scale_log2
        for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True):
            acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2(
                (acc_S_row[i], acc_S_row[i + 1]),
                (self.scale_log2, self.scale_log2),
                (-row_max_scaled, -row_max_scaled),
            )

    @cute.jit
    def apply_exp2_convert(
        self,
        acc_S_row: cute.Tensor,
        acc_S_row_converted: cute.Tensor,
        e2e: cutlass.Constexpr[bool] = False,
        e2e_freq: cutlass.Constexpr[int] = 16,
        e2e_res: cutlass.Constexpr[int] = 4,
        e2e_frg_limit: cutlass.Constexpr[int] = 1,
    ):
        """Apply exp2 to ``acc_S_row`` in place and store a converted copy.

        The row is split into 32-element fragments; its size must be a
        multiple of 32 (asserted). Each fragment is exponentiated pair by pair
        and then converted to ``acc_S_row_converted.element_type``.

        When ``e2e`` is True, some pairs use ``utils.ex2_emulation_2`` instead of
        ``cute.arch.exp2``. A pair starting at offset ``k`` is emulated when
        ``k % e2e_freq >= e2e_freq - e2e_res``, except in the last
        ``e2e_frg_limit`` fragments, which always use ``cute.arch.exp2``.
        """
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        frg_tile = 32
        assert frg_tile % 2 == 0
        frg_cnt = cute.size(acc_S_row) // frg_tile
        assert cute.size(acc_S_row) % frg_tile == 0
        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
        acc_S_row_converted_frg = cute.logical_divide(
            acc_S_row_converted, cute.make_layout(frg_tile)
        )
        for j in cutlass.range_constexpr(frg_cnt):
            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
                # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
                # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
                if cutlass.const_expr(not e2e):
                    acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
                    acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
                else:
                    # Compile-time choice per pair: cute.arch.exp2 or the emulated path.
                    if cutlass.const_expr(
                        k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit
                    ):
                        acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
                        acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
                    else:
                        # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j])
                        acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(
                            acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]
                        )
            # Convert the finished fragment to the destination element type.
            acc_S_row_converted_frg[None, j].store(
                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
            )

    @cute.jit
    def scale_apply_exp2_convert(
        self,
        acc_S_row: cute.Tensor,
        row_max: Float32,
        acc_S_row_converted: cute.Tensor,
    ):
        """Compute ``exp2(s * scale_log2 - row_max * scale_log2)`` in place and store a converted copy.

        First pass: the scale/subtract is applied pairwise with
        ``utils.fma_packed_f32x2``.

        Second pass, per 32-element fragment: apply ``cute.arch.exp2`` and
        write the fragment to ``acc_S_row_converted`` in its element type.

        The element count must be even and a multiple of 32 (asserted).
        """
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        minus_row_max_scaled = -row_max * self.scale_log2
        for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
            acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2(
                (acc_S_row[i], acc_S_row[i + 1]),
                (self.scale_log2, self.scale_log2),
                (minus_row_max_scaled, minus_row_max_scaled),
            )

        # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
        #     acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2(
        #         (acc_S_row[i], acc_S_row[i + 1]),
        #         (self.scale_log2, self.scale_log2),
        #         (minus_row_max_scaled, minus_row_max_scaled),
        #     )
        #     acc_S_row[i] = cute.arch.exp2(acc_S_row[i])
        #     acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1])

        frg_tile = 32
        assert frg_tile % 2 == 0
        frg_cnt = cute.size(acc_S_row) // frg_tile
        assert cute.size(acc_S_row) % frg_tile == 0
        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
        acc_S_row_converted_frg = cute.logical_divide(
            acc_S_row_converted, cute.make_layout(frg_tile)
        )
        for j in cutlass.range_constexpr(frg_cnt):
            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
                # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = (
                #     utils.fma_packed_f32x2(
                #         (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]),
                #         (self.scale_log2, self.scale_log2),
                #         (minus_row_max_scaled, minus_row_max_scaled),
                #     )
                # )
                # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
                # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
                acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
                acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
            acc_S_row_converted_frg[None, j].store(
                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
            )
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
@cute.jit
def floor_if_packed(
    q_idx,
    qhead_per_kvhead: cutlass.Constexpr[int],
):
    """Convert a Pack-GQA packed query index into its logical query index.

    With Pack-GQA, the query heads that share one KV head are interleaved
    along the query dimension. A packed index ``p`` therefore maps to logical
    query ``p // qhead_per_kvhead``; the remainder is the head offset, which
    callers recover separately.

    Args:
        q_idx: packed query index. It is a scalar index, not a tensor;
            presumably a DSL Int32 read from a coordinate tensor, so confirm
            this at the callers.
        qhead_per_kvhead: compile-time Pack-GQA replication factor. When it is
            1, nothing is packed and ``q_idx`` is returned unchanged, which
            avoids emitting a division.

    Returns:
        The logical query index, as the same kind of scalar as ``q_idx``.
    """
    # The return annotation used to claim cute.Tensor, but both paths return a
    # scalar index, so it was dropped.
    if cutlass.const_expr(qhead_per_kvhead == 1):
        return q_idx
    return q_idx // qhead_per_kvhead
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@cute.jit
def apply_score_mod_inner(
    score_tensor,
    index_tensor,
    score_mod: cutlass.Constexpr,
    batch_idx,
    head_idx,
    softmax_scale,
    vec_size: cutlass.Constexpr,
    qk_acc_dtype: cutlass.Constexpr,
    aux_tensors,
    fastdiv_mods,
    seqlen_info: SeqlenInfoQK,
    constant_q_idx: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    transpose_indices: cutlass.Constexpr[bool] = False,
):
    """Shared implementation for applying score modification.

    Scores are processed in chunks of ``vec_size`` elements. Each chunk is:

    1. multiplied by ``softmax_scale``;
    2. passed to ``score_mod`` together with vectors of batch, head, q and kv
       indices;
    3. written back into ``score_tensor`` in place.

    Args:
        score_tensor: The scores to modify in place (acc_S for flash_fwd,
            tSrS_t2r for sm100). Its element count is assumed to be a
            multiple of ``vec_size``; confirm this at the callers.
        index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100)
        score_mod: The score modification function to apply
        batch_idx: Batch index
        head_idx: Head index (KV-head granularity when Pack-GQA is active)
        softmax_scale: Scale to apply
        vec_size: Vector size for processing elements
        qk_acc_dtype: Data type for accumulator
        aux_tensors: Optional aux_tensors for FlexAttention
        fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping.
            Indices are only wrapped when both this and ``aux_tensors`` are
            provided.
        seqlen_info: Sequence length info
        constant_q_idx: If provided, use this constant for all q_idx values.
            If None, compute q_idx per-element.
        qhead_per_kvhead: Pack-GQA replication factor. When it is greater than
            1 and ``constant_q_idx`` is None, q_idx is divided by this factor
            and the remainder selects the logical query head, so score mods
            see logical (unpacked) indices.
        transpose_indices: If True, swap q_idx/kv_idx in index_tensor (for bwd
            kernel where S is transposed)
    """
    # Index positions in the index_tensor tuple
    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
    if cutlass.const_expr(transpose_indices):
        q_idx_pos = cutlass.const_expr(1)
        kv_idx_pos = cutlass.const_expr(0)
    else:
        q_idx_pos = cutlass.const_expr(0)
        kv_idx_pos = cutlass.const_expr(1)

    n_vals = cutlass.const_expr(cute.size(score_tensor.shape))
    score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)
    kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    # SSA values for batch (constant across all elements)
    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))

    # Handle q_idx based on whether it's constant
    q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    # For Pack-GQA with non-constant q_idx, we need per-element head indices
    # since a thread may process multiple query head indices
    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
        head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
        for j in cutlass.range(vec_size, unroll_full=True):
            score_vec[j] = score_tensor[i + j] * softmax_scale

            # Extract head offset from packed q_idx for Pack-GQA
            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
                q_idx_packed = index_tensor[i + j][q_idx_pos]
                # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)
                q_idx_logical = q_idx_packed // qhead_per_kvhead
                head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset

            # If we will do loads we mod, in order to not read OOB
            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
                if cutlass.const_expr(constant_q_idx is None):
                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                    q_idx_floored = floor_if_packed(
                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead
                    )
                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
                    q_idx_vec[j] = q_idx_wrapped
                else:
                    _, seqlen_k_divmod = fastdiv_mods

                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
                kv_idx_vec[j] = kv_idx_wrapped
            else:
                # No bounds checking - direct indexing.
                # Compile-time test, consistent with every other constant_q_idx check.
                if cutlass.const_expr(constant_q_idx is None):
                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]

        # Convert to SSA for score_mod call
        score_ssa = score_vec.load()
        kv_idx_ssa = kv_idx_vec.load()
        if cutlass.const_expr(constant_q_idx is None):
            q_idx_ssa = q_idx_vec.load()
        else:
            # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical
            q_idx_const = constant_q_idx
            q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,))

        # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise
        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
            head_idx_ssa = head_idx_vec.load()
        else:
            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))

        aux_args = []
        if cutlass.const_expr(aux_tensors is not None):
            aux_args = aux_tensors

        post_mod_scores = score_mod(
            score_ssa,
            batch_idx_ssa,
            head_idx_ssa,
            q_idx=q_idx_ssa,
            kv_idx=kv_idx_ssa,
            seqlen_info=seqlen_info,
            aux_tensors=aux_args,
        )

        # Write back modified scores
        score_vec.store(post_mod_scores)
        for j in cutlass.range(vec_size, unroll_full=True):
            score_tensor[i + j] = score_vec[j]
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@cute.jit
def apply_score_mod_bwd_inner(
    grad_tensor,
    score_tensor,
    index_tensor,
    score_mod_bwd: cutlass.Constexpr,
    batch_idx,
    head_idx,
    softmax_scale,
    vec_size: cutlass.Constexpr,
    qk_acc_dtype: cutlass.Constexpr,
    aux_tensors,
    fastdiv_mods,
    seqlen_info,
    constant_q_idx: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    transpose_indices: cutlass.Constexpr[bool] = False,
):
    """Apply backward score modification (joint graph).

    Works in chunks of ``vec_size`` elements. For each chunk:

    1. Load the incoming gradients and the pre-mod scores; the scores are
       multiplied by ``softmax_scale`` so the joint graph sees the same values
       as the forward score_mod.
    2. Call ``score_mod_bwd`` with the batch, head, q and kv index vectors.
    3. Write the returned gradients back into ``grad_tensor`` in place.

    Index handling is identical to ``apply_score_mod_inner``.

    Args:
        grad_tensor: in/out: dlogits rewritten in-place with d(scaled_scores)
        score_tensor: pre-mod scores (unscaled QK tile), scaled by softmax_scale internally
        index_tensor: Index positions (same as forward)
        score_mod_bwd: The backward score modification function (joint graph)
        batch_idx: Batch index
        head_idx: Head index (KV-head granularity when Pack-GQA is active)
        softmax_scale: Scale to apply to score_tensor
        vec_size: Vector size for processing elements
        qk_acc_dtype: Data type for accumulator
        aux_tensors: Optional aux_tensors for FlexAttention
        fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping.
            Indices are only wrapped when both this and ``aux_tensors`` are
            provided.
        seqlen_info: Sequence length info
        constant_q_idx: If provided, use this constant for all q_idx values.
            It is assumed to already be a logical (unpacked) index.
        qhead_per_kvhead: Pack-GQA replication factor
        transpose_indices: If True, swap q_idx/kv_idx in index_tensor
    """
    # Index positions in the index_tensor tuple
    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
    if cutlass.const_expr(transpose_indices):
        q_idx_pos = cutlass.const_expr(1)
        kv_idx_pos = cutlass.const_expr(0)
    else:
        q_idx_pos = cutlass.const_expr(0)
        kv_idx_pos = cutlass.const_expr(1)
    n_vals = cutlass.const_expr(cute.size(grad_tensor.shape))
    # Same allocation helper as the forward path (apply_score_mod_inner / Softmax.create).
    grad_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)
    score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)
    kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)
    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))
    q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    # For Pack-GQA with non-constant q_idx, we need per-element head indices
    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
        head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
        for j in cutlass.range(vec_size, unroll_full=True):
            grad_vec[j] = grad_tensor[i + j]
            # Scale score so joint graph sees same value as forward score_mod
            score_vec[j] = score_tensor[i + j] * softmax_scale

            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
                # Logical head = kv_head * qhead_per_kvhead + (packed_q % qhead_per_kvhead)
                q_idx_packed = index_tensor[i + j][q_idx_pos]
                q_idx_logical = q_idx_packed // qhead_per_kvhead
                head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset

            # If we will do loads we mod, in order to not read OOB
            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
                if cutlass.const_expr(constant_q_idx is None):
                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                    q_idx_floored = floor_if_packed(
                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead
                    )
                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
                    q_idx_vec[j] = q_idx_wrapped
                else:
                    _, seqlen_k_divmod = fastdiv_mods

                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
                kv_idx_vec[j] = kv_idx_wrapped
            else:
                # No bounds checking - direct indexing.
                # Compile-time test, consistent with every other constant_q_idx check.
                if cutlass.const_expr(constant_q_idx is None):
                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]

        grad_ssa = grad_vec.load()
        score_ssa = score_vec.load()
        kv_idx_ssa = kv_idx_vec.load()

        if cutlass.const_expr(constant_q_idx is None):
            q_idx_ssa = q_idx_vec.load()
        else:
            # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical
            q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,))

        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
            head_idx_ssa = head_idx_vec.load()
        else:
            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))

        aux_args = []
        if cutlass.const_expr(aux_tensors is not None):
            aux_args = aux_tensors

        grad_out_ssa = score_mod_bwd(
            grad_ssa,
            score_ssa,
            batch_idx_ssa,
            head_idx_ssa,
            q_idx=q_idx_ssa,
            kv_idx=kv_idx_ssa,
            seqlen_info=seqlen_info,
            aux_tensors=aux_args,
        )

        # Write back the modified gradients
        grad_vec.store(grad_out_ssa)
        for j in cutlass.range(vec_size, unroll_full=True):
            grad_tensor[i + j] = grad_vec[j]
|