mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Tri Dao.
|
|
3
|
+
|
|
4
|
+
from typing import Optional, Callable
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import cutlass
|
|
8
|
+
import cutlass.cute as cute
|
|
9
|
+
from cutlass import Float32, Int32, const_expr
|
|
10
|
+
|
|
11
|
+
import mslk.attention.flash_attn.utils as utils
|
|
12
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@cute.jit
def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None:
    """Write -inf into every column position of X at or beyond a (transformed) limit, in place.

    Column positions are processed in chunks of 24. For chunk ``s``, a bitmask whose low
    ``col_limit_transformed - s * 24`` bits are set is built, and position ``c = s * 24 + i``
    keeps its value only when bit ``i`` is set.

    Args:
        X: Tensor to mask. When ``rank1`` is True it is indexed as ``X[c]`` over all of
            ``X.shape``. Otherwise it is indexed as ``X[r, c]``, where ``c`` runs over the
            last mode and ``r`` over mode 0.
        col_limit: Column limit. Positions whose (transformed) index is >= this limit are masked.
        arch: 90 applies the sm90 index transform described below. Any other value uses
            ``col_limit`` unchanged.
        rank1: Whether X is indexed with a single coordinate.
    """
    # Bit manipulation, compiles down to the R2P instruction
    # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using.
    # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ...,
    # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
    if const_expr(arch == 90):
        col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2)
    else:
        col_limit_transformed = col_limit
    ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
    # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
    # NOTE(review): the code tests bits with `mask & (1 << i)`; the `>>` above presumably
    # refers to how the compiler lowers that test — confirm against the generated SASS.
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        # Don't need to clamp to 32 since the shr.u32 instruction does that already
        col_limit_right_s = max(col_limit_transformed - s * 24, 0)
        # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
        mask = (1 << col_limit_right_s) - 1
        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            # Bit i is set exactly when i < col_limit_right_s, i.e. the position is kept.
            in_bound = cutlass.Boolean(mask & (1 << i))
            c = s * 24 + i
            if const_expr(rank1):
                X[c] = X[c] if in_bound else -Float32.inf
                # This is the equivalent of:
                # X[s * 24 + i] = X[s * 24 + i] if i < col_limit_right_s else -Float32.inf
            else:
                for r in cutlass.range_constexpr(cute.size(X.shape[0])):
                    X[r, c] = X[r, c] if in_bound else -Float32.inf
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@cute.jit
def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None:
    """Write -inf into every position of rank-1 X whose transformed row index is below a limit.

    This is the transposed counterpart of ``mask_r2p``: here a *set* bit means "mask".
    ``row_limit_top`` is first remapped with period ``num_rep * num_wg``, where ``num_wg``
    is hard-coded to 2. Positions are then processed in chunks of 24. Position
    ``c = s * 24 + i`` is overwritten with -inf when ``i < row_limit_top_transformed - s * 24``.

    Args:
        X: Rank-1 indexed tensor to mask in place (indexed as ``X[c]`` over ``X.shape``).
        row_limit_top: Rows strictly below this (transformed) limit are masked.
        num_rep: Period parameter of the row transform. The sole visible caller passes
            ``cute.size(tScS_t2r, mode=[0])`` and comments "16 or 32".
    """
    # Bit manipulation, compiles down to the R2P instruction
    # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127
    # or 0, 1, ..., 15, 32, ..., 47, 64, ...
    # NOTE(review): the transform below maps rows k * (2 * num_rep) + j (0 <= j < num_rep)
    # onto consecutive indices, so for num_rep == 32 the first pattern is presumably
    # 0..31, 64..95, ... — confirm the "64, ..., 127" above.
    # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
    # Here we hardcode for the case of 2 warp groups.
    num_wg = 2
    row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min(
        row_limit_top % (num_rep * num_wg), num_rep
    )
    ncol = cute.size(X.shape)
    # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        row_limit_top_s = max(row_limit_top_transformed - s * 24, 0)
        # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
        mask = (1 << row_limit_top_s) - 1
        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            # Bit i is set exactly when i < row_limit_top_s, i.e. the position must be masked.
            out_bound = cutlass.Boolean(mask & (1 << i))
            c = s * 24 + i
            X[c] = -Float32.inf if out_bound else X[c]
            # tidx = cute.arch.thread_idx()[0] % 256
            # if tidx == 128:
            #     cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass(frozen=True)
class AttentionMask:
    """Masks an attention-score tile in place by overwriting masked entries with -inf.

    Supported masks: seqlen bounds, causal, local (sliding window), and a user ``mask_mod``
    callable.

    Fields:
        tile_m, tile_n: Compile-time tile sizes; ``m_block * tile_m`` / ``n_block * tile_n``
            give the tile origin along the query / key dimension.
        seqlen_info: Source of ``seqlen_q`` / ``seqlen_k`` (see the properties below); also
            forwarded to ``mask_mod``.
        window_size_left / window_size_right: Local-attention window sizes. ``None`` means
            no limit on that side.
        qhead_per_kvhead_packgqa: Number of query heads packed along the m dimension
            (PackGQA). 1 disables the PackGQA index mapping.
        swap_AB: Whether the accumulator's (row, col) roles are transposed relative to
            (query, key).
    """

    tile_m: cutlass.Constexpr[int]
    tile_n: cutlass.Constexpr[int]
    seqlen_info: SeqlenInfoQK
    window_size_left: Optional[Int32] = None
    window_size_right: Optional[Int32] = None
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # only pass in if we're doing PackGQA
    swap_AB: cutlass.Constexpr[bool] = False

    @property
    def seqlen_q(self) -> Int32:
        """Query sequence length, forwarded from ``seqlen_info``."""
        return self.seqlen_info.seqlen_q

    @property
    def seqlen_k(self) -> Int32:
        """Key sequence length, forwarded from ``seqlen_info``."""
        return self.seqlen_info.seqlen_k

    @cute.jit
    def apply_mask(
        self,
        acc_S: cute.Tensor,
        batch_idx: cutlass.Int32,
        head_idx: cutlass.Int32,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        thr_mma: cute.TiledMma,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
    ) -> None:
        """Mask ``acc_S`` in place, using coordinates from an identity tensor partitioned by ``thr_mma``.

        Exactly one mode is chosen at compile time:
          * no causal / local / mask_mod: when ``mask_seqlen``, columns at or beyond
            ``seqlen_k`` are masked;
          * ``mask_mod`` only: ``mask_mod(batch, head, q_idx, kv_idx, seqlen_info, aux_tensors)``
            is evaluated per element. When ``mask_seqlen``, out-of-range q / kv entries are
            also masked;
          * causal or local: diagonal / sliding-window limits, clipped to ``seqlen_k`` when
            ``mask_seqlen``.

        Args:
            acc_S: Score tile, viewed as (m, n) via ``utils.make_acc_tensor_mn_view``
                (transposed when ``swap_AB``).
            batch_idx, head_idx: Only used to build ``mask_mod`` arguments.
            m_block, n_block: Tile indices. A negative ``n_block`` is treated as 0.
            thr_mma: Used to partition the identity tensor and, for PackGQA, to obtain the
                thread index and C thread-value layout.
            mask_seqlen: Also apply sequence-length bounds.
            mask_causal, mask_local: Mutually exclusive causal / local masking.
            mask_mod: Optional element-wise mask callable; its result is converted to a Boolean.
            aux_tensors: Forwarded to ``mask_mod``.
            fastdiv_mods: Pair of divisors. When both are set, ``mask_seqlen`` is on and
                ``aux_tensors`` is given, q / kv indices passed to ``mask_mod`` are replaced
                by their remainders modulo these divisors.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB)
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS), transpose=self.swap_AB)
        # We use t0ScS as these indices are known at compile time. We then must subtract the
        # column limit by the thread column offset.
        t0ScS_mn = utils.make_acc_tensor_mn_view(
            thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
        )
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        thr_col_offset = tScS_mn[0][COL]
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                # The compiler now chooses not to use R2P
                r2p = const_expr(False and not self.swap_AB)
                if const_expr(not r2p):
                    # traverse column index.
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
                else:
                    mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90)

        elif const_expr(
            not mask_causal and not mask_local and mask_mod is not None
        ):  # FlexAttention mask mod
            nrow = const_expr(cute.size(tScS_mn.shape[0]))
            ncol = const_expr(cute.size(tScS_mn.shape[1]))
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            wrap_aux_indices = const_expr(
                has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
            )

            for r in cutlass.range_constexpr(nrow):
                # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.
                local_row = tScS_mn[r, 0][ROW]
                global_row_idx = local_row + m_block * self.tile_m
                row_for_mod = global_row_idx
                head_idx_for_mod = head_idx
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: the packed m index encodes (q position, head offset).
                    head_offset = global_row_idx % self.qhead_per_kvhead_packgqa
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                    row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa
                # Keep the unwrapped q index for the seqlen_q bounds check.
                row_for_seqlen = row_for_mod
                if const_expr(wrap_aux_indices):
                    _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])

                for col in cutlass.range_constexpr(ncol):
                    col_idx_local = t0ScS_mn[0, col][COL]
                    # Convert to absolute column index
                    global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
                    col_for_mod = global_col_idx
                    if const_expr(wrap_aux_indices):
                        _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])

                    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                    head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                    q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    if const_expr(mask_seqlen):
                        out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (
                            global_col_idx >= self.seqlen_k
                        )
                        if out_of_bounds:
                            acc_S_mn[r, col] = -cutlass.Float32.inf
                        else:
                            acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
                    else:
                        acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf

        else:  # Causal or local
            if const_expr(not self.swap_AB):
                # If PackGQA, we split the work of compute divmod among threads in the same row
                threads_per_row = thr_mma.tv_layout_C.shape[0][0]
                mma_m_idx = None
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    assert not self.swap_AB, "swap_AB with PackGQA not supported yet"
                    assert cute.arch.WARP_SIZE % threads_per_row == 0, (
                        "threads_per_row must divide WARP_SIZE"
                    )
                    assert cute.size(acc_S_mn.shape[0]) <= threads_per_row
                    tidx = thr_mma.thr_idx
                    mma_m_idx = (
                        m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0]
                    ) // self.qhead_per_kvhead_packgqa
                causal_row_offset = (
                    1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset
                )
                if const_expr(mask_causal):
                    r2p = const_expr(not self.swap_AB)  # R2P trick, see apply_mask_sm100
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        # get the column index limit based on current row. Only consider the row index, so the column index sets to 0.
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        col_limit_right = row_idx + causal_row_offset
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        if const_expr(not r2p):
                            # traverse column index.
                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                                acc_S_mn[r, c] = (
                                    -Float32.inf
                                    if t0ScS_mn[0, c][1] >= col_limit_right
                                    else acc_S_mn[r, c]
                                )
                        else:
                            mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True)
                else:  # Local
                    local_row_offset_right = (
                        causal_row_offset + self.window_size_right
                        if const_expr(self.window_size_right is not None)
                        else None
                    )
                    local_row_offset_left = (
                        causal_row_offset - 1 - self.window_size_left
                        if const_expr(self.window_size_left is not None)
                        else None
                    )
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        if const_expr(self.window_size_right is not None):
                            col_limit_right = row_idx + local_row_offset_right
                        else:
                            col_limit_right = self.tile_n
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        col_limit_left = (
                            row_idx + local_row_offset_left
                            if const_expr(self.window_size_left is not None)
                            else 0
                        )
                        # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left)
                        # traverse column index.
                        for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                            col_idx = t0ScS_mn[0, c][1]
                            # only consider the column index, so the row index sets to 0.
                            if col_idx >= col_limit_right or col_idx < col_limit_left:
                                acc_S_mn[r, c] = -Float32.inf
            else:  # swap_AB
                assert self.qhead_per_kvhead_packgqa == 1
                thr_row_offset = tScS_mn[0][ROW]
                causal_row_offset = (
                    seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset
                )
                if const_expr(mask_causal):
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit and mask_seqlen
                            else col0 - causal_row_offset
                        )
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if t0ScS_mn[r, 0][ROW] < row_limit_top
                                else acc_S_mn[r, c]
                            )
                else:  # Local
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        # Either window side may be None (no limit on that side), so each
                        # side is only applied when configured, as in the other local branches.
                        if const_expr(self.window_size_right is not None):
                            row_limit_top = (
                                self.tile_m
                                if col0 >= seqlenk_col_limit
                                else col0 - causal_row_offset - self.window_size_right
                            )
                        else:
                            # No right window: only the seqlen_k bound limits from the top.
                            row_limit_top = self.tile_m if col0 >= seqlenk_col_limit else 0
                        # TODO: do we need col_limit_sink?
                        if const_expr(self.window_size_left is not None):
                            row_limit_bot = col0 - causal_row_offset + self.window_size_left
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            row_idx = t0ScS_mn[r, 0][ROW]
                            if const_expr(self.window_size_left is not None):
                                acc_S_mn[r, c] = (
                                    -Float32.inf
                                    if row_idx < row_limit_top or row_idx > row_limit_bot
                                    else acc_S_mn[r, c]
                                )
                            else:
                                acc_S_mn[r, c] = (
                                    -Float32.inf if row_idx < row_limit_top else acc_S_mn[r, c]
                                )

    @cute.jit
    def apply_mask_sm100(
        self,
        acc_S: cute.Tensor,
        m_block: Int32,
        n_block: Int32,
        thr_mma: cute.TiledMma,
        thr_tmem_load: cute.TiledCopy,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        check_q_boundary: bool = False,
    ) -> None:
        """Mask ``acc_S`` in place. Element ``i``'s tile coordinate is ``tScS_t2r[i]``.

        ``tScS_t2r`` is an identity tensor partitioned by ``thr_mma`` and then by
        ``thr_tmem_load``. Modes mirror ``apply_mask``: seqlen-only, ``mask_mod``, or
        causal / local. A negative ``n_block`` is treated as 0.

        NOTE(review): the seqlen-only and causal / local paths always read coordinate [0]
        as the row and [1] as the column, independent of ``swap_AB``.

        Args:
            batch_idx, head_idx: Only used to build ``mask_mod`` arguments.
            check_q_boundary: mask_mod path only. When True, rows whose (PackGQA-adjusted)
                index is >= ``seqlen_q`` are masked. The row index passed to ``mask_mod`` is
                also wrapped by ``fastdiv_mods[0]`` when fastdiv and ``aux_tensors`` are given.
            fastdiv_mods: Pair of divisors used to wrap q / kv indices passed to ``mask_mod``.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS = thr_mma.partition_C(cS)
        tScS_t2r = thr_tmem_load.partition_D(tScS)
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n
        r2p = True
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        # if tScS_t2r[i][1] >= seqlenk_col_limit:
                        #     acc_S[i] = -Float32.inf
                        # For some reason the 2 lines above generate really bad SASS
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
                else:
                    mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)

        elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case w/ mask_mod
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)

            ncol = const_expr(cute.size(tScS_t2r.shape))
            for i in cutlass.range_constexpr(ncol):
                row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
                col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
                global_row = row_coord + m_block * self.tile_m
                global_col = col_coord + n_block * self.tile_n

                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: the packed m index encodes (q position, head offset).
                    head_offset = global_row % self.qhead_per_kvhead_packgqa
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                    mask_row = global_row // self.qhead_per_kvhead_packgqa
                else:
                    head_idx_for_mod = head_idx
                    mask_row = global_row

                mask_row_for_mod = mask_row
                if const_expr(has_fastdiv and aux_tensors is not None):
                    if check_q_boundary:
                        _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
                global_col_for_mod = global_col
                if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
                    _, global_col_for_mod = divmod(global_col, fastdiv_mods[1])

                head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
                kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
                mask_value = mask_mod(
                    batch_idx_ssa,
                    head_idx_ssa,
                    mask_row_ssa,
                    kv_idx_ssa,
                    self.seqlen_info,
                    aux_tensors,
                )
                cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                acc_S[i] = acc_S[i] if cond else -Float32.inf
                if const_expr(mask_seqlen):
                    acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
                if check_q_boundary:
                    acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]

        else:  # Causal or local
            causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q
            row_idx = tScS_t2r[0][0] + m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa != 1):
                row_idx = row_idx // self.qhead_per_kvhead_packgqa
            if const_expr(mask_causal):
                col_limit_right = row_idx + causal_row_offset
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                # if cute.arch.thread_idx()[0] % 32 == 0:
                #     cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset)
                ncol = const_expr(cute.size(tScS_t2r.shape))
                if const_expr(not r2p):
                    for i in cutlass.range(ncol, unroll_full=True):
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
                else:
                    mask_r2p(acc_S, col_limit_right, arch=100, rank1=True)
            else:
                local_row_offset_right = (
                    causal_row_offset + self.window_size_right
                    if const_expr(self.window_size_right is not None)
                    else None
                )
                local_row_offset_left = (
                    causal_row_offset - 1 - self.window_size_left
                    if const_expr(self.window_size_left is not None)
                    else None
                )
                if const_expr(self.window_size_right is not None):
                    col_limit_right = row_idx + local_row_offset_right
                else:
                    col_limit_right = self.tile_n
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                col_limit_left = (
                    row_idx + local_row_offset_left
                    if const_expr(self.window_size_left is not None)
                    else 0
                )
                # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
                for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                    col_idx = tScS_t2r[i][1]
                    acc_S[i] = (
                        -Float32.inf
                        if col_idx >= col_limit_right or col_idx < col_limit_left
                        else acc_S[i]
                    )

    @cute.jit
    def apply_mask_sm100_transposed(
        self,
        acc_S: cute.Tensor,
        tScS_t2r: cute.Tensor,
        t0ScS_t2r: cute.Tensor,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        mask_seqlen: cutlass.Constexpr,
        mask_causal: cutlass.Constexpr,
        mask_local: cutlass.Constexpr,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        is_full_block: bool = False,
        check_m_boundary: bool = True,
    ) -> None:
        """
        Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.

        Coordinate convention:
        - ROW corresponds to Q (m_block)
        - COL corresponds to KV (n_block)

        tScS_t2r / t0ScS_t2r: Per-element tile coordinates for this thread and for thread 0.
            t0ScS_t2r[0][COL] is asserted to be 0.
        is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking.
        check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks).
            When iterating m_blocks in forward order, only the last m_block may be partial.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        assert t0ScS_t2r[0][COL] == 0, "col0 == 0"
        thr_col_offset = tScS_t2r[0][COL]
        # <= 0 means this thread's column lies at or beyond seqlen_k.
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset

        if const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case with mask_mod (backward)
            #
            # Coordinate convention: ROW → Q (m_block), COL → KV (n_block).
            # These already account for swap_AB.
            #
            # FULL blocks: mask_mod returns True for all elements, so skip it.
            # Still need seqlen bounds check (elements may be OOB on last m_block).
            # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds.
            if is_full_block:
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        # Entire tile is OOB for K
                        for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                            acc_S[i] = -cutlass.Float32.inf
                    elif check_m_boundary:
                        # Last m_block: check Q and K boundaries
                        ncol = const_expr(cute.size(tScS_t2r.shape))
                        for i in cutlass.range_constexpr(ncol):
                            row_coord = tScS_t2r[i][ROW]
                            col_coord = tScS_t2r[i][COL]
                            global_q = row_coord + m_block * self.tile_m
                            global_kv = col_coord + n_block * self.tile_n
                            q_out_of_bounds = global_q >= self.seqlen_q
                            kv_out_of_bounds = global_kv >= self.seqlen_k
                            out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                            acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
            else:
                # Partial block
                has_fastdiv = const_expr(
                    fastdiv_mods is not None
                    and fastdiv_mods[0] is not None
                    and fastdiv_mods[1] is not None
                )
                wrap_aux_indices = const_expr(
                    has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
                )
                batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)

                ncol = const_expr(cute.size(tScS_t2r.shape))
                for i in cutlass.range_constexpr(ncol):
                    row_coord = tScS_t2r[i][ROW]
                    col_coord = tScS_t2r[i][COL]
                    global_q = row_coord + m_block * self.tile_m
                    global_kv = col_coord + n_block * self.tile_n

                    q_idx_for_mod = global_q
                    kv_idx_for_mod = global_kv
                    if const_expr(wrap_aux_indices):
                        _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0])
                        _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1])

                    q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32)

                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf

                    if const_expr(mask_seqlen):
                        # check_m_boundary=False skips q check for non-boundary m_blocks
                        q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q)
                        kv_out_of_bounds = global_kv >= self.seqlen_k
                        out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                        acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]

        elif const_expr(not mask_causal and not mask_local):
            if const_expr(mask_seqlen):
                if seqlenk_col_limit <= 0:
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = -cutlass.Float32.inf
        else:  # Causal or local
            thr_row_offset = tScS_t2r[0][ROW]
            seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
            # Rows (relative to this thread) strictly below causal_offset lie above the diagonal.
            causal_offset = seqlenq_row_limit - seqlenk_col_limit
            if const_expr(mask_causal):
                # tidx = cute.arch.thread_idx()[0] % 256
                # if tidx < 32:
                #     cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1])
                row_limit_top = causal_offset
                if const_expr(mask_seqlen):
                    # If col is beyond the column limit, we want to mask out the entire
                    # column, by setting row limit to be self.tile_m.
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                r2p = True
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = (
                            -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i]
                        )
                else:
                    num_rep = cute.size(tScS_t2r, mode=[0])  # 16 or 32
                    mask_r2p_transposed(acc_S, row_limit_top, num_rep)
            else:
                if const_expr(self.window_size_right is not None):
                    row_limit_top = causal_offset - self.window_size_right
                else:
                    row_limit_top = 0
                if const_expr(self.window_size_left is not None):
                    row_limit_bot = causal_offset + self.window_size_left
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                    row_idx = t0ScS_t2r[i][ROW]
                    local_mask = row_idx < row_limit_top
                    if const_expr(self.window_size_left is not None):
                        local_mask |= row_idx > row_limit_bot
                    acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
|