mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/copy_utils.py
@@ -0,0 +1,341 @@
# @nolint # fbcode
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

import math
from typing import Optional, Type, Callable

import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
from cutlass.cute.nvgpu import cpasync
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import llvm
import cutlass.pipeline


@dsl_user_op
def cvt_copy(
    atom: cute.CopyAtom,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    if const_expr(src.element_type != dst.element_type):
        src_cvt = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip)
        src_cvt.store(src.load().to(dst.element_type))
        src = src_cvt
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)


@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, dst, loc=loc, ip=ip)
    return dst


@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)


@dsl_user_op
def make_tmem_copy(
    tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None
) -> cute.CopyAtom:
    num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom)
    assert num_dp == 32
    assert num_bits == 32
    tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),)
    layout_tv = cute.make_layout(
        ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg))
    )
    return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn)


@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    num_copy_elems: int = 1,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)


def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    num_copy_bits = num_copy_elems * dtype.width
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
    thr_layout = cute.make_layout(num_threads)
    val_layout = cute.make_layout(num_copy_elems)
    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)


def tiled_copy_2d(
    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
) -> cute.TiledCopy:
    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
    copy_elems = num_copy_bits // dtype.width
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
    gmem_threads_per_row = major_mode_size // copy_elems
    assert num_threads % gmem_threads_per_row == 0
    thr_layout = cute.make_ordered_layout(
        (num_threads // gmem_threads_per_row, gmem_threads_per_row),
        order=(1, 0),
    )
    val_layout = cute.make_layout((1, copy_elems))
    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)


@dsl_user_op
def atomic_add_fp32x4(
    a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
) -> None:
    gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    # cache_hint = cutlass.Int64(0x12F0000000000000)
    llvm.inline_asm(
        None,
        [
            gmem_ptr_i64,
            Float32(a).ir_value(loc=loc, ip=ip),
            Float32(b).ir_value(loc=loc, ip=ip),
            Float32(c).ir_value(loc=loc, ip=ip),
            Float32(d).ir_value(loc=loc, ip=ip),
        ],
        # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
        "{\n\t"
        # ".reg .b128 abcd;\n\t"
        # "mov.b128 abcd, {$1, $2, $3, $4};\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        # "mov.b128 abcd, {$1, $2, $3, $4};\n\t"
        "mov.f32 abcd.x, $1;\n\t"
        "mov.f32 abcd.y, $2;\n\t"
        "mov.f32 abcd.z, $3;\n\t"
        "mov.f32 abcd.w, $4;\n\t"
        "red.global.add.v4.f32 [$0], abcd;\n\t"
        # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t"
        "}\n",
        # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
        # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
        "l,f,f,f,f",
        # "l,f,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def set_block_rank(
    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
    """Map the given smem pointer to the address at another CTA rank in the cluster."""
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    return Int32(
        llvm.inline_asm(
            T.i32(),
            [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
            "mapa.shared::cluster.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def store_shared_remote_fp32x4(
    a: Float32,
    b: Float32,
    c: Float32,
    d: Float32,
    smem_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    remote_smem_ptr_i32 = set_block_rank(
        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_mbar_ptr_i32 = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    llvm.inline_asm(
        None,
        [
            remote_smem_ptr_i32,
            remote_mbar_ptr_i32,
            Float32(a).ir_value(loc=loc, ip=ip),
            Float32(b).ir_value(loc=loc, ip=ip),
            Float32(c).ir_value(loc=loc, ip=ip),
            Float32(d).ir_value(loc=loc, ip=ip),
        ],
        "{\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        "mov.f32 abcd.x, $2;\n\t"
        "mov.f32 abcd.y, $3;\n\t"
        "mov.f32 abcd.z, $4;\n\t"
        "mov.f32 abcd.w, $5;\n\t"
        "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t"
        "}\n",
        "r,r,f,f,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def cpasync_bulk_g2s(
    gmem_ptr: cute.Pointer,
    smem_ptr: cute.Pointer,
    tma_bar_ptr: cute.Pointer,
    size: int | Int32,
    *,
    loc=None,
    ip=None,
):
    gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()],
        "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];",
        "l,r,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def cpasync_reduce_bulk_add_f32(
    smem_ptr: cute.Pointer,
    gmem_ptr: cute.Pointer,
    store_bytes: int | Int32,
    *,
    loc=None,
    ip=None,
):
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
    llvm.inline_asm(
        None,
        [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
        "l,r,r",
        # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
        # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
        # "l,r,r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


def cpasync_bulk_get_copy_fn(
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    # src_is_smem = const_expr(
    #     isinstance(src_tensor.iterator, cute.Pointer)
    #     and src_tensor.memspace == cute.AddressSpace.smem
    # )
    group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
    group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    src = cute.group_modes(src_tensor, 0, group_rank_src)
    dst = cute.group_modes(dst_tensor, 0, group_rank_dst)

    def copy_bulk(src_idx, dst_idx, **new_kwargs):
        size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8)
        cpasync_bulk_g2s(
            src[None, src_idx].iterator,
            dst[None, dst_idx].iterator,
            size=size,
            **new_kwargs,
            **kwargs,
        )

    def copy_bulk_single_stage(**new_kwargs):
        size = const_expr(cute.size(src.shape) * src.element_type.width // 8)
        cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs)

    return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage


def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, group_rank_smem),
        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    src, dst = (s, g) if src_is_smem else (g, s)

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    def copy_tma_single_stage(**new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs)

    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g


def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        copy(
            src_idx=src_idx,
            dst_idx=producer_state.index,
            tma_bar_ptr=pipeline.producer_get_barrier(producer_state),
            **new_kwargs,
        )

    return copy_fn
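For intuition, the vector width and thread tiling that tiled_copy_2d derives can be reproduced with ordinary host-side arithmetic. The sketch below is plain Python rather than the CuTe DSL, and the dtype width and tile sizes are illustrative assumptions only.

import math

def tiled_copy_2d_shape(dtype_width_bits: int, major_mode_size: int, num_threads: int):
    # The widest vectorized copy is 128 bits, but it must also divide the contiguous
    # (major-mode) extent, hence the gcd, as in tiled_copy_2d above.
    num_copy_bits = math.gcd(major_mode_size, 128 // dtype_width_bits) * dtype_width_bits
    copy_elems = num_copy_bits // dtype_width_bits
    threads_per_row = major_mode_size // copy_elems
    assert num_threads % threads_per_row == 0
    # (rows of threads, threads per row) and the (1, copy_elems) values each thread moves.
    return (num_threads // threads_per_row, threads_per_row), (1, copy_elems)

# e.g. a bf16 tile whose contiguous extent is 64, copied by 128 threads:
# 128-bit copies of 8 elements, arranged as a 16 x 8 thread grid.
print(tiled_copy_2d_shape(16, 64, 128))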
mslk/attention/flash_attn/cute_dsl_utils.py
@@ -0,0 +1,135 @@
# @nolint # fbcode
# Copyright (c) 2025, Tri Dao.

import os
import pathlib
from typing import Tuple
from functools import partial, lru_cache
from dataclasses import dataclass, fields

import torch

try:
    from triton.tools.disasm import extract
except ImportError:
    extract = None

import cutlass
import cutlass.cute as cute
from cutlass.base_dsl.typing import JitArgument
from cutlass.cutlass_dsl import NumericMeta
from cutlass.cute.runtime import from_dlpack

StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))


load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile


torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}


@lru_cache
def get_max_active_clusters(cluster_size):
    return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)


@lru_cache
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    return torch.cuda.get_device_capability(device)


@dataclass
class ParamsBase:
    def __extract_mlir_values__(self):
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        values, self._values_pos = [], []
        for obj in non_constexpr_fields:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)


@dataclass
class ArgumentsBase(JitArgument):
    def __c_pointers__(self):
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        c_ptrs = []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        types, self._values_pos = [], []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = obj.__get_mlir_types__()
                types.extend(obj_types)
                self._values_pos.append(len(obj_types))
            else:
                self._values_pos.append(0)
        return types

    def __new_from_mlir_values__(self, values):
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)


def load_cubin_module_data_patched(cubin_data, filepath):
    pathlib.Path(filepath).write_bytes(cubin_data)
    return load_cubin_module_data_og(cubin_data)


def cute_compile_patched(*args, **kwargs):
    """A patched version of cute.compile that dumps the SASS to a file if CUTE_CUBIN_PATH is set."""
    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
    if cubin_path is not None:
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
            load_cubin_module_data_patched, filepath=cubin_path
        )
    output = cute_compile_og(*args, **kwargs)
    if cubin_path is not None:
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
        if extract is not None:
            sass = extract(cubin_path, None)
            pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
    return output


def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
    """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
    tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
    if fully_dynamic:
        return tensor.mark_layout_dynamic()
    if leading_dim == -1:
        leading_dim = t.ndim - 1
    return tensor.mark_layout_dynamic(leading_dim=leading_dim)
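The two dataclass bases above implement the same idea: static (constexpr-like) fields stay baked into the Python object and are skipped, while everything else is flattened into a flat value list and later rebuilt positionally. A self-contained plain-Python analogue (no cutlass required; the Example fields are made up for illustration) looks like this:

from dataclasses import dataclass, fields

STATIC = (int, bool, str, float, type(None))

@dataclass
class Example:
    num_heads: int   # static: kept on the object, skipped when flattening
    scale: float     # static: kept on the object, skipped when flattening
    ptrs: list       # "dynamic": contributes entries to the flat value list

    def extract(self):
        vals, self._pos = [], []
        for f in fields(self):
            v = getattr(self, f.name)
            if not isinstance(v, STATIC):
                vals += list(v)
                self._pos.append(len(v))
        return vals

    def rebuild(self, values):
        kwargs = {}
        counts = iter(self._pos)
        for f in fields(self):
            v = getattr(self, f.name)
            if isinstance(v, STATIC):
                kwargs[f.name] = v
            else:
                n = next(counts)
                kwargs[f.name], values = values[:n], values[n:]
        return Example(**kwargs)

ex = Example(num_heads=8, scale=0.125, ptrs=[10, 20, 30])
flat = ex.extract()             # [10, 20, 30]: only the dynamic field is flattened
assert ex.rebuild(flat) == ex   # static fields are reused, dynamic ones rebuilt in order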
mslk/attention/flash_attn/fast_math.py
@@ -0,0 +1,22 @@
# @nolint # fbcode
# Copyright (c) 2025, Tri Dao.

import cutlass
import cutlass.cute as cute
from cutlass import Int32


@cute.jit
def clz(x: Int32) -> Int32:
    # for i in cutlass.range_constexpr(32):
    #     if (1 << (31 - i)) & x:
    #         return Int32(i)
    # return Int32(32)
    # Early exit is not supported yet
    res = Int32(32)
    done = False
    for i in cutlass.range(32):
        if ((1 << (31 - i)) & x) and not done:
            res = Int32(i)
            done = True
    return res
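Because the DSL loop in clz cannot break early, the kernel latches the index of the first set bit with a done flag instead of returning from inside the loop. A host-side reference in plain Python (not the DSL) showing that the predicated form matches the early-exit form:

def clz_ref(x: int) -> int:
    x &= 0xFFFFFFFF
    for i in range(32):
        if (1 << (31 - i)) & x:
            return i
    return 32

def clz_predicated(x: int) -> int:
    x &= 0xFFFFFFFF
    res, done = 32, False
    for i in range(32):  # mirrors cutlass.range(32): no early exit, just a latch
        if ((1 << (31 - i)) & x) and not done:
            res, done = i, True
    return res

assert all(clz_ref(v) == clz_predicated(v) for v in (0, 1, 5, 1 << 31, 0xFFFFFFFF))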