mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1132 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
import functools
|
|
10
|
+
import warnings
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import triton
|
|
15
|
+
import triton.language as tl
|
|
16
|
+
from triton.runtime import driver # @manual
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
# @manual=//triton:triton
|
|
20
|
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
21
|
+
|
|
22
|
+
TMA_AVAILABLE = True
|
|
23
|
+
except ImportError:
|
|
24
|
+
TMA_AVAILABLE = False
|
|
25
|
+
pass
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _grouped_gemm_set_block_size_hook(nargs):
|
|
29
|
+
BLOCK_M = nargs["BLOCK_SIZE_M"]
|
|
30
|
+
BLOCK_N = nargs["BLOCK_SIZE_N"]
|
|
31
|
+
BLOCK_K = nargs["BLOCK_SIZE_K"]
|
|
32
|
+
if nargs["USE_TMA_LOAD"]:
|
|
33
|
+
nargs["a_desc_ptr"].block_shape = [BLOCK_M, BLOCK_K]
|
|
34
|
+
nargs["b_desc_ptr"].block_shape = [BLOCK_N, BLOCK_K]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Autotuning search space for the non-"_ws" kernels on CUDA builds (also the
# fallback for _NV_WS_CONFIGS when TensorDescriptor is unavailable). Every
# config installs _grouped_gemm_set_block_size_hook so TMA descriptors, when
# used, get the matching block shape. Enumeration order is BLOCK_SIZE_M
# (outermost), then N, K, stages, warps, CTAs (innermost).
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": bm,
            "BLOCK_SIZE_N": bn,
            "BLOCK_SIZE_K": bk,
            "NUM_CONSUMER_GROUPS": 1,
        },
        num_stages=stages,
        num_warps=warps,
        num_ctas=ctas,
        pre_hook=_grouped_gemm_set_block_size_hook,
    )
    for bm in (64, 128)
    for bn in (64, 128, 256)
    for bk in (64, 128, 256)
    for stages in (3, 4)
    for warps in (4, 8)
    for ctas in (1,)
]
|
|
57
|
+
|
|
58
|
+
# Autotuning search space for the "_ws" kernels. Without TensorDescriptor
# support it simply aliases the regular CUDA search space.
if not TMA_AVAILABLE:
    _NV_WS_CONFIGS = _NV_CONFIGS
else:
    # Same enumeration scheme as _NV_CONFIGS (BLOCK_SIZE_M outermost), with a
    # wider sweep and an explicit USE_TMA_STORE knob (currently only False).
    _NV_WS_CONFIGS = [
        triton.Config(
            {
                "BLOCK_SIZE_M": bm,
                "BLOCK_SIZE_N": bn,
                "BLOCK_SIZE_K": bk,
                "NUM_CONSUMER_GROUPS": 1,
                "USE_TMA_STORE": tma_store,
            },
            num_stages=stages,
            num_warps=warps,
            num_ctas=ctas,
            pre_hook=_grouped_gemm_set_block_size_hook,
        )
        for bm in (64, 128, 256)
        for bn in (64, 128, 256)
        for bk in (64, 128, 256)
        for stages in (2, 3, 4)
        for warps in (4, 8, 16)
        for ctas in (1,)
        for tma_store in (False,)
    ]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Autotuning search space for ROCm builds. num_warps and waves_per_eu are
# swept together as fixed pairs. No pre_hook is installed, so these configs
# never resize TMA descriptors.
_AMD_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": bm,
            "BLOCK_SIZE_N": bn,
            "BLOCK_SIZE_K": bk,
            "waves_per_eu": waves,
            "matrix_instr_nonkdim": nonkdim,
            "NUM_CONSUMER_GROUPS": 1,
        },
        num_stages=stages,
        num_warps=warps,
    )
    for bm in (32, 64, 128)
    for bn in (32, 64, 128, 256)
    for bk in (128, 256)
    for stages in (1, 2)
    for warps, waves in ((4, 1), (8, 2), (16, 4))
    for nonkdim in (16,)
]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Drop autotuning configs that are unlikely to be useful for this problem.

    Used as the ``early_config_prune`` hook of ``triton.autotune`` for the
    grouped GEMM kernels in this file.

    Args:
        configs: candidate ``triton.Config`` objects. Each must carry
            ``BLOCK_SIZE_M/N/K`` kwargs and may carry
            ``USE_TMA_LOAD_ON_SCALES``.
        named_args: kernel arguments by name. ``G``, ``M_BUCKET`` and ``N``
            are read, and ``c_ptr`` is read when ``dtsize`` is not given.
        dtsize: element size in bytes used for the shared-memory estimate.
            Defaults to ``named_args["c_ptr"].element_size()``.
        dtype: accepted so callers can bind it with ``functools.partial``;
            the heuristics below do not use it.
        **kwargs: ignored; absorbs any extra arguments the autotuner passes.

    Returns:
        The configs that pass every heuristic, in their original order.
    """
    device = torch.cuda.current_device()
    if dtsize is None:
        dtsize = named_args["c_ptr"].element_size()

    # The problem shape and device limits are the same for every candidate, so
    # look them up once instead of once (device props: twice) per config.
    G, M, N = named_args["G"], named_args["M_BUCKET"], named_args["N"]
    device_props = driver.active.utils.get_device_properties(device)
    max_shared_memory = device_props["max_shared_mem"]
    num_sm = device_props["multiprocessor_count"]
    M_PER_GROUP = M // G
    # Block-size thresholds at or below which the "too big" rules never fire.
    MIN_M_TILES = 32 if torch.version.hip else 64
    MIN_N_TILES = 32 if torch.version.hip else 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M = kw["BLOCK_SIZE_M"]
        BLOCK_N = kw["BLOCK_SIZE_N"]
        BLOCK_K = kw["BLOCK_SIZE_K"]
        num_stages = config.num_stages
        use_tma_load_on_scales = kw.get("USE_TMA_LOAD_ON_SCALES", False)

        # 1. make sure we have enough smem (the ROCm estimate counts B only)
        if torch.version.hip:
            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
        else:
            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = (N + BLOCK_N - 1) // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue

        # 6. USE_TMA_LOAD_ON_SCALES configs are only kept for dtsize < 2
        #    (presumably only the 1-byte FP8 path supports it).
        if dtsize >= 2 and use_tma_load_on_scales:
            continue

        pruned_configs.append(config)

    return pruned_configs
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def early_config_prune_ws(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotuning configs for the "_ws" grouped GEMM kernels.

    The pruning heuristics for the "_ws" kernels are currently identical to
    those of :func:`early_config_prune`, so this delegates to it rather than
    keeping a second verbatim copy. It remains a separate entry point so the
    "_ws" heuristics can diverge later without touching the autotune
    decorators that reference it.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name (see :func:`early_config_prune`).
        dtsize: element size in bytes for the shared-memory estimate; defaults
            to ``named_args["c_ptr"].element_size()``.
        dtype: accepted for ``functools.partial`` callers; unused.
        **kwargs: forwarded unchanged.

    Returns:
        The configs that pass every heuristic, in their original order.
    """
    return early_config_prune(
        configs, named_args, dtsize=dtsize, dtype=dtype, **kwargs
    )
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# Grouped GEMM kernel, non-"_ws" variant (autotuned over _AMD_CONFIGS on ROCm
# builds and _NV_CONFIGS otherwise).
#
# For every group g with m_sizes[g] > 0 it computes, over that group's rows,
#     C_g = A_g @ B_g.T  (+ bias[g, :] if HAS_BIAS)  (* token_weight[row] if
#                                                       HAS_TOKEN_WEIGHTS)
# Layout implied by the pointer arithmetic below (raw-pointer path):
#   * A: row-major, K columns; groups occupy consecutive row ranges in group
#     order (M_start_offset is the running sum of the non-empty m_sizes).
#   * B: row-major (G * N) x K; group g uses rows [g * N, (g + 1) * N).
#   * C: row-major, N columns, same row layout as A. With FUSE_SCATTER_ADD,
#     output row r of the concatenated A rows is instead atomically added
#     into C row scatter_add_indices[r].
#   * bias is read at bias_ptr + g * N + n; token weights at
#     token_weights_ptr + row.
# With USE_TMA_LOAD, a_desc_ptr / b_desc_ptr are used as tensor descriptors
# (.load([row, k])) rather than raw pointers.
#
# Work distribution: program `tidx` handles tile indices tidx, tidx + NUM_SMS,
# ... over the concatenated tile space of all non-empty groups; inside a group
# the M tile index varies fastest. M_BUCKET and NUM_CONSUMER_GROUPS are not
# read in the body (M_BUCKET is an autotune key and a pruning input).
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _mslk_grouped_gemm(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    scatter_add_indices,
    m_sizes,
    bias_ptr,
    token_weights_ptr,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_TOKEN_WEIGHTS: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    # Running row offset into A/C, kept in int64 (presumably so row * K and
    # row * N offsets cannot overflow int32 for large inputs).
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Number of tiles in all groups already walked past.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        # Empty groups contribute no tiles and do not advance any offset.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Per-group output descriptor based at this group's first row;
                # its shape bounds stores to the group's m_size rows.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, n_size],
                    # pyre-ignore
                    strides=[n_size, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                if USE_TMA_LOAD:
                    tl.static_assert(K % BLOCK_SIZE_K == 0)
                    # Global descriptor coordinates, cast to int32 (assumes the
                    # row offsets fit in int32).
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = a_desc_ptr.load([m_offset, k_offset])
                        b = b_desc_ptr.load([n_offset, k_offset])
                        # USE_FAST_ACCUM passes the accumulator into tl.dot
                        # instead of adding the product afterwards.
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    # Raw-pointer path: masks rows past m_size / n_size and the
                    # K tail, so K need not be a multiple of BLOCK_SIZE_K here.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        updated_k_offset = k_offset + offs_k
                        updated_k_offset_mask = updated_k_offset[None, :] < K  # type: ignore[16]
                        a = tl.load(
                            a_ptrs,
                            mask=((offs_am[:, None] < m_size) & updated_k_offset_mask),
                            other=0.0,
                        )
                        b = tl.load(
                            b_ptrs,
                            mask=((offs_bn[:, None] < n_size) & updated_k_offset_mask),
                            other=0.0,
                        )
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                if HAS_BIAS:
                    # One bias value per output column of group g.
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bn
                    bias = tl.load(bias_ptrs, mask=(offs_bn < n_size), other=0.0).to(
                        accumulator.dtype
                    )
                    accumulator = accumulator + bias[None, :]

                if HAS_TOKEN_WEIGHTS:
                    # One weight per output row, applied after the bias; rows
                    # past m_size read 1.0.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    tw_ptrs = token_weights_ptr + M_start_offset + offs_am
                    tw = tl.load(tw_ptrs, mask=(offs_am < m_size), other=1.0).to(
                        accumulator.dtype
                    )
                    accumulator = accumulator * tw[:, None]

                if USE_TMA_STORE:
                    # Coordinates are relative to the per-group descriptor.
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    # pyre-ignore
                    c_desc_ptr.store(
                        [m_offset, n_offset], accumulator.to(c_ptr.dtype.element_ty)
                    )
                elif FUSE_SCATTER_ADD:
                    # Accumulate each output row into C row
                    # scatter_add_indices[M_start_offset + r].
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    # NOTE(review): `and` between tensors relies on Triton lowering
                    # it to an element-wise logical and; `&` is the explicit form.
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c,
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                # Stride to this program's next tile; it may fall into a later
                # group, in which case the while condition ends this group.
                tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# TODO(shikaili): Too much code duplication. Need to refactor.
#
# "_ws" variant of _mslk_grouped_gemm, autotuned over _NV_WS_CONFIGS and
# pruned with early_config_prune_ws. It computes the same per-group
# C_g = A_g @ B_g.T (+ bias) (* token weights) with the same A/B/C/bias/
# scatter layout. Differences visible in this body:
#   * USE_TMA_LOAD must be set (static_assert); there is no raw-pointer path.
#   * N % BLOCK_SIZE_N == 0 is static_asserted, so column masks are dropped
#     (bias load, scatter-add and plain store only mask rows).
#   * M_BUCKET is a constexpr here; it is still not read in the body.
#   * m_sizes is loaded with cache_modifier=".ca" and the plain store uses
#     ".cs".
#   * a group's tiles are visited with a `for` over
#     range(tidx, next_iterated_tiles, NUM_SMS) instead of a `while` loop.
# NUM_CONSUMER_GROUPS is not read in the body.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune_ws},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _mslk_grouped_gemm_ws(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    scatter_add_indices,
    m_sizes,
    bias_ptr,
    token_weights_ptr,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_TOKEN_WEIGHTS: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    # Running row offset into A/C, kept in int64 (presumably to avoid int32
    # overflow in row * N offsets).
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Number of tiles in all groups already walked past.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        # Empty groups contribute no tiles and do not advance any offset.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            # N must tile exactly; this is what allows the unmasked column
            # accesses below.
            tl.static_assert(N % BLOCK_SIZE_N == 0, f"{N=} {BLOCK_SIZE_N=}")
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                # Per-group output descriptor based at this group's first row.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, N],
                    # pyre-ignore
                    strides=[N, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            next_iterated_tiles = iterated_tiles + num_tiles
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    tl.static_assert(K % BLOCK_SIZE_K == 0)
                    # Global descriptor coordinates, cast to int32 (assumes the
                    # row offsets fit in int32).
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = a_desc_ptr.load([m_offset, k_offset])
                        b = b_desc_ptr.load([n_offset, k_offset])
                        # USE_FAST_ACCUM passes the accumulator into tl.dot
                        # instead of adding the product afterwards.
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)

                    if HAS_BIAS:
                        # Unmasked: N % BLOCK_SIZE_N == 0 is asserted above.
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bn
                        bias = tl.load(bias_ptrs).to(accumulator.dtype)
                        accumulator = accumulator + bias[None, :]

                    if HAS_TOKEN_WEIGHTS:
                        # One weight per output row; rows past m_size read 1.0.
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        tw_ptrs = token_weights_ptr + M_start_offset + offs_am
                        tw = tl.load(tw_ptrs, mask=(offs_am < m_size), other=1.0).to(
                            accumulator.dtype
                        )
                        accumulator = accumulator * tw[:, None]

                    if USE_TMA_STORE:
                        # Coordinates are relative to the per-group descriptor.
                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                        # pyre-ignore
                        c_desc_ptr.store(
                            [m_offset, n_offset],
                            accumulator.to(c_ptr.dtype.element_ty),
                        )
                    elif FUSE_SCATTER_ADD:
                        # Accumulate each output row into C row
                        # scatter_add_indices[M_start_offset + r]; only rows are
                        # masked since N tiles exactly.
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        mask = offs_am < m_size
                        m_offsets = tl.load(
                            scatter_add_indices + M_start_offset + offs_am,
                            mask=mask,
                            cache_modifier=".ca",
                        )
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        c = accumulator.to(c_ptr.dtype.element_ty)
                        tl.atomic_add(
                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                            c,
                            mask=mask[:, None],
                            sem="relaxed",
                        )
                    else:
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        c = accumulator.to(c_ptr.dtype.element_ty)
                        tl.store(
                            c_ptr
                            + (M_start_offset + offs_am[:, None]) * N
                            + offs_bn[None, :],
                            c,
                            mask=offs_am[:, None] < m_size,
                            cache_modifier=".cs",
                        )
                    # Advanced in lockstep with `i`, so once the loop ends tidx
                    # is this program's first tile index past this group.
                    tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
# Triton FP8 element type passed as `dtype` to the FP8 rowwise kernels' config
# pruning: tl.float8e4b8 on ROCm builds, tl.float8e4nv on CUDA builds.
if torch.version.hip:
    TT_FP8_DTYPE = tl.float8e4b8
else:
    TT_FP8_DTYPE = tl.float8e4nv
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument.
#
# Row-wise-scaled FP8 grouped GEMM (non-"_ws" variant). Config pruning uses
# early_config_prune with dtsize=1 and dtype=TT_FP8_DTYPE. For every group g
# with m_sizes[g] > 0 it computes
#     C_g = (A_g @ B_g.T) * a_scale[row][:, None] * b_scale[g * N + n][None, :]
# A/B/C and scatter-add layouts match _mslk_grouped_gemm: A rows concatenated
# by group, B is (G * N) x K, C has N columns. a_scale holds one value per A
# row and b_scale one value per B row. b_scale_desc_ptr and
# NUM_CONSUMER_GROUPS are not read in this body, and there is no bias or
# token-weight epilogue.
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _mslk_grouped_gemm_fp8_rowwise(
    a_desc_ptr,
    a_scale_ptr,
    b_desc_ptr,
    b_scale_ptr,
    b_scale_desc_ptr,
    c_ptr,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    # Running row offset into A/C, kept in int64 (presumably to avoid int32
    # overflow in row * K / row * N offsets).
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Number of tiles in all groups already walked past.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        # Empty groups contribute no tiles and do not advance any offset.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Per-group output descriptor based at this group's first row.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, n_size],
                    # pyre-ignore
                    strides=[n_size, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                # Required on both load paths: the raw-pointer loop below does
                # not mask a K tail.
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    # Global descriptor coordinates, cast to int32 (assumes the
                    # row offsets fit in int32).
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = a_desc_ptr.load([m_offset, k_offset])
                        b = b_desc_ptr.load([n_offset, k_offset])
                        # USE_FAST_ACCUM only affects this TMA path.
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for _ in range(0, K, BLOCK_SIZE_K):
                        # No `other=`: masked-off rows/columns hold unspecified
                        # values, but they only feed output elements that the
                        # stores below never write.
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                # Row-wise dequantization: one scale per A row and per B row.
                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                a_scale = tl.load(
                    a_scale_ptr + M_start_offset + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                )
                b_scale = tl.load(
                    b_scale_ptr + N_start_offset + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                )
                c = accumulator.to(tl.float32) * a_scale * b_scale

                if USE_TMA_STORE:
                    # Coordinates are relative to the per-group descriptor.
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    # pyre-ignore
                    c_desc_ptr.store([m_offset, n_offset], c.to(c_ptr.dtype.element_ty))
                elif FUSE_SCATTER_ADD:
                    # Accumulate each output row into C row
                    # scatter_add_indices[M_start_offset + r].
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    # NOTE(review): `and` between tensors relies on Triton lowering
                    # it to an element-wise logical and; `&` is the explicit form.
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c.to(c_ptr.dtype.element_ty),
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    # `c` is float32 here (not cast explicitly like the other
                    # branches); this relies on tl.store converting to C's
                    # element type.
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                # Stride to this program's next tile; it may fall into a later
                # group, in which case the while condition ends this group.
                tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
# TODO(shikaili): Too much code duplication. Need to refactor.
#
# FP8 row-wise-scaled grouped GEMM, warp-specialized variant.
#
# For every group g with m_sizes[g] > 0 this computes
#     C[rows_g] = (A[rows_g] @ B[g*N:(g+1)*N].T) * a_scale[rows_g] * b_scale[g*N:(g+1)*N]
# where rows_g is the contiguous row range of A that follows the rows of the
# previous non-empty groups. a_scale holds one value per row of A; b_scale holds
# one value per row of B (B is addressed as a [G * N, K] matrix). Both scales are
# applied in the epilogue, after the full K reduction.
#
# Work distribution: the output tiles of all groups are numbered consecutively.
# Within a group, tile gidx maps to tile_m_idx = gidx % num_m_tiles and
# tile_n_idx = gidx // num_m_tiles (M varies fastest). The program with
# tl.program_id(0) == t handles tiles t, t + NUM_SMS, t + 2 * NUM_SMS, ...
# (the launcher in this file uses a grid of NUM_SMS programs).
#
# Epilogue, selected at compile time:
#   * USE_TMA_STORE: store the tile through a per-group tensor descriptor.
#   * FUSE_SCATTER_ADD: atomically add row i of the result into
#     C[scatter_add_indices[i]]. This cannot be combined with USE_TMA_STORE.
#   * otherwise: a masked plain store into the group's rows of C.
#
# Compile-time requirements (tl.static_assert): USE_TMA_LOAD must be true,
# N % BLOCK_SIZE_N == 0 and K % BLOCK_SIZE_K == 0.
# M_BUCKET is not read in the body; it only participates in the autotune key.
# NUM_CONSUMER_GROUPS is not referenced in the body either. Presumably it is
# consumed through the autotune config / compiler; TODO confirm.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune_ws, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    # restore for scatter_add fusion (the fused epilogue accumulates into
    # c_ptr with atomic_add, so autotuning runs must not compound)
    restore_value=["c_ptr"],
)
@triton.jit
def _mslk_grouped_gemm_fp8_rowwise_ws(
    a_desc_ptr,  # tensor descriptor over A, loaded at [row, k] coordinates
    a_scale_ptr,  # one scale per row of A
    b_desc_ptr,  # tensor descriptor over B viewed as [G * N, K]
    b_scale_ptr,  # one scale per row of B (G * N entries)
    c_ptr,  # row-major output with N columns
    scatter_add_indices,  # destination row in C per row of A (scatter-add only)
    m_sizes,  # number of rows of A in each of the G groups
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    # Global index of the next tile this program will process.
    tidx = tl.program_id(0)

    # Running row offset into A/C; promoted to int64, presumably so that
    # M_start_offset * N cannot overflow 32 bits for large inputs.
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Number of tiles belonging to all groups processed so far.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # First row of this group's weights inside the [G * N, K] B matrix.
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            tl.static_assert(N % BLOCK_SIZE_N == 0)
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                # Descriptor covering only this group's rows of C, so store
                # coordinates below are group-local.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, N],
                    # pyre-ignore
                    strides=[N, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            next_iterated_tiles = iterated_tiles + num_tiles
            # Only enter if this program's next tile falls inside this group.
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    tl.static_assert(K % BLOCK_SIZE_K == 0)

                    # Global row offsets into A and B for the descriptor loads,
                    # cast to int32 coordinates.
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = a_desc_ptr.load([m_offset, k_offset])
                        b = b_desc_ptr.load([n_offset, k_offset])
                        if USE_FAST_ACCUM:
                            # Accumulate inside tl.dot. Presumably this is the
                            # faster, lower-precision FP8 accumulation mode;
                            # TODO confirm.
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)

                    # Row-wise dequantization in the epilogue.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    a_scale = tl.load(
                        a_scale_ptr + M_start_offset + offs_am[:, None],
                        mask=offs_am[:, None] < m_size,
                        cache_modifier=".ca",
                    )
                    # No mask needed: N is a multiple of BLOCK_SIZE_N.
                    b_scale = tl.load(
                        b_scale_ptr + N_start_offset + offs_bn[None, :],
                        cache_modifier=".ca",
                    )
                    c = accumulator.to(tl.float32) * a_scale * b_scale

                    if USE_TMA_STORE:
                        # Group-local coordinates (descriptor base is the group's first row).
                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                        # pyre-ignore
                        c_desc_ptr.store(
                            [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
                        )
                    elif FUSE_SCATTER_ADD:
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        mask = offs_am < m_size
                        # Destination row in C for each row of this tile.
                        m_offsets = tl.load(
                            scatter_add_indices + M_start_offset + offs_am,
                            mask=mask,
                            cache_modifier=".ca",
                        )
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        # Row mask only: every column < N is valid since
                        # N % BLOCK_SIZE_N == 0.
                        tl.atomic_add(
                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                            c,
                            mask=mask[:, None],
                            sem="relaxed",
                        )
                    else:
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        tl.store(
                            c_ptr
                            + (M_start_offset + offs_am[:, None]) * N
                            + offs_bn[None, :],
                            c,
                            mask=offs_am[:, None] < m_size,
                            cache_modifier=".cs",
                        )
                    # Advance to this program's next tile (mirrors the loop stride).
                    tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
865
|
+
|
|
866
|
+
|
|
867
|
+
# NOTE(review): this runs at import time and replaces the process-wide default
# warning action with "once" (show only the first occurrence of each matching
# warning, regardless of location). Presumably this keeps the fallback warnings
# emitted by _grouped_gemm from repeating on every call. It affects warnings
# raised by every module, not just this one; a module-scoped
# `warnings.filterwarnings("once", module=...)` may be preferable. TODO confirm.
warnings.simplefilter("once")
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
def _grouped_gemm(
    *,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: Optional[torch.Tensor],
    w_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    token_weights: Optional[torch.Tensor],
    use_fast_accum: bool,
    use_warp_specialization: bool,
    output_tensor: Optional[torch.Tensor],
    scatter_add_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    """Validate inputs, select a grouped-GEMM Triton kernel and launch it.

    ``x`` stacks the rows of all groups along dim 0 (``m_sizes[g]`` rows for
    group ``g``). ``w`` stacks ``G = len(m_sizes)`` weight matrices of ``N`` rows
    each, so ``w.shape == (G * N, K)``.

    Kernel selection:
      * ``x_scale`` and ``w_scale`` both given: FP8 rowwise kernel (the ``_ws``
        variant when warp specialization is enabled). ``bias`` and
        ``token_weights`` are not supported on this path.
      * both ``None``: unquantized kernel, which also receives ``bias`` and
        ``token_weights``.

    Fallbacks (each emits a warning):
      * warp specialization is disabled on ROCm;
      * TMA load/store and warp specialization are disabled when ``K`` or
        ``N`` is not a multiple of 8.

    Returns:
        A new bfloat16 tensor of shape ``(M, N)`` when ``output_tensor`` is
        ``None``. Otherwise each result row ``i`` is scatter-added into row
        ``scatter_add_indices[i]`` of ``output_tensor``, which is returned.

    Raises:
        AssertionError: on unsupported or inconsistent inputs.
    """
    USE_TMA_LOAD = not torch.version.hip and TMA_AVAILABLE
    USE_TMA_STORE = False

    # TODO(shikaili): Check the readiness of WS on ROCm side in Meta's Triton.
    if use_warp_specialization and torch.version.hip:
        warnings.warn(
            "Warp specialization is disabled as it is not supported on ROCm.",
            stacklevel=2,
        )
        use_warp_specialization = False

    if use_warp_specialization:
        assert TMA_AVAILABLE, "TMA is not available"
        USE_TMA_STORE = True  # Tuning decision

    G = m_sizes.shape[0]
    # Guards ``w.shape[0] // G`` below, which would otherwise fail with an
    # opaque ZeroDivisionError for an empty ``m_sizes``.
    assert G > 0, "m_sizes must describe at least one group"

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M, K = x.shape
    N = w.shape[0] // G
    assert K == w.shape[1]

    if K % 8 != 0 or N % 8 != 0:
        use_warp_specialization = False
        USE_TMA_LOAD = False
        USE_TMA_STORE = False
        warnings.warn(
            f"TMA load and warp specialization are disabled since K or N is not a multiple of 8: {K=}, {N=}.",
            stacklevel=2,
        )
        assert x_scale is None, (
            f"Quantisation is not supported yet when K or N is not a multiple of 8: {K=}, {N=}."
        )

        assert output_tensor is None, (
            f"Fused scatter add has large rounding error when K or N is not a multiple of 8: {K=}, {N=}."
        )

    HAS_BIAS = bias is not None
    if HAS_BIAS:
        assert bias is not None  # for type checker
        assert bias.is_contiguous(), "Bias must be contiguous"
        assert len(bias.shape) == 2, f"Bias must be 2D, got shape {bias.shape}"
        assert bias.shape[0] == G, f"Bias dim 0 must match G={G}, got {bias.shape[0]}"
        assert bias.shape[1] == N, f"Bias dim 1 must match N={N}, got {bias.shape[1]}"

    HAS_TOKEN_WEIGHTS = token_weights is not None
    if HAS_TOKEN_WEIGHTS:
        assert token_weights is not None  # for type checker
        assert token_weights.is_contiguous(), "token_weights must be contiguous"
        assert len(token_weights.shape) == 1, (
            f"token_weights must be 1D, got shape {token_weights.shape}"
        )
        assert token_weights.shape[0] == M, (
            f"token_weights dim 0 must match M={M}, got {token_weights.shape[0]}"
        )

    if x_scale is not None and w_scale is not None:
        # bias / token_weights are only forwarded to the unquantized kernels;
        # reject them instead of silently dropping them on the FP8 path.
        assert not HAS_BIAS and not HAS_TOKEN_WEIGHTS, (
            "bias and token_weights are not supported with FP8 rowwise scales"
        )

    if output_tensor is None:
        FUSE_SCATTER_ADD = False
        assert scatter_add_indices is None
        y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    else:
        FUSE_SCATTER_ADD = True
        assert scatter_add_indices is not None
        assert scatter_add_indices.is_contiguous()
        assert scatter_add_indices.shape == (M,)
        # The kernels address the output as ``c_ptr + row * N + col``, i.e. a
        # dense row-major buffer with exactly N columns; anything else would be
        # written to the wrong locations without any error.
        assert output_tensor.is_contiguous(), "output_tensor must be contiguous"
        assert output_tensor.dim() >= 1 and output_tensor.shape[-1] == N, (
            f"output_tensor must have N={N} columns, got shape {tuple(output_tensor.shape)}"
        )
        y = output_tensor
    if M == 0 or N == 0:
        return y

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # A dummy block value that will be overwritten in the pre_hook when we have the real block size
    dummy_block = [1, 1]

    if USE_TMA_LOAD:
        # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional
        # argument, expected `List[int]` but got `Size`
        desc_x = TensorDescriptor(x, x.shape, x.stride(), dummy_block)
        # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional
        # argument, expected `List[int]` but got `Size`
        desc_w = TensorDescriptor(w, w.shape, w.stride(), dummy_block)
    else:
        desc_x = x
        desc_w = w

    if USE_TMA_STORE:

        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
            return torch.empty(size, device="cuda", dtype=torch.int8)

        triton.set_allocator(alloc_fn)

    # One program per SM; the kernels stride over tiles by NUM_SMS.
    def grid(META):
        return (NUM_SMS,)

    M_BUCKET_CAP = 16384
    M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
    if x_scale is not None and w_scale is not None:
        assert x_scale.is_contiguous()
        assert w_scale.is_contiguous()
        fn = (
            _mslk_grouped_gemm_fp8_rowwise_ws
            if use_warp_specialization
            else _mslk_grouped_gemm_fp8_rowwise
        )
        if use_warp_specialization:
            args = (
                desc_x,
                x_scale,
                desc_w,
                w_scale,
                y,
                scatter_add_indices,
                m_sizes,
                G,
                M_BUCKET,
                N,
                K,
                NUM_SMS,
                FUSE_SCATTER_ADD,
                USE_TMA_LOAD,
                use_fast_accum,
            )
        else:
            args = (
                desc_x,
                x_scale,
                desc_w,
                w_scale,
                w_scale,  # b_scale_desc_ptr (unused, just passed for API compatibility)
                y,
                scatter_add_indices,
                m_sizes,
                G,
                M_BUCKET,
                N,
                K,
                NUM_SMS,
                FUSE_SCATTER_ADD,
                USE_TMA_LOAD,
                USE_TMA_STORE,
                use_fast_accum,
            )
        fn[grid](*args)
    else:
        assert x_scale is None
        assert w_scale is None
        fn = _mslk_grouped_gemm_ws if use_warp_specialization else _mslk_grouped_gemm
        args = (
            desc_x,
            desc_w,
            y,
            scatter_add_indices,
            m_sizes,
            bias,  # already None when HAS_BIAS is False
            token_weights,  # already None when HAS_TOKEN_WEIGHTS is False
            G,
            M_BUCKET,
            N,
            K,
            NUM_SMS,
            FUSE_SCATTER_ADD,
            USE_TMA_LOAD,
        )
        if use_warp_specialization:
            args += (use_fast_accum, HAS_BIAS, HAS_TOKEN_WEIGHTS)
        else:
            args += (USE_TMA_STORE, use_fast_accum, HAS_BIAS, HAS_TOKEN_WEIGHTS)
        fn[grid](*args)

    return y
|
|
1059
|
+
|
|
1060
|
+
|
|
1061
|
+
def grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    token_weights: Optional[torch.Tensor] = None,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Unquantized grouped (per-expert) GEMM with optional bias and token weights.

    Rows of ``x`` are split into ``G = len(m_sizes)`` consecutive groups; group
    ``g`` owns the next ``m_sizes[g]`` rows and is multiplied by the weight rows
    ``w[g * N:(g + 1) * N]``. The documented result is
    ``(x @ w.T + bias) * token_weights``, evaluated group by group.

    Args:
        x: Activations ``[M, K]``; ``M`` is the total number of rows over all groups.
        w: Stacked per-group weights ``[G * N, K]``.
        m_sizes: ``[G]`` row count of ``x`` belonging to each group.
        bias: Optional ``[G, N]`` bias, one row per group.
        token_weights: Optional ``[M]`` per-row multiplier (e.g. router weights).
        use_fast_accum: Request the kernel's fast-accumulation mode.
        _use_warp_specialization: Prefer the warp-specialized kernel; it is
            turned off with a warning when the platform or shapes do not allow it.
        _output_tensor: If given, results are scatter-added into this tensor
            instead of being written to a freshly allocated one.
        _scatter_add_indices: ``[M]`` destination row in ``_output_tensor`` for
            each row of ``x``; must be given exactly when ``_output_tensor`` is.

    Returns:
        A new bfloat16 ``[M, N]`` tensor, or ``_output_tensor`` itself when a
        scatter-add target was supplied.
    """
    # Optional scatter-add target, forwarded unchanged.
    scatter_options = {
        "output_tensor": _output_tensor,
        "scatter_add_indices": _scatter_add_indices,
    }
    return _grouped_gemm(
        x=x,
        w=w,
        m_sizes=m_sizes,
        x_scale=None,  # unquantized path: no row-wise scales
        w_scale=None,
        bias=bias,
        token_weights=token_weights,
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        **scatter_options,
    )
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
def grouped_gemm_fp8_rowwise(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Grouped GEMM over row-wise scaled quantized operands.

    Same grouping as ``grouped_gemm``: ``x`` is ``[M, K]`` with
    ``m_sizes[g]`` consecutive rows per group, and ``w`` is ``[G * N, K]``. The
    operands are presumably FP8 (the kernels are tuned for an FP8 dtype); TODO
    confirm whether other dtypes are accepted.

    Args:
        x: Quantized activations ``[M, K]``.
        w: Quantized stacked weights ``[G * N, K]``.
        m_sizes: ``[G]`` row count of ``x`` per group.
        x_scale: One scale per row of ``x`` (``M`` values).
        w_scale: One scale per row of ``w`` (``G * N`` values).
        use_fast_accum: Request the kernel's fast-accumulation mode.
        _use_warp_specialization: Prefer the warp-specialized kernel when allowed.
        _output_tensor: Optional scatter-add target instead of a fresh output.
        _scatter_add_indices: ``[M]`` destination rows in ``_output_tensor``.

    Returns:
        A new bfloat16 ``[M, N]`` tensor, or ``_output_tensor`` itself when a
        scatter-add target was supplied. Bias and per-token weights are not
        available on this path.
    """
    kernel_options = dict(
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        output_tensor=_output_tensor,
        scatter_add_indices=_scatter_add_indices,
    )
    return _grouped_gemm(
        x=x,
        w=w,
        m_sizes=m_sizes,
        x_scale=x_scale,
        w_scale=w_scale,
        bias=None,
        token_weights=None,
        **kernel_options,
    )
|