mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/block_sparsity.py
@@ -0,0 +1,219 @@
# @nolint # fbcode
"""
Block-sparsity utilities for FlexAttention
"""

from typing import Callable, NamedTuple, Tuple

import cutlass.cute as cute
import torch

from mslk.attention.flash_attn.cute_dsl_utils import to_cute_tensor


def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b


class BlockSparseTensors(NamedTuple):
    mask_block_cnt: cute.Tensor
    mask_block_idx: cute.Tensor
    full_block_cnt: cute.Tensor | None
    full_block_idx: cute.Tensor | None

    def __new_from_mlir_values__(self, values):
        if len(values) == 2:
            values = (*values, None, None)
        return BlockSparseTensors(*values)


class BlockSparseTensorsTorch(NamedTuple):
    mask_block_cnt: torch.Tensor
    mask_block_idx: torch.Tensor
    full_block_cnt: torch.Tensor | None = None
    full_block_idx: torch.Tensor | None = None


def _expand_sparsity_tensor(
    tensor: torch.Tensor,
    expected_shape: Tuple[int, ...],
    tensor_name: str,
    context: str | None,
    hint: str | Callable[[], str] | None,
) -> torch.Tensor:
    """Check if we need to expand the tensor to expected shape, and do so if possible."""
    needs_expand = tensor.shape != expected_shape
    if not needs_expand:
        return tensor
    can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape))
    if not can_expand:
        context_clause = f" ({context})" if context else ""
        resolved_hint = hint() if callable(hint) else hint
        hint_clause = f" Hint: {resolved_hint}" if resolved_hint else ""
        raise ValueError(
            f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
            f"{hint_clause}"
        )
    return tensor.expand(*expected_shape)


def _check_and_expand_block(
    name: str,
    cnt: torch.Tensor | None,
    idx: torch.Tensor | None,
    expected_count_shape: Tuple[int, int, int],
    expected_index_shape: Tuple[int, int, int, int],
    context: str | None,
    hint: str | Callable[[], str] | None,
) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
    if (cnt is None) != (idx is None):
        raise ValueError(
            f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
        )
    if cnt is None or idx is None:
        return None, None
    if cnt.dtype != torch.int32 or idx.dtype != torch.int32:
        raise ValueError(f"{name}_block tensors must have dtype torch.int32")
    if cnt.device != idx.device:
        raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
    if not cnt.is_cuda or not idx.is_cuda:
        raise ValueError(f"{name}_block tensors must live on CUDA")
    expanded_cnt = _expand_sparsity_tensor(
        cnt, expected_count_shape, f"{name}_block_cnt", context, hint
    )
    expanded_idx = _expand_sparsity_tensor(
        idx, expected_index_shape, f"{name}_block_idx", context, hint
    )
    return expanded_cnt, expanded_idx

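The two private helpers above implement a broadcast-style rule: each dimension of a count or index tensor must either equal the expected size or be 1, in which case the tensor is expanded with `tensor.expand` (a view, not a copy). A minimal sketch of that rule with made-up shapes (this example is editorial, not part of the packaged file):

```python
import torch

from mslk.attention.flash_attn.block_sparsity import _expand_sparsity_tensor

# A count tensor shared across the batch: dimension 0 is 1, so it broadcasts.
cnt = torch.zeros((1, 8, 16), dtype=torch.int32)
expanded = _expand_sparsity_tensor(cnt, (4, 8, 16), "mask_block_cnt", "example", None)
assert expanded.shape == (4, 8, 16)  # expanded view, no copy

# A dimension that is neither 1 nor the expected size cannot be broadcast:
# _expand_sparsity_tensor(torch.zeros((2, 8, 16), dtype=torch.int32),
#                         (4, 8, 16), "mask_block_cnt", "example", None)
# raises ValueError("... cannot be expanded to expected shape (4, 8, 16).")
```
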
def get_block_sparse_expected_shapes(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    q_stage: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for block sparse normalization."""
    m_block_size_effective = q_stage * m_block_size
    expected_m_blocks = ceildiv(seqlen_q, m_block_size_effective)
    expected_n_blocks = ceildiv(seqlen_k, n_block_size)
    expected_count_shape = (batch_size, num_head, expected_m_blocks)
    expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)
    return expected_count_shape, expected_index_shape


def get_block_sparse_expected_shapes_bwd(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    subtile_factor: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.

    Backward uses Q-direction indexing (transposed from forward), where shapes are
    indexed by N-blocks first, then M-blocks. The sparse_block_size_q is determined
    by subtile_factor * m_block_size.
    """
    sparse_block_size_q = subtile_factor * m_block_size
    expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
    expected_n_blocks = ceildiv(seqlen_k, n_block_size)
    expected_count_shape = (batch_size, num_head, expected_n_blocks)
    expected_index_shape = (batch_size, num_head, expected_n_blocks, expected_m_blocks)
    return expected_count_shape, expected_index_shape

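As a worked example of the two shape helpers above (all sizes are made up): with batch_size=2, num_head=8, seqlen_q=1024, seqlen_k=2048 and 128x128 blocks, the forward layout is indexed by M-blocks while the backward layout is indexed by N-blocks:

```python
from mslk.attention.flash_attn.block_sparsity import (
    get_block_sparse_expected_shapes,
    get_block_sparse_expected_shapes_bwd,
)

# Forward: one count per (batch, head, M-block); indices list N-blocks per M-block.
cnt_fwd, idx_fwd = get_block_sparse_expected_shapes(
    batch_size=2, num_head=8, seqlen_q=1024, seqlen_k=2048,
    m_block_size=128, n_block_size=128, q_stage=2,
)
assert cnt_fwd == (2, 8, 4)      # ceildiv(1024, 2 * 128) = 4 M-blocks
assert idx_fwd == (2, 8, 4, 16)  # ceildiv(2048, 128) = 16 N-blocks per row tile

# Backward: transposed, one count per N-block; indices list M-blocks per N-block.
cnt_bwd, idx_bwd = get_block_sparse_expected_shapes_bwd(
    batch_size=2, num_head=8, seqlen_q=1024, seqlen_k=2048,
    m_block_size=128, n_block_size=128, subtile_factor=1,
)
assert cnt_bwd == (2, 8, 16)
assert idx_bwd == (2, 8, 16, 8)  # ceildiv(1024, 1 * 128) = 8 M-blocks per N-block
```
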
def normalize_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch,
    *,
    expected_count_shape: Tuple[int, int, int],
    expected_index_shape: Tuple[int, int, int, int],
    context: str | None = None,
    hint: str | Callable[[], str] | None = None,
) -> BlockSparseTensorsTorch:
    if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    mask_cnt, mask_idx = _check_and_expand_block(
        "mask",
        tensors.mask_block_cnt,
        tensors.mask_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    if mask_cnt is None or mask_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    full_cnt, full_idx = _check_and_expand_block(
        "full",
        tensors.full_block_cnt,
        tensors.full_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    if full_cnt is not None and mask_cnt.device != full_cnt.device:
        raise ValueError("All block sparse tensors must be on the same device")

    return BlockSparseTensorsTorch(
        mask_block_cnt=mask_cnt,
        mask_block_idx=mask_idx,
        full_block_cnt=full_cnt,
        full_block_idx=full_idx,
    )

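A usage sketch for the normalization entry point above, assuming a CUDA device and made-up sizes; only the mask pair is supplied, since the full-block pair is optional:

```python
import torch

from mslk.attention.flash_attn.block_sparsity import (
    BlockSparseTensorsTorch,
    get_block_sparse_expected_shapes,
    normalize_block_sparse_tensors,
)

count_shape, index_shape = get_block_sparse_expected_shapes(
    batch_size=2, num_head=8, seqlen_q=1024, seqlen_k=2048,
    m_block_size=128, n_block_size=128, q_stage=1,
)
# Counts/indices shared across batch and heads: leading dims of 1 broadcast to (2, 8, ...).
tensors = BlockSparseTensorsTorch(
    mask_block_cnt=torch.zeros((1, 1) + count_shape[2:], dtype=torch.int32, device="cuda"),
    mask_block_idx=torch.zeros((1, 1) + index_shape[2:], dtype=torch.int32, device="cuda"),
)
normalized = normalize_block_sparse_tensors(
    tensors,
    expected_count_shape=count_shape,
    expected_index_shape=index_shape,
    context="flex attention forward",
)
assert normalized.mask_block_cnt.shape == count_shape
assert normalized.full_block_cnt is None  # the full-block pair stays optional
```
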
def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
    return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt))


def to_cute_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
) -> BlockSparseTensors | None:
    """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
    if not is_block_sparsity_enabled(tensors):
        return None
    (
        mask_block_cnt,
        mask_block_idx,
        full_block_cnt,
        full_block_idx,
    ) = tensors

    (
        mask_block_cnt_tensor,
        mask_block_idx_tensor,
    ) = [
        to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
        for t in (mask_block_cnt, mask_block_idx)
    ]
    (
        full_block_cnt_tensor,
        full_block_idx_tensor,
    ) = [
        to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
        if t is not None
        else None
        for t in (full_block_cnt, full_block_idx)
    ]

    return BlockSparseTensors(
        mask_block_cnt_tensor,
        mask_block_idx_tensor,
        full_block_cnt_tensor,
        full_block_idx_tensor,
    )


def fast_sampling(mask_mod):
    """Convenience decorator to mark mask_mod as safe for 5-point fast sampling"""
    mask_mod.use_fast_sampling = True
    return mask_mod

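`fast_sampling` only tags the callable; whether the 5-point check is actually sufficient is up to the mask author. A sketch with a hypothetical causal `mask_mod` (the argument order mirrors the kernel call sites in compute_block_sparsity.py below; inside the CuTe DSL the comparison operates on DSL scalars):

```python
from mslk.attention.flash_attn.block_sparsity import fast_sampling

@fast_sampling
def causal_mask_mod(b, h, q_idx, kv_idx, seqlen, aux_tensors):
    # A causal mask is monotone along both block axes, so sampling the tile's
    # corners plus its center is enough to classify it -- which is exactly
    # what the @fast_sampling tag asserts.
    return q_idx >= kv_idx

assert causal_mask_mod.use_fast_sampling is True
```
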
mslk/attention/flash_attn/compute_block_sparsity.py
@@ -0,0 +1,378 @@
# @nolint # fbcode
from functools import partial
from typing import Callable, Optional, Tuple

import cutlass
import cutlass.cute as cute
import torch
from cutlass import Boolean, Int8, Int32, const_expr

from mslk.attention.flash_attn.block_sparsity import (
    BlockSparseTensors,
    BlockSparseTensorsTorch,
    to_cute_block_sparse_tensors,
)
from mslk.attention.flash_attn.utils import hash_callable, scalar_to_ssa, ssa_to_scalar
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK


class BlockSparsityKernel:
    """Block sparsity kernel for FlexAttention.

    This kernel computes `mask_mod` for every token of each block
    to determine if an n block is full, masked, or neither.

    Writes block counts and indices to a BlockSparseTensors object.

    When use_fast_sampling=True, uses 5-point sampling (4 corners + center)
    which is much faster but only suitable for masks where this is sufficient.

    TODO:
    - optimize mask_mod evaluation
    - varlen support
    - transposed tensors for bwd pass
    """

    def __init__(
        self,
        mask_mod: Callable,
        tile_mn: Tuple[int, int],
        compute_full_blocks: bool = True,
        use_aux_tensors: bool = False,
        use_fast_sampling: bool = False,
    ):
        self.mask_mod = mask_mod
        self.tile_mn = tile_mn
        self.compute_full_blocks = compute_full_blocks
        self.use_aux_tensors = use_aux_tensors
        self.use_fast_sampling = use_fast_sampling

    @cute.jit
    def __call__(
        self,
        blocksparse_tensors: BlockSparseTensors,
        seqlen_q: Int32,
        seqlen_k: Int32,
        aux_tensors: Optional[list] = None,
    ):
        self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors

        if const_expr(self.compute_full_blocks):
            assert self.full_cnt is not None and self.full_idx is not None, (
                "full block tensors must be provided when computing full blocks"
            )

        batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape
        # launch 1 CTA per m block
        grid = [num_m_blocks, num_heads, batch_size]

        if const_expr(self.use_fast_sampling):
            num_threads = 5
            self.num_warps = 1
        else:
            num_threads = self.tile_mn[0]
            self.num_warps = (num_threads + 32 - 1) // 32

        self.kernel(
            self.mask_cnt,
            self.mask_idx,
            self.full_cnt,
            self.full_idx,
            num_n_blocks,
            seqlen_q,
            seqlen_k,
            aux_tensors,
        ).launch(grid=grid, block=[num_threads, 1, 1])

    @cute.kernel
    def kernel(
        self,
        mask_cnt: cute.Tensor,
        mask_idx: cute.Tensor,
        full_cnt: cute.Tensor,
        full_idx: cute.Tensor,
        num_n_blocks: Int32,
        seqlen_q: Int32,
        seqlen_k: Int32,
        aux_tensors: Optional[list] = None,
    ):
        tidx, _, _ = cute.arch.thread_idx()
        warp_idx = cute.arch.warp_idx()
        lane_id = cute.arch.lane_idx()
        m_block, head_idx, batch_idx = cute.arch.block_idx()

        ssa = partial(scalar_to_ssa, dtype=Int32)

        seqlen = SeqlenInfoQK.create(
            batch_idx,
            seqlen_q,
            seqlen_k,
            mCuSeqlensQ=None,
            mCuSeqlensK=None,
            mSeqUsedQ=None,
            mSeqUsedK=None,
        )

        @cute.struct
        class SharedStorage:
            reduction_buffer_smem: cute.struct.Align[
                cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024
            ]

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage, 16)

        reduction_buffer = storage.reduction_buffer_smem.get_tensor(
            cute.make_layout((self.num_warps, 2))
        )

        num_mask_blocks = Int32(0)
        num_full_blocks = Int32(0)

        for n_block in cutlass.range(num_n_blocks, unroll_full=True):
            m_base = m_block * self.tile_mn[0]
            n_base = n_block * self.tile_mn[1]

            if const_expr(self.use_fast_sampling):
                # Fast path: 5-point sampling (4 corners + center)
                # Clamps OOB indices to nearest in bounds.
                thread_result = Boolean(False)
                thread_is_valid = Boolean(False)
                q_idx = Int32(0)
                kv_idx = Int32(0)

                if tidx == 0:
                    # Top-left corner (0, 0); always in bounds
                    q_idx = m_base
                    kv_idx = n_base
                elif tidx == 1:
                    # Top-right corner
                    q_idx = m_base
                    kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
                elif tidx == 2:
                    # Bottom-left corner
                    q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
                    kv_idx = n_base
                elif tidx == 3:
                    # Bottom-right corner
                    q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
                    kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
                elif tidx == 4:
                    # Center point
                    q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2
                    kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2
                else:
                    thread_is_valid = Boolean(False)

                # Check bounds and determine if this thread has a valid index pair
                if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k:
                    thread_is_valid = Boolean(True)
                    q_idx_ssa = ssa(q_idx)
                    kv_idx_ssa = ssa(kv_idx)
                    thread_result = ssa_to_scalar(
                        self.mask_mod(
                            ssa(batch_idx),
                            ssa(head_idx),
                            q_idx_ssa,
                            kv_idx_ssa,
                            seqlen,
                            aux_tensors,
                        )
                    )
                else:
                    thread_is_valid = Boolean(False)

                # Use vote_any_sync to see if any valid thread found unmasked or masked
                # Only count results from threads that checked valid indices
                has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid)
                has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid)

            else:
                # Full path: check all elements in the block
                # Track if this thread's row has any masked or unmasked elements
                thread_has_unmasked = Boolean(False)
                thread_has_masked = Boolean(False)
                thread_is_valid = Boolean(False)

                # Each thread handles 1 row
                q_idx = m_base + tidx
                kv_idx = Int32(0)
                if tidx < self.tile_mn[0] and q_idx < seqlen_q:
                    thread_is_valid = Boolean(True)
                    q_idx_ssa = ssa(q_idx)

                    # Loop over all columns in this row
                    for c in cutlass.range(self.tile_mn[1], unroll_full=True):
                        kv_idx = n_base + c
                        kv_idx_ssa = ssa(kv_idx)

                        # Only check elements within valid sequence bounds
                        if kv_idx < seqlen_k:
                            # Direct scalar call
                            mask_val = ssa_to_scalar(
                                self.mask_mod(
                                    ssa(batch_idx),
                                    ssa(head_idx),
                                    q_idx_ssa,
                                    kv_idx_ssa,
                                    seqlen,
                                    aux_tensors,
                                )
                            )

                            # Update tracking flags
                            if mask_val:
                                thread_has_unmasked = Boolean(True)
                            else:
                                thread_has_masked = Boolean(True)

                # Block-level reduction to combine results across all threads
                # Only count votes from threads that checked valid indices
                warp_has_unmasked_mask = cute.arch.vote_any_sync(
                    thread_has_unmasked & thread_is_valid
                )
                warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid)

                # lane 0 writes the ballot mask to shared memory
                lane_id = tidx % 32
                if lane_id == 0:
                    # Store as Int8
                    reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0)
                    reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0)

                cute.arch.sync_threads()

                # Thread 0 ORs all warp results together
                has_unmasked = Boolean(False)
                has_masked = Boolean(False)
                if tidx == 0:
                    for w in cutlass.range(self.num_warps):
                        if reduction_buffer[w, 0]:
                            has_unmasked = Boolean(True)
                        if reduction_buffer[w, 1]:
                            has_masked = Boolean(True)

            # Only thread 0 updates the output arrays (common to both paths)
            if tidx == 0:
                # Block classification based on what we found:
                # - If has_masked and has_unmasked: partial block (needs masking)
                # - If only has_unmasked: full block (no masking needed)
                # - If only has_masked: skip this block entirely
                is_partial = Boolean(has_masked and has_unmasked)
                is_full = Boolean(has_unmasked and (not has_masked))

                if is_partial:
                    mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block
                    num_mask_blocks += 1
                elif is_full and const_expr(self.compute_full_blocks):
                    full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block
                    num_full_blocks += 1

        # Only thread 0 writes back the counts
        if tidx == 0:
            mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks
            if const_expr(self.compute_full_blocks):
                full_cnt[batch_idx, head_idx, m_block] = num_full_blocks

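The per-block classification the kernel performs (partial if a tile mixes allowed and masked positions, full if everything is allowed, skipped if everything is masked) can be cross-checked against a dense PyTorch reference. The sketch below is editorial, intended as a test oracle, and assumes a dense boolean mask with True meaning "attend":

```python
import torch

def reference_block_sparsity(mask: torch.Tensor, tile_m: int, tile_n: int):
    """Dense reference of the kernel's partial/full/skip classification.

    mask: bool tensor of shape (batch, heads, seqlen_q, seqlen_k), True = attend.
    """
    batch, heads, seqlen_q, seqlen_k = mask.shape
    num_m = (seqlen_q + tile_m - 1) // tile_m
    num_n = (seqlen_k + tile_n - 1) // tile_n
    mask_cnt = torch.zeros((batch, heads, num_m), dtype=torch.int32)
    mask_idx = torch.zeros((batch, heads, num_m, num_n), dtype=torch.int32)
    full_cnt = torch.zeros_like(mask_cnt)
    full_idx = torch.zeros_like(mask_idx)
    for b in range(batch):
        for h in range(heads):
            for m in range(num_m):
                rows = mask[b, h, m * tile_m:(m + 1) * tile_m]
                for n in range(num_n):
                    tile = rows[:, n * tile_n:(n + 1) * tile_n]
                    has_unmasked = bool(tile.any())
                    has_masked = bool((~tile).any())
                    if has_unmasked and has_masked:
                        # Partial block: needs per-element masking.
                        mask_idx[b, h, m, int(mask_cnt[b, h, m])] = n
                        mask_cnt[b, h, m] += 1
                    elif has_unmasked:
                        # Full block: no masking needed.
                        full_idx[b, h, m, int(full_cnt[b, h, m])] = n
                        full_cnt[b, h, m] += 1
                    # All-masked blocks are simply skipped.
    return mask_cnt, mask_idx, full_cnt, full_idx
```

Comparing these counts and index prefixes with the tensors produced by `compute_block_sparsity` below is a convenient way to sanity-check a new `mask_mod`.
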
def compute_block_sparsity(
    tile_m,
    tile_n,
    batch_size,
    num_heads,
    seqlen_q,
    seqlen_k,
    mask_mod: Callable,
    aux_tensors: Optional[list],  # list[cute.Tensor]
    device,
    compute_full_blocks: bool = True,
    use_fast_sampling: bool = False,
) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]:
    """
    Computes block sparsity for a given `mask_mod`.

    Args:
        tile_m: The tile size for the m dimension.
        tile_n: The tile size for the n dimension.
        batch_size: The batch size.
        num_heads: The number of heads.
        seqlen_q: The sequence length for the query.
        seqlen_k: The sequence length for the key.
        mask_mod: The `mask_mod` callable to use.
        aux_tensors: A list of auxiliary tensors.
        device: The device to use.
        compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed.
        use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient.

    Returns:
        A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`.
    """
    # Check if mask_mod is marked as suitable for 5-point fast sampling
    use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling)

    num_m_blocks = (seqlen_q + tile_m - 1) // tile_m
    num_n_blocks = (seqlen_k + tile_n - 1) // tile_n

    mask_block_cnt = torch.zeros(
        (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32
    )
    mask_block_idx = torch.zeros(
        (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32
    )
    full_block_cnt = (
        torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32)
        if compute_full_blocks
        else None
    )
    full_block_idx = (
        torch.zeros(
            (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32
        )
        if compute_full_blocks
        else None
    )

    blocksparse_tensors_torch = BlockSparseTensorsTorch(
        mask_block_cnt=mask_block_cnt,
        mask_block_idx=mask_block_idx,
        full_block_cnt=full_block_cnt,
        full_block_idx=full_block_idx,
    )

    mask_mod_hash = hash_callable(mask_mod)
    blocksparse_tensors = to_cute_block_sparse_tensors(
        blocksparse_tensors_torch, enable_tvm_ffi=True
    )

    compile_key = (
        tile_m,
        tile_n,
        mask_mod_hash,
        compute_full_blocks,
        aux_tensors is not None,
        use_fast_sampling,
    )
    if compile_key not in compute_block_sparsity.compile_cache:
        kernel = BlockSparsityKernel(
            mask_mod,
            tile_mn=(tile_m, tile_n),
            compute_full_blocks=compute_full_blocks,
            use_aux_tensors=aux_tensors is not None,
            use_fast_sampling=use_fast_sampling,
        )

        compute_block_sparsity.compile_cache[compile_key] = cute.compile(
            kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options="--enable-tvm-ffi"
        )

    compute_block_sparsity.compile_cache[compile_key](
        blocksparse_tensors_torch,
        seqlen_q,
        seqlen_k,
        aux_tensors,
    )

    return blocksparse_tensors, blocksparse_tensors_torch


compute_block_sparsity.compile_cache = {}
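Finally, a minimal end-to-end sketch of the entry point above, assuming a CUDA device, a CuTe-DSL-compatible `mask_mod`, and made-up sizes (the tile sizes must match the attention kernel that will consume the tensors):

```python
from mslk.attention.flash_attn.compute_block_sparsity import compute_block_sparsity

def causal_mask_mod(b, h, q_idx, kv_idx, seqlen, aux_tensors):
    # Hypothetical mask; it must be expressible in the CuTe DSL.
    return q_idx >= kv_idx

cute_tensors, torch_tensors = compute_block_sparsity(
    tile_m=128,
    tile_n=128,
    batch_size=2,
    num_heads=8,
    seqlen_q=1024,
    seqlen_k=1024,
    mask_mod=causal_mask_mod,
    aux_tensors=None,
    device="cuda",
    compute_full_blocks=True,
    use_fast_sampling=False,  # decorate the mask with @fast_sampling for the 5-point path
)
# torch_tensors.mask_block_cnt[b, h, m] is the number of partially-masked N-blocks
# for row-block m; torch_tensors.mask_block_idx[b, h, m, :cnt] lists their indices.
```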