mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,273 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Tri Dao.
3
+
4
+ # import math
5
+ from typing import Optional
6
+ from dataclasses import dataclass
7
+
8
+ import cutlass
9
+ import cutlass.cute as cute
10
+ from cutlass import Boolean, Int32, const_expr
11
+ from cutlass.cutlass_dsl import if_generate
12
+ from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup
13
+ from cutlass.pipeline import PipelineUserType, PipelineOp
14
+ from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
15
+ from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
16
+
17
+
18
+ # We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed
19
def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
    """
    Fence the mbarrier initialization, then synchronize the threadblock or the cluster.

    :param cta_layout_vmnk: Layout of the cluster shape. ``None`` or a layout of size 1
        is treated as "no cluster".
    :type cta_layout_vmnk: cute.Layout | None
    """
    cute.arch.mbarrier_init_fence()

    if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1:
        # A real cluster is in use: every threadblock in the cluster has to sync.
        _sync(Agent.ThreadBlockCluster)
    else:
        # No cluster: syncing the threadblock is enough.
        _sync(Agent.ThreadBlock)
31
+
32
+
33
def _sync(group: Agent):
    """
    Synchronize all threads that belong to the given agent.

    :param group: Agent whose threads should be synchronized.
    :type group: Agent
    :raises NotImplementedError: if ``group`` is ``Agent.Thread``.
    """
    if group is Agent.ThreadBlock:
        cute.arch.sync_threads()
    elif group is Agent.ThreadBlockCluster:
        # Relaxed arrive followed by the matching wait.
        cute.arch.cluster_arrive_relaxed()
        cute.arch.cluster_wait()
    elif group is Agent.Thread:
        raise NotImplementedError("Error: Not supported.")
    else:
        assert False, (
            "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
        )
48
+
49
+
50
class PipelineStateSimple:
    """
    Position of a pipeline user inside the circular buffer, stored as a single Int32 counter.

    Both the stage index and the phase are derived from that one counter with a divmod by
    the number of stages; if the number of stages is a power of 2 the divmod turns into bit
    twiddling. With exactly one stage the counter is used directly as the phase and is
    toggled between 0 and 1.
    """

    def __init__(self, stages: int, phase_index: Int32):
        # Note: the stage count is not required to be a power of 2.
        self._stages = stages
        self._phase_index = phase_index

    def clone(self) -> "PipelineStateSimple":
        """Return a new state with the same stage count and the same counter value."""
        return PipelineStateSimple(self._stages, self._phase_index)

    @property
    def stages(self) -> int:
        """Number of stages in the circular buffer."""
        return self._stages

    @property
    def index(self) -> Int32:
        """Current stage index: the counter modulo the stage count (always 0 for one stage)."""
        if const_expr(self._stages != 1):
            return self._phase_index % self._stages
        else:
            return Int32(0)

    @property
    def phase(self) -> Int32:
        """Current phase: the counter divided by the stage count (the counter itself for one stage)."""
        # The PTX docs say the phase parity must be 0 or 1, so strictly speaking a modulo 2
        # is needed here. In practice passing the phase without the modulo works fine.
        if const_expr(self._stages != 1):
            return self._phase_index // self._stages
        else:
            return self._phase_index

    def advance(self):
        """Move to the next stage: bump the counter, or flip it when there is a single stage."""
        if const_expr(self._stages != 1):
            self._phase_index += 1
        else:
            self._phase_index ^= 1

    def __extract_mlir_values__(self):
        # Only the counter is exposed as an MLIR value; the stage count is a Python int.
        return [self._phase_index.ir_value()]

    def __new_from_mlir_values__(self, values):
        # Rebuild the state around the new counter value, keeping the stage count.
        return PipelineStateSimple(self._stages, Int32(values[0]))
120
+
121
+
122
def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Create the initial PipelineStateSimple for a producer or a consumer.

    Producers are assumed to start with an empty buffer, so their phase bit starts flipped
    (phase 1); consumers start at phase 0. Both start at stage index 0.

    :param type: Whether the state is for the producer or the consumer side.
    :type type: PipelineUserType
    :param stages: Number of buffer stages in the pipeline.
    :type stages: int
    """
    if type is PipelineUserType.Consumer:
        consumer_start = Int32(0)
        return PipelineStateSimple(stages, consumer_start)
    elif type is PipelineUserType.Producer:
        # A counter equal to `stages` gives index 0 and phase 1.
        producer_start = Int32(stages)
        return PipelineStateSimple(stages, producer_start)
    else:
        assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
133
+
134
+
135
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineTmaAsyncOg):
    """
    PipelineTmaAsync whose producer_acquire additionally takes an extra_tx_count parameter.
    """

    @staticmethod
    def create(*args, **kwargs):
        """
        Build the pipeline with the base class factory and retag the result as this subclass.

        All arguments are forwarded unchanged to ``PipelineTmaAsyncOg.create``.
        """
        pipeline = PipelineTmaAsyncOg.create(*args, **kwargs)
        # The dataclass is frozen, so a plain `pipeline.__class__ = ...` assignment is not
        # possible; go through object.__setattr__ instead.
        object.__setattr__(pipeline, "__class__", PipelineTmaAsync)
        return pipeline

    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
    ):
        """
        TMA producer acquire: conditionally wait for the buffer to be empty, then arrive on
        the full barrier.

        :param state: Pipeline state providing the stage index and phase.
        :param try_acquire_token: If given and non-zero, the wait on the empty barrier is skipped.
        :param extra_tx_count: Added to the full barrier's tx_count for this stage. When it is 0
            the barrier is arrived on without setting an expected transaction count here.
        """
        needs_wait = try_acquire_token is None or try_acquire_token == 0
        if_generate(
            needs_wait,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )
        if const_expr(extra_tx_count != 0):
            total_tx_count = self.sync_object_full.tx_count + extra_tx_count
            self.sync_object_full.arrive_and_expect_tx(state.index, total_tx_count)
        else:
            self.sync_object_full.arrive(state.index, self.producer_mask)
167
+
168
+
169
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineTmaUmmaOg):
    """
    PipelineTmaUmma with an optional init wait in ``create`` and an ``extra_tx_count``
    parameter on ``producer_acquire``.
    """

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        mcast_mode_mn: tuple[int, int] = (1, 1),
        init_wait: cutlass.Constexpr[bool] = True,
    ):
        """
        Compute the necessary attributes and return an instance of PipelineTmaUmma.

        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: int
        :param producer_group: `CooperativeGroup` for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: `CooperativeGroup` for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1.
        :type mcast_mode_mn: tuple[int, int]
        :param init_wait: Compile-time flag; when true, ``pipeline_init_wait(cta_layout_vmnk)``
            is called before the pipeline is returned.
        :type init_wait: bool
        :raises ValueError: if ``barrier_storage`` is not a ``cute.Pointer``.
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        # Full barriers come first in the storage; the producer is the TMA load.
        full_barriers = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8),
            num_stages,
            (PipelineOp.TmaLoad, producer_group),
            tx_count,
        )
        # Empty barriers follow, num_stages entries later; the consumer is the tcgen05 MMA.
        empty_barriers = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages,
            num_stages,
            (PipelineOp.TCGen05Mma, consumer_group),
        )

        if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1:
            mcast_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
                cta_layout_vmnk, mcast_mode_mn
            )
            leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
        else:
            # Without clusters there is no mcast mask and every threadblock is a leader.
            mcast_mask = None
            leader_cta = True

        # Two-CTA group only when the first mode of the cluster layout has size other than 1.
        cta_group = (
            cute.nvgpu.tcgen05.CtaGroup.TWO
            if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk, mode=[0]) != 1
            else cute.nvgpu.tcgen05.CtaGroup.ONE
        )

        if const_expr(init_wait):
            pipeline_init_wait(cta_layout_vmnk)

        # The same mask is used for the producer and the consumer side.
        return PipelineTmaUmma(
            full_barriers,
            empty_barriers,
            num_stages,
            mcast_mask,
            mcast_mask,
            leader_cta,
            cta_group,
        )

    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
    ):
        """
        TMA producer acquire: conditionally wait for the buffer to be empty, then let leader
        threadblocks arrive on the full barrier.

        :param state: Pipeline state providing the stage index and phase.
        :param try_acquire_token: If given and non-zero, the wait on the empty barrier is skipped.
        :param extra_tx_count: Added to the full barrier's tx_count for this stage. When it is 0
            the leader arrives without setting an expected transaction count here.
        """
        needs_wait = try_acquire_token is None or try_acquire_token == 0
        if_generate(
            needs_wait,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )
        if const_expr(extra_tx_count != 0):
            total_tx_count = self.sync_object_full.tx_count + extra_tx_count
            if_generate(
                self.is_leader_cta,
                lambda: self.sync_object_full.arrive_and_expect_tx(state.index, total_tx_count),
            )
        else:
            if_generate(
                self.is_leader_cta,
                lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
            )
@@ -0,0 +1,139 @@
1
+ # @nolint # fbcode
2
+ from typing import Optional
3
+ from dataclasses import dataclass
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ from cutlass import Int32, const_expr
8
+
9
+ """
10
+ This consolidates all the info related to sequence length. This is so that we can do all
11
+ the gmem reads once at the beginning of each tile, rather than having to repeat these reads
12
+ to compute various things like n_block_min, n_block_max, etc.
13
+ """
14
+
15
+
16
@dataclass(frozen=True)
class SeqlenInfo:
    # Start of this batch entry: cu_seqlens[batch_idx], or 0 when cu_seqlens is absent.
    offset: cutlass.Int32
    # Sequence length for this batch entry.
    seqlen: cutlass.Int32

    @staticmethod
    def create(
        batch_idx: cutlass.Int32,
        seqlen_static: cutlass.Int32,
        cu_seqlens: Optional[cute.Tensor] = None,
        seqused: Optional[cute.Tensor] = None,
    ):
        """
        Build the SeqlenInfo for one batch entry.

        The length is taken from ``seqused`` if given, otherwise from the difference of
        consecutive ``cu_seqlens`` entries if given, otherwise from ``seqlen_static``.
        """
        if const_expr(cu_seqlens is not None):
            start = cu_seqlens[batch_idx]
        else:
            start = 0

        if const_expr(seqused is not None):
            length = seqused[batch_idx]
        elif const_expr(cu_seqlens is not None):
            length = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
        else:
            length = seqlen_static

        return SeqlenInfo(offset=start, seqlen=length)
36
+
37
+
38
@dataclass(frozen=True)
class SeqlenInfoQK:
    # Start offsets from the cumulative-seqlens tensors (0 when the tensor is absent).
    offset_q: cutlass.Int32
    offset_k: cutlass.Int32
    # (offset + batch_idx * tile) rounded down to a multiple of tile (0 when cu_seqlens is absent).
    padded_offset_q: cutlass.Int32
    padded_offset_k: cutlass.Int32
    # Sequence lengths for this batch entry.
    seqlen_q: cutlass.Int32
    seqlen_k: cutlass.Int32
    # Compile-time flags recording which optional tensors were supplied to create().
    has_cu_seqlens_q: cutlass.Constexpr[bool]
    has_cu_seqlens_k: cutlass.Constexpr[bool]
    has_seqused_q: cutlass.Constexpr[bool]
    has_seqused_k: cutlass.Constexpr[bool]

    @staticmethod
    def create(
        batch_idx: cutlass.Int32,
        seqlen_q_static: cutlass.Int32,
        seqlen_k_static: cutlass.Int32,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        tile_m: cutlass.Constexpr[cutlass.Int32] = 128,
        tile_n: cutlass.Constexpr[cutlass.Int32] = 128,
    ):
        """
        Build the SeqlenInfoQK for one batch entry.

        For each of Q and K: the length comes from the seqused tensor if given, otherwise
        from consecutive cu_seqlens entries if given, otherwise from the static length.
        Offsets are 0 unless the matching cu_seqlens tensor is given.
        """
        # Q side: offset and tile-aligned padded offset.
        if const_expr(mCuSeqlensQ is None):
            offset_q = 0
            padded_offset_q = 0
        else:
            offset_q = mCuSeqlensQ[batch_idx]
            padded_offset_q = (offset_q + batch_idx * tile_m) // tile_m * tile_m

        # K side: offset and tile-aligned padded offset.
        if const_expr(mCuSeqlensK is None):
            offset_k = 0
            padded_offset_k = 0
        else:
            offset_k = mCuSeqlensK[batch_idx]
            padded_offset_k = (offset_k + batch_idx * tile_n) // tile_n * tile_n

        if const_expr(mSeqUsedQ is not None):
            seqlen_q = mSeqUsedQ[batch_idx]
        elif const_expr(mCuSeqlensQ is not None):
            seqlen_q = mCuSeqlensQ[batch_idx + 1] - offset_q
        else:
            seqlen_q = seqlen_q_static

        if const_expr(mSeqUsedK is not None):
            seqlen_k = mSeqUsedK[batch_idx]
        elif const_expr(mCuSeqlensK is not None):
            seqlen_k = mCuSeqlensK[batch_idx + 1] - offset_k
        else:
            seqlen_k = seqlen_k_static

        return SeqlenInfoQK(
            offset_q=offset_q,
            offset_k=offset_k,
            padded_offset_q=padded_offset_q,
            padded_offset_k=padded_offset_k,
            seqlen_q=seqlen_q,
            seqlen_k=seqlen_k,
            has_cu_seqlens_q=mCuSeqlensQ is not None,
            has_cu_seqlens_k=mCuSeqlensK is not None,
            has_seqused_q=mSeqUsedQ is not None,
            has_seqused_k=mSeqUsedK is not None,
        )

    def offset_batch_Q(
        self,
        mQ: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
    ) -> cute.Tensor:
        """
        Select this batch entry's part of mQ.

        Without cu_seqlens_q, mQ is indexed with batch_idx at mode ``dim`` and every other
        mode kept. With cu_seqlens_q, the first mode of mQ is shifted by offset_q (or
        padded_offset_q when ``padded``); seqlen must be the first dimension of mQ.
        """
        if const_expr(self.has_cu_seqlens_q):
            start = self.padded_offset_q if const_expr(padded) else self.offset_q
            # If the first mode is itself nested (rank other than 1), the offset is applied to
            # its second sub-mode. NOTE(review): presumably the packed-GQA layout; confirm.
            if const_expr(cute.rank(mQ.shape[0]) == 1):
                first_mode = start
            else:
                first_mode = (0, start)
            coord = (first_mode,) + (0,) * (cute.rank(mQ) - 1)
            return cute.domain_offset(coord, mQ)
        else:
            keep_before = (None,) * dim
            keep_after = (None,) * (cute.rank(mQ) - 1 - dim)
            coord = keep_before + (batch_idx,) + keep_after
            return mQ[coord]

    def offset_batch_K(
        self,
        mK: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
    ) -> cute.Tensor:
        """
        Select this batch entry's part of mK.

        Without cu_seqlens_k, mK is indexed with batch_idx at mode ``dim`` and every other
        mode kept. With cu_seqlens_k, the first mode of mK is shifted by offset_k (or
        padded_offset_k when ``padded``); seqlen must be the first dimension of mK.
        """
        if const_expr(self.has_cu_seqlens_k):
            start = self.padded_offset_k if const_expr(padded) else self.offset_k
            coord = (start,) + (0,) * (cute.rank(mK) - 1)
            return cute.domain_offset(coord, mK)
        else:
            keep_before = (None,) * dim
            keep_after = (None,) * (cute.rank(mK) - 1 - dim)
            coord = keep_before + (batch_idx,) + keep_after
            return mK[coord]