mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,754 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Tri Dao.
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import cutlass
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
from cutlass import Int32, Boolean, const_expr
|
|
8
|
+
from cutlass.cute.nvgpu import tcgen05
|
|
9
|
+
from cutlass._mlir.dialects import llvm
|
|
10
|
+
|
|
11
|
+
import mslk.attention.flash_attn.mma_sm100_desc as sm100_desc
|
|
12
|
+
from mslk.attention.flash_attn.utils import parse_swizzle_from_pointer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@cute.jit
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    swap_AB: bool = False,
) -> None:
    """Run a tcgen05 MMA over every k-block, optionally picking a mode-3 slice of A/B.

    ``acc`` is passed to ``cute.gemm`` as both the C input and the D output.

    Args:
        tiled_mma: tiled MMA whose ``op`` is used to create a fresh MMA atom.
        acc: accumulator tensor.
        tCrA: A operand fragment; mode 2 is iterated as the k-block axis.
        tCrB: B operand fragment; indexed with the same k-block as A.
        A_idx: if given, ``tCrA[None, None, None, A_idx]`` is used instead of ``tCrA``
            (presumably a pipeline-stage index -- confirm with callers).
        B_idx: same as ``A_idx`` but for ``tCrB``.
        zero_init: if true, the atom's ACCUMULATE field is cleared for k-block 0.
        swap_AB: if true, the roles of (tCrA, A_idx) and (tCrB, B_idx) are exchanged.
    """
    if const_expr(swap_AB):
        # Re-enter with the operands (and their slice indices) exchanged.
        return gemm_w_idx(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            A_idx=B_idx,
            B_idx=A_idx,
            zero_init=zero_init,
            swap_AB=False,
        )
    else:
        a_frag = tCrA[None, None, None, A_idx] if const_expr(A_idx is not None) else tCrA
        b_frag = tCrB[None, None, None, B_idx] if const_expr(B_idx is not None) else tCrB
        atom = cute.make_mma_atom(tiled_mma.op)
        num_k_blocks = cute.size(tCrA.shape[2])
        for k_blk in cutlass.range_constexpr(num_k_blocks):
            # Only k-block 0 may skip accumulation (when zero_init is requested).
            atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k_blk != 0)
            cute.gemm(atom, acc, a_frag[None, None, k_blk], b_frag[None, None, k_blk], acc)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@cute.jit
def gemm_ptx_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    **kwargs,
) -> None:
    """Optionally slice the operands along mode 3, then forward to ``gemm_ptx_partial``.

    Args:
        tiled_mma: tiled MMA; a fresh atom is built from its ``op`` and the atom's ``op``
            is what gets passed on.
        acc: accumulator; its iterator is converted with ``toint()`` and passed as the
            accumulator address.
        tCrA, tCrB: MMA operand fragments.
        sA: shared-memory A tensor, or None (e.g. when A is sourced from TMEM).
        sB: shared-memory B tensor.
        A_idx: if given, ``tCrA`` and ``sA`` are replaced by their
            ``[None, None, None, A_idx]`` slices.
        B_idx: if given, ``tCrB`` and ``sB`` are replaced by their
            ``[None, None, None, B_idx]`` slices.
        zero_init: forwarded to ``gemm_ptx_partial``.
        **kwargs: forwarded verbatim to ``gemm_ptx_partial`` (e.g. ``mbar_ptr``,
            ``mbar_phase``, ``tA_addr``).
    """
    a_frag = tCrA[None, None, None, A_idx] if const_expr(A_idx is not None) else tCrA
    b_frag = tCrB[None, None, None, B_idx] if const_expr(B_idx is not None) else tCrB
    # sA is optional; only slice it when it exists.
    if const_expr(sA is None):
        a_smem = None
    elif const_expr(A_idx is None):
        a_smem = sA
    else:
        a_smem = sA[None, None, None, A_idx]
    b_smem = sB[None, None, None, B_idx] if const_expr(B_idx is not None) else sB
    atom = cute.make_mma_atom(tiled_mma.op)
    acc_addr = acc.iterator.toint()
    gemm_ptx_partial(
        atom.op,
        acc_addr,
        a_frag,
        b_frag,
        a_smem,
        b_smem,
        zero_init=zero_init,
        **kwargs,
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> cute.TiledMma:
    """Accumulate the product of ``tCrA`` and ``tCrB`` into ``acc``, one k-block at a time.

    Mode 2 of the operand fragments is unrolled at compile time as the k-block axis,
    and ``acc`` is used as both the C input and the D output of ``cute.gemm``.

    Args:
        tiled_mma: tiled MMA used for every k-block; its ACCUMULATE field is set in the loop.
        acc: accumulator tensor.
        tCrA, tCrB: operand fragments indexed as ``[None, None, k]``.
        zero_init: if true, ACCUMULATE is cleared for k-block 0 only.

    Returns:
        The same ``tiled_mma`` object that was passed in.
    """
    num_k_blocks = cute.size(tCrA.shape[2])
    for k_blk in cutlass.range_constexpr(num_k_blocks):
        a_k = tCrA[None, None, k_blk]
        b_k = tCrB[None, None, k_blk]
        tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k_blk != 0)
        cute.gemm(tiled_mma, acc, a_k, b_k, acc)
    return tiled_mma
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def i64_to_i32x2(i: int) -> Tuple[int, int]:
    """Split a 64-bit integer into its ``(low, high)`` 32-bit halves.

    Both halves are masked to 32 bits, so negative inputs yield their
    two's-complement 32-bit words.
    """
    word_mask = (1 << 32) - 1
    low_word = i & word_mask
    high_word = (i >> 32) & word_mask
    return low_word, high_word
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@cute.jit
def gemm_ptx(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Issue one inline-PTX ``tcgen05.mma.cta_group::1.kind::f16`` per k-block.

    A separate inline-asm statement is emitted for every k-block (mode 2 of ``tCrA``),
    each wrapped in ``cute.arch.elect_one()``. B is always read through a shared-memory
    descriptor derived from ``sB``. A is read through a shared-memory descriptor derived
    from ``sA`` ("SS" mode) unless ``op.a_src`` is TMEM ("TS" mode), in which case the
    integer address of ``tCrA[None, None, k]`` is passed as the A operand.

    Args:
        op: tcgen05 MMA op; ``a_src`` selects SS vs TS mode, and the dtypes / major modes
            feed the instruction-descriptor and shared-memory-descriptor construction.
        acc: accumulator; ``acc.iterator.toint()`` is passed as the D address operand.
        tCrA: A fragment; its mode-2 size is the number of k-blocks issued.
        tCrB: B fragment (not referenced in this function's body).
        sA: shared-memory A tensor; asserted non-None when A is not sourced from TMEM.
        sB: shared-memory B tensor.
        zero_init: if true, the final predicate operand of the MMA (PTX enable-input-d)
            is 0 for k-block 0; every later k-block passes 1.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else None
    sB_layout = sB.layout
    # Instruction descriptor, folded to a compile-time constant.
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
        # Compile-time part of A's shared-memory descriptor (mode-0 layout recast to
        # 128-bit units, swizzle, K- vs MN-major).
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Split into 32-bit halves: only the low word is modified at runtime below; the
        # high word is baked into the asm string as a hex constant.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = parse_swizzle_from_pointer(sB.iterator)
    # Same construction for B, which is always read from shared memory.
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Low descriptor words for k-block 0: constant base OR'd with the encoded start address.
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(
            sA[None, None, 0].iterator
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(
        sB[None, None, 0].iterator
    )
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Offset of k-block k: element index * bit-width // 8 gives bytes, >> 4 gives
        # 16-byte units, which are added to the descriptor's low word.
        if const_expr(not is_ts):
            smem_desc_a_lo = smem_desc_start_a_lo + (
                (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            )
        smem_desc_b_lo = smem_desc_start_b_lo + (
            (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        )
        # Disabled debug output (note: references *_correct names not defined here).
        # with cute.arch.elect_one():
        # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo)
        # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct)
        with cute.arch.elect_one():
            if const_expr(not is_ts):
                # SS: $0 = D address, $1/$2 = A/B descriptor low words,
                # $3 = accumulate flag turned into predicate p.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        smem_desc_a_lo.ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
                    ".reg .b32 idesc;\n\t"
                    f"mov.b32 idesc, {hex(idesc)};\n\t"
                    f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
            else:
                # TS: $1 is the A address taken from tCrA[None, None, k].
                # NOTE(review): gemm_ptx_partial in this file states that for TS the
                # tCrA iterator's toint() returns 0, and it takes an explicit tA_addr
                # instead; this path may be affected the same way -- confirm.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        tCrA[None, None, k].iterator.toint().ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_b;\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@cute.jit
def gemm_ptx_loop(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Issue every k-block's ``tcgen05.mma`` from a single inline-PTX statement.

    Unlike ``gemm_ptx``, all k-blocks (mode 2 of ``tCrA``) go into one asm string. A
    leader lane is chosen with ``elect.sync`` and predicates every MMA. Between MMAs the
    descriptor low words are advanced by compile-time deltas (``offset_*_diff``). In TS
    mode, A is instead addressed as ``tmem_a + offset_a[k]``.

    Args:
        op: tcgen05 MMA op; ``a_src`` selects SS vs TS mode, and the dtypes / major modes
            feed the descriptor construction.
        acc: accumulator; ``acc.iterator.toint()`` is passed as the D address ($0).
        tCrA: A fragment; its mode-2 size is the number of k-blocks. In TS mode its
            layout is also used (in place of ``sA``) to compute A offsets.
        tCrB: B fragment; only its mode-2 size is used (to size the B offset lists).
            Presumably it equals tCrA's; the asm loop is driven by tCrA -- confirm.
        sA: shared-memory A tensor; asserted non-None when A is not sourced from TMEM.
        sB: shared-memory B tensor.
        zero_init: accumulate predicate for the first MMA. A Python bool is baked in as
            the literal "0"/"1"; a ``Boolean`` is evaluated at runtime into PTX ``p``.
            Later MMAs always pass 1.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    # In TS mode there is no sA; tCrA's layout stands in for computing A offsets.
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Only the low 32 bits change at runtime; the high word is emitted as a constant.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = parse_swizzle_from_pointer(sB.iterator)
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Per-k A offsets. SS: element index -> bytes (* bits // 8) -> 16-byte units (>> 4).
    # TS: element index converted into 32-bit units (presumably TMEM columns -- confirm).
    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    # Deltas between consecutive k-blocks, used for in-asm incremental updates.
    offset_a_diff = [
        offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
    ]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [
        offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))
    ]

    # Low descriptor words for k-block 0: constant base OR'd with the encoded start address.
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # First-MMA accumulate predicate: runtime "p" for a DSL Boolean, else a literal.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS: $0 = D address, $1/$2 = A/B descriptor low words for k-block 0,
        # $3 = accumulate flag (used only when pred_str == "p").
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 smem_desc_a_lo, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            # k-blocks 1..K-1: bump both low words by the precomputed delta and re-issue.
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS: $1 is the A address of k-block 0, taken from tCrA's iterator.
        # NOTE(review): gemm_ptx_partial in this file states that for TS the tCrA
        # iterator's toint() returns 0, and it accepts an explicit tA_addr instead;
        # this function has no such override -- confirm it is still correct.
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_a, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            # k-blocks 1..K-1: B advanced by delta; A addressed as tmem_a + absolute offset.
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@cute.jit
def gemm_ptx_partial(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: Int32,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    tA_addr: Optional[Int32] = None,
) -> None:
    """Issue all k-blocks of a ``tcgen05.mma.cta_group::1.kind::f16`` from one asm block.

    The accumulator is given as a raw address. For every k-block k >= 1, each descriptor
    low word is recomputed as ``start + offset[k]``, not accumulated from deltas. A
    leader lane chosen with ``elect.sync`` predicates every MMA.

    Args:
        op: tcgen05 MMA op. ``a_src`` selects SS mode (A in shared memory) or TS mode
            (A in TMEM); the dtypes / major modes feed the descriptor construction.
        acc_tmem_addr: accumulator address, passed through ``make_warp_uniform``.
        tCrA: A fragment. Mode 2 is the k-block axis, and its layout gives the per-k A
            offsets (recast to 32-bit units in TS mode).
        tCrB: B fragment; its layout gives the per-k B descriptor offsets.
        sA: shared-memory A tensor; asserted non-None in SS mode.
        sB: shared-memory B tensor.
        mbar_ptr: TS mode only (asserted None in SS mode). When given, the k-blocks are
            split in two. The first part (about 3/4 of them, at least k-block 0) is
            issued, then every thread spins on ``mbarrier.try_wait.parity`` at
            ``mbar_ptr``, then the remaining k-blocks are issued.
        mbar_phase: parity passed to the wait; asserted non-None when ``mbar_ptr`` is set.
        zero_init: accumulate predicate for the first MMA. A Python bool is baked in as
            the literal "0"/"1"; a ``Boolean`` is evaluated at runtime into PTX ``p``.
            All later MMAs pass 1.
        tA_addr: TS mode only. Explicit A address; if None, the address of
            ``tCrA[None, None, 0]`` is used.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    # Instruction descriptor, folded to a compile-time constant.
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Only the low 32 bits change at runtime; the high word is emitted as a constant.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = parse_swizzle_from_pointer(sB.iterator)
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Per-k offsets, added to the start values inside the asm. In TS mode the A layout is
    # recast to 32-bit units first.
    tCrA_layout = (
        tCrA.layout
        if const_expr(not is_ts)
        else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
    )
    num_k_blocks = cute.size(tCrA.shape[2])
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_blocks)]
    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]

    # Low descriptor words for k-block 0: constant base OR'd with the encoded start address.
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # First-MMA accumulate predicate: runtime "p" for a DSL Boolean, else a literal.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        # SS operands: $0/$1 = A/B descriptor low-word starts, $2 = accumulate flag,
        # $3 = accumulator address.
        llvm.inline_asm(
            None,
            [
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, num_k_blocks)
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to
        # explicitly pass in the tA_addr for correctness.
        tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr
        # TS operands: $0 = A address, $1 = B descriptor low-word start,
        # $2 = accumulate flag, $3 = accumulator address (+ $4/$5 = mbarrier, parity).
        input_args = [
            Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
            # First k-block issued after the wait: about 3/4 of the way through. Clamp to 1
            # because with fewer than 4 k-blocks K // 4 * 3 == 0. k-block 0 is already
            # issued before the loop, so it must never be issued again after the wait.
            k_split = num_k_blocks // 4 * 3 if const_expr(num_k_blocks >= 4) else 1
        else:
            mbar_wait_str = ""
            k_split = num_k_blocks
        # asm for k-blocks 1..K-1 (entry i is k-block i + 1). Each entry derives the B
        # descriptor low word from its start value, so it does not depend on which
        # entries ran before it (e.g. across the mbarrier split).
        k_block_asm = [
            f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
            for k in range(1, num_k_blocks)
        ]
        llvm.inline_asm(
            None,
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 tmem_a, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(k_block_asm[: k_split - 1])
            + mbar_wait_str
            + "".join(k_block_asm[k_split - 1 :])
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
@cute.jit
def gemm_ptx_partial1(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: cutlass.Constexpr[int],
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA_base_addr_for_desc: Int32,
    sA_addr_offset_for_desc: cutlass.Constexpr[int],
    sA_stage: Int32,
    sB_base_addr_for_desc: Int32,
    sB_addr_offset_for_desc: cutlass.Constexpr[int],
    sB_stage: Int32,
    sA_layout: Optional[cute.Layout],
    sB_layout: Optional[cute.Layout],
    sA_swizzle: Optional[cute.Swizzle],
    sB_swizzle: cute.Swizzle,
    zero_init: bool | Boolean = False,
) -> None:
    """Issue one ``tcgen05.mma`` per K-block of ``tCrA`` through inline PTX.

    A single thread chosen by ``elect.sync`` issues every MMA. The first MMA
    honours ``zero_init``; every later MMA accumulates (enable-input-d = 1).

    Two operand modes are selected by ``op.a_src``:

    * SS (A in shared memory): A and B are both addressed through shared
      memory descriptors. For each operand the descriptor's low 32 bits are
      ``s*_base_addr_for_desc + s*_stage * s*_addr_offset_for_desc`` and the
      high 32 bits come from ``sm100_desc.make_smem_desc_base``.
    * TS (A in TMEM): A is addressed by the TMEM address of
      ``tCrA[None, None, 0]``. B uses a shared-memory descriptor built the same
      way as in SS mode.

    NOTE(review): the low word of the layout/swizzle base descriptor is never
    folded into the runtime address. Presumably callers pre-encode it into
    ``s*_base_addr_for_desc``; confirm against the call sites.

    Args:
        op: tcgen05 MMA op. Supplies the instruction descriptor, the operand
            dtypes and major modes, and the A source (SMEM or TMEM).
        acc_tmem_addr: Compile-time TMEM address of the accumulator.
        tCrA: A fragment. ``tCrA.shape[2]`` is the number of K-blocks. In TS
            mode its iterator gives the TMEM address of A.
        tCrB: B fragment. ``tCrB.shape[2]`` sizes the B offset table.
        sA_base_addr_for_desc: Runtime low 32 bits of the A descriptor for
            stage 0 (SS mode only).
        sA_addr_offset_for_desc: Compile-time per-stage increment of that low
            word (SS mode only).
        sA_stage: Runtime stage index for A (SS mode only).
        sB_base_addr_for_desc: Runtime low 32 bits of the B descriptor for
            stage 0.
        sB_addr_offset_for_desc: Compile-time per-stage increment of that
            low word.
        sB_stage: Runtime stage index for B.
        sA_layout: Layout of A. Required in SS mode; also read in TS mode to
            compute per-K TMEM offsets.
        sB_layout: Layout of B.
        sA_swizzle: Swizzle of A. Required in SS mode.
        sB_swizzle: Swizzle of B.
        zero_init: If true, the first MMA overwrites the accumulator. A Python
            ``bool`` is baked into the PTX as an immediate; a DSL ``Boolean``
            is evaluated at runtime through predicate ``p``.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
        assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
    num_k_blocks = cute.size(tCrA.shape[2])
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))

    # Only the high 32 bits of each compile-time base descriptor are used. The
    # low 32 bits (start address) are formed at runtime inside the asm.
    if const_expr(not is_ts):
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        _, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_a_hi = None
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    _, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)
    # All-zero vector for the {disable-output-lane} operand of tcgen05.mma.
    mask = [Int32(0)] * 4

    # Per-K-block start offsets. Descriptor offsets are bytes >> 4 (16-byte
    # units); TMEM offsets of A are element offsets converted to 32-bit units.
    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
            for k in range(num_k_blocks)
        ]
    else:
        # NOTE(review): sA_layout is not asserted non-None in TS mode but is
        # still read here; callers must pass a layout for A's K-blocks.
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in range(num_k_blocks)
        ]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_blocks)]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4
        for k in range(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    # enable-input-d of the first MMA: immediate for a Python bool, predicate
    # `p` (set from the "not zero_init" operand) for a runtime Boolean.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS: $0 sA base, $1 sA stage, $2 sB base, $3 sB stage,
        #     $4 not zero_init, $5..$8 disable-output-lane mask.
        llvm.inline_asm(
            None,
            [
                Int32(sA_base_addr_for_desc).ir_value(),
                Int32(sA_stage).ir_value(),
                Int32(sB_base_addr_for_desc).ir_value(),
                Int32(sB_stage).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            # desc_lo = stage * per-stage offset + base
            f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t"
            f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $4, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
            # Remaining K-blocks: advance both descriptors, always accumulate.
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
                )
                for k in range(1, num_k_blocks)
            )
            + "}\n",
            "r,r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS: $0 TMEM address of A, $1 sB base, $2 sB stage, $3 not zero_init,
        #     $4..$7 disable-output-lane mask. The accumulator address is a
        #     compile-time immediate (acc_tmem_addr), as in the SS path.
        llvm.inline_asm(
            None,
            [
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(sB_base_addr_for_desc).ir_value(),
                Int32(sB_stage).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            "mov.b32 tmem_a, $0;\n\t"
            # desc_lo = stage * per-stage offset + base (same as the SS path)
            f"mad.lo.u32 smem_desc_b_lo, $2, {hex(sB_addr_offset_for_desc)}, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
            # Remaining K-blocks: advance A in TMEM and the B descriptor.
            + "".join(
                (
                    f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t"
                )
                for k in range(1, num_k_blocks)
            )
            + "}\n",
            "r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
|