mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0

mslk/attention/flash_attn/mma_sm100_desc.py
@@ -0,0 +1,292 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Tri Dao.
+# Ported Cutlass code from C++ to Python:
+# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp
+# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp
+
+from enum import IntEnum
+
+import cutlass
+import cutlass.cute as cute
+
+# ---------------------------------------------------------------------------
+# Enumerations that match the HW encodings (values MUST stay identical)
+# ---------------------------------------------------------------------------
+
+
+class Major(IntEnum):  # matrix “layout” in the ISA docs
+    K = 0
+    MN = 1
+
+
+class ScaleIn(IntEnum):  # negate flags
+    One = 0
+    Neg = 1
+
+
+class Saturate(IntEnum):
+    False_ = 0
+    True_ = 1
+
+
+class CFormat(IntEnum):  # 2-bit field (bits 4-5)
+    F16 = 0
+    F32 = 1
+    S32 = 2
+
+
+class F16F32Format(IntEnum):  # 3-bit field (A/B element type)
+    F16 = 0
+    BF16 = 1
+    TF32 = 2
+
+
+class S8Format(IntEnum):
+    UINT8 = 0
+    INT8 = 1
+
+
+class MXF8F6F4Format(IntEnum):
+    E4M3 = 0
+    E5M2 = 1
+    E2M3 = 3
+    E3M2 = 4
+    E2M1 = 5
+
+
+class MaxShift(IntEnum):
+    NoShift = 0
+    MaxShift8 = 1
+    MaxShift16 = 2
+    MaxShift32 = 3
+
+
+# ---------------------------------------------------------------------------
+# CUTLASS-type → encoding helpers
+# ---------------------------------------------------------------------------
+
+
+def to_UMMA_format(cutlass_type) -> int:
+    """
+    Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B.
+    """
+    if cutlass_type is cutlass.Int8:
+        return S8Format.INT8
+    # Unsigned 8-bit (if available in your CUTLASS build)
+    if cutlass_type is cutlass.Uint8:
+        return S8Format.UINT8
+    # FP-16 / BF-16
+    if cutlass_type is cutlass.Float16:
+        return F16F32Format.F16
+    if cutlass_type is cutlass.BFloat16:
+        return F16F32Format.BF16
+    # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits)
+    if cutlass_type is cutlass.TFloat32:
+        return F16F32Format.TF32
+    # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them
+    if cutlass_type is cutlass.FloatE4M3FN:
+        return MXF8F6F4Format.E4M3
+    if cutlass_type is cutlass.FloatE5M2:
+        return MXF8F6F4Format.E5M2
+    raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}")
+
+
+def to_C_format(cutlass_type) -> int:
+    """
+    Map a CUTLASS scalar class to the 2-bit accumulator encoding.
+    """
+    if cutlass_type is cutlass.Float16:
+        return CFormat.F16
+    if cutlass_type is cutlass.Float32:
+        return CFormat.F32
+    if cutlass_type is cutlass.Int32:
+        return CFormat.S32
+    raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}")
+
+
+# ---------------------------------------------------------------------------
+# The constructor – accepts only CUTLASS scalar classes
+# ---------------------------------------------------------------------------
+
+
+def make_instr_desc(
+    a_type,  # CUTLASS scalar class, e.g. cutlass.Int8
+    b_type,
+    c_type,
+    M: int,  # 64, 128 or 256
+    N: int,  # 8 … 256 (multiple of 8)
+    a_major: Major,
+    b_major: Major,
+    a_neg: ScaleIn = ScaleIn.One,
+    b_neg: ScaleIn = ScaleIn.One,
+    c_sat: Saturate = Saturate.False_,
+    is_sparse: bool = False,
+    max_shift: MaxShift = MaxShift.NoShift,
+) -> int:
+    """
+    Build the 32-bit instruction descriptor for Blackwell MMA.
+    All matrix/accumulator **types must be CUTLASS scalar classes** –
+    passing integers is forbidden.
+    """
+    # --- encode element formats -------------------------------------------------
+    a_fmt = int(to_UMMA_format(a_type))
+    b_fmt = int(to_UMMA_format(b_type))
+    c_fmt = int(to_C_format(c_type))
+
+    # --- range checks on M/N -----------------------------------------------------
+    if M not in (64, 128, 256):
+        raise ValueError("M must be 64, 128 or 256")
+    if N < 8 or N > 256 or (N & 7):
+        raise ValueError("N must be a multiple of 8 in the range 8…256")
+
+    m_dim = M >> 4  # 5-bit field
+    n_dim = N >> 3  # 6-bit field
+
+    # fmt: off
+    # --- pack the bit-fields -----------------------------------------------------
+    desc = 0
+    desc |= (0 & 0x3) << 0                # sparse_id2 (always 0 here)
+    desc |= (int(is_sparse) & 0x1) << 2   # sparse_flag
+    desc |= (int(c_sat) & 0x1) << 3       # saturate
+    desc |= (c_fmt & 0x3) << 4            # c_format
+    desc |= (a_fmt & 0x7) << 7            # a_format
+    desc |= (b_fmt & 0x7) << 10           # b_format
+    desc |= (int(a_neg) & 0x1) << 13      # a_negate
+    desc |= (int(b_neg) & 0x1) << 14      # b_negate
+    desc |= (int(a_major) & 0x1) << 15    # a_major
+    desc |= (int(b_major) & 0x1) << 16    # b_major
+    desc |= (n_dim & 0x3F) << 17          # n_dim (6 bits)
+    desc |= (m_dim & 0x1F) << 24          # m_dim (5 bits)
+    desc |= (int(max_shift) & 0x3) << 30  # max_shift (2 bits)
+    # fmt: on
+
+    return desc & 0xFFFF_FFFF  # ensure 32-bit result
+
+
+def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):
+    return make_instr_desc(
+        op.a_dtype,
+        op.b_dtype,
+        op.acc_dtype,
+        op.shape_mnk[0],
+        op.shape_mnk[1],
+        Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN,
+        Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN,
+    )
+
+
+class LayoutType(IntEnum):  # occupies the top-3 bits [61:64)
+    SWIZZLE_NONE = 0  # (a.k.a. “INTERLEAVE” in older docs)
+    SWIZZLE_128B_BASE32B = 1
+    SWIZZLE_128B = 2
+    SWIZZLE_64B = 4
+    SWIZZLE_32B = 6
+    # values 3,5,7 are reserved / illegal for UMMA
+
+
+# ---------------------------------------------------------------------------
+# Helpers – figure out the SWIZZLE_* family from the tensor layout
+# ---------------------------------------------------------------------------
+
+
+def _layout_type(swizzle: cute.Swizzle) -> LayoutType:
+    # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__
+    # Swizzle string has the form "S<B,M,S>"
+    swz_str = str(swizzle)
+    inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")]  # '3,4,3'
+    B, M, S = [int(x) for x in inside.split(",")]  # [3, 4, 3]
+
+    if M == 4:  # Swizzle<*,4,3>
+        if S != 3:
+            raise ValueError("Unexpected swizzle shift – want S==3 for M==4")
+        return {
+            0: LayoutType.SWIZZLE_NONE,
+            1: LayoutType.SWIZZLE_32B,
+            2: LayoutType.SWIZZLE_64B,
+            3: LayoutType.SWIZZLE_128B,
+        }[B]  # KeyError ⇒ invalid B → raise
+    if M == 5:  # Swizzle<2,5,2> (the only legal triple for M==5)
+        if (B, S) != (2, 2):
+            raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B")
+        return LayoutType.SWIZZLE_128B_BASE32B
+
+    # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout
+    raise ValueError("Unsupported swizzle triple for UMMA smem descriptor")
+
+
+def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int:
+    """
+    Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit
+    smem-descriptor, without the smem start address.
+    layout must correspond to layout of an uint128 tensor.
+    """
+    # ------------------------------------------------------------------ meta
+    layout_type = _layout_type(swizzle)  # resolve SWIZZLE_* family
+
+    VERSION = 1  # bits 46–47
+    LBO_MODE = 0  # bit 52
+    BASE_OFFSET = 0  # bits 49–51 (CUTLASS always 0)
+
+    # ---------------------------------------------------------- strides (units: uint128_t = 16 B)
+    swizzle_atom_mn_size = {
+        LayoutType.SWIZZLE_NONE: 1,
+        LayoutType.SWIZZLE_32B: 2,
+        LayoutType.SWIZZLE_64B: 4,
+        LayoutType.SWIZZLE_128B: 8,
+        LayoutType.SWIZZLE_128B_BASE32B: 8,
+    }[layout_type]
+
+    if major is Major.MN:
+        swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8
+        canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size))
+        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
+            raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.")
+        stride_00 = canonical_layout.stride[0][0]
+        if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1:
+            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
+        stride_10 = canonical_layout.stride[1][0]
+        if stride_10 != swizzle_atom_mn_size:
+            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
+        stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1]
+        if layout_type is LayoutType.SWIZZLE_NONE:
+            stride_byte_offset, leading_byte_offset = stride_01, stride_11
+        else:
+            stride_byte_offset, leading_byte_offset = stride_11, stride_01
+    else:
+        if layout_type == LayoutType.SWIZZLE_128B_BASE32B:
+            raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K")
+        if not cute.size(layout.shape[0]) % 8 == 0:
+            raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.")
+        canonical_layout = cute.logical_divide(layout, (8, 2))
+        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
+            raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.")
+        stride_00 = canonical_layout.stride[0][0]
+        if stride_00 != swizzle_atom_mn_size:
+            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
+        stride_10 = canonical_layout.stride[1][0]
+        if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1:
+            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
+        stride_01 = canonical_layout.stride[0][1]
+        stride_byte_offset, leading_byte_offset = stride_01, stride_10

+    # ------------------------------------------------------------------ pack
+    desc = 0
+    # leading_byte_offset_ [16:30)
+    desc |= (leading_byte_offset & 0x3FFF) << 16
+    # stride_byte_offset_ [32:46)
+    desc |= (stride_byte_offset & 0x3FFF) << 32
+    # version_ [46:48)
+    desc |= (VERSION & 0x3) << 46
+    # base_offset_ [49:52)
+    desc |= (BASE_OFFSET & 0x7) << 49
+    # lbo_mode_ [52:53)
+    desc |= (LBO_MODE & 0x1) << 52
+    # layout_type_ [61:64)
+    desc |= (int(layout_type) & 0x7) << 61
+
+    return desc & 0xFFFF_FFFF_FFFF_FFFF  # force 64-bit width
+
+
+def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:
+    # 14 bits, remove 4 LSB (bits 0-13 in desc)
+    return (start_addr.toint() & 0x3FFFF) >> 4
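
The bit layout that make_instr_desc packs can be sanity-checked with a small standalone decoder. The sketch below is illustrative only (decode_instr_desc is not part of the package); it mirrors the shifts and masks used above and needs no cutlass install.

# Hypothetical helper (not in mslk): unpack a 32-bit Blackwell MMA instruction
# descriptor using the same bit offsets that make_instr_desc uses to pack it.
def decode_instr_desc(desc: int) -> dict:
    return {
        "sparse_id2": (desc >> 0) & 0x3,
        "sparse_flag": (desc >> 2) & 0x1,
        "saturate": (desc >> 3) & 0x1,
        "c_format": (desc >> 4) & 0x3,
        "a_format": (desc >> 7) & 0x7,
        "b_format": (desc >> 10) & 0x7,
        "a_negate": (desc >> 13) & 0x1,
        "b_negate": (desc >> 14) & 0x1,
        "a_major": (desc >> 15) & 0x1,
        "b_major": (desc >> 16) & 0x1,
        "N": ((desc >> 17) & 0x3F) << 3,  # n_dim field stores N / 8
        "M": ((desc >> 24) & 0x1F) << 4,  # m_dim field stores M / 16
        "max_shift": (desc >> 30) & 0x3,
    }

# Example: BF16 x BF16 -> F32 accumulate, 128x256 tile, both operands K-major.
desc = (1 << 4) | (1 << 7) | (1 << 10) | ((256 >> 3) << 17) | ((128 >> 4) << 24)
assert decode_instr_desc(desc)["M"] == 128 and decode_instr_desc(desc)["N"] == 256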

mslk/attention/flash_attn/named_barrier.py
@@ -0,0 +1,32 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+
+import enum
+
+
+class NamedBarrierFwd(enum.IntEnum):
+    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()
+    WarpSchedulerWG1 = enum.auto()
+    WarpSchedulerWG2 = enum.auto()
+    WarpSchedulerWG3 = enum.auto()
+    PFull = enum.auto()
+    PEmpty = enum.auto()
+
+
+class NamedBarrierBwd(enum.IntEnum):
+    Epilogue = enum.auto()
+    WarpSchedulerWG1 = enum.auto()
+    WarpSchedulerWG2 = enum.auto()
+    WarpSchedulerWG3 = enum.auto()
+    PdS = enum.auto()
+    dQFullWG0 = enum.auto()
+    dQFullWG1 = enum.auto()
+    dQEmptyWG0 = enum.auto()
+    dQEmptyWG1 = enum.auto()
+
+
+class NamedBarrierBwdSm100(enum.IntEnum):
+    EpilogueWG1 = enum.auto()
+    EpilogueWG2 = enum.auto()
+    Compute = enum.auto()
+    dQaccReduce = enum.auto()
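
As the comment on NamedBarrierFwd.Epilogue notes, enum.auto() on an IntEnum starts counting at 1, so the first member of each class maps to hardware named barrier 1 and barrier 0 stays reserved for sync_threads(). A quick illustrative check (not part of the package):

import enum

class _Demo(enum.IntEnum):
    First = enum.auto()
    Second = enum.auto()

assert int(_Demo.First) == 1 and int(_Demo.Second) == 2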

mslk/attention/flash_attn/pack_gqa.py
@@ -0,0 +1,165 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Tri Dao.
+
+
+import cutlass
+import cutlass.cute as cute
+
+import mslk.attention.flash_attn.utils as utils
+
+
+class PackGQA:
+    def __init__(
+        self,
+        m_block_size: cutlass.Constexpr[int],
+        head_dim_padded: cutlass.Constexpr[int],
+        check_hdim_oob: cutlass.Constexpr[bool],
+        qhead_per_kvhead: cutlass.Constexpr[bool],
+    ):
+        self.m_block_size = m_block_size
+        self.head_dim_padded = head_dim_padded
+        self.check_hdim_oob = check_hdim_oob
+        self.qhead_per_kvhead = qhead_per_kvhead
+
+    @cute.jit
+    def compute_ptr(
+        self,
+        tensor: cute.Tensor,
+        cRows: cute.Tensor,
+        tidx: cutlass.Int32,
+        block: cutlass.Int32,
+        threads_per_row: cutlass.Constexpr[int],
+        num_threads: cutlass.Constexpr[int],
+    ):
+        num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row)
+        tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64)
+        for i in cutlass.range_constexpr(num_ptr_per_thread):
+            row = i * num_threads + cRows[tidx % threads_per_row][0]
+            idx = block * self.m_block_size + row
+            m_idx = idx // self.qhead_per_kvhead
+            h_idx = idx - m_idx * self.qhead_per_kvhead
+            tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()
+        return tPrPtr
+
+    @cute.jit
+    def load_Q(
+        self,
+        mQ: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
+        sQ: cute.Tensor,  # (m_block_size, head_dim_padded)
+        gmem_tiled_copy: cute.TiledCopy,
+        tidx: cutlass.Int32,
+        block: cutlass.Int32,
+        seqlen: cutlass.Int32,
+    ):
+        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
+        cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
+        tQsQ = gmem_thr_copy.partition_D(sQ)
+        tQcQ = gmem_thr_copy.partition_S(cQ)
+        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
+        tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1])
+        tQcQ_row = tQcQ[0, None, 0]
+        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
+        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
+        num_threads = gmem_tiled_copy.size
+        tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads)
+        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
+            q_ptr_i64 = utils.shuffle_sync(
+                tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
+            )
+            q_gmem_ptr = cute.make_ptr(
+                mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
+            )
+            if (
+                t0QcQ[0, m, 0][0]
+                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]
+            ):
+                mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,))
+                elems_per_load = cute.size(tQsQ.shape[0][0])
+                mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,))
+                for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])):
+                    ki = tQcQ[0, 0, k][1] // elems_per_load
+                    cute.copy(
+                        gmem_thr_copy,
+                        mQ_cur_copy[None, ki],
+                        tQsQ[None, m, k],
+                        pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
+                    )
+        # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+
+    @cute.jit
+    def store_LSE(
+        self,
+        mLSE: cute.Tensor,  # (qhead_per_kvhead, seqlen_q)
+        tLSErLSE: cute.Tensor,  # (m_block_size, head_dim_padded)
+        tiled_mma: cute.TiledMma,
+        tidx: cutlass.Int32,
+        block: cutlass.Int32,
+        seqlen: cutlass.Int32,
+    ):
+        thr_mma = tiled_mma.get_slice(tidx)
+        caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
+        taccOcO = thr_mma.partition_C(caccO)
+        taccOcO_row = utils.make_acc_tensor_mn_view(taccOcO)[None, 0]
+        assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
+        threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
+        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
+        assert cute.size(tLSErLSE) <= threads_per_row
+        num_threads = tiled_mma.size
+        tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads)
+        for m in cutlass.range_constexpr(cute.size(tLSErLSE)):
+            lse_ptr_i64 = utils.shuffle_sync(
+                tPrLSEPtr[m // threads_per_row],
+                m % threads_per_row,
+                width=threads_per_row,
+            )
+            lse_gmem_ptr = cute.make_ptr(
+                mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
+            )
+            row = block * self.m_block_size + taccOcO_row[m][0]
+            # Only the thread corresponding to column 0 writes out the lse to gmem
+            if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead:
+                mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,))
+                mLSE_copy[0] = tLSErLSE[m]
+
+    @cute.jit
+    def store_O(
+        self,
+        mO: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
+        tOrO: cute.Tensor,  # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy
+        gmem_tiled_copy: cute.TiledCopy,
+        tidx: cutlass.Int32,
+        block: cutlass.Int32,
+        seqlen: cutlass.Int32,
+    ):
+        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
+        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
+        tOcO = gmem_thr_copy.partition_S(cO)
+        t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO)
+        tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
+        tOcO_row = tOcO[0, None, 0]
+        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
+        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
+        num_threads = gmem_tiled_copy.size
+        tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads)
+        for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
+            o_ptr_i64 = utils.shuffle_sync(
+                tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
+            )
+            o_gmem_ptr = cute.make_ptr(
+                mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
+            )
+            if (
+                t0OcO[0, m, 0][0]
+                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]
+            ):
+                mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,))
+                elems_per_load = cute.size(tOrO.shape[0][0])
+                mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,))
+                for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])):
+                    ki = tOcO[0, 0, k][1] // elems_per_load
+                    cute.copy(
+                        gmem_thr_copy,
+                        tOrO[None, m, k],
+                        mO_cur_copy[None, ki],
+                        pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
+                    )
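
The index arithmetic in PackGQA.compute_ptr packs the qhead_per_kvhead query heads that share one KV head into consecutive rows of the M dimension, so a packed row index splits into (query head within the group, sequence position). A pure-Python sketch of that split (illustrative only; unpack_gqa_row is not part of the package):

def unpack_gqa_row(idx, qhead_per_kvhead):
    m_idx = idx // qhead_per_kvhead          # sequence position
    h_idx = idx - m_idx * qhead_per_kvhead   # query head within the KV group
    return h_idx, m_idx

# With 4 query heads per KV head, consecutive packed rows cycle through the
# 4 heads before advancing to the next sequence position.
assert [unpack_gqa_row(i, 4) for i in range(6)] == [
    (0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1),
]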

mslk/attention/flash_attn/paged_kv.py
@@ -0,0 +1,176 @@
+# @nolint # fbcode
+from typing import Type
+from dataclasses import dataclass
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync
+from cutlass import Int32, const_expr
+
+from mslk.attention.flash_attn import utils
+from mslk.attention.flash_attn.cute_dsl_utils import ParamsBase
+from cutlass.cute import FastDivmodDivisor
+
+
+@dataclass
+class PagedKVManager(ParamsBase):
+    mPageTable: cute.Tensor
+    mK_paged: cute.Tensor
+    mV_paged: cute.Tensor
+    thread_idx: Int32
+
+    page_size_divmod: FastDivmodDivisor
+    seqlen_k: Int32
+    leftpad_k: Int32
+    n_block_size: Int32
+    num_threads: cutlass.Constexpr[Int32]
+    head_dim_padded: cutlass.Constexpr[Int32]
+    head_dim_v_padded: cutlass.Constexpr[Int32]
+
+    gmem_threads_per_row: cutlass.Constexpr[Int32]
+    page_entry_per_thread: Int32
+    async_copy_elems: Int32
+
+    gmem_tiled_copy_KV: cute.TiledCopy
+    gmem_thr_copy_KV: cute.TiledCopy
+    tPrPage: cute.Tensor
+    tPrPageOffset: cute.Tensor
+    tKpK: cute.Tensor
+    tVpV: cute.Tensor
+
+    @staticmethod
+    def create(
+        mPageTable: cute.Tensor,
+        mK_paged: cute.Tensor,
+        mV_paged: cute.Tensor,
+        page_size_divmod: FastDivmodDivisor,
+        bidb: Int32,
+        bidh: Int32,
+        thread_idx: Int32,
+        seqlen_k: Int32,
+        leftpad_k: Int32,
+        n_block_size: cutlass.Constexpr[Int32],
+        head_dim_padded: cutlass.Constexpr[Int32],
+        head_dim_v_padded: cutlass.Constexpr[Int32],
+        num_threads: cutlass.Constexpr[Int32],
+        dtype: Type[cutlass.Numeric],
+    ):
+        universal_copy_bits = 128
+        gmem_threads_per_row = 8  # 8 threads loading 128 bits = 128 bytes = 1 cache line
+        async_copy_elems = universal_copy_bits // dtype.width
+        atom_async_copy = cute.make_copy_atom(
+            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
+            dtype,
+            num_bits_per_copy=universal_copy_bits,
+        )
+        thr_layout = cute.make_ordered_layout(
+            (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+            order=(1, 0),
+        )
+        val_layout = cute.make_layout((1, async_copy_elems))
+        gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
+        gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
+        page_entry_per_thread = n_block_size * gmem_threads_per_row // num_threads
+
+        tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
+        tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
+
+        mPageTable = mPageTable[bidb, None]
+        mK_paged = mK_paged[None, None, bidh, None]
+        mV_paged = mV_paged[None, None, bidh, None]
+
+        cK = cute.make_identity_tensor((n_block_size, head_dim_padded))
+        tKcK = gmem_thr_copy_KV.partition_S(cK)
+        tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1])
+
+        if const_expr(head_dim_padded == head_dim_v_padded):
+            tVpV = tKpK
+        else:
+            cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
+            tVcV = gmem_thr_copy_KV.partition_S(cV)
+            tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0])
+
+        return PagedKVManager(
+            mPageTable,
+            mK_paged,
+            mV_paged,
+            thread_idx,
+            page_size_divmod,
+            seqlen_k,
+            leftpad_k,
+            n_block_size,
+            num_threads,
+            head_dim_padded,
+            head_dim_v_padded,
+            gmem_threads_per_row,
+            page_entry_per_thread,
+            async_copy_elems,
+            gmem_tiled_copy_KV,
+            gmem_thr_copy_KV,
+            tPrPage,
+            tPrPageOffset,
+            tKpK,
+            tVpV,
+        )
+
+    @cute.jit
+    def load_page_table(self, n_block: Int32):
+        for i in cutlass.range(self.page_entry_per_thread, unroll=1):
+            row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row
+            row_idx = n_block * self.n_block_size + row
+
+            page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)
+
+            is_valid = (
+                (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size
+            ) and row_idx < self.seqlen_k
+            page = self.mPageTable[page_idx] if is_valid else 0
+
+            self.tPrPage[i] = page
+            self.tPrPageOffset[i] = page_offset
+
+    @cute.jit
+    def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
+        assert K_or_V in ("K", "V")
+
+        # Finesse sX layout to be (M, N).
+        sX_pi = cute.make_tensor(
+            sX.iterator,
+            cute.make_layout(
+                (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
+                stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
+            ),
+        )
+
+        if const_expr(K_or_V == "V"):
+            # Need to transpose V
+            sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))
+
+        head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
+        cX = cute.make_identity_tensor((self.n_block_size, head_dim))
+        tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)
+        tXcX = self.gmem_thr_copy_KV.partition_S(cX)
+
+        seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0
+        for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
+            row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
+            should_load = cute.make_fragment_like(tXsX[None, m, 0], cute.Boolean)
+            should_load.fill(row_valid)
+
+            page = self.tPrPage[m]
+            page_offset = self.tPrPageOffset[m]
+            mX_paged_cur = (
+                self.mK_paged[page_offset, None, page]
+                if const_expr(K_or_V == "K")
+                else self.mV_paged[None, page_offset, page]
+            )
+            mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))
+
+            for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])):
+                ki = tXcX[0, 0, k][1] // self.async_copy_elems
+                cute.copy(
+                    self.gmem_tiled_copy_KV,
+                    mX_paged_cur_copy[None, ki],
+                    tXsX[None, m, k],
+                    pred=should_load,
+                )
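
The addressing scheme in PagedKVManager.load_page_table shifts a logical K/V row by leftpad_k, splits it by the page size into a page-table slot and an offset within the page, and then looks up the physical page id in the page table (the kernel performs the split with a FastDivmodDivisor rather than Python's divmod). A minimal sketch of the same mapping (illustrative only; locate_kv_row is not part of the package):

def locate_kv_row(row_idx, leftpad_k, page_size, page_table):
    page_slot, page_offset = divmod(row_idx + leftpad_k, page_size)
    return page_table[page_slot], page_offset

page_table = [7, 3, 9]  # logical pages 0..2 live in physical pages 7, 3, 9
assert locate_kv_row(0, 0, 256, page_table) == (7, 0)
assert locate_kv_row(300, 0, 256, page_table) == (3, 44)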