mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/pipeline.py
@@ -0,0 +1,273 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Tri Dao.
+
+# import math
+from typing import Optional
+from dataclasses import dataclass
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Boolean, Int32, const_expr
+from cutlass.cutlass_dsl import if_generate
+from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup
+from cutlass.pipeline import PipelineUserType, PipelineOp
+from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
+from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
+
+
+# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed
+def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
+    """
+    Fences the mbarrier init and syncs the threadblock or cluster
+    """
+    cute.arch.mbarrier_init_fence()
+
+    if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+        # If not using clusters, sync the threadblock
+        _sync(Agent.ThreadBlock)
+    else:
+        # If using clusters, sync the cluster
+        _sync(Agent.ThreadBlockCluster)
+
+
+def _sync(group: Agent):
+    """
+    Syncs all threads within an agent.
+    """
+    if group is Agent.Thread:
+        raise NotImplementedError("Error: Not supported.")
+    elif group is Agent.ThreadBlock:
+        cute.arch.sync_threads()
+    elif group is Agent.ThreadBlockCluster:
+        cute.arch.cluster_arrive_relaxed()
+        cute.arch.cluster_wait()
+    else:
+        assert False, (
+            "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
+        )
+
+
+class PipelineStateSimple:
+    """
+    Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
+    Use a single Int32 to store both the index and phase bit, then we use divmod to get the
+    index and phase. If stages is a power of 2, divmod turns into bit twiddling.
+    """
+
+    def __init__(self, stages: int, phase_index: Int32):
+        # assert stages < 2**16
+        # self._log_stages = int(math.log2(stages))
+        # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2."
+        self._stages = stages
+        self._phase_index = phase_index
+
+    def clone(self) -> "PipelineStateSimple":
+        return PipelineStateSimple(self.stages, self._phase_index)
+
+    @property
+    def stages(self) -> int:
+        # return 1 << self._log_stages
+        return self._stages
+
+    @property
+    def index(self) -> Int32:
+        # return self._phase_index & 0xFFFF
+        # return self._phase_index & ((1 << self._log_stages) - 1)
+        if const_expr(self._stages == 1):
+            return Int32(0)
+        else:
+            return self._phase_index % self._stages
+
+    @property
+    def phase(self) -> Int32:
+        # return self._phase_index >> 16
+        # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to
+        # take modulo 2. But in practice just passing the phase in without modulo works fine.
+        # return (self._phase_index >> self._log_stages) % 2
+        # return self._phase_index >> self._log_stages
+        if const_expr(self._stages == 1):
+            return self._phase_index
+        else:
+            return self._phase_index // self._stages
+
+    def advance(self):
+        if const_expr(self._stages == 1):
+            self._phase_index ^= 1
+        else:
+            self._phase_index += 1
+
+        # def then_body(phase_index):
+        #     # XOR the phase bit and set the index to 0
+        #     return (phase_index & 0xFFFF0000) ^ (1 << 16)
+
+        # def else_body(phase_index):
+        #     return phase_index
+
+        # self._phase_index = if_generate(
+        #     (self._phase_index & 0xFFFF) == self.stages,
+        #     then_body,
+        #     else_body,
+        #     [self._phase_index],
+        #     [Int32],
+        # )
+
+    def __extract_mlir_values__(self):
+        phase_index = self._phase_index
+        return [phase_index.ir_value()]
+
+    def __new_from_mlir_values__(self, values):
+        return PipelineStateSimple(self.stages, Int32(values[0]))
+
+
|
123
|
+
"""
|
|
124
|
+
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
|
|
125
|
+
"""
|
|
126
|
+
if type is PipelineUserType.Producer:
|
|
127
|
+
# return PipelineStateSimple(stages, Int32(1 << 16))
|
|
128
|
+
return PipelineStateSimple(stages, Int32(stages))
|
|
129
|
+
elif type is PipelineUserType.Consumer:
|
|
130
|
+
return PipelineStateSimple(stages, Int32(0))
|
|
131
|
+
else:
|
|
132
|
+
assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
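For intuition, the single-counter encoding and the producer/consumer starting states above can be modeled in a few lines of ordinary Python. This is an illustrative sketch only: StateModel is a made-up stand-in, and the real PipelineStateSimple manipulates cutlass Int32 values inside the CuTe DSL.

# Plain-Python model of the phase_index encoding (illustration, not the DSL class).
class StateModel:
    def __init__(self, stages: int, phase_index: int):
        self.stages = stages
        self.phase_index = phase_index

    @property
    def index(self) -> int:
        return self.phase_index % self.stages  # slot in the circular buffer

    @property
    def phase(self) -> int:
        # The real code passes phase_index // stages straight through and lets the
        # hardware inspect only the parity; we take % 2 explicitly for clarity.
        return (self.phase_index // self.stages) % 2

    def advance(self):
        self.phase_index += 1

# Mirrors make_pipeline_state: the producer starts one full lap ahead (phase 1).
producer = StateModel(stages=3, phase_index=3)
consumer = StateModel(stages=3, phase_index=0)
for _ in range(6):
    print(producer.index, producer.phase, consumer.index, consumer.phase)
    producer.advance()  # indices cycle 0,1,2 and the phase flips on each wrap
    consumer.advance()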
+@dataclass(frozen=True)
+class PipelineTmaAsync(PipelineTmaAsyncOg):
+    """
+    Override producer_acquire to take in extra_tx_count parameter.
+    """
+
+    @staticmethod
+    def create(*args, **kwargs):
+        obj = PipelineTmaAsyncOg.create(*args, **kwargs)
+        # Can't assign to __class__ directly since the dataclass is frozen
+        # obj.__class__ = PipelineTmaAsync
+        object.__setattr__(obj, "__class__", PipelineTmaAsync)
+        return obj
+
+    def producer_acquire(
+        self,
+        state: PipelineState,
+        try_acquire_token: Optional[Boolean] = None,
+        extra_tx_count: int = 0,
+    ):
+        """
+        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
+        """
+        if_generate(
+            try_acquire_token is None or try_acquire_token == 0,
+            lambda: self.sync_object_empty.wait(state.index, state.phase),
+        )
+        if const_expr(extra_tx_count == 0):
+            self.sync_object_full.arrive(state.index, self.producer_mask)
+        else:
+            tx_count = self.sync_object_full.tx_count + extra_tx_count
+            self.sync_object_full.arrive_and_expect_tx(state.index, tx_count)
+
+
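The extra_tx_count parameter is easiest to see against a toy model of a transaction barrier. The sketch below is hypothetical: TxBarrier and its methods are made-up stand-ins, not the CUTLASS API. The producer declares how many bytes it expects for the stage, each TMA copy reports bytes as they land, and the wait only succeeds once the totals match, which is why any extra payload loaded through the same barrier must be folded into the expectation up front.

# Hypothetical stand-in for a transaction barrier (not the CUTLASS API).
class TxBarrier:
    def __init__(self):
        self.expected = 0
        self.arrived = 0

    def arrive_and_expect_tx(self, nbytes: int):
        self.expected += nbytes  # what producer_acquire sets for the stage

    def complete_tx(self, nbytes: int):
        self.arrived += nbytes   # what each TMA copy reports on completion

    def try_wait(self) -> bool:
        return self.arrived >= self.expected

bar = TxBarrier()
base_tx = 128 * 64 * 2     # e.g. one 128x64 bf16 tile per stage (bytes)
extra_tx = 64 * 4          # e.g. a small fp32 side buffer (the extra_tx_count case)
bar.arrive_and_expect_tx(base_tx + extra_tx)
bar.complete_tx(base_tx)
assert not bar.try_wait()  # still waiting on the side buffer
bar.complete_tx(extra_tx)
assert bar.try_wait()      # all expected bytes have landed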
+@dataclass(frozen=True)
+class PipelineTmaUmma(PipelineTmaUmmaOg):
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        tx_count: int,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        mcast_mode_mn: tuple[int, int] = (1, 1),
+        init_wait: cutlass.Constexpr[bool] = True,
+    ):
+        """
+        This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
+        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: Int32
+        :param producer_group: `CooperativeGroup` for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: `CooperativeGroup` for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
+        :type tx_count: int
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout | None
+        :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1.
+        :type mcast_mode_mn: tuple[int, int]
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.TmaLoad
+        consumer_type = PipelineOp.TCGen05Mma
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8), num_stages, producer, tx_count
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+        )
+
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+            # No mcast mask if not using clusters
+            producer_mask = None
+            # All threadblocks are leaders if not using clusters
+            is_leader_cta = True
+        else:
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
+                cta_layout_vmnk, mcast_mode_mn
+            )
+            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
+
+        cta_group = (
+            cute.nvgpu.tcgen05.CtaGroup.ONE
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
+            else cute.nvgpu.tcgen05.CtaGroup.TWO
+        )
+
+        consumer_mask = producer_mask
+
+        if const_expr(init_wait):
+            pipeline_init_wait(cta_layout_vmnk)
+
+        return PipelineTmaUmma(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            consumer_mask,
+            is_leader_cta,
+            cta_group,
+        )
+
+    def producer_acquire(
+        self,
+        state: PipelineState,
+        try_acquire_token: Optional[Boolean] = None,
+        extra_tx_count: int = 0,
+    ):
+        """
+        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
+        """
+        if_generate(
+            try_acquire_token is None or try_acquire_token == 0,
+            lambda: self.sync_object_empty.wait(state.index, state.phase),
+        )
+        if const_expr(extra_tx_count == 0):
+            if_generate(
+                self.is_leader_cta,
+                lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+            )
+        else:
+            tx_count = self.sync_object_full.tx_count + extra_tx_count
+            if_generate(
+                self.is_leader_cta,
+                lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count),
+            )
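One branch in create above is worth restating: the V mode of the (V, M, N, K) cluster layout decides whether tcgen05 MMAs issue as one-CTA or two-CTA groups. A plain-Python paraphrase of that selection, with pick_cta_group as an illustrative stand-in operating on a bare shape tuple rather than a cute.Layout:

# Illustrative paraphrase of the CtaGroup selection in PipelineTmaUmma.create.
def pick_cta_group(cta_shape_vmnk):
    if cta_shape_vmnk is None or cta_shape_vmnk[0] == 1:
        return "ONE"  # cute.nvgpu.tcgen05.CtaGroup.ONE
    return "TWO"      # cute.nvgpu.tcgen05.CtaGroup.TWO

assert pick_cta_group(None) == "ONE"          # no cluster at all
assert pick_cta_group((1, 2, 1, 1)) == "ONE"  # cluster, but V mode has size 1
assert pick_cta_group((2, 1, 1, 1)) == "TWO"  # paired CTAs along the V mode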
mslk/attention/flash_attn/seqlen_info.py
@@ -0,0 +1,139 @@
+# @nolint # fbcode
+from typing import Optional
+from dataclasses import dataclass
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, const_expr
+
+"""
+This consolidates all the info related to sequence length. This is so that we can do all
+the gmem reads once at the beginning of each tile, rather than having to repeat these reads
+to compute various things like n_block_min, n_block_max, etc.
+"""
+
+
+@dataclass(frozen=True)
+class SeqlenInfo:
+    offset: cutlass.Int32
+    seqlen: cutlass.Int32
+
+    @staticmethod
+    def create(
+        batch_idx: cutlass.Int32,
+        seqlen_static: cutlass.Int32,
+        cu_seqlens: Optional[cute.Tensor] = None,
+        seqused: Optional[cute.Tensor] = None,
+    ):
+        offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]
+        if const_expr(seqused is not None):
+            seqlen = seqused[batch_idx]
+        elif const_expr(cu_seqlens is not None):
+            seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
+        else:
+            seqlen = seqlen_static
+        return SeqlenInfo(offset, seqlen)
+
+
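SeqlenInfo.create reduces to prefix-sum arithmetic over cu_seqlens, with seqused optionally overriding the length. The same logic with plain Python lists in place of cute.Tensors (the values are invented for illustration):

# SeqlenInfo.create's logic with plain lists standing in for cute.Tensors.
cu_seqlens = [0, 100, 250, 300]  # three packed sequences of lengths 100, 150, 50
seqused = [80, 120, 30]          # optionally use only a prefix of each sequence

def seqlen_info(batch_idx, seqlen_static, cu_seqlens=None, seqused=None):
    offset = 0 if cu_seqlens is None else cu_seqlens[batch_idx]
    if seqused is not None:
        seqlen = seqused[batch_idx]
    elif cu_seqlens is not None:
        seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
    else:
        seqlen = seqlen_static
    return offset, seqlen

assert seqlen_info(1, 128) == (0, 128)                         # fixed-length batch
assert seqlen_info(1, 128, cu_seqlens) == (100, 150)           # varlen via prefix sums
assert seqlen_info(1, 128, cu_seqlens, seqused) == (100, 120)  # capped by seqused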
+@dataclass(frozen=True)
+class SeqlenInfoQK:
+    offset_q: cutlass.Int32
+    offset_k: cutlass.Int32
+    padded_offset_q: cutlass.Int32
+    padded_offset_k: cutlass.Int32
+    seqlen_q: cutlass.Int32
+    seqlen_k: cutlass.Int32
+    has_cu_seqlens_q: cutlass.Constexpr[bool]
+    has_cu_seqlens_k: cutlass.Constexpr[bool]
+    has_seqused_q: cutlass.Constexpr[bool]
+    has_seqused_k: cutlass.Constexpr[bool]
+
+    @staticmethod
+    def create(
+        batch_idx: cutlass.Int32,
+        seqlen_q_static: cutlass.Int32,
+        seqlen_k_static: cutlass.Int32,
+        mCuSeqlensQ: Optional[cute.Tensor] = None,
+        mCuSeqlensK: Optional[cute.Tensor] = None,
+        mSeqUsedQ: Optional[cute.Tensor] = None,
+        mSeqUsedK: Optional[cute.Tensor] = None,
+        tile_m: cutlass.Constexpr[cutlass.Int32] = 128,
+        tile_n: cutlass.Constexpr[cutlass.Int32] = 128,
+    ):
+        offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
+        offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
+        padded_offset_q = (
+            0
+            if const_expr(mCuSeqlensQ is None)
+            else (offset_q + batch_idx * tile_m) // tile_m * tile_m
+        )
+        padded_offset_k = (
+            0
+            if const_expr(mCuSeqlensK is None)
+            else (offset_k + batch_idx * tile_n) // tile_n * tile_n
+        )
+        if const_expr(mSeqUsedQ is not None):
+            seqlen_q = mSeqUsedQ[batch_idx]
+        else:
+            seqlen_q = (
+                seqlen_q_static
+                if const_expr(mCuSeqlensQ is None)
+                else mCuSeqlensQ[batch_idx + 1] - offset_q
+            )
+        if const_expr(mSeqUsedK is not None):
+            seqlen_k = mSeqUsedK[batch_idx]
+        else:
+            seqlen_k = (
+                seqlen_k_static
+                if const_expr(mCuSeqlensK is None)
+                else mCuSeqlensK[batch_idx + 1] - offset_k
+            )
+        has_cu_seqlens_q: int = mCuSeqlensQ is not None
+        has_cu_seqlens_k: int = mCuSeqlensK is not None
+        has_seqused_q: int = mSeqUsedQ is not None
+        has_seqused_k: int = mSeqUsedK is not None
+        return SeqlenInfoQK(
+            offset_q,
+            offset_k,
+            padded_offset_q,
+            padded_offset_k,
+            seqlen_q,
+            seqlen_k,
+            has_cu_seqlens_q,
+            has_cu_seqlens_k,
+            has_seqused_q,
+            has_seqused_k,
+        )
+
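The padded offsets computed in create are a round-down to a tile boundary, applied after adding batch_idx worth of tile-sized slack to the raw offset. A worked instance of that arithmetic (tile size and cu_seqlens values invented for illustration):

# Worked instance of the padded-offset arithmetic in SeqlenInfoQK.create.
tile_m = 128
cu_seqlens_q = [0, 100, 250, 300]

for batch_idx in (0, 1, 2):
    offset_q = cu_seqlens_q[batch_idx]
    padded = (offset_q + batch_idx * tile_m) // tile_m * tile_m
    print(batch_idx, offset_q, padded)
# batch 0: offset 0   -> (0 + 0)     // 128 * 128 = 0
# batch 1: offset 100 -> (100 + 128) // 128 * 128 = 128
# batch 2: offset 250 -> (250 + 256) // 128 * 128 = 384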
+    def offset_batch_Q(
+        self,
+        mQ: cute.Tensor,
+        batch_idx: Int32,
+        dim: int,
+        padded: cutlass.Constexpr[bool] = False,
+    ) -> cute.Tensor:
+        """Seqlen must be the first dimension of mQ"""
+        if const_expr(not self.has_cu_seqlens_q):
+            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
+            return mQ[idx]
+        else:
+            offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
+            offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q)
+            idx = (offset,) + (0,) * (cute.rank(mQ) - 1)
+            return cute.domain_offset(idx, mQ)
+
+    def offset_batch_K(
+        self,
+        mK: cute.Tensor,
+        batch_idx: Int32,
+        dim: int,
+        padded: cutlass.Constexpr[bool] = False,
+    ) -> cute.Tensor:
+        """Seqlen must be the first dimension of mK"""
+        if const_expr(not self.has_cu_seqlens_k):
+            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
+            return mK[idx]
+        else:
+            offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
+            idx = (offset_k,) + (0,) * (cute.rank(mK) - 1)
+            return cute.domain_offset(idx, mK)
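offset_batch_Q and offset_batch_K switch between two addressing schemes, which a plain-NumPy analogy makes concrete (this is not CuTe semantics, and the shapes are invented): without cu_seqlens the tensor keeps an explicit batch dimension that is simply indexed, while with cu_seqlens all sequences are packed along one axis and the tensor is offset to the current sequence's start.

# NumPy analogy for the two addressing schemes in offset_batch_Q (not CuTe).
import numpy as np

# Fixed batch: Q is (batch, seqlen, head_dim); select the batch slice directly,
# like mQ[idx] with batch_idx placed at position dim.
q_fixed = np.zeros((4, 128, 64))
q_b = q_fixed[2]

# Varlen: sequences packed along one axis; cu_seqlens marks each start, and the
# offset plays the role of cute.domain_offset on the seqlen mode.
q_packed = np.zeros((300, 64))
cu_seqlens = [0, 100, 250, 300]
b = 1
q_b = q_packed[cu_seqlens[b]:]                # view starting at sequence b
seqlen_b = cu_seqlens[b + 1] - cu_seqlens[b]  # the kernel reads this many rows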