mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Tri Dao.
|
|
3
|
+
# Ported Cutlass code from C++ to Python:
|
|
4
|
+
# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp
|
|
5
|
+
# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp
|
|
6
|
+
|
|
7
|
+
from enum import IntEnum
|
|
8
|
+
|
|
9
|
+
import cutlass
|
|
10
|
+
import cutlass.cute as cute
|
|
11
|
+
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
# Enumerations that match the HW encodings (values MUST stay identical)
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Major(IntEnum):
    """Which mode of an operand is contiguous (the matrix "layout" in the ISA docs).

    ``K``: the K mode has unit stride; ``MN``: the M/N mode has unit stride.
    Packed as a 1-bit field (a_major at bit 15, b_major at bit 16 of the
    instruction descriptor).
    """

    K = 0b0
    MN = 0b1
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ScaleIn(IntEnum):
    """1-bit negate flag for the A or B operand (a_negate / b_negate)."""

    One = 0b0  # operand used as-is
    Neg = 0b1  # operand negated
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Saturate(IntEnum):
    """1-bit accumulator saturate flag.

    Member names carry a trailing underscore because ``False`` / ``True`` are
    Python keywords.
    """

    False_ = 0b0
    True_ = 0b1
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CFormat(IntEnum):
    """Accumulator element-type code; 2-bit field at bits [4, 6) of the instruction descriptor."""

    F16 = 0b00
    F32 = 0b01
    S32 = 0b10
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class F16F32Format(IntEnum):
    """3-bit A/B element-type code for half-precision and TF32 operands."""

    F16 = 0b000
    BF16 = 0b001
    TF32 = 0b010
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class S8Format(IntEnum):
    """3-bit A/B element-type code for 8-bit integer operands."""

    UINT8 = 0b000
    INT8 = 0b001
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class MXF8F6F4Format(IntEnum):
    """3-bit A/B element-type code for 8/6/4-bit floating-point operands."""

    E4M3 = 0b000
    E5M2 = 0b001
    # 0b010 is intentionally not assigned here.
    E2M3 = 0b011
    E3M2 = 0b100
    E2M1 = 0b101
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class MaxShift(IntEnum):
    """2-bit max_shift code, packed at bits [30, 32) of the instruction descriptor."""

    NoShift = 0b00
    MaxShift8 = 0b01
    MaxShift16 = 0b10
    MaxShift32 = 0b11
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# ---------------------------------------------------------------------------
|
|
65
|
+
# CUTLASS-type → encoding helpers
|
|
66
|
+
# ---------------------------------------------------------------------------
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def to_UMMA_format(cutlass_type) -> int:
    """Return the 3-bit A/B operand format code for a CUTLASS scalar class.

    The code comes from ``S8Format``, ``F16F32Format`` or ``MXF8F6F4Format``
    depending on the type family; all of them go into the same 3-bit a_format /
    b_format slot of the instruction descriptor.

    Raises:
        TypeError: if ``cutlass_type`` is not one of the supported classes
            (6-bit / 4-bit float types are not mapped yet).
    """
    # Identity checks, tried in order.
    if cutlass_type is cutlass.Int8:
        code = S8Format.INT8
    elif cutlass_type is cutlass.Uint8:  # unsigned 8-bit, if the CUTLASS build has it
        code = S8Format.UINT8
    elif cutlass_type is cutlass.Float16:
        code = F16F32Format.F16
    elif cutlass_type is cutlass.BFloat16:
        code = F16F32Format.BF16
    elif cutlass_type is cutlass.TFloat32:
        code = F16F32Format.TF32
    elif cutlass_type is cutlass.FloatE4M3FN:
        code = MXF8F6F4Format.E4M3
    elif cutlass_type is cutlass.FloatE5M2:
        code = MXF8F6F4Format.E5M2
    else:
        raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}")
    return code
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def to_C_format(cutlass_type) -> int:
    """Return the 2-bit accumulator format code (``CFormat``) for a CUTLASS scalar class.

    Raises:
        TypeError: if ``cutlass_type`` is not Float16, Float32 or Int32.
    """
    if cutlass_type is cutlass.Float16:
        fmt = CFormat.F16
    elif cutlass_type is cutlass.Float32:
        fmt = CFormat.F32
    elif cutlass_type is cutlass.Int32:
        fmt = CFormat.S32
    else:
        raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}")
    return fmt
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# ---------------------------------------------------------------------------
|
|
108
|
+
# The constructor – accepts only CUTLASS scalar classes
|
|
109
|
+
# ---------------------------------------------------------------------------
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def make_instr_desc(
    a_type,  # CUTLASS scalar class, e.g. cutlass.Int8
    b_type,
    c_type,
    M: int,  # 64, 128 or 256
    N: int,  # 8 … 256 (multiple of 8)
    a_major: Major,
    b_major: Major,
    a_neg: ScaleIn = ScaleIn.One,
    b_neg: ScaleIn = ScaleIn.One,
    c_sat: Saturate = Saturate.False_,
    is_sparse: bool = False,
    max_shift: MaxShift = MaxShift.NoShift,
) -> int:
    """Build the 32-bit instruction descriptor for a Blackwell MMA.

    Matrix and accumulator types must be CUTLASS scalar classes; plain integers
    are rejected by the format helpers with ``TypeError``.

    Raises:
        TypeError: unsupported A/B or accumulator type (checked first).
        ValueError: ``M`` not in {64, 128, 256}, or ``N`` not a multiple of 8
            in [8, 256].
    """
    # Element formats are resolved before the shape checks.
    a_fmt = int(to_UMMA_format(a_type))
    b_fmt = int(to_UMMA_format(b_type))
    c_fmt = int(to_C_format(c_type))

    if M not in (64, 128, 256):
        raise ValueError("M must be 64, 128 or 256")
    if not 8 <= N <= 256 or N & 7:
        raise ValueError("N must be a multiple of 8 in the range 8…256")

    m_dim = M >> 4  # M / 16
    n_dim = N >> 3  # N / 8

    # (value, width in bits, bit offset). Bits 6, 23 and 29 are left unused.
    # fmt: off
    fields = (
        (0,               2, 0),   # sparse_id2 (always 0 here)
        (int(is_sparse),  1, 2),   # sparse_flag
        (int(c_sat),      1, 3),   # saturate
        (c_fmt,           2, 4),   # c_format
        (a_fmt,           3, 7),   # a_format
        (b_fmt,           3, 10),  # b_format
        (int(a_neg),      1, 13),  # a_negate
        (int(b_neg),      1, 14),  # b_negate
        (int(a_major),    1, 15),  # a_major
        (int(b_major),    1, 16),  # b_major
        (n_dim,           6, 17),  # n_dim
        (m_dim,           5, 24),  # m_dim
        (int(max_shift),  2, 30),  # max_shift
    )
    # fmt: on

    desc = 0
    for value, width, offset in fields:
        desc |= (value & ((1 << width) - 1)) << offset
    return desc & 0xFFFF_FFFF  # keep the result 32 bits wide
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):
    """Build the instruction descriptor for a tcgen05 ``MmaOp``.

    Reads the A/B/accumulator dtypes, the M and N extents from
    ``op.shape_mnk`` and both operand major modes. All remaining descriptor
    fields keep the defaults of ``make_instr_desc``.
    """
    k_mode = cute.nvgpu.tcgen05.mma.OperandMajorMode.K

    def _as_major(mode):
        # Anything that is not K-major is encoded as MN-major.
        return Major.K if mode == k_mode else Major.MN

    shape = op.shape_mnk
    return make_instr_desc(
        op.a_dtype,
        op.b_dtype,
        op.acc_dtype,
        shape[0],
        shape[1],
        _as_major(op.a_major_mode),
        _as_major(op.b_major_mode),
    )
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class LayoutType(IntEnum):
    """Shared-memory swizzle family, stored in the top 3 bits [61, 64) of the smem descriptor."""

    SWIZZLE_NONE = 0b000  # (a.k.a. "INTERLEAVE" in older docs)
    SWIZZLE_128B_BASE32B = 0b001
    SWIZZLE_128B = 0b010
    SWIZZLE_64B = 0b100
    SWIZZLE_32B = 0b110
    # values 3, 5, 7 are reserved / illegal for UMMA
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# ---------------------------------------------------------------------------
|
|
188
|
+
# Helpers – figure out the SWIZZLE_* family from the tensor layout
|
|
189
|
+
# ---------------------------------------------------------------------------
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _layout_type(swizzle: cute.Swizzle) -> LayoutType:
    """Map a CuTe swizzle to the UMMA shared-memory ``LayoutType``.

    The (B, M, S) parameters are parsed from ``str(swizzle)``, which has the
    form ``"S<B,M,S>"``; no structured accessor is used.

    Supported triples:
        Swizzle<0,4,3> -> SWIZZLE_NONE
        Swizzle<1,4,3> -> SWIZZLE_32B
        Swizzle<2,4,3> -> SWIZZLE_64B
        Swizzle<3,4,3> -> SWIZZLE_128B
        Swizzle<2,5,2> -> SWIZZLE_128B_BASE32B

    Raises:
        ValueError: for any other triple, or if the string cannot be parsed
            (``str.index``, ``int`` and tuple unpacking all raise ValueError).
    """
    swz_str = str(swizzle)
    inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")]  # e.g. '3,4,3'
    B, M, S = [int(x) for x in inside.split(",")]  # e.g. [3, 4, 3]

    if M == 4:  # Swizzle<*,4,3>
        if S != 3:
            raise ValueError("Unexpected swizzle shift – want S==3 for M==4")
        by_bits = {
            0: LayoutType.SWIZZLE_NONE,
            1: LayoutType.SWIZZLE_32B,
            2: LayoutType.SWIZZLE_64B,
            3: LayoutType.SWIZZLE_128B,
        }
        if B not in by_bits:
            # Report an unsupported B like every other unsupported triple
            # (ValueError) instead of leaking a bare KeyError from the lookup.
            raise ValueError(f"Unsupported swizzle B={B} for M==4 (want 0..3)")
        return by_bits[B]
    if M == 5:  # Swizzle<2,5,2> (the only legal triple for M==5)
        if (B, S) != (2, 2):
            raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B")
        return LayoutType.SWIZZLE_128B_BASE32B

    # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout
    raise ValueError("Unsupported swizzle triple for UMMA smem descriptor")
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int:
    """
    Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit
    smem-descriptor, without the smem start address.

    Args:
        layout: (MN, K) layout of the operand expressed in uint128 elements, so
            every stride (and hence LBO / SBO) is in 16-byte units.
        swizzle: swizzle applied on top of ``layout``; selects the
            ``LayoutType`` via ``_layout_type``.
        major: ``Major.MN`` or ``Major.K``. Compared by value, so the plain
            integers 1 / 0 select the same branches as the enum members.

    Returns:
        The descriptor with leading/stride byte offsets, version, base offset,
        LBO mode and layout type packed in. Bits [0, 14) (start address) are left
        zero for the caller to fill in.

    Raises:
        ValueError: if ``layout`` is not a canonical UMMA layout for ``major``
            (profile or stride mismatch, MN-size not a multiple of 8 for
            K-major, or SWIZZLE_128B_BASE32B with K-major). Errors raised by
            ``_layout_type`` for unsupported swizzles propagate unchanged.
    """
    # ------------------------------------------------------------------ meta
    layout_type = _layout_type(swizzle)  # resolve SWIZZLE_* family

    VERSION = 1  # bits 46–47
    LBO_MODE = 0  # bit 52
    BASE_OFFSET = 0  # bits 49–51 (CUTLASS always 0)

    # ---------------------------------------------------------- strides (units: uint128_t = 16 B)
    # Extent of one swizzle atom along MN, in uint128 elements.
    swizzle_atom_mn_size = {
        LayoutType.SWIZZLE_NONE: 1,
        LayoutType.SWIZZLE_32B: 2,
        LayoutType.SWIZZLE_64B: 4,
        LayoutType.SWIZZLE_128B: 8,
        LayoutType.SWIZZLE_128B_BASE32B: 8,
    }[layout_type]

    # Compare by value: `is` would send an equal-valued int (or a Major from a
    # second copy of this module) silently down the K-major branch.
    if major == Major.MN:
        # Tile as ((atom_mn, rest_mn), (atom_k, rest_k)). The inner MN stride must
        # be 1 (unless unswizzled) and the inner K stride must equal atom_mn.
        swizzle_atom_k_size = 4 if layout_type == LayoutType.SWIZZLE_128B_BASE32B else 8
        canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.")
        stride_00 = canonical_layout.stride[0][0]
        if layout_type != LayoutType.SWIZZLE_NONE and stride_00 != 1:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        stride_10 = canonical_layout.stride[1][0]
        if stride_10 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1]
        # The roles of the two outer strides swap between the unswizzled and
        # swizzled MN-major forms.
        if layout_type == LayoutType.SWIZZLE_NONE:
            stride_byte_offset, leading_byte_offset = stride_01, stride_11
        else:
            stride_byte_offset, leading_byte_offset = stride_11, stride_01
    else:
        # K-major: tile as ((8, rest_mn), (2, rest_k)). The inner MN stride must
        # equal atom_mn and the inner K stride must be 1 (unless unswizzled).
        if layout_type == LayoutType.SWIZZLE_128B_BASE32B:
            raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K")
        if not cute.size(layout.shape[0]) % 8 == 0:
            raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.")
        canonical_layout = cute.logical_divide(layout, (8, 2))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.")
        stride_00 = canonical_layout.stride[0][0]
        if stride_00 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_10 = canonical_layout.stride[1][0]
        if layout_type != LayoutType.SWIZZLE_NONE and stride_10 != 1:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_01 = canonical_layout.stride[0][1]
        stride_byte_offset, leading_byte_offset = stride_01, stride_10

    # ------------------------------------------------------------------ pack
    desc = 0
    # leading_byte_offset_ [16:30)
    desc |= (leading_byte_offset & 0x3FFF) << 16
    # stride_byte_offset_ [32:46)
    desc |= (stride_byte_offset & 0x3FFF) << 32
    # version_ [46:48)
    desc |= (VERSION & 0x3) << 46
    # base_offset_ [49:52)
    desc |= (BASE_OFFSET & 0x7) << 49
    # lbo_mode_ [52:53)
    desc |= (LBO_MODE & 0x1) << 52
    # layout_type_ [61:64)
    desc |= (int(layout_type) & 0x7) << 61

    return desc & 0xFFFF_FFFF_FFFF_FFFF  # force 64-bit width
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:
    """Encode a shared-memory pointer for the descriptor's start-address field.

    Keeps address bits [4, 18) and shifts them down. The result is the 14-bit
    value for descriptor bits [0, 14), i.e. the address in 16-byte units.
    """
    raw = start_addr.toint()
    low_18_bits = raw & 0x3FFFF
    return low_18_bits >> 4  # drop the 4 LSBs (16-byte granularity)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
|
3
|
+
|
|
4
|
+
import enum
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NamedBarrierFwd(enum.IntEnum):
    """Named-barrier IDs for the ``Fwd`` variant.

    IDs start from 1 because barrier 0 is reserved for sync_threads().
    """

    Epilogue = 1
    WarpSchedulerWG1 = 2
    WarpSchedulerWG2 = 3
    WarpSchedulerWG3 = 4
    PFull = 5
    PEmpty = 6
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class NamedBarrierBwd(enum.IntEnum):
    """Named-barrier IDs for the ``Bwd`` variant (IDs start at 1; 0 is reserved)."""

    Epilogue = 1
    WarpSchedulerWG1 = 2
    WarpSchedulerWG2 = 3
    WarpSchedulerWG3 = 4
    PdS = 5
    dQFullWG0 = 6
    dQFullWG1 = 7
    dQEmptyWG0 = 8
    dQEmptyWG1 = 9
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class NamedBarrierBwdSm100(enum.IntEnum):
    """Named-barrier IDs for the ``BwdSm100`` variant (IDs start at 1; 0 is reserved)."""

    EpilogueWG1 = 1
    EpilogueWG2 = 2
    Compute = 3
    dQaccReduce = 4
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Tri Dao.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import cutlass
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
|
|
8
|
+
import mslk.attention.flash_attn.utils as utils
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PackGQA:
    """Load/store helpers for packed-GQA tiles.

    In a packed tile, row ``idx`` of the block refers to sequence position
    ``idx // qhead_per_kvhead`` and query head ``idx % qhead_per_kvhead`` (see
    ``compute_ptr``). Rows are addressed through per-row 64-bit global pointers:
    they are computed cooperatively by ``compute_ptr`` and redistributed with
    ``utils.shuffle_sync``, rather than with a single tiled copy over a strided
    tensor.
    """

    def __init__(
        self,
        m_block_size: cutlass.Constexpr[int],
        head_dim_padded: cutlass.Constexpr[int],
        check_hdim_oob: cutlass.Constexpr[bool],
        qhead_per_kvhead: cutlass.Constexpr[int],
    ):
        """Store the compile-time tile parameters.

        Args:
            m_block_size: number of packed rows per block.
            head_dim_padded: padded head dimension of the tile.
            check_hdim_oob: if true, head-dim columns are predicated against the
                real head dimension (``mQ.shape[1]`` / ``mO.shape[1]``).
            qhead_per_kvhead: query heads per KV head. It is used as an integer
                divisor when unpacking row indices, hence annotated ``int``.
        """
        self.m_block_size = m_block_size
        self.head_dim_padded = head_dim_padded
        self.check_hdim_oob = check_hdim_oob
        self.qhead_per_kvhead = qhead_per_kvhead

    @cute.jit
    def compute_ptr(
        self,
        tensor: cute.Tensor,
        cRows: cute.Tensor,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        threads_per_row: cutlass.Constexpr[int],
        num_threads: cutlass.Constexpr[int],
    ):
        """Compute this lane's share of per-row global-memory addresses.

        For pointer ``i``, lane ``tidx % threads_per_row`` takes row
        ``i * num_threads + cRows[tidx % threads_per_row][0]``. It converts that
        to a packed index within ``block``, splits it into ``(h_idx, m_idx)`` and
        stores the address of element ``((h_idx, m_idx),)`` of ``tensor``.

        Returns:
            A fragment of ``ceil(size(cRows) / threads_per_row)`` Int64
            addresses. Callers fetch entry ``m // threads_per_row`` from lane
            ``m % threads_per_row`` via ``utils.shuffle_sync``.
        """
        num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row)
        tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64)
        for i in cutlass.range_constexpr(num_ptr_per_thread):
            row = i * num_threads + cRows[tidx % threads_per_row][0]
            idx = block * self.m_block_size + row
            # Packed index -> (sequence position, query head within the KV group).
            m_idx = idx // self.qhead_per_kvhead
            h_idx = idx - m_idx * self.qhead_per_kvhead
            tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()
        return tPrPtr

    @cute.jit
    def load_Q(
        self,
        mQ: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        sQ: cute.Tensor,  # (m_block_size, head_dim_padded)
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Gather one packed block of Q from global memory into ``sQ``.

        A row is copied only if its packed index is below
        ``seqlen * qhead_per_kvhead``. Other rows of ``sQ`` are left untouched.
        When ``check_hdim_oob`` is set, columns are predicated with
        ``utils.predicate_k`` against ``mQ.shape[1]``.
        """
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tQsQ = gmem_thr_copy.partition_D(sQ)
        tQcQ = gmem_thr_copy.partition_S(cQ)
        # Thread 0's coordinates. The row check below compares these against a
        # bound shifted by this thread's first-row offset, presumably so the
        # left-hand side is known at compile time.
        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
        tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1])
        tQcQ_row = tQcQ[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            # Fetch the address of this thread's m-th row from the lane that computed it.
            q_ptr_i64 = utils.shuffle_sync(
                tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            q_gmem_ptr = cute.make_ptr(
                mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            if (
                t0QcQ[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]
            ):
                # View the row as a 1-D tensor and copy it in elems_per_load chunks.
                mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tQsQ.shape[0][0])
                mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])):
                    ki = tQcQ[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        mQ_cur_copy[None, ki],
                        tQsQ[None, m, k],
                        pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
        # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs

    @cute.jit
    def store_LSE(
        self,
        mLSE: cute.Tensor,  # (qhead_per_kvhead, seqlen_q)
        tLSErLSE: cute.Tensor,  # this thread's LSE values, one per accumulator row it owns (size checked below)
        tiled_mma: cute.TiledMma,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Scatter this thread's per-row LSE values to ``mLSE``.

        Rows are taken from the MMA C-partition of the block. Only the thread
        whose first accumulator element is in column 0 writes, and only rows
        with packed index below ``seqlen * qhead_per_kvhead``.
        """
        thr_mma = tiled_mma.get_slice(tidx)
        caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        taccOcO = thr_mma.partition_C(caccO)
        taccOcO_row = utils.make_acc_tensor_mn_view(taccOcO)[None, 0]
        assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
        threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        assert cute.size(tLSErLSE) <= threads_per_row
        num_threads = tiled_mma.size
        tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tLSErLSE)):
            lse_ptr_i64 = utils.shuffle_sync(
                tPrLSEPtr[m // threads_per_row],
                m % threads_per_row,
                width=threads_per_row,
            )
            lse_gmem_ptr = cute.make_ptr(
                mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
            )
            row = block * self.m_block_size + taccOcO_row[m][0]
            # Only the thread corresponding to column 0 writes out the lse to gmem
            if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead:
                mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,))
                mLSE_copy[0] = tLSErLSE[m]

    @cute.jit
    def store_O(
        self,
        mO: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        tOrO: cute.Tensor,  # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Scatter this thread's O fragment to global memory (mirror of ``load_Q``).

        Rows with packed index at or beyond ``seqlen * qhead_per_kvhead`` are not
        written. When ``check_hdim_oob`` is set, columns are predicated against
        ``mO.shape[1]``.
        """
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tOcO = gmem_thr_copy.partition_S(cO)
        # Thread 0's coordinates; see the matching comment in load_Q.
        t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO)
        tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
        tOcO_row = tOcO[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
            # Fetch the address of this thread's m-th row from the lane that computed it.
            o_ptr_i64 = utils.shuffle_sync(
                tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            o_gmem_ptr = cute.make_ptr(
                mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            if (
                t0OcO[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]
            ):
                mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tOrO.shape[0][0])
                mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])):
                    ki = tOcO[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        tOrO[None, m, k],
                        mO_cur_copy[None, ki],
                        pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
from typing import Type
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import cutlass
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
from cutlass.cute.nvgpu import cpasync
|
|
8
|
+
from cutlass import Int32, const_expr
|
|
9
|
+
|
|
10
|
+
from mslk.attention.flash_attn import utils
|
|
11
|
+
from mslk.attention.flash_attn.cute_dsl_utils import ParamsBase
|
|
12
|
+
from cutlass.cute import FastDivmodDivisor
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class PagedKVManager(ParamsBase):
    """Per-thread state for copying K/V tiles out of a paged KV cache.

    ``create`` slices the page table to one batch (``bidb``) and the K/V caches to
    one head (``bidh``), and builds a cp.async global->shared tiled copy.
    ``load_page_table`` resolves, for one ``n_block``, the physical page and
    in-page offset of every row this thread copies. ``load_KV`` then issues the
    copies of that block of K or V into shared memory.
    """

    mPageTable: cute.Tensor  # page-table row of one batch, indexed by logical page
    mK_paged: cute.Tensor  # K cache of one head, indexed [page_offset, dim, page]
    mV_paged: cute.Tensor  # V cache of one head, indexed [dim, page_offset, page]
    thread_idx: Int32

    page_size_divmod: FastDivmodDivisor  # splits a row index into (page_idx, page_offset)
    seqlen_k: Int32
    leftpad_k: Int32  # added to the row index before the page lookup
    n_block_size: Int32
    num_threads: cutlass.Constexpr[Int32]
    head_dim_padded: cutlass.Constexpr[Int32]
    head_dim_v_padded: cutlass.Constexpr[Int32]

    gmem_threads_per_row: cutlass.Constexpr[Int32]  # threads cooperating on one row
    page_entry_per_thread: Int32  # number of rows (page entries) each thread owns
    async_copy_elems: Int32  # elements per 128-bit copy

    gmem_tiled_copy_KV: cute.TiledCopy
    gmem_thr_copy_KV: cute.TiledCopy  # this thread's slice of gmem_tiled_copy_KV
    tPrPage: cute.Tensor  # physical page per owned row (filled by load_page_table)
    tPrPageOffset: cute.Tensor  # offset inside that page per owned row
    tKpK: cute.Tensor  # head-dim predicate for K
    tVpV: cute.Tensor  # head-dim predicate for V

    @staticmethod
    def create(
        mPageTable: cute.Tensor,
        mK_paged: cute.Tensor,
        mV_paged: cute.Tensor,
        page_size_divmod: FastDivmodDivisor,
        bidb: Int32,
        bidh: Int32,
        thread_idx: Int32,
        seqlen_k: Int32,
        leftpad_k: Int32,
        n_block_size: cutlass.Constexpr[Int32],
        head_dim_padded: cutlass.Constexpr[Int32],
        head_dim_v_padded: cutlass.Constexpr[Int32],
        num_threads: cutlass.Constexpr[Int32],
        dtype: Type[cutlass.Numeric],
    ):
        """Build the manager for batch ``bidb`` and KV head ``bidh``.

        Uses 128-bit cp.async copies with 8 threads per row. Each thread gets
        ``page_entry_per_thread = n_block_size * 8 // num_threads`` page/offset
        entries. NOTE(review): that division truncates, so this presumably
        assumes ``num_threads`` divides ``n_block_size * 8`` -- confirm.
        Head-dim predicates are built for K. V reuses them when the padded
        widths match; otherwise its own are built.
        """
        universal_copy_bits = 128
        gmem_threads_per_row = 8  # 8 threads loading 128 bits = 128 bytes = 1 cache line
        async_copy_elems = universal_copy_bits // dtype.width
        atom_async_copy = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        # (rows, threads-per-row) thread layout; order=(1, 0) makes consecutive
        # thread indices walk along a row first.
        thr_layout = cute.make_ordered_layout(
            (num_threads // gmem_threads_per_row, gmem_threads_per_row),
            order=(1, 0),
        )
        val_layout = cute.make_layout((1, async_copy_elems))
        gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
        gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
        page_entry_per_thread = n_block_size * gmem_threads_per_row // num_threads

        tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
        tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)

        # Restrict to this batch's page-table row and this head's K/V pages.
        mPageTable = mPageTable[bidb, None]
        mK_paged = mK_paged[None, None, bidh, None]
        mV_paged = mV_paged[None, None, bidh, None]

        cK = cute.make_identity_tensor((n_block_size, head_dim_padded))
        tKcK = gmem_thr_copy_KV.partition_S(cK)
        tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1])

        if const_expr(head_dim_padded == head_dim_v_padded):
            tVpV = tKpK
        else:
            cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
            tVcV = gmem_thr_copy_KV.partition_S(cV)
            # V's head dim is mode 0 of the sliced V cache (see load_KV indexing).
            tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0])

        # Positional construction: argument order must match the field order above.
        return PagedKVManager(
            mPageTable,
            mK_paged,
            mV_paged,
            thread_idx,
            page_size_divmod,
            seqlen_k,
            leftpad_k,
            n_block_size,
            num_threads,
            head_dim_padded,
            head_dim_v_padded,
            gmem_threads_per_row,
            page_entry_per_thread,
            async_copy_elems,
            gmem_tiled_copy_KV,
            gmem_thr_copy_KV,
            tPrPage,
            tPrPageOffset,
            tKpK,
            tVpV,
        )

    @cute.jit
    def load_page_table(self, n_block: Int32):
        """Fill ``tPrPage`` / ``tPrPageOffset`` for the rows of ``n_block`` this thread copies.

        Entry ``i`` covers block row
        ``(i * num_threads + thread_idx) // gmem_threads_per_row``. The row's
        index plus ``leftpad_k`` is split into ``(page_idx, page_offset)``. The
        page is read from the page table only for rows inside the block and
        below ``seqlen_k``; other rows get page 0. The offset is always stored.
        """
        for i in cutlass.range(self.page_entry_per_thread, unroll=1):
            row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row
            row_idx = n_block * self.n_block_size + row

            page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)

            # The first clause is a sufficient condition for every row of
            # iteration i to lie inside the block; the second is the per-row check.
            is_valid = (
                (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size
            ) and row_idx < self.seqlen_k
            page = self.mPageTable[page_idx] if is_valid else 0

            self.tPrPage[i] = page
            self.tPrPageOffset[i] = page_offset

    @cute.jit
    def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
        """Issue the async copies of block ``n_block`` of K or V into ``sX``.

        Uses the page/offset entries from ``load_page_table``. Entry ``m`` is
        assumed to describe this thread's m-th copy row. Rows at or beyond
        ``seqlen_k`` (all rows when ``n_block < 0``) are predicated off and
        their ``sX`` contents are left untouched.

        NOTE(review): only the row predicate is applied here; the head-dim
        predicates ``tKpK`` / ``tVpV`` built in ``create`` are not used by this
        method -- confirm that padded head-dim columns are safe to read.
        """
        assert K_or_V in ("K", "V")

        # Finesse sX layout to be (M, N).
        # Mode 0 = sX mode (0,0); mode 1 = (sX mode (0,1), sX mode 2). sX mode 1
        # is dropped -- presumably size 1; TODO confirm against the caller.
        sX_pi = cute.make_tensor(
            sX.iterator,
            cute.make_layout(
                (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
                stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
            ),
        )

        if const_expr(K_or_V == "V"):
            # Need to transpose V
            sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))

        head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
        cX = cute.make_identity_tensor((self.n_block_size, head_dim))
        tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)
        tXcX = self.gmem_thr_copy_KV.partition_S(cX)

        # Number of valid rows left in this block; 0 disables every row for n_block < 0.
        seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0
        for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
            row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
            should_load = cute.make_fragment_like(tXsX[None, m, 0], cute.Boolean)
            should_load.fill(row_valid)

            page = self.tPrPage[m]
            page_offset = self.tPrPageOffset[m]
            # One cache row of the selected page: K is indexed [offset, dim, page],
            # V is indexed [dim, offset, page].
            mX_paged_cur = (
                self.mK_paged[page_offset, None, page]
                if const_expr(K_or_V == "K")
                else self.mV_paged[None, page_offset, page]
            )
            mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))

            for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])):
                ki = tXcX[0, 0, k][1] // self.async_copy_elems
                cute.copy(
                    self.gmem_tiled_copy_KV,
                    mX_paged_cur_copy[None, ki],
                    tXsX[None, m, k],
                    pred=should_load,
                )
|