mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/block_sparsity.py
@@ -0,0 +1,219 @@
+ # @nolint # fbcode
+ """
+ Block-sparsity utilities for FlexAttention
+ """
+
+ from typing import Callable, NamedTuple, Tuple
+
+ import cutlass.cute as cute
+ import torch
+
+ from mslk.attention.flash_attn.cute_dsl_utils import to_cute_tensor
+
+
+ def ceildiv(a: int, b: int) -> int:
+     return (a + b - 1) // b
+
+
+ class BlockSparseTensors(NamedTuple):
+     mask_block_cnt: cute.Tensor
+     mask_block_idx: cute.Tensor
+     full_block_cnt: cute.Tensor | None
+     full_block_idx: cute.Tensor | None
+
+     def __new_from_mlir_values__(self, values):
+         if len(values) == 2:
+             values = (*values, None, None)
+         return BlockSparseTensors(*values)
+
+
+ class BlockSparseTensorsTorch(NamedTuple):
+     mask_block_cnt: torch.Tensor
+     mask_block_idx: torch.Tensor
+     full_block_cnt: torch.Tensor | None = None
+     full_block_idx: torch.Tensor | None = None
+
+
+ def _expand_sparsity_tensor(
+     tensor: torch.Tensor,
+     expected_shape: Tuple[int, ...],
+     tensor_name: str,
+     context: str | None,
+     hint: str | Callable[[], str] | None,
+ ) -> torch.Tensor:
+     """Check if we need to expand the tensor to expected shape, and do so if possible."""
+     needs_expand = tensor.shape != expected_shape
+     if not needs_expand:
+         return tensor
+     can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape))
+     if not can_expand:
+         context_clause = f" ({context})" if context else ""
+         resolved_hint = hint() if callable(hint) else hint
+         hint_clause = f" Hint: {resolved_hint}" if resolved_hint else ""
+         raise ValueError(
+             f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
+             f"{hint_clause}"
+         )
+     return tensor.expand(*expected_shape)
+
+
+ def _check_and_expand_block(
+     name: str,
+     cnt: torch.Tensor | None,
+     idx: torch.Tensor | None,
+     expected_count_shape: Tuple[int, int, int],
+     expected_index_shape: Tuple[int, int, int, int],
+     context: str | None,
+     hint: str | Callable[[], str] | None,
+ ) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
+     if (cnt is None) != (idx is None):
+         raise ValueError(
+             f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
+         )
+     if cnt is None or idx is None:
+         return None, None
+     if cnt.dtype != torch.int32 or idx.dtype != torch.int32:
+         raise ValueError(f"{name}_block tensors must have dtype torch.int32")
+     if cnt.device != idx.device:
+         raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
+     if not cnt.is_cuda or not idx.is_cuda:
+         raise ValueError(f"{name}_block tensors must live on CUDA")
+     expanded_cnt = _expand_sparsity_tensor(
+         cnt, expected_count_shape, f"{name}_block_cnt", context, hint
+     )
+     expanded_idx = _expand_sparsity_tensor(
+         idx, expected_index_shape, f"{name}_block_idx", context, hint
+     )
+     return expanded_cnt, expanded_idx
+
+
+ def get_block_sparse_expected_shapes(
+     batch_size: int,
+     num_head: int,
+     seqlen_q: int,
+     seqlen_k: int,
+     m_block_size: int,
+     n_block_size: int,
+     q_stage: int,
+ ) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
+     """Return (expected_count_shape, expected_index_shape) for block sparse normalization."""
+     m_block_size_effective = q_stage * m_block_size
+     expected_m_blocks = ceildiv(seqlen_q, m_block_size_effective)
+     expected_n_blocks = ceildiv(seqlen_k, n_block_size)
+     expected_count_shape = (batch_size, num_head, expected_m_blocks)
+     expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)
+     return expected_count_shape, expected_index_shape
+
+
+ def get_block_sparse_expected_shapes_bwd(
+     batch_size: int,
+     num_head: int,
+     seqlen_q: int,
+     seqlen_k: int,
+     m_block_size: int,
+     n_block_size: int,
+     subtile_factor: int,
+ ) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
+     """Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.
+
+     Backward uses Q-direction indexing (transposed from forward), where shapes are
+     indexed by N-blocks first, then M-blocks. The sparse_block_size_q is determined
+     by subtile_factor * m_block_size.
+     """
+     sparse_block_size_q = subtile_factor * m_block_size
+     expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
+     expected_n_blocks = ceildiv(seqlen_k, n_block_size)
+     expected_count_shape = (batch_size, num_head, expected_n_blocks)
+     expected_index_shape = (batch_size, num_head, expected_n_blocks, expected_m_blocks)
+     return expected_count_shape, expected_index_shape
+
+
+ def normalize_block_sparse_tensors(
+     tensors: BlockSparseTensorsTorch,
+     *,
+     expected_count_shape: Tuple[int, int, int],
+     expected_index_shape: Tuple[int, int, int, int],
+     context: str | None = None,
+     hint: str | Callable[[], str] | None = None,
+ ) -> BlockSparseTensorsTorch:
+     if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
+         raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
+
+     mask_cnt, mask_idx = _check_and_expand_block(
+         "mask",
+         tensors.mask_block_cnt,
+         tensors.mask_block_idx,
+         expected_count_shape,
+         expected_index_shape,
+         context,
+         hint,
+     )
+     if mask_cnt is None or mask_idx is None:
+         raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
+
+     full_cnt, full_idx = _check_and_expand_block(
+         "full",
+         tensors.full_block_cnt,
+         tensors.full_block_idx,
+         expected_count_shape,
+         expected_index_shape,
+         context,
+         hint,
+     )
+     if full_cnt is not None and mask_cnt.device != full_cnt.device:
+         raise ValueError("All block sparse tensors must be on the same device")
+
+     return BlockSparseTensorsTorch(
+         mask_block_cnt=mask_cnt,
+         mask_block_idx=mask_idx,
+         full_block_cnt=full_cnt,
+         full_block_idx=full_idx,
+     )
+
+
+ def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
+     return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt))
+
+
+ def to_cute_block_sparse_tensors(
+     tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
+ ) -> BlockSparseTensors | None:
+     """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
+     if not is_block_sparsity_enabled(tensors):
+         return None
+     (
+         mask_block_cnt,
+         mask_block_idx,
+         full_block_cnt,
+         full_block_idx,
+     ) = tensors
+
+     (
+         mask_block_cnt_tensor,
+         mask_block_idx_tensor,
+     ) = [
+         to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
+         for t in (mask_block_cnt, mask_block_idx)
+     ]
+     (
+         full_block_cnt_tensor,
+         full_block_idx_tensor,
+     ) = [
+         to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
+         if t is not None
+         else None
+         for t in (full_block_cnt, full_block_idx)
+     ]
+
+     return BlockSparseTensors(
+         mask_block_cnt_tensor,
+         mask_block_idx_tensor,
+         full_block_cnt_tensor,
+         full_block_idx_tensor,
+     )
+
+
+ def fast_sampling(mask_mod):
+     """Convenience decorator to mark mask_mod as safe for 5-point fast sampling"""
+     mask_mod.use_fast_sampling = True
+     return mask_mod
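
The helpers in this file compose as follows. This is an illustrative sketch, not code shipped in the wheel: the causal mask_mod body, the sizes, and the broadcast pattern are assumptions chosen to exercise fast_sampling, get_block_sparse_expected_shapes, and normalize_block_sparse_tensors (inside the compiled kernels the same mask_mod signature receives cute-dsl SSA values plus a seqlen info object and optional aux_tensors).

import torch

from mslk.attention.flash_attn.block_sparsity import (
    BlockSparseTensorsTorch,
    fast_sampling,
    get_block_sparse_expected_shapes,
    normalize_block_sparse_tensors,
)


# Hypothetical mask_mod: causal masking. Marked safe for 5-point sampling because
# a causal mask is monotone within a tile, so the four corners plus the center
# are enough to tell full, partial, and empty blocks apart.
@fast_sampling
def causal_mask(b, h, q_idx, kv_idx, seqlen, aux_tensors):
    return q_idx >= kv_idx


# Expected metadata shapes for illustrative sizes: batch=2, heads=8,
# seqlen_q=seqlen_k=4096, 128x128 tiles, q_stage=1.
cnt_shape, idx_shape = get_block_sparse_expected_shapes(
    batch_size=2, num_head=8, seqlen_q=4096, seqlen_k=4096,
    m_block_size=128, n_block_size=128, q_stage=1,
)  # -> (2, 8, 32) and (2, 8, 32, 32)

# Per-head metadata with a head dimension of 1 is broadcast up to the expected
# shapes by normalize_block_sparse_tensors; shapes that cannot broadcast are rejected.
shared_across_heads = BlockSparseTensorsTorch(
    mask_block_cnt=torch.zeros((2, 1, 32), dtype=torch.int32, device="cuda"),
    mask_block_idx=torch.zeros((2, 1, 32, 32), dtype=torch.int32, device="cuda"),
)
normalized = normalize_block_sparse_tensors(
    shared_across_heads,
    expected_count_shape=cnt_shape,
    expected_index_shape=idx_shape,
)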
mslk/attention/flash_attn/compute_block_sparsity.py
@@ -0,0 +1,378 @@
+ # @nolint # fbcode
+ from functools import partial
+ from typing import Callable, Optional, Tuple
+
+ import cutlass
+ import cutlass.cute as cute
+ import torch
+ from cutlass import Boolean, Int8, Int32, const_expr
+
+ from mslk.attention.flash_attn.block_sparsity import (
+     BlockSparseTensors,
+     BlockSparseTensorsTorch,
+     to_cute_block_sparse_tensors,
+ )
+ from mslk.attention.flash_attn.utils import hash_callable, scalar_to_ssa, ssa_to_scalar
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
+
+
+ class BlockSparsityKernel:
+     """Block sparsity kernel for FlexAttention.
+
+     This kernel computes `mask_mod` for every token of each block
+     to determine if an n block is full, masked, or neither.
+
+     Writes block counts and indices to a BlockSparseTensors object.
+
+     When use_fast_sampling=True, uses 5-point sampling (4 corners + center)
+     which is much faster but only suitable for masks where this is sufficient.
+
+     TODO:
+     - optimize mask_mod evaluation
+     - varlen support
+     - transposed tensors for bwd pass
+     """
+
+     def __init__(
+         self,
+         mask_mod: Callable,
+         tile_mn: Tuple[int, int],
+         compute_full_blocks: bool = True,
+         use_aux_tensors: bool = False,
+         use_fast_sampling: bool = False,
+     ):
+         self.mask_mod = mask_mod
+         self.tile_mn = tile_mn
+         self.compute_full_blocks = compute_full_blocks
+         self.use_aux_tensors = use_aux_tensors
+         self.use_fast_sampling = use_fast_sampling
+
+     @cute.jit
+     def __call__(
+         self,
+         blocksparse_tensors: BlockSparseTensors,
+         seqlen_q: Int32,
+         seqlen_k: Int32,
+         aux_tensors: Optional[list] = None,
+     ):
+         self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors
+
+         if const_expr(self.compute_full_blocks):
+             assert self.full_cnt is not None and self.full_idx is not None, (
+                 "full block tensors must be provided when computing full blocks"
+             )
+
+         batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape
+         # launch 1 CTA per m block
+         grid = [num_m_blocks, num_heads, batch_size]
+
+         if const_expr(self.use_fast_sampling):
+             num_threads = 5
+             self.num_warps = 1
+         else:
+             num_threads = self.tile_mn[0]
+             self.num_warps = (num_threads + 32 - 1) // 32
+
+         self.kernel(
+             self.mask_cnt,
+             self.mask_idx,
+             self.full_cnt,
+             self.full_idx,
+             num_n_blocks,
+             seqlen_q,
+             seqlen_k,
+             aux_tensors,
+         ).launch(grid=grid, block=[num_threads, 1, 1])
+
+     @cute.kernel
+     def kernel(
+         self,
+         mask_cnt: cute.Tensor,
+         mask_idx: cute.Tensor,
+         full_cnt: cute.Tensor,
+         full_idx: cute.Tensor,
+         num_n_blocks: Int32,
+         seqlen_q: Int32,
+         seqlen_k: Int32,
+         aux_tensors: Optional[list] = None,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         warp_idx = cute.arch.warp_idx()
+         lane_id = cute.arch.lane_idx()
+         m_block, head_idx, batch_idx = cute.arch.block_idx()
+
+         ssa = partial(scalar_to_ssa, dtype=Int32)
+
+         seqlen = SeqlenInfoQK.create(
+             batch_idx,
+             seqlen_q,
+             seqlen_k,
+             mCuSeqlensQ=None,
+             mCuSeqlensK=None,
+             mSeqUsedQ=None,
+             mSeqUsedK=None,
+         )
+
+         @cute.struct
+         class SharedStorage:
+             reduction_buffer_smem: cute.struct.Align[
+                 cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024
+             ]
+
+         smem = cutlass.utils.SmemAllocator()
+         storage = smem.allocate(SharedStorage, 16)
+
+         reduction_buffer = storage.reduction_buffer_smem.get_tensor(
+             cute.make_layout((self.num_warps, 2))
+         )
+
+         num_mask_blocks = Int32(0)
+         num_full_blocks = Int32(0)
+
+         for n_block in cutlass.range(num_n_blocks, unroll_full=True):
+             m_base = m_block * self.tile_mn[0]
+             n_base = n_block * self.tile_mn[1]
+
+             if const_expr(self.use_fast_sampling):
+                 # Fast path: 5-point sampling (4 corners + center)
+                 # Clamps OOB indices to nearest in bounds.
+                 thread_result = Boolean(False)
+                 thread_is_valid = Boolean(False)
+                 q_idx = Int32(0)
+                 kv_idx = Int32(0)
+
+                 if tidx == 0:
+                     # Top-left corner (0, 0); always in bounds
+                     q_idx = m_base
+                     kv_idx = n_base
+                 elif tidx == 1:
+                     # Top-right corner
+                     q_idx = m_base
+                     kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
+                 elif tidx == 2:
+                     # Bottom-left corner
+                     q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
+                     kv_idx = n_base
+                 elif tidx == 3:
+                     # Bottom-right corner
+                     q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
+                     kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
+                 elif tidx == 4:
+                     # Center point
+                     q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2
+                     kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2
+                 else:
+                     thread_is_valid = Boolean(False)
+
+                 # Check bounds and determine if this thread has a valid index pair
+                 if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k:
+                     thread_is_valid = Boolean(True)
+                     q_idx_ssa = ssa(q_idx)
+                     kv_idx_ssa = ssa(kv_idx)
+                     thread_result = ssa_to_scalar(
+                         self.mask_mod(
+                             ssa(batch_idx),
+                             ssa(head_idx),
+                             q_idx_ssa,
+                             kv_idx_ssa,
+                             seqlen,
+                             aux_tensors,
+                         )
+                     )
+                 else:
+                     thread_is_valid = Boolean(False)
+
+                 # Use vote_any_sync to see if any valid thread found unmasked or masked
+                 # Only count results from threads that checked valid indices
+                 has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid)
+                 has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid)
+
+             else:
+                 # Full path: check all elements in the block
+                 # Track if this thread's row has any masked or unmasked elements
+                 thread_has_unmasked = Boolean(False)
+                 thread_has_masked = Boolean(False)
+                 thread_is_valid = Boolean(False)
+
+                 # Each thread handles 1 row
+                 q_idx = m_base + tidx
+                 kv_idx = Int32(0)
+                 if tidx < self.tile_mn[0] and q_idx < seqlen_q:
+                     thread_is_valid = Boolean(True)
+                     q_idx_ssa = ssa(q_idx)
+
+                     # Loop over all columns in this row
+                     for c in cutlass.range(self.tile_mn[1], unroll_full=True):
+                         kv_idx = n_base + c
+                         kv_idx_ssa = ssa(kv_idx)
+
+                         # Only check elements within valid sequence bounds
+                         if kv_idx < seqlen_k:
+                             # Direct scalar call
+                             mask_val = ssa_to_scalar(
+                                 self.mask_mod(
+                                     ssa(batch_idx),
+                                     ssa(head_idx),
+                                     q_idx_ssa,
+                                     kv_idx_ssa,
+                                     seqlen,
+                                     aux_tensors,
+                                 )
+                             )
+
+                             # Update tracking flags
+                             if mask_val:
+                                 thread_has_unmasked = Boolean(True)
+                             else:
+                                 thread_has_masked = Boolean(True)
+
+                 # Block-level reduction to combine results across all threads
+                 # Only count votes from threads that checked valid indices
+                 warp_has_unmasked_mask = cute.arch.vote_any_sync(
+                     thread_has_unmasked & thread_is_valid
+                 )
+                 warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid)
+
+                 # lane 0 writes the ballot mask to shared memory
+                 lane_id = tidx % 32
+                 if lane_id == 0:
+                     # Store as Int8
+                     reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0)
+                     reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0)
+
+                 cute.arch.sync_threads()
+
+                 # Thread 0 ORs all warp results together
+                 has_unmasked = Boolean(False)
+                 has_masked = Boolean(False)
+                 if tidx == 0:
+                     for w in cutlass.range(self.num_warps):
+                         if reduction_buffer[w, 0]:
+                             has_unmasked = Boolean(True)
+                         if reduction_buffer[w, 1]:
+                             has_masked = Boolean(True)
+
+             # Only thread 0 updates the output arrays (common to both paths)
+             if tidx == 0:
+                 # Block classification based on what we found:
+                 # - If has_masked and has_unmasked: partial block (needs masking)
+                 # - If only has_unmasked: full block (no masking needed)
+                 # - If only has_masked: skip this block entirely
+                 is_partial = Boolean(has_masked and has_unmasked)
+                 is_full = Boolean(has_unmasked and (not has_masked))
+
+                 if is_partial:
+                     mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block
+                     num_mask_blocks += 1
+                 elif is_full and const_expr(self.compute_full_blocks):
+                     full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block
+                     num_full_blocks += 1
+
+         # Only thread 0 writes back the counts
+         if tidx == 0:
+             mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks
+             if const_expr(self.compute_full_blocks):
+                 full_cnt[batch_idx, head_idx, m_block] = num_full_blocks
+
+
+ def compute_block_sparsity(
+     tile_m,
+     tile_n,
+     batch_size,
+     num_heads,
+     seqlen_q,
+     seqlen_k,
+     mask_mod: Callable,
+     aux_tensors: Optional[list],  # list[cute.Tensor]
+     device,
+     compute_full_blocks: bool = True,
+     use_fast_sampling: bool = False,
+ ) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]:
+     """
+     Computes block sparsity for a given `mask_mod`.
+
+     Args:
+         tile_m: The tile size for the m dimension.
+         tile_n: The tile size for the n dimension.
+         batch_size: The batch size.
+         num_heads: The number of heads.
+         seqlen_q: The sequence length for the query.
+         seqlen_k: The sequence length for the key.
+         mask_mod: The `mask_mod` callable to use.
+         aux_tensors: A list of auxiliary tensors.
+         device: The device to use.
+         compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed.
+         use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient.
+
+     Returns:
+         A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`.
+     """
+     # Check if mask_mod is marked as suitable for 5-point fast sampling
+     use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling)
+
+     num_m_blocks = (seqlen_q + tile_m - 1) // tile_m
+     num_n_blocks = (seqlen_k + tile_n - 1) // tile_n
+
+     mask_block_cnt = torch.zeros(
+         (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32
+     )
+     mask_block_idx = torch.zeros(
+         (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32
+     )
+     full_block_cnt = (
+         torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32)
+         if compute_full_blocks
+         else None
+     )
+     full_block_idx = (
+         torch.zeros(
+             (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32
+         )
+         if compute_full_blocks
+         else None
+     )
+
+     blocksparse_tensors_torch = BlockSparseTensorsTorch(
+         mask_block_cnt=mask_block_cnt,
+         mask_block_idx=mask_block_idx,
+         full_block_cnt=full_block_cnt,
+         full_block_idx=full_block_idx,
+     )
+
+     mask_mod_hash = hash_callable(mask_mod)
+     blocksparse_tensors = to_cute_block_sparse_tensors(
+         blocksparse_tensors_torch, enable_tvm_ffi=True
+     )
+
+     compile_key = (
+         tile_m,
+         tile_n,
+         mask_mod_hash,
+         compute_full_blocks,
+         aux_tensors is not None,
+         use_fast_sampling,
+     )
+     if compile_key not in compute_block_sparsity.compile_cache:
+         kernel = BlockSparsityKernel(
+             mask_mod,
+             tile_mn=(tile_m, tile_n),
+             compute_full_blocks=compute_full_blocks,
+             use_aux_tensors=aux_tensors is not None,
+             use_fast_sampling=use_fast_sampling,
+         )
+
+         compute_block_sparsity.compile_cache[compile_key] = cute.compile(
+             kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options="--enable-tvm-ffi"
+         )
+
+     compute_block_sparsity.compile_cache[compile_key](
+         blocksparse_tensors_torch,
+         seqlen_q,
+         seqlen_k,
+         aux_tensors,
+     )
+
+     return blocksparse_tensors, blocksparse_tensors_torch
+
+
+ compute_block_sparsity.compile_cache = {}
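
End to end, callers drive the kernel through compute_block_sparsity. The sketch below is illustrative and not part of the packaged module; it reuses the hypothetical causal_mask from the earlier example and the same assumed sizes.

from mslk.attention.flash_attn.compute_block_sparsity import compute_block_sparsity

blocksparse_cute, blocksparse_torch = compute_block_sparsity(
    tile_m=128,
    tile_n=128,
    batch_size=2,
    num_heads=8,
    seqlen_q=4096,
    seqlen_k=4096,
    mask_mod=causal_mask,      # hypothetical @fast_sampling-tagged mask_mod
    aux_tensors=None,
    device="cuda",
    compute_full_blocks=True,  # also record fully-unmasked blocks
    use_fast_sampling=False,   # overridden to True because causal_mask is tagged
)

# blocksparse_torch.mask_block_cnt[b, h, m] is the number of partially-masked
# n-blocks for row-block m, and mask_block_idx[b, h, m, :cnt] lists their indices;
# full_block_cnt / full_block_idx describe blocks that need no masking at all.

Because the compiled kernel is cached on (tile_m, tile_n, mask_mod hash, compute_full_blocks, aux_tensors presence, use_fast_sampling), repeated calls with the same mask and tiling reuse the cute.compile result and only relaunch the kernel.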