mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2702 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import triton # @manual
|
|
14
|
+
import triton.language as tl # @manual
|
|
15
|
+
from mslk.gemm.triton.matmul_perf_model import early_config_prune, estimate_matmul_time
|
|
16
|
+
from mslk.gemm.triton.utils import map_dtype_to_triton, TmaAutoTuneHelper
|
|
17
|
+
from mslk.utils.triton.fp8_utils import get_fp8_constants, reinterpret_fp8_type
|
|
18
|
+
from packaging import version
|
|
19
|
+
from torch._tensor import Tensor
|
|
20
|
+
from triton import Config # @manual
|
|
21
|
+
from triton.runtime.jit import TensorWrapper # @manual
|
|
22
|
+
|
|
23
|
+
logger: logging.Logger = logging.getLogger(__name__)

# True when the GITHUB_ENV environment variable exists at import time; the name
# indicates it is used to detect a GitHub-hosted CI run.
running_on_github: bool = "GITHUB_ENV" in os.environ

try:
    # pyre-ignore[21]
    from triton.fb.compat import disable_bufferops  # @manual
except ModuleNotFoundError:
    # Builds without triton.fb.compat (e.g. open-source Triton) still need a
    # callable ``disable_bufferops``; substitute a context manager that does
    # nothing and binds ``None`` for ``with ... as``.
    # TODO(njriasan): Remove when we integrate triton.fb.compat into every Triton
    # version.
    from contextlib import contextmanager

    @contextmanager
    def disable_bufferops(_unused: bool):
        """No-op stand-in for ``triton.fb.compat.disable_bufferops``; ignores its flag."""
        yield
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def init_to_zero(name):
    """Build an autotuner ``pre_hook`` that zero-fills one kernel argument.

    Args:
        name: Key of the argument inside the ``nargs`` mapping the hook receives.

    Returns:
        A callable ``hook(nargs)`` that calls ``nargs[name].zero_()`` and returns
        whatever that call returns.
    """

    def _zero_fill(nargs):
        target = nargs[name]
        return target.zero_()

    return _zero_fill
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_configs_io_bound() -> List[Config]:
    """
    Returns a list of configs for matmul that are IO bound.

    One SPLIT_K=1 config is produced for every combination of
    num_stages in (2, 3, 4, 5, 6), BLOCK_M in (16, 32), BLOCK_K in (32, 64)
    and BLOCK_N in (32, 64, 128, 256), in that nesting order (num_stages
    outermost). Tiles with BLOCK_N > 64 use 4 warps, narrower ones use 2.

    Returns:
        List[Config]: list of configs.
    """
    # Split-K variants (SPLIT_K in 2/4/8/16 with an init_to_zero("C") pre_hook)
    # are currently disabled, so only SPLIT_K=1 configs are generated.
    return [
        Config(
            {
                "BLOCK_M": tile_m,
                "BLOCK_N": tile_n,
                "BLOCK_K": tile_k,
                "SPLIT_K": 1,
            },
            num_stages=stages,
            num_warps=4 if tile_n > 64 else 2,
        )
        for stages in (2, 3, 4, 5, 6)
        for tile_m in (16, 32)
        for tile_k in (32, 64)
        for tile_n in (32, 64, 128, 256)
    ]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def dummy_prune_configs(configs, named_args, **kwargs):
    """Pass-through ``early_config_prune`` hook that only logs the problem shape.

    No pruning is performed: the candidate list is returned unchanged.

    Args:
        configs: Candidate autotuning configs.
        named_args: Kernel arguments by name; must contain ``"M"``, ``"N"``
            and ``"K"``.
        **kwargs: Any extra keyword arguments (e.g. from the autotuner); ignored.

    Returns:
        The same ``configs`` object that was passed in.

    Raises:
        KeyError: if ``named_args`` lacks ``"M"``, ``"N"`` or ``"K"``.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]

    # Lazy %-formatting: the message is only rendered when INFO is enabled for
    # this logger. The count is logged once (it used to be printed twice).
    logger.info("len(configs)=%d for M=%r N=%r K=%r", len(configs), M, N, K)
    return configs
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Autotuning candidates shared by the row-wise FP8 matmul kernels below. Each
# table row is (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps); every entry
# uses SPLIT_K=1. The small-tile configs from get_configs_io_bound() follow.
MATMUL_CONFIGS: List[Config] = [
    Config(
        {"BLOCK_M": tile_m, "BLOCK_N": tile_n, "BLOCK_K": tile_k, "SPLIT_K": 1},
        num_stages=stages,
        num_warps=warps,
    )
    for tile_m, tile_n, tile_k, stages, warps in (
        # basic configs for compute-bound matmuls
        (128, 256, 32, 3, 8),
        (256, 128, 32, 3, 8),
        (256, 64, 32, 4, 4),
        (64, 256, 32, 4, 4),
        (64, 128, 128, 4, 4),
        (64, 128, 256, 4, 4),
        (128, 128, 32, 4, 4),
        (128, 64, 32, 4, 4),
        (64, 128, 32, 4, 4),
        (128, 32, 32, 4, 4),
        (64, 32, 32, 5, 2),
        # good for int8
        (128, 256, 128, 3, 8),
        (256, 128, 128, 3, 8),
        (256, 64, 128, 4, 4),
        (64, 256, 128, 4, 4),
        (128, 128, 128, 4, 4),
        (128, 64, 64, 4, 4),
        (64, 128, 64, 4, 4),
        (128, 32, 64, 4, 4),
        (64, 32, 64, 5, 2),
    )
] + get_configs_io_bound()
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    prune_configs_by={
        # Pass-through hook: logs the problem shape and keeps every config.
        "early_config_prune": dummy_prune_configs,
    },
    # Autotuning is keyed on these argument values (see triton.autotune `key`).
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
)
@triton.jit
def _kernel_matmul_fp8_row(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    skip_scaling_a: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    USE_BIAS: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    NUM_SMS: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with row-wise scales

    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.

    Computes C[m, n] = sum_k A[m, k] * B[n, k], multiplied by
    A_scale[m] * B_scale[n] (only B_scale[n] when skip_scaling_a), plus
    Bias[n] when USE_BIAS, then cast to C's element type.

    Work distribution: program ``start_pid`` handles output tiles
    start_pid, start_pid + NUM_SMS, start_pid + 2 * NUM_SMS, ... in grouped
    (GROUP_M) order. Presumably the launcher starts NUM_SMS programs so this
    runs as a persistent kernel -- TODO confirm against the caller.

    Args:
        A_ptr (TensorWrapper): [M, K] input tensor.
        B_ptr (TensorWrapper): [N, K] input tensor.
        C_ptr (TensorWrapper): [M, N] output tensor.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A.
            Not read when skip_scaling_a is set.
        B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B.
        Bias (TensorWrapper): [N] Optional bias tensor; only read when USE_BIAS.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        dot_out_dtype (tl.dtype): Accumulator / tl.dot output dtype (passed to
            tl.zeros, so a Triton dtype). Ignored when allow_tf32 is set, in
            which case float32 is used.
        allow_tf32 (bool): Passed to tl.dot; also forces a float32 accumulator.
        fp8_fast_accum (bool): Not read by this kernel.
        skip_scaling_a (bool): If set, only B_scale is applied to the output.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M-tiles per group in the tile-order swizzle.
        SPLIT_K (int): Not read by this kernel (there is no split-K reduction).
        USE_BIAS (bool): Whether to use bias.
        AB_DTYPE (bool): Not read by this kernel.
        NUM_SMS (int): Stride between consecutive tiles handled by one program.
    """
    # Matrix multiplication.
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    offs_k_for_mask = tl.arange(0, BLOCK_K)

    num_pid_in_group = GROUP_M * num_pid_n

    # With allow_tf32 the accumulator is always float32; otherwise dot_out_dtype.
    acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype

    # Outer loop over tiles assigned to this SM
    for tile_id in range(start_pid, num_tiles, NUM_SMS):
        # Grouped (swizzled) tile order: GROUP_M consecutive M-tiles are visited
        # for each N-tile, presumably for better L2 reuse as in the other kernels.
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        # pyre-ignore[58]: `%` is not supported for operand types `int` and `tl.core.constexpr`.
        pid_m = first_pid_m + (tile_id % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m

        start_m = pid_m * BLOCK_M
        start_n = pid_n * BLOCK_N
        offs_am = start_m + tl.arange(0, BLOCK_M)
        offs_bn = start_n + tl.arange(0, BLOCK_N)
        # Out-of-range rows/cols are redirected to index 0 so the A/B loads
        # (masked only along K) stay in bounds; those lanes are dropped by the
        # masked store at the end.
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)

        # Inner loop over K dimension
        for ki in range(0, k_tiles):
            offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
            A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

            # Only the K tail is masked; masked elements load as 0.
            a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
            b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
            acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)

        # rematerialize rm and rn to save registers
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

        # Invert scaling.
        b_scale = tl.load(B_scale + rn, mask=rn < N)
        if skip_scaling_a:
            acc *= b_scale[None, :]
        else:
            a_scale = tl.load(A_scale + rm, mask=rm < M)
            # pyre-ignore[16]: Undefined attribute [16]: `float`
            # has no attribute `__getitem__`.
            scale = a_scale[:, None] * b_scale[None, :]
            acc *= scale

        # Load and add bias if specified.
        if USE_BIAS:
            bias = tl.load(Bias + rn, mask=rn < N)
            acc += bias[None, :]

        acc = acc.to(C_ptr.dtype.element_ty)
        C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # Masked write-back of this tile (a plain store; this kernel performs no
        # split-K reduction).
        tl.store(C, acc, mask=mask)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
@triton.autotune(
    # Shared candidates plus one extra 128x128x128 / 3-stage / 8-warp config.
    configs=MATMUL_CONFIGS
    + [
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
    ],
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_no_fast_acc(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    USE_BIAS: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    NUM_SMS: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with row-wise scales

    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.

    Work distribution: program ``start_pid`` owns output tiles start_pid,
    start_pid + NUM_SMS, ... and walks all of them in ONE flattened loop of
    k_tiles * tiles_per_SM iterations. ``ki`` cycles through 0..k_tiles-1; a
    new tile is selected when ki == 0 and that tile's epilogue (scaling, bias,
    store) runs when ki == k_tiles - 1. Partial products are accumulated with a
    separate ``acc += tl.dot(...)`` instead of passing ``acc`` into tl.dot
    (the kernel name suggests this avoids fast FP8 accumulation -- TODO confirm).

    Args:
        A_ptr (TensorWrapper): [M, K] input tensor.
        B_ptr (TensorWrapper): [N, K] input tensor.
        C_ptr (TensorWrapper): [M, N] output tensor.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
        B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
        Bias (TensorWrapper): [N] Optional bias tensor; only read when USE_BIAS.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        dot_out_dtype (tl.dtype): Accumulator / tl.dot output dtype (passed to
            tl.zeros, so a Triton dtype).
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): Not read by this kernel.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M-tiles per group in the tile-order swizzle.
        SPLIT_K (int): Not read by this kernel (there is no split-K reduction).
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K. Set by
            the heuristic but not read: the K tail is always masked.
        USE_BIAS(bool): Whether to use bias.
        AB_DTYPE (bool): Not read by this kernel.
        NUM_SMS (int): Stride between consecutive tiles handled by one program.
    """
    # Matrix multiplication.

    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    # Tiles are dealt round-robin across NUM_SMS programs, so the first
    # (num_tiles % NUM_SMS) programs get one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Start one stride before start_pid and at ki = -1 so the first loop
    # iteration advances to (tile_id = start_pid, ki = 0).
    tile_id = start_pid - NUM_SMS
    ki = -1

    offs_k_for_mask = tl.arange(0, BLOCK_K)

    num_pid_in_group = GROUP_M * num_pid_n

    # Placeholder values; the real ones are assigned whenever ki == 0. They are
    # bound before the loop presumably because Triton requires loop-carried
    # variables to be defined ahead of the loop.
    pid_m = 0
    pid_n = 0
    offs_am = tl.arange(0, BLOCK_M)
    offs_bn = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

    for _ in range(0, k_tiles * tiles_per_SM):
        # Advance ki cyclically through 0..k_tiles-1.
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # Starting a new output tile: move to this program's next tile and
            # recompute its grouped (swizzled) coordinates.
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            start_m = pid_m * BLOCK_M
            start_n = pid_n * BLOCK_N
            offs_am = start_m + tl.arange(0, BLOCK_M)
            offs_bn = start_n + tl.arange(0, BLOCK_N)
            # Out-of-range rows/cols are redirected to index 0 so the loads
            # (masked only along K) stay in bounds; the masked store drops them.
            offs_am = tl.where(offs_am < M, offs_am, 0)
            offs_bn = tl.where(offs_bn < N, offs_bn, 0)
            offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
            offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)
        offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
        A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # Only the K tail is masked; masked elements load as 0.
        a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
        b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
        # Separate add rather than passing acc into tl.dot.
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        if ki == k_tiles - 1:
            # Last K step of the current tile: run the epilogue.
            # rematerialize rm and rn to save registers
            rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

            # Invert scaling.
            a_scale = tl.load(A_scale + rm, mask=rm < M)
            b_scale = tl.load(B_scale + rn, mask=rn < N)
            # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
            scale = a_scale[:, None] * b_scale[None, :]
            acc *= scale

            # Load and add bias if specified.
            if USE_BIAS:
                bias = tl.load(Bias + rn, mask=rn < N)
                acc += bias[None, :]

            acc = acc.to(C_ptr.dtype.element_ty)
            C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            # Masked write-back of this tile (a plain store; this kernel performs
            # no split-K reduction).
            tl.store(C, acc, mask=mask)
            # Reset the accumulator (back to dot_out_dtype) for the next tile.
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_imprecise_acc(
    A,
    B,
    C,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    USE_BIAS: tl.constexpr,
    AB_DTYPE: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with row-wise scales

    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.

    Launch layout: program_id(0) selects the output tile (grouped by GROUP_M)
    and program_id(1) selects the K-slice ``pid_z``. Slice pid_z processes the
    BLOCK_K-wide chunks pid_z, pid_z + SPLIT_K, pid_z + 2 * SPLIT_K, ... of K.
    With SPLIT_K == 1 the tile is written with a plain masked store. With
    SPLIT_K > 1 every slice atomically adds its scaled partial result into C,
    so C must hold zeros before launch (not enforced here); the bias is added
    only by slice 0 so that it is counted exactly once.

    Args:
        A (TensorWrapper): [M, K] input tensor.
        B (TensorWrapper): [N, K] input tensor.
        C (TensorWrapper): [M, N] output tensor.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
        B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
        Bias (TensorWrapper): [N] Optional bias tensor; only read when USE_BIAS.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        dot_out_dtype (tl.dtype): Accumulator / tl.dot output dtype (passed to
            tl.zeros, so a Triton dtype).
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): If set, accumulate via
            tl.dot(a, b, acc, max_num_imprecise_acc=32); otherwise via
            acc += tl.dot(a, b).
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M-tiles per group in the tile-order swizzle.
        SPLIT_K (int): Number of K-slices; program_id(1) selects the slice and
            slices are combined with atomic_add when SPLIT_K > 1.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K; if so
            the K-tail masking is skipped.
        USE_BIAS (bool): Whether to use bias.
        AB_DTYPE (bool): Whether to cast A and B tiles to C's element type before
            the tensor-core dot.
    """
    # Matrix multiplication.
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Re-order program ID for better L2 performance (swizzle).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Do matrix multiplication.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols back into [0, M) / [0, N) so the loads stay in
    # bounds; those lanes are discarded by the masked write-back.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    # This program starts at the pid_z-th BLOCK_K chunk of K.
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # Pointers.
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # No K tail: every chunk is full, so no masking is needed.
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if fp8_fast_accum:
            acc = tl.dot(
                a,
                b,
                acc,
                max_num_imprecise_acc=32,
                out_dtype=dot_out_dtype,
                allow_tf32=allow_tf32,
            )
        else:
            acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        # Skip the chunks owned by the other SPLIT_K - 1 slices.
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Invert scaling.
    a_scale = tl.load(A_scale + rm, mask=rm < M)
    b_scale = tl.load(B_scale + rn, mask=rn < N)
    # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
    scale = a_scale[:, None] * b_scale[None, :]
    acc *= scale

    # Apply bias. Only the pid_z == 0 slice contributes it: with SPLIT_K > 1 the
    # slices' partial results are summed by atomic_add below, so adding the
    # bias in every slice would count it SPLIT_K times. Other slices load 0.
    if USE_BIAS:
        bias = tl.load(Bias + rn, mask=(rn < N) & (pid_z == 0), other=0.0)
        acc += bias[None, :]

    acc = acc.to(C.dtype.element_ty)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # Handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
@triton.autotune(
    configs=[
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 512, "SPLIT_K": 1},
            num_stages=3,
            num_warps=4,
        ),
    ],
    # Autotuning results are cached per (m_key, n_key, k_key).
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
    use_cuda_graph=True,
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_tma_persistent(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    c_dtype: tl.constexpr,
    bias_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_BIAS: tl.constexpr,
) -> None:
    """Persistent TMA matmul of [M, K] @ [N, K]^T with row-wise scales.

    Each program owns the output tiles start_pid, start_pid + NUM_SMS, ...
    For every owned tile it loops over all K tiles, loading [BLOCK_M, BLOCK_K]
    tiles of A and [BLOCK_N, BLOCK_K] tiles of B via TMA descriptor loads. B is
    transposed in-register with ``b.T``. After the last K tile it applies the
    row/column scales and optional bias, casts to ``c_dtype``, stores the
    [BLOCK_M, BLOCK_N] tile and resets the accumulator. Output tiles are
    visited in GROUP_M-grouped ("swizzled") order.

    Args:
        A_ptr: TMA descriptor for the [M, K] input (loaded as tl.float8e4nv).
        B_ptr: TMA descriptor for the [N, K] input (loaded as tl.float8e4nv).
        C_ptr: TMA descriptor for the [M, N] output.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key, n_key, k_key: Autotuning cache keys; not read by the kernel body.
        A_scale: TMA descriptor for the [M] reciprocal scale per row (loaded as
            float32). A * A_scale = original A.
        B_scale: TMA descriptor for the [N] reciprocal scale per row (loaded as
            float32). B * B_scale = original B.
        Bias: TMA descriptor for the [N] bias; only read when USE_BIAS.
        stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn:
            Strides of A, B and C. Unused here because all addressing goes
            through the TMA descriptors.
        dot_out_dtype (tl.dtype): Accumulator / tensor-core output type.
        c_dtype (tl.dtype): Type the accumulator is cast to before the store.
        bias_dtype (tl.dtype): Element type of Bias.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): If True, pass the accumulator into ``tl.dot``;
            otherwise add each ``tl.dot`` result to it separately.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M tiles per group in the tile swizzle.
        AB_DTYPE (bool): Not used by this kernel.
        SPLIT_K (int): Autotuning meta-parameter. This kernel does not split K,
            and all of its configs use 1.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K. Not
            used here; a ragged last K tile presumably relies on TMA
            out-of-bounds handling (TODO confirm).
        NUM_SMS (int): Stride of the persistent tile schedule; the host launches
            min(NUM_SMS, num_tiles) programs.
        USE_BIAS (bool): Whether to load and add Bias.
    """
    # Matrix multiplication.
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    # Distribute tiles round-robin: the first (num_tiles % NUM_SMS) programs
    # get one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Start one stride "before" this program's first tile and with ki = -1,
    # so that the first loop iteration wraps ki to 0 and advances tile_id to
    # start_pid.
    tile_id = start_pid - NUM_SMS
    ki = -1

    pid_m = 0
    pid_n = 0
    offs_am = 0
    offs_bn = 0

    num_pid_in_group = GROUP_M * num_pid_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

    dtype_fp8 = tl.float8e4nv
    scale_dtype = tl.float32

    # Single flattened loop over (owned tile, k tile) pairs.
    for _ in range(0, k_tiles * tiles_per_SM):
        # ki cycles 0 .. k_tiles - 1; wrapping to 0 starts a new output tile.
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            tile_id += NUM_SMS
            # Grouped ordering: walk GROUP_M rows of tiles column by column.
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            offs_am = pid_m * BLOCK_M
            offs_bn = pid_n * BLOCK_N
            offs_am = tl.multiple_of(offs_am, BLOCK_M)
            offs_bn = tl.multiple_of(offs_bn, BLOCK_N)

        offs_k = ki * BLOCK_K

        a = tl._experimental_descriptor_load(
            A_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], dtype_fp8
        )
        b = tl._experimental_descriptor_load(
            B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
        )

        if fp8_fast_accum:
            acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
        else:
            acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        if ki == k_tiles - 1:
            # Epilogue for the current output tile: scale, optional bias,
            # cast, store, then reset the accumulator for the next tile.

            # Invert scaling.
            a_scale = tl._experimental_descriptor_load(
                A_scale, [offs_am], [BLOCK_M], scale_dtype
            )
            b_scale = tl._experimental_descriptor_load(
                B_scale, [offs_bn], [BLOCK_N], scale_dtype
            )
            # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
            scale = a_scale[:, None] * b_scale[None, :]
            acc *= scale

            # Load and add bias if specified.
            if USE_BIAS:
                bias = tl._experimental_descriptor_load(
                    Bias, [offs_bn], [BLOCK_N], bias_dtype
                )
                acc += bias[None, :]

            acc = acc.to(c_dtype)
            tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
# True when this Triton build exposes `tl.async_task`, which the
# warp-specialized kernel uses. It is probed with hasattr because it may not
# exist in every Triton build. get_ws_configs() returns no configs when this
# is False, and matmul_fp8_row asserts it before taking the warp-specialized
# path.
has_warp_specialization: bool = hasattr(tl, "async_task")
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
def make_autotuner_config(dictargs, **kwargs):
    """Build a Triton autotuner ``Config`` that works across Triton versions.

    Triton 3.4.x removed some keyword arguments (``num_buffers_warp_spec``,
    ``num_consumer_groups``) from the ``Config`` constructor; however, fbcode
    uses 3.3.1, so this shim supports both versions.

    https://github.com/triton-lang/triton/blob/v3.3.1/python/triton/runtime/autotuner.py#L275
    https://github.com/triton-lang/triton/blame/release/3.4.x/python/triton/runtime/autotuner.py#L319

    Args:
        dictargs: Meta-parameters (e.g. ``BLOCK_M``) passed positionally to
            ``Config``.
        **kwargs: Keyword arguments for ``Config``. The two warp-specialization
            keywords above are dropped when the installed ``Config`` does not
            accept them; everything else is forwarded unchanged.

    Returns:
        The constructed ``Config``.
    """
    import inspect

    # Detect support from the constructor signature rather than comparing
    # version strings. ``parse(v) > parse("3.3.1")`` is also true for local
    # builds of 3.3.1 such as "3.3.1+git<sha>", because PEP 440 orders local
    # labels after the public release. That check therefore silently dropped
    # kwargs those builds still accept.
    params = inspect.signature(Config.__init__).parameters
    accepts_any_kwarg = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if not accepts_any_kwarg:
        for key in ("num_buffers_warp_spec", "num_consumer_groups"):
            if key not in params:
                kwargs.pop(key, None)
    return Config(dictargs, **kwargs)
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
def get_ws_configs() -> List[Config]:
    """Return the extra autotuning configs for the warp-specialized kernel.

    When Triton lacks warp-specialization support (``has_warp_specialization``
    is False), no extra configs are offered. Otherwise four configs are built
    through ``make_autotuner_config``, which strips keywords the installed
    ``Config`` does not understand.
    """
    if not has_warp_specialization:
        return []

    # Columns: BLOCK_M, BLOCK_N, BLOCK_K, NUM_CONSUMER_GROUPS (meta-param),
    #          num_stages, num_warps, num_consumer_groups, num_buffers_warp_spec
    ws_table = (
        (128, 256, 128, 2, 3, 4, 2, 3),
        (128, 128, 128, 2, 4, 4, 2, 4),
        (128, 256, 128, 1, 3, 8, 0, 3),
        (64, 64, 512, 1, 3, 4, 0, 3),
    )
    return [
        make_autotuner_config(
            {
                "BLOCK_M": blk_m,
                "BLOCK_N": blk_n,
                "BLOCK_K": blk_k,
                "SPLIT_K": 1,
                "NUM_CONSUMER_GROUPS": meta_groups,
            },
            num_stages=stages,
            num_warps=warps,
            num_consumer_groups=groups,
            num_buffers_warp_spec=buffers,
        )
        for (
            blk_m,
            blk_n,
            blk_k,
            meta_groups,
            stages,
            warps,
            groups,
            buffers,
        ) in ws_table
    ]
|
|
945
|
+
|
|
946
|
+
|
|
947
|
+
@triton.autotune(
    configs=[
        Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 256,
                "BLOCK_K": 128,
                "SPLIT_K": 1,
                "NUM_CONSUMER_GROUPS": 1,
            },
            num_stages=3,
            num_warps=8,
        ),
    ]
    # Extra warp-specialized configs; empty when tl.async_task is unavailable.
    + get_ws_configs(),
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
    use_cuda_graph=True,
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    c_dtype: tl.constexpr,
    bias_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_BIAS: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    """Warp-specialized persistent TMA matmul of [M, K] @ [N, K]^T with row-wise scales.

    Each program strides over output tiles (pid, pid + num_programs, ...) in
    GROUP_M-grouped order. The A/B TMA loads are wrapped in
    ``tl.async_task([0])``, and the scale/bias/store epilogue is wrapped in
    ``tl.async_task([1, NUM_CONSUMER_GROUPS])``. The exact task-id semantics
    are defined by the Triton warp-specialization support (TODO confirm).

    Args:
        A_ptr: TMA descriptor for the [M, K] input (loaded as tl.float8e4nv).
            The host fills it with BLOCK_M // NUM_CONSUMER_GROUPS block rows.
        B_ptr: TMA descriptor for the [N, K] input (loaded as tl.float8e4nv).
        C_ptr: TMA descriptor for the [M, N] output.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key, n_key, k_key: Autotuning cache keys; not read by the kernel body.
        A_scale: Plain pointer to the [M] reciprocal scale per row, read with
            tl.load (not a TMA descriptor). A * A_scale = original A.
        B_scale: Plain pointer to the [N] reciprocal scale per row, read with
            tl.load. B * B_scale = original B.
        Bias: TMA descriptor for the [N] bias; only read when USE_BIAS.
        stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn:
            Strides of A, B and C; unused because A, B and C are addressed via
            TMA descriptors.
        dot_out_dtype (tl.dtype): Accumulator / tensor-core output type.
        c_dtype (tl.dtype): Type the accumulator is cast to before the store.
        bias_dtype (tl.dtype): Element type of Bias.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): If True, pass the accumulator into ``tl.dot``;
            otherwise add each ``tl.dot`` result to it separately.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M tiles per group in the tile swizzle.
        AB_DTYPE (bool): Not used by this kernel.
        SPLIT_K (int): Autotuning meta-parameter; K is not split here.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K; not
            used by this kernel.
        NUM_SMS (int): Not read by the kernel body; the host uses it to size
            the grid.
        USE_BIAS (bool): Whether to load and add Bias.
        NUM_CONSUMER_GROUPS (int): Number of consumer warp groups; used as the
            upper task id of the epilogue ``tl.async_task``.
    """
    num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    dtype_fp8 = tl.float8e4nv
    # Persistent loop: this program handles every num_programs-th tile.
    for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
        # Grouped ordering: walk GROUP_M rows of tiles column by column.
        num_pid_in_group = GROUP_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        # pyre-ignore
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        # Element coordinates of this output tile's top-left corner; offs_k
        # walks along K. A and B are read through TMA descriptors, so no
        # pointer blocks are built here.
        offs_am = pid_m * BLOCK_M
        offs_bn = pid_n * BLOCK_N
        offs_k = 0
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
        # pyre-ignore
        tl.assume(tl.cdiv(K, BLOCK_K) > 0)
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            # pyre-ignore
            with tl.async_task([0]):
                a = tl._experimental_descriptor_load(
                    A_ptr,
                    [offs_am, offs_k],
                    [BLOCK_M, BLOCK_K],
                    dtype_fp8,
                )
                b = tl._experimental_descriptor_load(
                    B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
                )

            if fp8_fast_accum:
                acc = tl.dot(
                    a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
                )
            else:
                acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

            offs_k += BLOCK_K

        # pyre-ignore
        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
            # Invert scaling.
            # Masked-out lanes get unspecified values (no `other=`). This is
            # presumably harmless because the descriptor store below clips
            # out-of-range rows/cols (TODO confirm).
            rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            a_scale = tl.load(A_scale + rm, mask=rm < M)
            b_scale = tl.load(B_scale + rn, mask=rn < N)
            scale = a_scale[:, None] * b_scale[None, :]
            acc *= scale
            # Load and add bias if specified.
            if USE_BIAS:
                bias = tl._experimental_descriptor_load(
                    Bias, [offs_bn], [BLOCK_N], bias_dtype
                )
                acc += bias[None, :]
            acc = acc.to(c_dtype)
            tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
def _is_eligible_for_skip_scaling(
|
|
1108
|
+
is_rowwise: bool,
|
|
1109
|
+
fp8_fast_accum: bool,
|
|
1110
|
+
imprecise_acc: bool,
|
|
1111
|
+
tma_persistent: bool,
|
|
1112
|
+
no_use_persistent: Optional[bool],
|
|
1113
|
+
use_warp_specialization: bool,
|
|
1114
|
+
) -> bool:
|
|
1115
|
+
if not is_rowwise:
|
|
1116
|
+
return False
|
|
1117
|
+
|
|
1118
|
+
return (
|
|
1119
|
+
fp8_fast_accum
|
|
1120
|
+
and not imprecise_acc
|
|
1121
|
+
and not tma_persistent
|
|
1122
|
+
and not no_use_persistent
|
|
1123
|
+
and not use_warp_specialization
|
|
1124
|
+
)
|
|
1125
|
+
|
|
1126
|
+
|
|
1127
|
+
@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
def matmul_fp8_row(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dot_out_dtype: Optional[torch.dtype] = None,
    allow_tf32: bool = True,
    fp8_fast_accum: bool = True,
    imprecise_acc: bool = False,
    tma_persistent: bool = True,
    no_use_persistent: Optional[bool] = None,
    use_warp_specialization: bool = False,
) -> torch.Tensor:
    """
    Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].

    The kernel is chosen by the first true flag in this order:
    no_use_persistent, use_warp_specialization, tma_persistent, imprecise_acc,
    fp8_fast_accum; otherwise the no-fast-accumulation kernel is used. CPU
    inputs fall back to a bfloat16 torch.matmul.

    Args:
        a (torch.Tensor): [M, K] input tensor. Extra leading dimensions
            ([..., K]) are flattened into M and restored in the output.
        b (torch.Tensor): [N, K] input tensor.
        a_scale (Optional[torch.Tensor]): [M] reciprocal scale tensor per row.
            A * a_scale = original A. Scaling will be skipped if a_scale is
            None; this is only allowed when `_is_eligible_for_skip_scaling`
            holds for the other flags.
        b_scale (torch.Tensor): [N] reciprocal scale tensor per row. B * b_scale = original B
        bias (torch.Tensor): [N] optional bias tensor to add to output if provided.
        dot_out_dtype (torch.dtype): Output type of tensor core.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
        imprecise_acc (bool): Whether to use the imprecise-accumulation kernel.
        tma_persistent (bool): Whether to use TMA persistent kernel impl.
        no_use_persistent (Optional[bool]): Force the non-persistent kernel.
            When None, defaults to True on ROCm (HIP) and False otherwise.
        use_warp_specialization (bool): Whether to use the warp-specialized
            TMA persistent kernel; requires Triton support for `tl.async_task`.

    Returns:
        torch.Tensor: [..., N] output tensor
            (a @ b.T) * (a_scale[:, None] * b_scale[None, :]) (+ bias).
            Inputs with M, N or K equal to 0 produce a bfloat16 zero tensor.
    """
    if no_use_persistent is None:
        # Default True for AMD and False for Nvidia.
        no_use_persistent = torch.version.hip is not None
    # Get datatypes and constants to use.
    pt_fp8_dtype, _, _, _ = get_fp8_constants()
    # Handle 3D+ a shape
    a_shape = a.shape
    a = a.view(-1, a.size(-1))
    # View inputs into proper torch fp8 dtype.
    if torch.version.cuda:
        assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    elif torch.version.hip:
        if torch.cuda.get_device_capability() < (9, 5):
            assert a.dtype in (
                torch.float8_e4m3fnuz,
                torch.float8_e5m2fnuz,
            )
        else:
            assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    else:
        assert a.dtype in (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
    assert b.dtype == pt_fp8_dtype
    M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
        prep_matmul(a, b, dot_out_dtype)
    )

    # Skip scaling (a_scale is None) can only be applied in certain cases.
    assert a_scale is not None or _is_eligible_for_skip_scaling(
        is_rowwise=True,
        fp8_fast_accum=fp8_fast_accum,
        imprecise_acc=imprecise_acc,
        tma_persistent=tma_persistent,
        no_use_persistent=no_use_persistent,
        use_warp_specialization=use_warp_specialization,
    )

    output_shape = a_shape[:-1] + (N,)
    # Handle tensor with empty inputs.
    if (M == 0) or (N == 0) or (K == 0):
        return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)
    # launch kernel
    if a.device == torch.device("cpu"):
        logger.info(
            "FP8 Row-wise Triton kernel not supported on cpu, fallback to torch"
        )
        if a_scale is None:
            scale = b_scale[None, :]
        else:
            scale = a_scale[:, None] * b_scale[None, :]
        output = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T) * scale
        if bias is not None:
            output += bias[None, :]
        # Restore a's leading dims like the GPU paths (c.view(output_shape));
        # without this, 3D+ inputs came back flattened to [M, N].
        return output.to(c.dtype).reshape(output_shape)

    def grid(META: Dict[str, int]) -> Tuple[int, int]:
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META["SPLIT_K"],
        )

    # Query the GPU that holds the inputs, not whichever device is current.
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count

    def persistent_grid(META: Dict[str, int]) -> Tuple[int]:
        return (
            min(
                NUM_SMS,
                triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            ),
        )

    # Launch every kernel variant with a's device as the current device.
    # Previously only the non-persistent and imprecise-acc paths did this, so
    # the other paths could launch on the wrong GPU when a lives on a
    # non-current device.
    with torch.cuda.device(a.device.index):
        if no_use_persistent:
            logger.debug("Using non-persistent kernel")
            torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                a_scale,
                b_scale,
                bias,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                allow_tf32=allow_tf32,
                fp8_fast_accum=fp8_fast_accum,
                # GROUP_M=8,
                USE_BIAS=bias is not None,
                AB_DTYPE=False,
            )
        elif use_warp_specialization:
            assert has_warp_specialization
            # used by TMA warp specialization kernel
            desc_helper = TmaAutoTuneHelper()
            desc_helper.init_tma_descriptor("a")
            desc_helper.init_tma_descriptor("b")
            desc_helper.init_tma_descriptor("c")
            desc_helper.init_tma_descriptor("a_scale")
            desc_helper.init_tma_descriptor("b_scale")
            desc_helper.init_tma_descriptor("bias")

            def persistent_grid_tma_ws(META: Dict[str, int]) -> Tuple[int]:
                nonlocal desc_helper  # noqa: F824
                assert a_scale is not None  # Type narrowing for Pyre
                desc_helper.fill_2d_tma_descriptor(
                    "a",
                    a.data_ptr(),
                    M,
                    K,
                    META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
                    META["BLOCK_K"],
                    a.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "b",
                    b.data_ptr(),
                    N,
                    K,
                    META["BLOCK_N"],
                    META["BLOCK_K"],
                    b.element_size(),
                )
                desc_helper.fill_2d_tma_descriptor(
                    "c",
                    c.data_ptr(),
                    M,
                    N,
                    META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
                    META["BLOCK_N"],
                    c.element_size(),
                )
                desc_helper.fill_1d_tma_descriptor(
                    "a_scale",
                    a_scale.data_ptr(),
                    M,
                    META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
                    a_scale.element_size(),
                )
                desc_helper.fill_1d_tma_descriptor(
                    "b_scale",
                    b_scale.data_ptr(),
                    N,
                    META["BLOCK_N"],
                    b_scale.element_size(),
                )
                if bias is not None:
                    desc_helper.fill_1d_tma_descriptor(
                        "bias",
                        bias.data_ptr(),
                        N,
                        META["BLOCK_N"],
                        bias.element_size(),
                    )
                return (
                    min(
                        NUM_SMS,
                        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                    ),
                )

            desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
            desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
            desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
            # NOTE(review): the scale descriptors are prepared but the
            # warp-specialized kernel receives the raw a_scale/b_scale tensors
            # (it reads them with tl.load).
            desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
            desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
            desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

            bias_dtype_triton = None
            if bias is not None:
                bias_dtype_triton = map_dtype_to_triton(bias.dtype)

            # pyre-ignore
            torch._library.capture_triton(
                _kernel_matmul_fp8_row_tma_persistent_ws_cooperative
            )[persistent_grid_tma_ws](
                desc_a,
                desc_b,
                desc_c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                a_scale,
                b_scale,
                desc_bias,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                c_dtype=c_dtype_triton,
                bias_dtype=bias_dtype_triton,
                allow_tf32=allow_tf32,
                fp8_fast_accum=fp8_fast_accum,
                GROUP_M=8,
                AB_DTYPE=False,
                NUM_SMS=NUM_SMS,
                USE_BIAS=bias is not None,
            )
        elif tma_persistent:
            # used by TMA persistent kernel
            desc_helper = TmaAutoTuneHelper()
            desc_helper.init_tma_descriptor("a")
            desc_helper.init_tma_descriptor("b")
            desc_helper.init_tma_descriptor("c")
            desc_helper.init_tma_descriptor("a_scale")
            desc_helper.init_tma_descriptor("b_scale")
            desc_helper.init_tma_descriptor("bias")

            def persistent_grid_tma(META: Dict[str, int]) -> Tuple[int]:
                nonlocal desc_helper  # noqa: F824
                assert a_scale is not None  # Type narrowing for Pyre
                desc_helper.fill_2d_tma_descriptor(
                    "a",
                    a.data_ptr(),
                    M,
                    K,
                    META["BLOCK_M"],
                    META["BLOCK_K"],
                    a.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "b",
                    b.data_ptr(),
                    N,
                    K,
                    META["BLOCK_N"],
                    META["BLOCK_K"],
                    b.element_size(),
                )
                desc_helper.fill_2d_tma_descriptor(
                    "c",
                    c.data_ptr(),
                    M,
                    N,
                    META["BLOCK_M"],
                    META["BLOCK_N"],
                    c.element_size(),
                )
                desc_helper.fill_1d_tma_descriptor(
                    "a_scale",
                    a_scale.data_ptr(),
                    M,
                    META["BLOCK_M"],
                    a_scale.element_size(),
                )
                desc_helper.fill_1d_tma_descriptor(
                    "b_scale",
                    b_scale.data_ptr(),
                    N,
                    META["BLOCK_N"],
                    b_scale.element_size(),
                )
                if bias is not None:
                    desc_helper.fill_1d_tma_descriptor(
                        "bias",
                        bias.data_ptr(),
                        N,
                        META["BLOCK_N"],
                        bias.element_size(),
                    )
                return (
                    min(
                        NUM_SMS,
                        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                    ),
                )

            desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
            desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
            desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
            desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
            desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
            desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

            bias_dtype_triton = None
            if bias is not None:
                bias_dtype_triton = map_dtype_to_triton(bias.dtype)

            # pyre-ignore
            torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
                persistent_grid_tma
            ](
                desc_a,
                desc_b,
                desc_c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                desc_a_scale,
                desc_b_scale,
                desc_bias,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                c_dtype=c_dtype_triton,
                bias_dtype=bias_dtype_triton,
                allow_tf32=allow_tf32,
                fp8_fast_accum=fp8_fast_accum,
                GROUP_M=8,
                AB_DTYPE=False,
                NUM_SMS=NUM_SMS,
                USE_BIAS=bias is not None,
            )
        elif imprecise_acc:
            torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                a_scale,
                b_scale,
                bias,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                allow_tf32=allow_tf32,
                fp8_fast_accum=fp8_fast_accum,
                GROUP_M=8,
                USE_BIAS=bias is not None,
                AB_DTYPE=False,
            )
        elif fp8_fast_accum:
            skip_scaling_a = a_scale is None
            torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
                a,
                b,
                c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                a_scale,
                b_scale,
                bias,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                allow_tf32=allow_tf32,
                fp8_fast_accum=fp8_fast_accum,
                skip_scaling_a=skip_scaling_a,
                GROUP_M=8,
                USE_BIAS=bias is not None,
                AB_DTYPE=False,
                NUM_SMS=NUM_SMS,
            )
        else:
            torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
                persistent_grid
            ](
                a,
                b,
                c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                a_scale,
                b_scale,
                bias,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                allow_tf32=allow_tf32,
                fp8_fast_accum=fp8_fast_accum,
                GROUP_M=8,
                USE_BIAS=bias is not None,
                AB_DTYPE=False,
                NUM_SMS=NUM_SMS,
            )
    return c.view(output_shape)
|
|
1578
|
+
|
|
1579
|
+
|
|
1580
|
+
@matmul_fp8_row.register_fake
def matmul_fp8_row_meta(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dot_out_dtype: Optional[torch.dtype] = None,
    allow_tf32: bool = True,
    fp8_fast_accum: bool = True,
    imprecise_acc: bool = False,
    tma_persistent: bool = True,
    no_use_persistent: Optional[bool] = None,
    use_warp_specialization: bool = False,
) -> torch.Tensor:
    """Shape function for torch compile.

    Produces an uninitialized tensor with the shape, dtype and device of the
    result of ``matmul_fp8_row`` without launching any kernel.

    Args:
        a: [..., K] input; any leading dimensions are kept in the output.
        b: [N, K] input.
        dot_out_dtype: Output dtype; bfloat16 when None.
        (remaining arguments only mirror the real op's signature)

    Returns:
        Empty tensor of shape ``a.shape[:-1] + (N,)``.
    """
    # The real op returns ``c.view(output_shape)``. Presumably output_shape is
    # ``a.shape[:-1] + (N,)``, as in matmul_fp8_block; confirm against
    # matmul_fp8_row. Either way, a 2D ``a`` still yields (M, N), and
    # leading batch dims of a 3D+ ``a`` are no longer rejected.
    N = b.shape[0]
    return torch.empty(
        (*a.shape[:-1], N),
        device=a.device,
        dtype=torch.bfloat16 if dot_out_dtype is None else dot_out_dtype,
    )
|
|
1603
|
+
|
|
1604
|
+
|
|
1605
|
+
# Prune configs that the block-scaled fp8 kernels cannot run correctly.
def prune_configs_block(configs, named_args, **kwargs):
    """Autotuner ``early_config_prune`` hook for the block-scaled fp8 kernels.

    The block-scaled kernels load a single A scale and a single B scale per
    output tile (``scale_m = pid_m * BLOCK_M // scale_block_m``, and likewise
    for N and K). A tile is therefore only scaled correctly when it lies inside
    one scale block. That is guaranteed when each scale block size is a
    multiple of the matching tile size.

    Args:
        configs: Candidate ``triton.Config`` objects.
        named_args: Kernel arguments by name; must contain ``scale_block_m``,
            ``scale_block_n`` and ``scale_block_k``.
        **kwargs: Forwarded unchanged to ``early_config_prune``.

    Returns:
        The configs kept by ``early_config_prune`` whose BLOCK_M, BLOCK_N and
        BLOCK_K evenly divide scale_block_m, scale_block_n and scale_block_k.
    """
    configs = early_config_prune(configs, named_args, **kwargs)
    # (tile-size kwarg, scale block size along the same dimension)
    tile_and_scale = (
        ("BLOCK_M", named_args["scale_block_m"]),
        ("BLOCK_N", named_args["scale_block_n"]),
        ("BLOCK_K", named_args["scale_block_k"]),
    )
    # Tiles start at multiples of their own size. If the scale block is a whole
    # number of tiles, no tile straddles two scale blocks.
    return [
        config
        for config in configs
        if all(
            scale_block % config.kwargs[tile] == 0
            for tile, scale_block in tile_and_scale
        )
    ]
|
|
1618
|
+
|
|
1619
|
+
|
|
1620
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],  # TODO caller side bin keys so similar shapes can use same triton.autotune.
    prune_configs_by={
        "early_config_prune": prune_configs_block,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_block_fastacc(
    A,
    B,
    C,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    scale_block_m: tl.constexpr,
    scale_block_n: tl.constexpr,
    scale_block_k: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_scale_am,
    stride_scale_ak,
    stride_scale_bn,
    stride_scale_bk,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    AB_DTYPE: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with block-wise scales (fast accumulation).

    Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles and
    A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
    [scale_block_n, scale_block_k] tiles
    respectively.

    Scales are not applied to each partial dot product. ``tl.dot``
    accumulates into ``acc``, and when a K scale block ends ``acc`` is
    multiplied by ``scale / scale_next``. On the last K iteration it is
    multiplied by ``scale``. At the end, every block's partial sum carries
    that block's scale.

    Todo:
        * Support scale_block_{mnk} < BLOCK{MNK} for each dim.
    Args:
        A (TensorWrapper): [M, K] input tensor.
        B (TensorWrapper): [N, K] input tensor.
        C (TensorWrapper): [M, N] output tensor.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
        B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
        scale_block_m (int): Block size for M dimension of A_scale.
        scale_block_n (int): Block size for N dimension of B_scale.
        scale_block_k (int): Block size for K dimension of A_scale and B_scale.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        stride_scale_am (int): Stride of M dimension of A_scale.
        stride_scale_ak (int): Stride of K dimension of A_scale.
        stride_scale_bn (int): Stride of N dimension of B_scale.
        stride_scale_bk (int): Stride of K dimension of B_scale.
        dot_out_dtype (torch.dtype): Output type of tensor core.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of groups for M dimension swizzle.
        SPLIT_K (int): Number of programs (grid axis 1) splitting the K
            reduction; their partial results are combined with atomic_add.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
    """
    assert BLOCK_M < scale_block_m
    assert BLOCK_N < scale_block_n
    assert BLOCK_K < scale_block_k
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)

    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
    # One scalar scale index per tile in M and N (the tile must not straddle
    # a scale block; see prune_configs_block).
    scale_m = pid_m * BLOCK_M // scale_block_m
    scale_n = pid_n * BLOCK_N // scale_block_n
    # Number of BLOCK_K steps that make up one K scale block.
    k_multiple = scale_block_k // BLOCK_K

    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        k_remaining = K - k * (BLOCK_K * SPLIT_K)

        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)

        acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

        # Deferred scaling: acc holds sums expressed in units of the current
        # scale block. When a scale block ends, multiply by scale / scale_next
        # so the next block's (unscaled) dot products can be added directly.
        # On the final iteration multiply by scale to leave acc fully scaled:
        #   acc = sum_i(dot_i * s_i).
        # The final-iteration test must not depend on K being ragged. When K
        # is a multiple of BLOCK_K but not of scale_block_k, the last scale
        # block ends early and would otherwise never be scaled.
        # NOTE(review): this rescale cadence assumes a program visits
        # consecutive pid_k values. With SPLIT_K > 1 a program only sees every
        # SPLIT_K-th step. Verify correctness before allowing split-K configs.
        pid_k = k * SPLIT_K + pid_z
        if ((pid_k + 1) % k_multiple == 0) or (
            k + 1 == tl.cdiv(K, BLOCK_K * SPLIT_K)
        ):
            # Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
            # Access a_scale[pid_m, k * SPLIT_K + pid_z]
            # and b_scale[k * SPLIT_K + pid_z, pid_n]

            scale_k = pid_k // k_multiple
            scale_k_next = scale_k + 1
            a_scale = tl.load(
                A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
            )
            b_scale = tl.load(
                B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
            )
            scale = a_scale * b_scale
            if k + 1 == tl.cdiv(K, BLOCK_K * SPLIT_K):
                scale_next_inv_scale = scale
            else:
                a_scale_next = tl.load(
                    A_scale + scale_m * stride_scale_am + scale_k_next * stride_scale_ak
                )
                b_scale_next = tl.load(
                    B_scale + scale_n * stride_scale_bn + scale_k_next * stride_scale_bk
                )
                scale_next = a_scale_next * b_scale_next
                scale_next_inv_scale = scale / scale_next
            acc *= scale_next_inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = acc.to(C.dtype.element_ty)
    c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c, acc, mask=mask)
    else:
        tl.atomic_add(c, acc, mask=mask)
|
|
1813
|
+
|
|
1814
|
+
|
|
1815
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],  # TODO caller side bin keys so similar shapes can use same triton.autotune.
    prune_configs_by={
        # Each BLOCK_K tile is scaled with one scale (scale_k below), so tiles
        # must not straddle K scale blocks. prune_configs_block enforces that
        # on top of early_config_prune.
        "early_config_prune": prune_configs_block,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_block_slowacc(
    A,
    B,
    C,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    scale_block_m: tl.constexpr,
    scale_block_n: tl.constexpr,
    scale_block_k: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_scale_am,
    stride_scale_ak,
    stride_scale_bn,
    stride_scale_bk,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    AB_DTYPE: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with block-wise scales (slow accumulation).

    Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles and
    A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
    [scale_block_n, scale_block_k] tiles
    respectively.

    Every partial ``tl.dot`` result is multiplied by its block's
    ``a_scale * b_scale`` before it is added to the accumulator.

    Todo:
        * Support scale_block_{mnk} < BLOCK{MNK} for each dim.
    Args:
        A (TensorWrapper): [M, K] input tensor.
        B (TensorWrapper): [N, K] input tensor.
        C (TensorWrapper): [M, N] output tensor.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
        B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
        scale_block_m (int): Block size for M dimension of A_scale.
        scale_block_n (int): Block size for N dimension of B_scale.
        scale_block_k (int): Block size for K dimension of A_scale and B_scale.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        stride_scale_am (int): Stride of M dimension of A_scale.
        stride_scale_ak (int): Stride of K dimension of A_scale.
        stride_scale_bn (int): Stride of N dimension of B_scale.
        stride_scale_bk (int): Stride of K dimension of B_scale.
        dot_out_dtype (torch.dtype): Output type of tensor core.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of groups for M dimension swizzle.
        SPLIT_K (int): Number of programs (grid axis 1) splitting the K
            reduction; their partial results are combined with atomic_add.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
    """
    assert BLOCK_M < scale_block_m
    assert BLOCK_N < scale_block_n
    assert BLOCK_K < scale_block_k
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)

    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    # One scalar scale index per tile in M and N.
    scale_m = pid_m * BLOCK_M // scale_block_m
    scale_n = pid_n * BLOCK_N // scale_block_n
    _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)

    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        # Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
        # Access a_scale[pid_m, k * SPLIT_K + pid_z]
        # and b_scale[k * SPLIT_K + pid_z, pid_n]
        pid_k = k * SPLIT_K + pid_z
        # Scale block containing the start of this BLOCK_K tile.
        scale_k = pid_k * BLOCK_K // scale_block_k
        a_scale = tl.load(
            A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
        )
        b_scale = tl.load(
            B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
        )
        scale = a_scale * b_scale

        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)

            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)

        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) * scale
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = acc.to(C.dtype.element_ty)
    c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c, acc, mask=mask)
    else:
        tl.atomic_add(c, acc, mask=mask)
|
|
1985
|
+
|
|
1986
|
+
|
|
1987
|
+
@torch.library.custom_op("triton::matmul_fp8_block", mutates_args=())
def matmul_fp8_block(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    scale_block_m: int = 256,
    scale_block_n: int = 256,
    scale_block_k: int = 256,
    dot_out_dtype: Optional[torch.dtype] = None,
    allow_tf32: bool = True,
    fp8_fast_accum: bool = True,
) -> Tensor:
    """Performs matmul on [M, K] and [N, K] fp8 matrices with block-wise scalings.

    Args:
        a (torch.Tensor): [M, K] input tensor; leading batch dims ([..., K]) are
            flattened into M and restored in the output.
        b (torch.Tensor): [N, K] input tensor.
        a_scale (torch.Tensor): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. A * A_scale = original A
        b_scale (torch.Tensor): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. B * B_scale = original B
        scale_block_m (int): Block size for M dimension of A_scale.
        scale_block_n (int): Block size for N dimension of B_scale.
        scale_block_k (int): Block size for K dimension of A_scale and B_scale.
        dot_out_dtype (torch.dtype): Output type of tensor core and of the
            result; bfloat16 result when None.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core
            (selects the deferred-scaling kernel instead of scaling every
            partial product).

    Returns:
        Tensor: [..., N] output tensor, (a * a_scale) @ (b * b_scale).T with
        each scale broadcast over its scale block.

    Raises:
        ValueError: If b, a_scale or b_scale is not on a's device.
    """
    # Get datatypes and constants to use.
    _, tl_fp8_dtype, _, _ = get_fp8_constants()
    # Handle 3D+ a shape: flatten leading dims so the kernel sees [M, K].
    a_shape = a.shape
    a = a.view(-1, a.size(-1))
    # View inputs into proper triton fp8 dtype.
    a_tl = reinterpret_fp8_type(a, tl_fp8_dtype)
    b_tl = reinterpret_fp8_type(b, tl_fp8_dtype)

    M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
        a_tl, b_tl, dot_out_dtype
    )

    output_shape = a_shape[:-1] + (N,)
    # Handle case where inputs are empty. Use the dtype prep_matmul chose for
    # `c` so the result dtype does not depend on whether the input was empty.
    if (M == 0) or (N == 0) or (K == 0):
        return torch.zeros(output_shape, device=device, dtype=c.dtype)

    # launch kernel
    assert device != torch.device("cpu"), (
        "Blockwise matmul not supported on cpu, please use row-wise instead."
    )

    if b.device != a.device:
        raise ValueError("'b' must be on the same device as 'a'")
    if a_scale.device != a.device:
        raise ValueError("'a_scale' must be on the same device as 'a'")
    if b_scale.device != a.device:
        raise ValueError("'b_scale' must be on the same device as 'a'")

    def grid(META: Dict[str, int]) -> Tuple[int, int]:
        # Axis 0: one program per output tile; axis 1: K splits.
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META["SPLIT_K"],
        )

    # Both kernels share the same signature; only the accumulation strategy differs.
    kernel = (
        _kernel_matmul_fp8_block_fastacc
        if fp8_fast_accum
        else _kernel_matmul_fp8_block_slowacc
    )
    with torch.cuda.device(a_tl.device.index):
        kernel[grid](
            a_tl,
            b_tl,
            c,
            M,
            N,
            K,
            m_key,
            n_key,
            k_key,
            a_scale,
            b_scale,
            scale_block_m,
            scale_block_n,
            scale_block_k,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            a_scale.stride(0),
            a_scale.stride(1),
            b_scale.stride(0),
            b_scale.stride(1),
            dot_out_dtype=dot_out_dtype_triton,
            allow_tf32=allow_tf32,
            GROUP_M=8,
            AB_DTYPE=False,
        )
    return c.view(output_shape)
|
|
2119
|
+
|
|
2120
|
+
|
|
2121
|
+
@matmul_fp8_block.register_fake
def matmul_fp8_block_meta(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    scale_block_m: int = 256,
    scale_block_n: int = 256,
    scale_block_k: int = 256,
    dot_out_dtype: Optional[torch.dtype] = None,
    allow_tf32: bool = True,
    fp8_fast_accum: bool = True,
) -> torch.Tensor:
    """Shape function for torch compile.

    Matches the result of ``matmul_fp8_block``. That op views its output as
    ``a.shape[:-1] + (N,)``, and its output buffer is allocated with
    ``dot_out_dtype``, or bfloat16 when None.

    Returns:
        Empty tensor of shape ``a.shape[:-1] + (N,)`` on ``a``'s device.
    """
    N = b.shape[0]
    return torch.empty(
        (*a.shape[:-1], N),
        device=a.device,
        dtype=torch.bfloat16 if dot_out_dtype is None else dot_out_dtype,
    )
|
|
2138
|
+
|
|
2139
|
+
|
|
2140
|
+
def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:
    """
    Build a coarse autotuning key for A @ B.T, with A of shape [M, K] and
    B of shape [N, K], so that nearby M values share a tuning result.

    Args:
        M (int): Number of rows in A.
        N (int): Number of rows in B.
        K (int): Number of cols in A and cols in B.

    Returns:
        m_key (int): M itself when M < 256, otherwise 256 + M // 1024.
        n_key (int): N, unchanged.
        k_key (int): K, unchanged.

    TODO: Refine this. For now it's useful for LLM inference where N, K dims are fixed
    and M dim varies due to seq_len.
    """
    # Small M values keep a key each; larger M values collapse into
    # 1024-wide buckets offset by 256 so they never collide with small keys.
    bucketed_m = M if M < 256 else 256 + M // 1024
    return bucketed_m, N, K
|
|
2163
|
+
|
|
2164
|
+
|
|
2165
|
+
def prep_matmul(
    a: Union[TensorWrapper, torch.Tensor],
    b: Union[TensorWrapper, torch.Tensor],
    dot_out_dtype: Optional[torch.dtype],
) -> Tuple[
    int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
]:
    """
    Shared bookkeeping for a @ b.T matmul.

    Args:
        a (torch.Tensor | TensorWrapper): [M, K] fp8 input tensor.
        b (torch.Tensor | TensorWrapper): [N, K] fp8 input tensor.
        dot_out_dtype (torch.dtype, optional): Output type of tensor core and
            dtype of the allocated output. When None, the output is bfloat16
            and the tensor core accumulates in float32.

    Returns:
        M (int): Number of rows in A.
        N (int): Number of rows in B.
        K (int): Number of cols in A and cols in B.
        m_key (int): Autotuning key for M dim.
        n_key (int): Autotuning key for N dim.
        k_key (int): Autotuning key for K dim.
        c (Tensor): [M, N] output tensor (uninitialized).
        c_dtype_triton (tl.dtype): Type of output tensor.
        dot_out_dtype (tl.dtype): Output type of tensor core.
        device (torch.device): Device of output tensor.

    Raises:
        AssertionError: On mismatched K dims, non-fp8 inputs, or a
            dot_out_dtype that is not a torch.dtype.
    """
    device = a.device

    # checks constraints
    assert a.shape[1] == b.shape[1], (
        f"incompatible dimensions, a: {a.shape}, b: {b.shape}"
    )
    M, K = a.shape
    N, _ = b.shape
    m_key, n_key, k_key = get_matmul_tune(M, N, K)

    # Inputs may be plain torch fp8 tensors or TensorWrappers carrying a
    # triton fp8 dtype, so both spellings are accepted.
    fp8_dtypes = [
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2fnuz,
        tl.float8e4nv,
        tl.float8e4b15,
        tl.float8e5,
        tl.float8e4b8,
    ]
    assert a.dtype in fp8_dtypes
    assert b.dtype in fp8_dtypes

    # Validate dot_out_dtype before it is mapped or used to allocate, so a bad
    # value fails with the intended message, and map it only once.
    if dot_out_dtype is None:
        c_dtype, c_dtype_triton = torch.bfloat16, tl.bfloat16
        dot_out_dtype_triton = tl.float32
    else:
        assert isinstance(dot_out_dtype, torch.dtype), (
            f"dot_out_dtype type {type(dot_out_dtype)} must be a torch.dtype"
        )
        c_dtype = dot_out_dtype
        c_dtype_triton = map_dtype_to_triton(dot_out_dtype)
        dot_out_dtype_triton = c_dtype_triton

    # allocates output
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    return M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device
|
|
2240
|
+
|
|
2241
|
+
|
|
2242
|
+
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Heuristic: is split-K worth trying for this GEMM shape?

    True only when the reduction dimension is long (K > 1024) and at least
    one output dimension is thin (M or N below 64).
    """
    if SIZE_K <= 1024:
        return False
    return min(SIZE_M, SIZE_N) < 64
|
|
2244
|
+
|
|
2245
|
+
|
|
2246
|
+
# Controls what prune_configs does when every config has been pruned:
#   False - report it and fall back to (up to 10) configs that still pass
#           is_invalid_config, raising RuntimeError only if none remain.
#   True  - skip that fallback and raise RuntimeError right away.
# TODO: Determine a better approach for model level testing. We need
# to standardize our approach around prune_configs in general.
FORCE_FAILURE_ON_EMPTY_CONFIGS: bool = False
|
|
2250
|
+
|
|
2251
|
+
|
|
2252
|
+
def is_invalid_config(config, N, M, K, mfma, use_bias):
    """
    Report whether ``config`` must never be selected for this problem.

    These checks reject configs that would produce an invalid result, not
    merely a slow one. prune_configs therefore still applies them when it
    falls back after all configs were pruned as sub-optimal.

    Args:
        config: Triton config whose ``kwargs`` hold BLOCK_M, BLOCK_N, BLOCK_K,
            SPLIT_K and matrix_instr_nonkdim.
        N, M, K: Problem sizes (note the argument order: N before M).
        mfma: Matrix-instruction size chosen for this shape; configs whose
            matrix_instr_nonkdim exceeds it are rejected.
        use_bias: Whether the kernel adds a bias.

    Returns:
        True if the config is invalid, False otherwise.
    """
    params = config.kwargs
    block_m = params.get("BLOCK_M")
    block_n = params.get("BLOCK_N")
    block_k = params.get("BLOCK_K")
    split_k = params.get("SPLIT_K")
    nonkdim = params.get("matrix_instr_nonkdim")

    return bool(
        nonkdim > mfma
        or (mfma == 4 and block_k < 64)
        # some layouts do not work when a thread would own fewer than one
        # element
        or block_m * block_n < 64
        or block_m < nonkdim
        or block_n < nonkdim
        or (M <= nonkdim and block_m != nonkdim)
        or (N <= nonkdim and block_n != nonkdim)
        # split_k cannot be combined with a bias
        or (use_bias and split_k != 1)
    )
|
|
2283
|
+
|
|
2284
|
+
|
|
2285
|
+
# Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py
def prune_configs(configs, named_args, **kwargs):
    """Autotuner ``early_config_prune`` hook for the ROCm-style matmul configs.

    Drops configs that are invalid (see is_invalid_config) or unlikely to be
    fast for the current M, N, K. If nothing survives and
    FORCE_FAILURE_ON_EMPTY_CONFIGS is False, it falls back to up to 10
    configs that are merely valid.

    Args:
        configs: Candidate ``triton.Config`` objects.
        named_args: Kernel arguments by name; reads M, N, K, A and B.
        **kwargs: Keyword kernel arguments; must contain USE_BIAS.

    Returns:
        The list of configs to benchmark.

    Raises:
        RuntimeError: If no valid config remains.
    """
    # Diagnostics go through logging rather than stdout: this hook runs on
    # every autotune and previously printed unconditionally.
    import logging

    logger = logging.getLogger(__name__)

    pruned_configs = []
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    elemBytes_a = named_args["A"].element_size()
    elemBytes_b = named_args["B"].element_size()
    use_bias = kwargs["USE_BIAS"]

    if M < 32 or N < 32:
        mfma = 16
    else:
        mfma = 32

    for config in configs:
        BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
        BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
        BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
        SPLIT_K = config.kwargs.get("SPLIT_K")
        GROUP_M = config.kwargs.get("GROUP_M")
        if is_invalid_config(config, N, M, K, mfma, use_bias):
            continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
            continue
        if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M >= M and GROUP_M != 1:
            continue
        # out of shared memory resource (64 KiB budget for the A and B tiles)
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        pruned_configs.append(config)

    logger.debug(
        "len(configs)=%d len(pruned_configs)=%d for M=%s N=%s K=%s",
        len(configs),
        len(pruned_configs),
        M,
        N,
        K,
    )
    if len(pruned_configs) == 0:
        if not FORCE_FAILURE_ON_EMPTY_CONFIGS:
            # Prune configs that can lead to incorrect results even if all configs are sub-optimal.
            candidate_configs = [
                c for c in configs if not is_invalid_config(c, N, M, K, mfma, use_bias)
            ]
            logger.warning("No configs left after pruning! M=%s N=%s K=%s", M, N, K)
            pruned_configs = candidate_configs[:10]
        if len(pruned_configs) == 0:
            raise RuntimeError(
                "No valid configs left after pruning! Consider autotuning further with TritonBench"
            )
    return pruned_configs
|
|
2344
|
+
|
|
2345
|
+
|
|
2346
|
+
def get_full_non_persistent_tuning_space():
|
|
2347
|
+
configs = []
|
|
2348
|
+
|
|
2349
|
+
block_mn_range = [16, 32, 64, 128, 256]
|
|
2350
|
+
block_k_range = [16, 32, 64, 128, 256]
|
|
2351
|
+
split_k_range = [1]
|
|
2352
|
+
num_warps_range = [1, 2, 4, 8]
|
|
2353
|
+
group_m_range = [1, 2, 4, 8, 16, 32]
|
|
2354
|
+
num_stage_range = [2]
|
|
2355
|
+
waves_per_eu_range = [0]
|
|
2356
|
+
matrix_instr_nonkdim_range = [16, 32]
|
|
2357
|
+
kpack_range = [1, 2]
|
|
2358
|
+
|
|
2359
|
+
for block_m in block_mn_range:
|
|
2360
|
+
for block_n in block_mn_range:
|
|
2361
|
+
for block_k in block_k_range:
|
|
2362
|
+
for num_warps in num_warps_range:
|
|
2363
|
+
for group_m in group_m_range:
|
|
2364
|
+
for split_k in split_k_range:
|
|
2365
|
+
for num_stages in num_stage_range:
|
|
2366
|
+
for waves_per_eu in waves_per_eu_range:
|
|
2367
|
+
for (
|
|
2368
|
+
matrix_instr_nonkdim
|
|
2369
|
+
) in matrix_instr_nonkdim_range:
|
|
2370
|
+
for kpack in kpack_range:
|
|
2371
|
+
configs.append(
|
|
2372
|
+
triton.Config(
|
|
2373
|
+
{
|
|
2374
|
+
"BLOCK_M": block_m,
|
|
2375
|
+
"BLOCK_N": block_n,
|
|
2376
|
+
"BLOCK_K": block_k,
|
|
2377
|
+
"GROUP_M": group_m,
|
|
2378
|
+
"SPLIT_K": split_k,
|
|
2379
|
+
"waves_per_eu": waves_per_eu,
|
|
2380
|
+
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
|
2381
|
+
"kpack": kpack,
|
|
2382
|
+
},
|
|
2383
|
+
num_warps=num_warps,
|
|
2384
|
+
num_stages=num_stages,
|
|
2385
|
+
)
|
|
2386
|
+
)
|
|
2387
|
+
return configs
|
|
2388
|
+
|
|
2389
|
+
|
|
2390
|
+
MATMUL_CONFIGS_NON_PERSISTENT: List[Config] = get_full_non_persistent_tuning_space()
|
|
2391
|
+
# (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages)
|
|
2392
|
+
_MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K = [
|
|
2393
|
+
(16, 16, 256, 1, 1, 8, 16, 2, 2, 2),
|
|
2394
|
+
(16, 16, 256, 1, 1, 0, 16, 2, 2, 2),
|
|
2395
|
+
(32, 64, 512, 1, 1, 2, 16, 2, 8, 2),
|
|
2396
|
+
(64, 64, 256, 1, 1, 2, 16, 2, 4, 2),
|
|
2397
|
+
(256, 256, 128, 32, 1, 2, 16, 1, 8, 2),
|
|
2398
|
+
(256, 256, 128, 2, 1, 0, 32, 2, 8, 2),
|
|
2399
|
+
(256, 256, 128, 1, 1, 0, 32, 2, 8, 2),
|
|
2400
|
+
(256, 256, 128, 2, 1, 0, 16, 1, 8, 2),
|
|
2401
|
+
(256, 256, 64, 2, 1, 2, 16, 1, 8, 2),
|
|
2402
|
+
(128, 256, 64, 2, 1, 2, 16, 1, 4, 2),
|
|
2403
|
+
(256, 128, 128, 4, 1, 0, 16, 1, 8, 2),
|
|
2404
|
+
(128, 128, 128, 1, 1, 2, 16, 2, 4, 2),
|
|
2405
|
+
(128, 128, 256, 1, 1, 2, 16, 2, 8, 2),
|
|
2406
|
+
(128, 128, 64, 4, 1, 2, 16, 2, 4, 2),
|
|
2407
|
+
(128, 128, 64, 1, 1, 2, 16, 2, 4, 2),
|
|
2408
|
+
(128, 64, 64, 4, 1, 0, 16, 2, 4, 2),
|
|
2409
|
+
(128, 64, 64, 1, 1, 0, 16, 2, 4, 2),
|
|
2410
|
+
(256, 128, 128, 1, 1, 2, 16, 1, 8, 2),
|
|
2411
|
+
(128, 256, 128, 2, 1, 2, 16, 2, 4, 1),
|
|
2412
|
+
(256, 128, 64, 2, 1, 2, 16, 1, 4, 2),
|
|
2413
|
+
(128, 128, 256, 2, 1, 0, 16, 2, 8, 2),
|
|
2414
|
+
(128, 64, 128, 2, 1, 2, 16, 2, 4, 2),
|
|
2415
|
+
(128, 128, 64, 2, 1, 0, 16, 1, 4, 2),
|
|
2416
|
+
(128, 128, 128, 1, 1, 2, 16, 1, 4, 2),
|
|
2417
|
+
]
|
|
2418
|
+
|
|
2419
|
+
|
|
2420
|
+
def _should_skip_config(block_k, matrix_instr_nonkdim):
|
|
2421
|
+
"""Skip config if BLOCK_K=64 and matrix_instr_nonkdim=16 on GFX95+"""
|
|
2422
|
+
try:
|
|
2423
|
+
return (
|
|
2424
|
+
block_k == 64
|
|
2425
|
+
and matrix_instr_nonkdim == 16
|
|
2426
|
+
and torch.version.hip is not None
|
|
2427
|
+
and torch.cuda.get_device_capability() >= (9, 5)
|
|
2428
|
+
)
|
|
2429
|
+
except RuntimeError:
|
|
2430
|
+
# If no HIP GPUs are available, we can't check device capability
|
|
2431
|
+
# so we don't skip any configs
|
|
2432
|
+
return False
|
|
2433
|
+
|
|
2434
|
+
|
|
2435
|
+
MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K = [
|
|
2436
|
+
triton.Config(
|
|
2437
|
+
{
|
|
2438
|
+
"BLOCK_M": block_m,
|
|
2439
|
+
"BLOCK_N": block_n,
|
|
2440
|
+
"BLOCK_K": block_k,
|
|
2441
|
+
"GROUP_M": group_m,
|
|
2442
|
+
"SPLIT_K": split_k,
|
|
2443
|
+
"waves_per_eu": waves_per_eu,
|
|
2444
|
+
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
|
2445
|
+
"kpack": kpack,
|
|
2446
|
+
},
|
|
2447
|
+
num_warps=num_warps,
|
|
2448
|
+
num_stages=num_stages,
|
|
2449
|
+
)
|
|
2450
|
+
for block_m, block_n, block_k, group_m, split_k, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages in _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K
|
|
2451
|
+
if not _should_skip_config(block_k, matrix_instr_nonkdim)
|
|
2452
|
+
]
|
|
2453
|
+
|
|
2454
|
+
# Set this to enable full autotuning for proper benchmarking.
|
|
2455
|
+
# This should only be used when invoking the kernel through
|
|
2456
|
+
# Triton directly (e.g. TritonBench)
|
|
2457
|
+
#
|
|
2458
|
+
# NOTE: This will SIGNIFICANTLY increase autotuning time, often
|
|
2459
|
+
# taking hours. You should combine this with TRITON_PRINT_AUTOTUNING=1
|
|
2460
|
+
# to extract and add the optimal autotuning configs to
|
|
2461
|
+
# MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K.
|
|
2462
|
+
|
|
2463
|
+
FULL_NON_PERSISTENT_AUTOTUNING = False
|
|
2464
|
+
USED_MATMUL_NON_PERSISTENT_CONFIGS = (
|
|
2465
|
+
MATMUL_CONFIGS_NON_PERSISTENT
|
|
2466
|
+
if FULL_NON_PERSISTENT_AUTOTUNING
|
|
2467
|
+
else MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K
|
|
2468
|
+
)
|
|
2469
|
+
|
|
2470
|
+
|
|
2471
|
+
@triton.autotune(
|
|
2472
|
+
configs=USED_MATMUL_NON_PERSISTENT_CONFIGS,
|
|
2473
|
+
key=["M", "N", "K"],
|
|
2474
|
+
prune_configs_by={
|
|
2475
|
+
"early_config_prune": prune_configs,
|
|
2476
|
+
"perf_model": None,
|
|
2477
|
+
"top_k": None,
|
|
2478
|
+
},
|
|
2479
|
+
use_cuda_graph=FULL_NON_PERSISTENT_AUTOTUNING,
|
|
2480
|
+
)
|
|
2481
|
+
@triton.heuristics(
|
|
2482
|
+
{
|
|
2483
|
+
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
|
|
2484
|
+
}
|
|
2485
|
+
)
|
|
2486
|
+
@triton.jit
|
|
2487
|
+
def _kernel_matmul_fp8_row_non_persistent(
|
|
2488
|
+
A,
|
|
2489
|
+
B,
|
|
2490
|
+
C,
|
|
2491
|
+
M,
|
|
2492
|
+
N,
|
|
2493
|
+
K,
|
|
2494
|
+
m_key,
|
|
2495
|
+
n_key,
|
|
2496
|
+
k_key,
|
|
2497
|
+
A_scale,
|
|
2498
|
+
B_scale,
|
|
2499
|
+
Bias,
|
|
2500
|
+
stride_am,
|
|
2501
|
+
stride_ak,
|
|
2502
|
+
stride_bn,
|
|
2503
|
+
stride_bk,
|
|
2504
|
+
stride_cm,
|
|
2505
|
+
stride_cn,
|
|
2506
|
+
dot_out_dtype: tl.constexpr,
|
|
2507
|
+
allow_tf32: tl.constexpr,
|
|
2508
|
+
fp8_fast_accum: tl.constexpr,
|
|
2509
|
+
BLOCK_M: tl.constexpr,
|
|
2510
|
+
BLOCK_N: tl.constexpr,
|
|
2511
|
+
BLOCK_K: tl.constexpr,
|
|
2512
|
+
GROUP_M: tl.constexpr,
|
|
2513
|
+
SPLIT_K: tl.constexpr,
|
|
2514
|
+
EVEN_K: tl.constexpr,
|
|
2515
|
+
USE_BIAS: tl.constexpr,
|
|
2516
|
+
AB_DTYPE: tl.constexpr,
|
|
2517
|
+
) -> None:
|
|
2518
|
+
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
|
|
2519
|
+
|
|
2520
|
+
performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
|
|
2521
|
+
|
|
2522
|
+
Args:
|
|
2523
|
+
A (TensorWrapper): [M, K] input tensor.
|
|
2524
|
+
B (TensorWrapper): [N, K] input tensor.
|
|
2525
|
+
C (TensorWrapper): [M, N] output tensor.
|
|
2526
|
+
M (int): M dimension of input tensor.
|
|
2527
|
+
N (int): N dimension of input tensor.
|
|
2528
|
+
K (int): K dimension of input tensor.
|
|
2529
|
+
m_key (int): Autotuning key for M dimension of input tensor.
|
|
2530
|
+
n_key (int): Autotuning key for N dimension of input tensor.
|
|
2531
|
+
k_key (int): Autotuning key for K dimension of input tensor.
|
|
2532
|
+
A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
|
|
2533
|
+
B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
|
|
2534
|
+
Bias (tensorWrapper): [N] Optional bias tensor.
|
|
2535
|
+
stride_am (int): Stride of M dimension of A.
|
|
2536
|
+
stride_ak (int): Stride of K dimension of A.
|
|
2537
|
+
stride_bn (int): Stride of N dimension of B.
|
|
2538
|
+
stride_bk (int): Stride of K dimension of B.
|
|
2539
|
+
stride_cm (int): Stride of M dimension of C.
|
|
2540
|
+
stride_cn (int): Stride of N dimension of C.
|
|
2541
|
+
dot_out_dtype (torch.dtype): Output type of tensor core.
|
|
2542
|
+
allow_tf32 (bool): Whether to use TF32 for tensor core.
|
|
2543
|
+
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
|
|
2544
|
+
BLOCK_M (int): Block size for M dimension.
|
|
2545
|
+
BLOCK_N (int): Block size for N dimension.
|
|
2546
|
+
BLOCK_K (int): Block size for K dimension.
|
|
2547
|
+
GROUP_M (int): Number of groups for M dimension swizzle.
|
|
2548
|
+
SPLIT_K (int): Number of SM's to launch per row.
|
|
2549
|
+
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
|
|
2550
|
+
USE_BIAS (bool): Whether to use bias.
|
|
2551
|
+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
|
|
2552
|
+
"""
|
|
2553
|
+
tl.assume(M >= 0)
|
|
2554
|
+
tl.assume(N >= 0)
|
|
2555
|
+
tl.assume(K >= 0)
|
|
2556
|
+
tl.assume(stride_am >= 0)
|
|
2557
|
+
tl.assume(stride_ak >= 0)
|
|
2558
|
+
tl.assume(stride_bn >= 0)
|
|
2559
|
+
tl.assume(stride_bk >= 0)
|
|
2560
|
+
tl.assume(stride_cm >= 0)
|
|
2561
|
+
tl.assume(stride_cn >= 0)
|
|
2562
|
+
# Matrix multiplication.
|
|
2563
|
+
pid = tl.program_id(0)
|
|
2564
|
+
pid_z = tl.program_id(1)
|
|
2565
|
+
grid_m = tl.cdiv(M, BLOCK_M)
|
|
2566
|
+
grid_n = tl.cdiv(N, BLOCK_N)
|
|
2567
|
+
# Re-order program ID for better L2 performance (swizzle).
|
|
2568
|
+
width = GROUP_M * grid_n
|
|
2569
|
+
group_id = pid // width
|
|
2570
|
+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
|
2571
|
+
pid_m = group_id * GROUP_M + ((pid % width) % group_size)
|
|
2572
|
+
pid_n = (pid % width) // (group_size)
|
|
2573
|
+
tl.assume(pid_m >= 0)
|
|
2574
|
+
tl.assume(pid_n >= 0)
|
|
2575
|
+
# Do matrix multiplication.
|
|
2576
|
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
2577
|
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
2578
|
+
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
|
2579
|
+
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
|
2580
|
+
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
|
2581
|
+
# Pointers.
|
|
2582
|
+
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
|
2583
|
+
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
|
2584
|
+
acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype
|
|
2585
|
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
|
|
2586
|
+
|
|
2587
|
+
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
|
2588
|
+
if EVEN_K:
|
|
2589
|
+
a = tl.load(A)
|
|
2590
|
+
b = tl.load(B)
|
|
2591
|
+
else:
|
|
2592
|
+
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
|
2593
|
+
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
|
|
2594
|
+
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
|
|
2595
|
+
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
|
|
2596
|
+
if AB_DTYPE:
|
|
2597
|
+
a = a.to(C.dtype.element_ty)
|
|
2598
|
+
b = b.to(C.dtype.element_ty)
|
|
2599
|
+
if fp8_fast_accum:
|
|
2600
|
+
acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
|
|
2601
|
+
else:
|
|
2602
|
+
acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)
|
|
2603
|
+
|
|
2604
|
+
A += BLOCK_K * SPLIT_K * stride_ak
|
|
2605
|
+
B += BLOCK_K * SPLIT_K * stride_bk
|
|
2606
|
+
|
|
2607
|
+
# rematerialize rm and rn to save registers
|
|
2608
|
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
2609
|
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
2610
|
+
|
|
2611
|
+
# Invert scaling.
|
|
2612
|
+
a_scale = tl.load(A_scale + rm, mask=rm < M)
|
|
2613
|
+
b_scale = tl.load(B_scale + rn, mask=rn < N)
|
|
2614
|
+
# Invert vector, then multiply on matrix for speed.
|
|
2615
|
+
# pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
|
|
2616
|
+
scale = a_scale[:, None] * b_scale[None, :]
|
|
2617
|
+
acc *= scale
|
|
2618
|
+
|
|
2619
|
+
# Load and add bias if specified.
|
|
2620
|
+
if USE_BIAS:
|
|
2621
|
+
bias = tl.load(Bias + rn, mask=rn < N)
|
|
2622
|
+
acc += bias[None, :]
|
|
2623
|
+
|
|
2624
|
+
acc = acc.to(C.dtype.element_ty)
|
|
2625
|
+
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
|
2626
|
+
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
|
2627
|
+
# Handles write-back with reduction-splitting
|
|
2628
|
+
if SPLIT_K == 1:
|
|
2629
|
+
tl.store(C, acc, mask=mask)
|
|
2630
|
+
else:
|
|
2631
|
+
tl.atomic_add(C, acc, mask=mask)
|
|
2632
|
+
|
|
2633
|
+
|
|
2634
|
+
# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
|
|
2635
|
+
def to_mxfp8(
|
|
2636
|
+
data_hp: torch.Tensor,
|
|
2637
|
+
block_size: int = 32,
|
|
2638
|
+
):
|
|
2639
|
+
assert data_hp.dtype in (
|
|
2640
|
+
torch.bfloat16,
|
|
2641
|
+
torch.float,
|
|
2642
|
+
), f"{data_hp.dtype} is not supported yet"
|
|
2643
|
+
assert data_hp.shape[-1] % block_size == 0, (
|
|
2644
|
+
f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
|
|
2645
|
+
)
|
|
2646
|
+
assert data_hp.is_contiguous(), "unsupported"
|
|
2647
|
+
|
|
2648
|
+
orig_shape = data_hp.shape
|
|
2649
|
+
data_hp = data_hp.reshape(
|
|
2650
|
+
*orig_shape[:-1], orig_shape[-1] // block_size, block_size
|
|
2651
|
+
)
|
|
2652
|
+
|
|
2653
|
+
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
|
|
2654
|
+
|
|
2655
|
+
data_hp = data_hp.to(torch.float32)
|
|
2656
|
+
max_abs = max_abs.to(torch.float32)
|
|
2657
|
+
|
|
2658
|
+
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
|
|
2659
|
+
max_pos = F8E4M3_MAX
|
|
2660
|
+
|
|
2661
|
+
# RCEIL
|
|
2662
|
+
def _to_mx_rceil(
|
|
2663
|
+
data_hp: torch.Tensor,
|
|
2664
|
+
max_abs: torch.Tensor,
|
|
2665
|
+
max_pos: float,
|
|
2666
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
2667
|
+
E8M0_EXPONENT_BIAS = 127
|
|
2668
|
+
descale = max_abs / max_pos
|
|
2669
|
+
exponent = torch.where(
|
|
2670
|
+
torch.isnan(descale),
|
|
2671
|
+
0xFF, # Handle biased exponent for nan
|
|
2672
|
+
# NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
|
|
2673
|
+
(
|
|
2674
|
+
torch.clamp(
|
|
2675
|
+
torch.ceil(torch.log2(descale)),
|
|
2676
|
+
min=-E8M0_EXPONENT_BIAS,
|
|
2677
|
+
max=E8M0_EXPONENT_BIAS,
|
|
2678
|
+
)
|
|
2679
|
+
+ E8M0_EXPONENT_BIAS
|
|
2680
|
+
).to(torch.uint8),
|
|
2681
|
+
)
|
|
2682
|
+
|
|
2683
|
+
descale_fp = torch.where(
|
|
2684
|
+
exponent == 0,
|
|
2685
|
+
1.0,
|
|
2686
|
+
torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
|
|
2687
|
+
)
|
|
2688
|
+
|
|
2689
|
+
# scale and saturated cast the data elements to max of target dtype
|
|
2690
|
+
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
|
|
2691
|
+
return exponent, data_lp
|
|
2692
|
+
|
|
2693
|
+
scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
|
|
2694
|
+
|
|
2695
|
+
# cast to target dtype
|
|
2696
|
+
data_lp = data_lp.to(torch.float8_e4m3fn)
|
|
2697
|
+
# need to reshape at the end to help inductor fuse things
|
|
2698
|
+
data_lp = data_lp.reshape(orig_shape)
|
|
2699
|
+
|
|
2700
|
+
scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
|
|
2701
|
+
scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
|
|
2702
|
+
return scale_e8m0_biased, data_lp
|