fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,4422 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
import functools
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
from typing import Optional, Union
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import triton # @manual
|
|
15
|
+
|
|
16
|
+
import triton.language as tl # @manual
|
|
17
|
+
|
|
18
|
+
from fbgemm_gpu.experimental.gemm.triton_gemm.matmul_perf_model import (
|
|
19
|
+
early_config_prune,
|
|
20
|
+
estimate_matmul_time,
|
|
21
|
+
)
|
|
22
|
+
from fbgemm_gpu.experimental.gemm.triton_gemm.utils import (
|
|
23
|
+
map_dtype_to_triton,
|
|
24
|
+
TmaAutoTuneHelper,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from packaging import version
|
|
28
|
+
from torch._tensor import Tensor
|
|
29
|
+
|
|
30
|
+
from triton import Config # @manual
|
|
31
|
+
from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper # @manual
|
|
32
|
+
|
|
33
|
+
logger: logging.Logger = logging.getLogger(__name__)

# True whenever the GITHUB_ENV environment variable is defined (it is set on
# GitHub Actions runners); used to soften hard failures in CI.
running_on_github: bool = "GITHUB_ENV" in os.environ

try:
    # pyre-ignore[21]
    from triton.fb.compat import disable_bufferops  # @manual
except ModuleNotFoundError:
    # Open-source Triton builds do not ship triton.fb.compat, so provide a
    # stand-in with the same call shape that does nothing.
    # TODO(njriasan): Remove when we integrate triton.fb.compat into every Triton
    # version.
    from contextlib import contextmanager

    @contextmanager
    def disable_bufferops(_unused: bool):
        """No-op fallback: the managed block runs unchanged and ``as`` binds None."""
        yield
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@functools.lru_cache
def supports_float8_fnuz(throw_on_hip_incompatibility: bool = True) -> bool:
    """
    Report whether the FNUZ fp8 variants should be used on the current device.

    Only ROCm (HIP) builds of torch can return True. On such builds the current
    device's compute capability decides the outcome:

    * below (9, 4): FP8 is unsupported. Raise ``RuntimeError`` when
      ``throw_on_hip_incompatibility`` is True; otherwise log an error and
      return False.
    * exactly (9, 4): return True.
    * anything above (9, 4): return False.

    Non-HIP builds always return False. Results are memoized per argument
    value by ``functools.lru_cache``.

    Args:
        throw_on_hip_incompatibility (bool): Raise instead of logging when the
            HIP device is too old for FP8.

    Returns:
        bool: True if ``torch.float8_e4m3fnuz`` should be used.

    Raises:
        RuntimeError: On a HIP device below capability (9, 4) when
            ``throw_on_hip_incompatibility`` is True.
    """
    if not torch.version.hip:
        return False

    device_capability = torch.cuda.get_device_capability()
    if device_capability < (9, 4):
        gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
        msg = f"Unsupported GPU arch: {gpu_arch} for FP8"
        if throw_on_hip_incompatibility:
            raise RuntimeError(msg)
        # Log through the module logger: the root-level ``logging.error`` helper
        # implicitly calls ``logging.basicConfig()`` when the root logger has no
        # handlers, which reconfigures logging for the whole process.
        logger.error(msg)
        return False

    return device_capability == (9, 4)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
    """
    Select the fp8 dtypes and numeric constants for the current platform.

    When ``supports_float8_fnuz`` reports FNUZ support, the
    ``torch.float8_e4m3fnuz`` / ``tl.float8e4b8`` pair is chosen. Otherwise the
    ``torch.float8_e4m3fn`` / ``tl.float8e4nv`` pair is used. Outside GitHub CI,
    an unsupported HIP device raises instead of merely logging.

    Returns:
        pt_dtype (torch.dtype): The torch fp8 datatype to use.
        tl_dtype (tl.dtype): The matching triton fp8 datatype.
        max_fp8 (float): The maximum representable value for the fp8 datatype.
        eps (float): Minimum clip value to prevent divide by zero.
    """
    use_fnuz = supports_float8_fnuz(throw_on_hip_incompatibility=not running_on_github)
    pt_fp8_dtype, tl_fp8_dtype = (
        (torch.float8_e4m3fnuz, tl.float8e4b8)
        if use_fnuz
        else (torch.float8_e4m3fn, tl.float8e4nv)
    )
    max_fp8 = torch.finfo(pt_fp8_dtype).max
    return pt_fp8_dtype, tl_fp8_dtype, max_fp8, 1e-12
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
    """
    Wrap ``tensor`` so that Triton kernels see it as the fp8 ``dtype``.

    Thin delegation to ``triton.runtime.jit.reinterpret``.

    Args:
        tensor (torch.Tensor): Tensor whose storage should be reinterpreted.
        dtype (tl.dtype): Triton dtype the kernel should observe.

    Returns:
        triton.TensorWrapper: The reinterpreted view of ``tensor``.
    """
    wrapped = tl_reinterpret(tensor, dtype=dtype)
    return wrapped
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def init_to_zero(name):
    """
    Build a Triton ``pre_hook`` that zero-fills the kernel argument ``name``.

    The returned callable receives the kernel's named-argument mapping and
    calls ``.zero_()`` on the entry stored under ``name``. It returns whatever
    ``zero_()`` returns.
    """

    def _zero_named_arg(nargs):
        return nargs[name].zero_()

    return _zero_named_arg
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_configs_io_bound() -> list[Config]:
    """
    Build matmul autotuning configs intended for IO-bound problems.

    Enumerates every combination of num_stages in {2, 3, 4, 5, 6},
    BLOCK_M in {16, 32}, BLOCK_K in {32, 64} and BLOCK_N in {32, 64, 128, 256},
    always with SPLIT_K=1. Tiles with BLOCK_N <= 64 use 2 warps; wider tiles use 4.

    Returns:
        list[Config]: 80 configs, ordered by num_stages, then BLOCK_M, BLOCK_K
        and BLOCK_N (innermost).
    """
    # Split-K variants (SPLIT_K in {2, 4, 8, 16}) are intentionally not generated.
    return [
        Config(
            {
                "BLOCK_M": block_m,
                "BLOCK_N": block_n,
                "BLOCK_K": block_k,
                "SPLIT_K": 1,
            },
            num_stages=num_stages,
            num_warps=2 if block_n <= 64 else 4,
        )
        for num_stages in (2, 3, 4, 5, 6)
        for block_m in (16, 32)
        for block_k in (32, 64)
        for block_n in (32, 64, 128, 256)
    ]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def dummy_prune_configs(configs, named_args, **kwargs):
    """
    No-op ``early_config_prune`` hook that logs the candidate-config count.

    Args:
        configs: Candidate autotuning configs supplied by ``triton.autotune``.
        named_args: Mapping of kernel argument names to values; must contain
            ``"M"``, ``"N"`` and ``"K"``.
        **kwargs: Extra arguments from the autotuner; ignored.

    Returns:
        The ``configs`` argument, unchanged (nothing is pruned).
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]

    # The previous f-string interpolated len(configs) twice; log it once, with
    # lazy %-formatting so nothing is formatted when INFO is disabled.
    logger.info("len(configs)=%d for M=%r N=%r K=%r", len(configs), M, N, K)
    return configs
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# Autotuning search space: hand-picked compute-bound tiles (all SPLIT_K=1)
# followed by the generated IO-bound set.
MATMUL_CONFIGS: list[Config] = [
    Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
        num_stages=num_stages,
        num_warps=num_warps,
    )
    # Columns: (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
    for block_m, block_n, block_k, num_stages, num_warps in (
        # basic configs for compute-bound matmuls
        (128, 256, 32, 3, 8),
        (256, 128, 32, 3, 8),
        (256, 64, 32, 4, 4),
        (64, 256, 32, 4, 4),
        (64, 128, 128, 4, 4),
        (64, 128, 256, 4, 4),
        (128, 128, 32, 4, 4),
        (128, 64, 32, 4, 4),
        (64, 128, 32, 4, 4),
        (128, 32, 32, 4, 4),
        (64, 32, 32, 5, 2),
        # good for int8
        (128, 256, 128, 3, 8),
        (256, 128, 128, 3, 8),
        (256, 64, 128, 4, 4),
        (64, 256, 128, 4, 4),
        (128, 128, 128, 4, 4),
        (128, 64, 64, 4, 4),
        (64, 128, 64, 4, 4),
        (128, 32, 64, 4, 4),
        (64, 32, 64, 5, 2),
    )
] + get_configs_io_bound()
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    # dummy_prune_configs only logs the candidate count and returns every config.
    prune_configs_by={
        "early_config_prune": dummy_prune_configs,
    },
    # A new tuning run happens whenever any of these shape keys changes.
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
)
@triton.jit
def _kernel_matmul_fp8_row(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    skip_scaling_a: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    USE_BIAS: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    NUM_SMS: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with row-wise scales

    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.

    Each program handles output tiles ``start_pid, start_pid + NUM_SMS, ...``.
    For every tile it runs over all K-blocks inside one flattened loop and
    writes the tile to C after the last K-block. The per-program tile count is
    derived from ``start_pid`` and ``NUM_SMS``, so the launch grid presumably
    has exactly NUM_SMS programs (TODO confirm at the call site).

    Args:
        A_ptr (TensorWrapper): [M, K] input tensor.
        B_ptr (TensorWrapper): [N, K] input tensor (loaded as [BLOCK_K, BLOCK_N] tiles).
        C_ptr (TensorWrapper): [M, N] output tensor; its element type is the store dtype.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A.
        B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B.
        Bias (TensorWrapper): [N] Optional bias tensor, read only when USE_BIAS.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        dot_out_dtype (torch.dtype): Output type of tensor core.
        allow_tf32 (bool): Whether to use TF32 for tensor core. When True the
            accumulator is float32; otherwise it uses dot_out_dtype.
        fp8_fast_accum (bool): Accepted for interface parity; not referenced in this body.
        skip_scaling_a (bool): If True, A_scale is not read and only B_scale is applied.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M-tiles per group in the tile-order swizzle.
        SPLIT_K (int): Accepted for interface parity; not referenced in this body.
        USE_BIAS (bool): Whether to use bias.
        AB_DTYPE (bool): Accepted for interface parity; not referenced in this body.
        NUM_SMS (int): Stride between the tile ids processed by one program.
    """
    # Matrix multiplication.
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    # Spread num_tiles over NUM_SMS programs; the first (num_tiles % NUM_SMS)
    # programs take one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Pre-decremented so the first loop iteration lands on tile_id == start_pid
    # and ki == 0.
    tile_id = start_pid - NUM_SMS
    ki = -1

    offs_k_for_mask = tl.arange(0, BLOCK_K)

    num_pid_in_group = GROUP_M * num_pid_n

    pid_m = 0
    pid_n = 0
    offs_am = tl.arange(0, BLOCK_M)
    offs_bn = tl.arange(0, BLOCK_N)
    acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)

    # One iteration per (tile, K-block) pair; ki cycles 0..k_tiles-1.
    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # Start of a new output tile: advance and map tile_id to (pid_m, pid_n)
            # using GROUP_M-row groups.
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            start_m = pid_m * BLOCK_M
            start_n = pid_n * BLOCK_N
            offs_am = start_m + tl.arange(0, BLOCK_M)
            offs_bn = start_n + tl.arange(0, BLOCK_N)
            # Out-of-range rows/cols are redirected to index 0 so the loads
            # stay in bounds; those lanes are masked out at the store.
            offs_am = tl.where(offs_am < M, offs_am, 0)
            offs_bn = tl.where(offs_bn < N, offs_bn, 0)
            offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
            offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)
        offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
        A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # Only the K tail is masked (zero-filled) here.
        a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
        b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
        acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)

        if ki == k_tiles - 1:
            # Last K-block of this tile: apply scales/bias and write back.
            # rematerialize rm and rn to save registers
            rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

            # Invert scaling.
            b_scale = tl.load(B_scale + rn, mask=rn < N)
            if skip_scaling_a:
                acc *= b_scale[None, :]
            else:
                a_scale = tl.load(A_scale + rm, mask=rm < M)
                # pyre-ignore[16]: Undefined attribute [16]: `float`
                # has no attribute `__getitem__`.
                scale = a_scale[:, None] * b_scale[None, :]
                acc *= scale

            # Load and add bias if specified.
            if USE_BIAS:
                bias = tl.load(Bias + rn, mask=rn < N)
                acc += bias[None, :]

            acc = acc.to(C_ptr.dtype.element_ty)
            C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            # Masked write-back of the finished tile (no split-K reduction here).
            tl.store(C, acc, mask=mask)
            # Reset the accumulator for this program's next tile.
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
@triton.autotune(
    # The shared search space plus one extra 128x128x128 / 3-stage / 8-warp config.
    configs=MATMUL_CONFIGS
    + [
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
    ],
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_no_fast_acc(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    USE_BIAS: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    NUM_SMS: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with row-wise scales

    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.

    The tile scheduling matches ``_kernel_matmul_fp8_row``: each program handles
    tiles ``start_pid, start_pid + NUM_SMS, ...``, walking every K-block of a
    tile before writing it back. The difference is that each ``tl.dot`` result
    is added to a separately held accumulator (``acc += tl.dot(...)``) instead
    of being passed into ``tl.dot`` as its accumulator. Presumably this avoids
    the fast/imprecise fp8 accumulation path, as the name suggests — confirm.
    Both A and B row scales are always applied.

    Args:
        A_ptr (TensorWrapper): [M, K] input tensor.
        B_ptr (TensorWrapper): [N, K] input tensor (loaded as [BLOCK_K, BLOCK_N] tiles).
        C_ptr (TensorWrapper): [M, N] output tensor; its element type is the store dtype.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
        B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
        Bias (TensorWrapper): [N] Optional bias tensor, read only when USE_BIAS.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        dot_out_dtype (torch.dtype): Output type of tensor core; also the accumulator dtype.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): Accepted for interface parity; not referenced in this body.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M-tiles per group in the tile-order swizzle.
        SPLIT_K (int): Used only by the EVEN_K heuristic; not referenced in this body.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K (set by
            the heuristic; not referenced in this body).
        USE_BIAS(bool): Whether to use bias.
        AB_DTYPE (bool): Accepted for interface parity; not referenced in this body.
        NUM_SMS (int): Stride between the tile ids processed by one program.
    """
    # Matrix multiplication.

    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    # The first (num_tiles % NUM_SMS) programs take one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Pre-decremented so the first iteration lands on tile_id == start_pid, ki == 0.
    tile_id = start_pid - NUM_SMS
    ki = -1

    offs_k_for_mask = tl.arange(0, BLOCK_K)

    num_pid_in_group = GROUP_M * num_pid_n

    pid_m = 0
    pid_n = 0
    offs_am = tl.arange(0, BLOCK_M)
    offs_bn = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

    # One iteration per (tile, K-block) pair; ki cycles 0..k_tiles-1.
    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # Start of a new output tile: advance and map tile_id to (pid_m, pid_n).
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            start_m = pid_m * BLOCK_M
            start_n = pid_n * BLOCK_N
            offs_am = start_m + tl.arange(0, BLOCK_M)
            offs_bn = start_n + tl.arange(0, BLOCK_N)
            # Out-of-range rows/cols are redirected to index 0 so the loads stay in
            # bounds; those lanes are masked out at the store.
            offs_am = tl.where(offs_am < M, offs_am, 0)
            offs_bn = tl.where(offs_bn < N, offs_bn, 0)
            offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
            offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)
        offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
        A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # Only the K tail is masked (zero-filled) here.
        a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
        b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
        # Dot product computed on its own, then added to the running sum.
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        if ki == k_tiles - 1:
            # Last K-block of this tile: apply scales/bias and write back.
            # rematerialize rm and rn to save registers
            rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

            # Invert scaling.
            a_scale = tl.load(A_scale + rm, mask=rm < M)
            b_scale = tl.load(B_scale + rn, mask=rn < N)
            # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
            scale = a_scale[:, None] * b_scale[None, :]
            acc *= scale

            # Load and add bias if specified.
            if USE_BIAS:
                bias = tl.load(Bias + rn, mask=rn < N)
                acc += bias[None, :]

            acc = acc.to(C_ptr.dtype.element_ty)
            C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            mask = (rm < M)[:, None] & (rn < N)[None, :]
            # Masked write-back of the finished tile (no split-K reduction here).
            tl.store(C, acc, mask=mask)
            # Reset the accumulator for this program's next tile.
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_imprecise_acc(
    A,
    B,
    C,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    USE_BIAS: tl.constexpr,
    AB_DTYPE: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with row-wise scales.

    Performs a swizzled matmul in [BLOCK_M, BLOCK_K] x [BLOCK_K, BLOCK_N] tiles.
    Each program computes one [BLOCK_M, BLOCK_N] output tile, chosen from grid
    axis 0. The K range it covers is chosen by grid axis 1 (``pid_z``).

    Args:
        A (TensorWrapper): [M, K] input tensor.
        B (TensorWrapper): [N, K] input tensor.
        C (TensorWrapper): [M, N] output tensor. When SPLIT_K > 1, partial
            results are combined with ``tl.atomic_add``, so C presumably must
            be zero-initialized by the caller in that case (NOTE(review): confirm).
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
        B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
        Bias (TensorWrapper): [N] Optional bias tensor, only read when USE_BIAS.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        dot_out_dtype (torch.dtype): Output type of tensor core (accumulator dtype).
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): If True, the accumulator is fed into
            ``tl.dot(a, b, acc, max_num_imprecise_acc=32)``. Otherwise, each
            tile product is computed separately and added to the accumulator.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of groups for M dimension swizzle.
        SPLIT_K (int): Number of splits of the K dimension. Each split is a
            separate program along grid axis 1. When SPLIT_K > 1, splits are
            summed into C with ``tl.atomic_add``.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K
            (enables unmasked loads).
        USE_BIAS (bool): Whether to add the bias.
        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
    """
    # Matrix multiplication.
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Re-order program ID for better L2 performance (swizzle).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Do matrix multiplication.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols back into bounds so loads need no M/N mask;
    # those rows/cols are masked out when C is written.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # Pointers.
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if fp8_fast_accum:
            acc = tl.dot(
                a,
                b,
                acc,
                max_num_imprecise_acc=32,
                out_dtype=dot_out_dtype,
                allow_tf32=allow_tf32,
            )
        else:
            acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        # Each K split advances past the tiles owned by the other splits.
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Invert scaling.
    a_scale = tl.load(A_scale + rm, mask=rm < M)
    b_scale = tl.load(B_scale + rn, mask=rn < N)
    # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
    scale = a_scale[:, None] * b_scale[None, :]
    acc *= scale

    # Apply bias.
    if USE_BIAS:
        # With SPLIT_K > 1, every K split atomically adds its partial tile into
        # C. Only the first split (pid_z == 0) may contribute the bias;
        # otherwise it would be added SPLIT_K times. For SPLIT_K == 1, pid_z is
        # always 0, so this mask matches the plain `rn < N` bound.
        bias = tl.load(Bias + rn, mask=(rn < N) & (pid_z == 0), other=0.0)
        acc += bias[None, :]

    acc = acc.to(C.dtype.element_ty)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # Handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
@triton.autotune(
    configs=[
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 512, "SPLIT_K": 1},
            num_stages=3,
            num_warps=4,
        ),
    ],
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
    use_cuda_graph=True,
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_tma_persistent(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    c_dtype: tl.constexpr,
    bias_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_BIAS: tl.constexpr,
) -> None:
    """Persistent TMA matmul kernel of [M, K] @ [N, K] with row-wise scales.

    Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
    All operands (A, B, C, scales, bias) are accessed through
    ``tl._experimental_descriptor_load/store``, so the *_ptr / scale / bias
    arguments are TMA descriptors rather than raw pointers. Each program walks
    the output tiles start_pid, start_pid + NUM_SMS, start_pid + 2 * NUM_SMS, ...

    Args:
        A_ptr: TMA descriptor for the [M, K] input tensor A.
        B_ptr: TMA descriptor for the [N, K] input tensor B.
        C_ptr: TMA descriptor for the [M, N] output tensor C.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M; not read by the kernel body.
        n_key (int): Autotuning key for N; not read by the kernel body.
        k_key (int): Autotuning key for K; not read by the kernel body.
        A_scale: TMA descriptor for the [M] reciprocal scale tensor per row.
            A * A_scale = original A. Loaded as float32.
        B_scale: TMA descriptor for the [N] reciprocal scale tensor per row.
            B * B_scale = original B. Loaded as float32.
        Bias: TMA descriptor for the optional [N] bias; only read when USE_BIAS.
        stride_am (int): Stride of M dimension of A (not read; TMA addressing).
        stride_ak (int): Stride of K dimension of A (not read; TMA addressing).
        stride_bn (int): Stride of N dimension of B (not read; TMA addressing).
        stride_bk (int): Stride of K dimension of B (not read; TMA addressing).
        stride_cm (int): Stride of M dimension of C (not read; TMA addressing).
        stride_cn (int): Stride of N dimension of C (not read; TMA addressing).
        dot_out_dtype (torch.dtype): Output type of tensor core (accumulator dtype).
        c_dtype: Triton dtype the accumulator is cast to before being stored.
        bias_dtype: Triton dtype used when loading the bias descriptor.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): If True, the accumulator is passed into
            ``tl.dot(a, b.T, acc)``. Otherwise, each tile product is added
            to the accumulator separately.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of groups for M dimension swizzle.
        AB_DTYPE (bool): Not read by this kernel body.
        SPLIT_K (int): Not read by this kernel body (configs use SPLIT_K=1).
        EVEN_K (bool): Set by the heuristic; not read by this kernel body.
        NUM_SMS (int): Number of persistent programs; tile stride of the schedule.
        USE_BIAS (bool): Whether to load and add the bias.
    """
    # Matrix multiplication.
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    # Split num_tiles across NUM_SMS programs; the first
    # (num_tiles % NUM_SMS) programs each take one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Start one stride "before" the first tile and at ki = -1 so that the first
    # loop iteration sets ki = 0 and advances tile_id to start_pid.
    tile_id = start_pid - NUM_SMS
    ki = -1

    pid_m = 0
    pid_n = 0
    offs_am = 0
    offs_bn = 0

    num_pid_in_group = GROUP_M * num_pid_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

    # NOTE(review): A/B descriptor loads hard-code float8e4nv (e4m3); inputs in
    # another fp8 format would be reinterpreted - confirm callers only pass e4m3.
    dtype_fp8 = tl.float8e4nv
    scale_dtype = tl.float32

    # One flat loop over (tile, k-step) pairs; ki cycles 0 .. k_tiles - 1.
    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # First K step of a new tile: advance to this program's next tile
            # and compute its swizzled (pid_m, pid_n).
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            offs_am = pid_m * BLOCK_M
            offs_bn = pid_n * BLOCK_N
            offs_am = tl.multiple_of(offs_am, BLOCK_M)
            offs_bn = tl.multiple_of(offs_bn, BLOCK_N)

        offs_k = ki * BLOCK_K

        # NOTE(review): when K % BLOCK_K != 0 the last K tile extends past K;
        # this presumably relies on TMA zero-filling out-of-bounds elements.
        a = tl._experimental_descriptor_load(
            A_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], dtype_fp8
        )
        b = tl._experimental_descriptor_load(
            B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
        )

        # B is stored [N, K], so the [BLOCK_N, BLOCK_K] tile is transposed.
        if fp8_fast_accum:
            acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
        else:
            acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        if ki == k_tiles - 1:
            # Last K step of the tile: scale, add bias, store, then reset the
            # accumulator for this program's next tile.

            # Invert scaling.
            a_scale = tl._experimental_descriptor_load(
                A_scale, [offs_am], [BLOCK_M], scale_dtype
            )
            b_scale = tl._experimental_descriptor_load(
                B_scale, [offs_bn], [BLOCK_N], scale_dtype
            )
            # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
            scale = a_scale[:, None] * b_scale[None, :]
            acc *= scale

            # Load and add bias if specified.
            if USE_BIAS:
                bias = tl._experimental_descriptor_load(
                    Bias, [offs_bn], [BLOCK_N], bias_dtype
                )
                acc += bias[None, :]

            acc = acc.to(c_dtype)
            tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
|
|
945
|
+
|
|
946
|
+
|
|
947
|
+
# True when the installed Triton exposes `tl.async_task`, which the
# warp-specialized cooperative kernel uses. When False, get_ws_configs() returns
# no configs, and matmul_fp8_row asserts this flag before taking the
# warp-specialization path.
has_warp_specialization = hasattr(tl, "async_task")
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
def make_autotuner_config(dictargs, **kwargs):
    """Build a Triton autotune ``Config`` that works across Triton versions.

    Args:
        dictargs: Meta-parameter dict forwarded as the first ``Config`` argument.
        **kwargs: Remaining ``Config`` keyword arguments. When the installed
            Triton is newer than 3.3.1, ``num_buffers_warp_spec`` and
            ``num_consumer_groups`` are dropped before forwarding.

    Returns:
        The constructed ``Config``.
    """
    # NOTE: Triton 3.4.x removed some keyword arguments from Config constructor;
    # however, fbcode uses 3.3.1, and so this shim is provided to support both
    # versions.
    #
    # https://github.com/triton-lang/triton/blob/v3.3.1/python/triton/runtime/autotuner.py#L275
    # https://github.com/triton-lang/triton/blame/release/3.4.x/python/triton/runtime/autotuner.py#L319
    removed_in_newer_triton = ("num_buffers_warp_spec", "num_consumer_groups")
    if version.parse(triton.__version__) > version.parse("3.3.1"):
        kwargs = {
            name: value
            for name, value in kwargs.items()
            if name not in removed_in_newer_triton
        }
    return Config(dictargs, **kwargs)
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
def get_ws_configs() -> list[Config]:
    """Return the extra autotune configs for the warp-specialized TMA kernel.

    Returns an empty list when ``has_warp_specialization`` is False. Each
    config is built with ``make_autotuner_config``, which drops the
    warp-specialization keyword arguments on Triton versions newer than 3.3.1.
    """
    if has_warp_specialization:
        # Columns:
        #   BLOCK_M, BLOCK_N, BLOCK_K, NUM_CONSUMER_GROUPS (kernel meta-param),
        #   num_stages, num_warps, num_consumer_groups, num_buffers_warp_spec
        ws_table = (
            (128, 256, 128, 2, 3, 4, 2, 3),
            (128, 128, 128, 2, 4, 4, 2, 4),
            (128, 256, 128, 1, 3, 8, 0, 3),
            (64, 64, 512, 1, 3, 4, 0, 3),
        )
        return [
            make_autotuner_config(
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "BLOCK_K": block_k,
                    "SPLIT_K": 1,
                    "NUM_CONSUMER_GROUPS": meta_consumer_groups,
                },
                num_stages=stages,
                num_warps=warps,
                num_consumer_groups=consumer_groups,
                num_buffers_warp_spec=ws_buffers,
            )
            for (
                block_m,
                block_n,
                block_k,
                meta_consumer_groups,
                stages,
                warps,
                consumer_groups,
                ws_buffers,
            ) in ws_table
        ]
    return []
|
|
1020
|
+
|
|
1021
|
+
|
|
1022
|
+
@triton.autotune(
    configs=[
        Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 256,
                "BLOCK_K": 128,
                "SPLIT_K": 1,
                "NUM_CONSUMER_GROUPS": 1,
            },
            num_stages=3,
            num_warps=8,
        ),
    ]
    # Extra warp-specialized configs; empty when tl.async_task is unavailable.
    + get_ws_configs(),
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],
    use_cuda_graph=True,
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    Bias,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    c_dtype: tl.constexpr,
    bias_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    fp8_fast_accum: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    AB_DTYPE: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_BIAS: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    """Warp-specialized persistent TMA matmul of [M, K] @ [N, K] with row-wise scales.

    Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
    Each program loops over output tiles program_id(0),
    program_id(0) + num_programs(0), ... For each tile:
      - the A/B TMA loads are issued inside ``tl.async_task([0])``;
      - the scale/bias epilogue and store run inside
        ``tl.async_task([1, NUM_CONSUMER_GROUPS])``.

    Args:
        A_ptr: TMA descriptor for the [M, K] input tensor A.
        B_ptr: TMA descriptor for the [N, K] input tensor B.
        C_ptr: TMA descriptor for the [M, N] output tensor C.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M; not read by the kernel body.
        n_key (int): Autotuning key for N; not read by the kernel body.
        k_key (int): Autotuning key for K; not read by the kernel body.
        A_scale (TensorWrapper): [M] reciprocal scale tensor per row, read with
            plain pointer loads. A * A_scale = original A
        B_scale (TensorWrapper): [N] reciprocal scale tensor per row, read with
            plain pointer loads. B * B_scale = original B
        Bias: TMA descriptor for the optional [N] bias; only read when USE_BIAS.
        stride_am (int): Stride of M dimension of A (not read; TMA addressing).
        stride_ak (int): Stride of K dimension of A (not read; TMA addressing).
        stride_bn (int): Stride of N dimension of B (not read; TMA addressing).
        stride_bk (int): Stride of K dimension of B (not read; TMA addressing).
        stride_cm (int): Stride of M dimension of C (not read; TMA addressing).
        stride_cn (int): Stride of N dimension of C (not read; TMA addressing).
        dot_out_dtype (torch.dtype): Output type of tensor core (accumulator dtype).
        c_dtype: Triton dtype the accumulator is cast to before being stored.
        bias_dtype: Triton dtype used when loading the bias descriptor.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): If True, the accumulator is passed into
            ``tl.dot(a, b.T, acc)``. Otherwise, each tile product is added
            to the accumulator separately.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of groups for M dimension swizzle.
        AB_DTYPE (bool): Not read by this kernel body.
        SPLIT_K (int): Not read by this kernel body (configs use SPLIT_K=1).
        EVEN_K (bool): Set by the heuristic; not read by this kernel body.
        NUM_SMS (int): Not read by this kernel body; the tile loop strides by
            tl.num_programs(0) instead.
        USE_BIAS (bool): Whether to load and add the bias.
        NUM_CONSUMER_GROUPS (int): Consumer-group count used in the epilogue's
            ``tl.async_task`` ids. NOTE(review): the host-side grid function
            sizes the A/C/a_scale descriptor blocks as BLOCK_M // NUM_CONSUMER_GROUPS.
    """
    num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    # NOTE(review): A/B descriptor loads hard-code float8e4nv (e4m3) - confirm
    # callers only pass e4m3 inputs to this kernel.
    dtype_fp8 = tl.float8e4nv
    # Persistent loop: each program strides over tiles by the number of programs.
    for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
        num_pid_in_group = GROUP_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        # pyre-ignore
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        # ----------------------------------------------------------
        # Tile origin for this output tile. A and B are read through TMA
        # descriptors at [offs_am, offs_k] / [offs_bn, offs_k]; offs_k is
        # advanced by BLOCK_K per step while accumulating into acc.
        offs_am = pid_m * BLOCK_M
        offs_bn = pid_n * BLOCK_N
        offs_k = 0
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
        # Compiler hint: at least one K tile (the caller returns early for K == 0).
        # pyre-ignore
        tl.assume(tl.cdiv(K, BLOCK_K) > 0)
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            # TMA loads of the A/B tiles (task 0 - presumably the producer).
            # pyre-ignore
            with tl.async_task([0]):
                a = tl._experimental_descriptor_load(
                    A_ptr,
                    [offs_am, offs_k],
                    [BLOCK_M, BLOCK_K],
                    dtype_fp8,
                )
                b = tl._experimental_descriptor_load(
                    B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
                )

            # B is stored [N, K], so the [BLOCK_N, BLOCK_K] tile is transposed.
            if fp8_fast_accum:
                acc = tl.dot(
                    a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
                )
            else:
                acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

            offs_k += BLOCK_K

        # Epilogue: scale, optional bias, cast, and TMA store of the tile.
        # pyre-ignore
        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
            # Invert scaling.
            rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            a_scale = tl.load(A_scale + rm, mask=rm < M)
            b_scale = tl.load(B_scale + rn, mask=rn < N)
            scale = a_scale[:, None] * b_scale[None, :]
            acc *= scale
            # Load and add bias if specified.
            if USE_BIAS:
                bias = tl._experimental_descriptor_load(
                    Bias, [offs_bn], [BLOCK_N], bias_dtype
                )
                acc += bias[None, :]
            acc = acc.to(c_dtype)
            tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
|
|
1180
|
+
|
|
1181
|
+
|
|
1182
|
+
def _is_eligible_for_skip_scaling(
|
|
1183
|
+
is_rowwise: bool,
|
|
1184
|
+
fp8_fast_accum: bool,
|
|
1185
|
+
imprecise_acc: bool,
|
|
1186
|
+
tma_persistent: bool,
|
|
1187
|
+
no_use_persistent: Optional[bool],
|
|
1188
|
+
use_warp_specialization: bool,
|
|
1189
|
+
) -> bool:
|
|
1190
|
+
if not is_rowwise:
|
|
1191
|
+
return False
|
|
1192
|
+
|
|
1193
|
+
return (
|
|
1194
|
+
fp8_fast_accum
|
|
1195
|
+
and not imprecise_acc
|
|
1196
|
+
and not tma_persistent
|
|
1197
|
+
and not no_use_persistent
|
|
1198
|
+
and not use_warp_specialization
|
|
1199
|
+
)
|
|
1200
|
+
|
|
1201
|
+
|
|
1202
|
+
@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
|
|
1203
|
+
def matmul_fp8_row(
|
|
1204
|
+
a: torch.Tensor,
|
|
1205
|
+
b: torch.Tensor,
|
|
1206
|
+
a_scale: Optional[torch.Tensor],
|
|
1207
|
+
b_scale: torch.Tensor,
|
|
1208
|
+
bias: Optional[torch.Tensor] = None,
|
|
1209
|
+
dot_out_dtype: Optional[torch.dtype] = None,
|
|
1210
|
+
allow_tf32: bool = True,
|
|
1211
|
+
fp8_fast_accum: bool = True,
|
|
1212
|
+
imprecise_acc: bool = False,
|
|
1213
|
+
tma_persistent: bool = True,
|
|
1214
|
+
no_use_persistent: Optional[bool] = None,
|
|
1215
|
+
# add an option to explicitly require the use of persistent process
|
|
1216
|
+
use_persistent: Optional[bool] = None,
|
|
1217
|
+
use_warp_specialization: bool = False,
|
|
1218
|
+
) -> torch.Tensor:
|
|
1219
|
+
"""
|
|
1220
|
+
Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
|
|
1221
|
+
|
|
1222
|
+
Args:
|
|
1223
|
+
a (torch.Tensor): [M, K] input tensor.
|
|
1224
|
+
b (torch.Tensor): [N, K] input tensor.
|
|
1225
|
+
a_scale (Optiona;[torch.Tensor]): [M] reciprocal scale tensor per row.
|
|
1226
|
+
A * a_scale = original A. Scaling will be skiped if a_scale is None.
|
|
1227
|
+
b_scale (torch.Tensor): [N] reciprocal scale tensor per row. B * b_scale = original B
|
|
1228
|
+
bias (torch.Tensor): [N] optional bias tensor to add to output if provided.
|
|
1229
|
+
dot_out_dtype (torch.dtype): Output type of tensor core.
|
|
1230
|
+
allow_tf32 (bool): Whether to use TF32 for tensor core.
|
|
1231
|
+
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
|
|
1232
|
+
tma_persistent (bool): Whether to use TMA persistent kernel impl.
|
|
1233
|
+
|
|
1234
|
+
Returns:
|
|
1235
|
+
torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
|
|
1236
|
+
"""
|
|
1237
|
+
if use_persistent:
|
|
1238
|
+
no_use_persistent = False
|
|
1239
|
+
elif no_use_persistent is None:
|
|
1240
|
+
# Default True for AMD and False for Nvidia.
|
|
1241
|
+
if torch.version.hip is not None:
|
|
1242
|
+
no_use_persistent = True
|
|
1243
|
+
else:
|
|
1244
|
+
no_use_persistent = False
|
|
1245
|
+
# if use_persistent is explicitly requested, set o_use_persistent to False
|
|
1246
|
+
|
|
1247
|
+
# Get datatypes and constants to use.
|
|
1248
|
+
pt_fp8_dtype, _, _, _ = get_fp8_constants()
|
|
1249
|
+
# Handle 3D+ a shape
|
|
1250
|
+
a_shape = a.shape
|
|
1251
|
+
a = a.view(-1, a.size(-1))
|
|
1252
|
+
# View inputs into proper torch fp8 dtype.
|
|
1253
|
+
if torch.version.cuda:
|
|
1254
|
+
assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
|
|
1255
|
+
elif torch.version.hip:
|
|
1256
|
+
if torch.cuda.get_device_capability() < (9, 5):
|
|
1257
|
+
assert a.dtype in (
|
|
1258
|
+
torch.float8_e4m3fnuz,
|
|
1259
|
+
torch.float8_e5m2fnuz,
|
|
1260
|
+
)
|
|
1261
|
+
else:
|
|
1262
|
+
assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
|
|
1263
|
+
else:
|
|
1264
|
+
assert a.dtype in (
|
|
1265
|
+
torch.float8_e4m3fnuz,
|
|
1266
|
+
torch.float8_e5m2fnuz,
|
|
1267
|
+
)
|
|
1268
|
+
assert b.dtype == pt_fp8_dtype
|
|
1269
|
+
M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
|
|
1270
|
+
prep_matmul(a, b, dot_out_dtype)
|
|
1271
|
+
)
|
|
1272
|
+
|
|
1273
|
+
# Skip scaling (a_scale is None) can only be applied in certain cases.
|
|
1274
|
+
assert a_scale is not None or _is_eligible_for_skip_scaling(
|
|
1275
|
+
is_rowwise=True,
|
|
1276
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
1277
|
+
imprecise_acc=imprecise_acc,
|
|
1278
|
+
tma_persistent=tma_persistent,
|
|
1279
|
+
no_use_persistent=no_use_persistent,
|
|
1280
|
+
use_warp_specialization=use_warp_specialization,
|
|
1281
|
+
)
|
|
1282
|
+
|
|
1283
|
+
output_shape = a_shape[:-1] + (N,)
|
|
1284
|
+
# Handle tensor with empty inputs.
|
|
1285
|
+
if (M == 0) or (N == 0) or (K == 0):
|
|
1286
|
+
return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)
|
|
1287
|
+
# launch kernel
|
|
1288
|
+
if a.device == torch.device("cpu"):
|
|
1289
|
+
logger.info(
|
|
1290
|
+
"FP8 Row-wise Triton kernel not supported on cpu, fallback to torch"
|
|
1291
|
+
)
|
|
1292
|
+
if a_scale is None:
|
|
1293
|
+
scale = b_scale[None, :]
|
|
1294
|
+
else:
|
|
1295
|
+
scale = a_scale[:, None] * b_scale[None, :]
|
|
1296
|
+
output = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T) * scale
|
|
1297
|
+
if bias is not None:
|
|
1298
|
+
output += bias[None, :]
|
|
1299
|
+
return output.to(c.dtype)
|
|
1300
|
+
|
|
1301
|
+
def grid(META: dict[str, int]) -> tuple[int, int]:
|
|
1302
|
+
return (
|
|
1303
|
+
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
|
|
1304
|
+
META["SPLIT_K"],
|
|
1305
|
+
)
|
|
1306
|
+
|
|
1307
|
+
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
|
1308
|
+
|
|
1309
|
+
def persistent_grid(META: dict[str, int]) -> tuple[int]:
|
|
1310
|
+
return (
|
|
1311
|
+
min(
|
|
1312
|
+
NUM_SMS,
|
|
1313
|
+
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
|
|
1314
|
+
),
|
|
1315
|
+
)
|
|
1316
|
+
|
|
1317
|
+
if no_use_persistent:
|
|
1318
|
+
logger.debug("Using non-persistent kernel")
|
|
1319
|
+
with torch.cuda.device(a.device.index):
|
|
1320
|
+
torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid](
|
|
1321
|
+
a,
|
|
1322
|
+
b,
|
|
1323
|
+
c,
|
|
1324
|
+
M,
|
|
1325
|
+
N,
|
|
1326
|
+
K,
|
|
1327
|
+
m_key,
|
|
1328
|
+
n_key,
|
|
1329
|
+
k_key,
|
|
1330
|
+
a_scale,
|
|
1331
|
+
b_scale,
|
|
1332
|
+
bias,
|
|
1333
|
+
a.stride(0),
|
|
1334
|
+
a.stride(1),
|
|
1335
|
+
b.stride(0),
|
|
1336
|
+
b.stride(1),
|
|
1337
|
+
c.stride(0),
|
|
1338
|
+
c.stride(1),
|
|
1339
|
+
dot_out_dtype=dot_out_dtype_triton,
|
|
1340
|
+
allow_tf32=allow_tf32,
|
|
1341
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
1342
|
+
# GROUP_M=8,
|
|
1343
|
+
USE_BIAS=bias is not None,
|
|
1344
|
+
AB_DTYPE=False,
|
|
1345
|
+
)
|
|
1346
|
+
elif use_warp_specialization:
|
|
1347
|
+
assert has_warp_specialization
|
|
1348
|
+
# used by TMA warp specialization kernel
|
|
1349
|
+
desc_helper = TmaAutoTuneHelper()
|
|
1350
|
+
desc_helper.init_tma_descriptor("a")
|
|
1351
|
+
desc_helper.init_tma_descriptor("b")
|
|
1352
|
+
desc_helper.init_tma_descriptor("c")
|
|
1353
|
+
desc_helper.init_tma_descriptor("a_scale")
|
|
1354
|
+
desc_helper.init_tma_descriptor("b_scale")
|
|
1355
|
+
desc_helper.init_tma_descriptor("bias")
|
|
1356
|
+
|
|
1357
|
+
def persistent_grid_tma_ws(META: dict[str, int]) -> tuple[int]:
|
|
1358
|
+
nonlocal desc_helper # noqa: F824
|
|
1359
|
+
assert a_scale is not None # Type narrowing for Pyre
|
|
1360
|
+
desc_helper.fill_2d_tma_descriptor(
|
|
1361
|
+
"a",
|
|
1362
|
+
a.data_ptr(),
|
|
1363
|
+
M,
|
|
1364
|
+
K,
|
|
1365
|
+
META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
|
|
1366
|
+
META["BLOCK_K"],
|
|
1367
|
+
a.element_size(),
|
|
1368
|
+
)
|
|
1369
|
+
|
|
1370
|
+
desc_helper.fill_2d_tma_descriptor(
|
|
1371
|
+
"b",
|
|
1372
|
+
b.data_ptr(),
|
|
1373
|
+
N,
|
|
1374
|
+
K,
|
|
1375
|
+
META["BLOCK_N"],
|
|
1376
|
+
META["BLOCK_K"],
|
|
1377
|
+
b.element_size(),
|
|
1378
|
+
)
|
|
1379
|
+
desc_helper.fill_2d_tma_descriptor(
|
|
1380
|
+
"c",
|
|
1381
|
+
c.data_ptr(),
|
|
1382
|
+
M,
|
|
1383
|
+
N,
|
|
1384
|
+
META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
|
|
1385
|
+
META["BLOCK_N"],
|
|
1386
|
+
c.element_size(),
|
|
1387
|
+
)
|
|
1388
|
+
desc_helper.fill_1d_tma_descriptor(
|
|
1389
|
+
"a_scale",
|
|
1390
|
+
a_scale.data_ptr(),
|
|
1391
|
+
M,
|
|
1392
|
+
META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
|
|
1393
|
+
a_scale.element_size(),
|
|
1394
|
+
)
|
|
1395
|
+
desc_helper.fill_1d_tma_descriptor(
|
|
1396
|
+
"b_scale",
|
|
1397
|
+
b_scale.data_ptr(),
|
|
1398
|
+
N,
|
|
1399
|
+
META["BLOCK_N"],
|
|
1400
|
+
b_scale.element_size(),
|
|
1401
|
+
)
|
|
1402
|
+
if bias is not None:
|
|
1403
|
+
desc_helper.fill_1d_tma_descriptor(
|
|
1404
|
+
"bias",
|
|
1405
|
+
bias.data_ptr(),
|
|
1406
|
+
N,
|
|
1407
|
+
META["BLOCK_N"],
|
|
1408
|
+
bias.element_size(),
|
|
1409
|
+
)
|
|
1410
|
+
return (
|
|
1411
|
+
min(
|
|
1412
|
+
NUM_SMS,
|
|
1413
|
+
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
|
|
1414
|
+
),
|
|
1415
|
+
)
|
|
1416
|
+
|
|
1417
|
+
desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
|
|
1418
|
+
desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
|
|
1419
|
+
desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
|
|
1420
|
+
desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
|
|
1421
|
+
desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
|
|
1422
|
+
desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
|
|
1423
|
+
|
|
1424
|
+
bias_dtype_triton = None
|
|
1425
|
+
if bias is not None:
|
|
1426
|
+
bias_dtype_triton = map_dtype_to_triton(bias.dtype)
|
|
1427
|
+
|
|
1428
|
+
# pyre-ignore
|
|
1429
|
+
torch._library.capture_triton(
|
|
1430
|
+
_kernel_matmul_fp8_row_tma_persistent_ws_cooperative
|
|
1431
|
+
)[persistent_grid_tma_ws](
|
|
1432
|
+
desc_a,
|
|
1433
|
+
desc_b,
|
|
1434
|
+
desc_c,
|
|
1435
|
+
M,
|
|
1436
|
+
N,
|
|
1437
|
+
K,
|
|
1438
|
+
m_key,
|
|
1439
|
+
n_key,
|
|
1440
|
+
k_key,
|
|
1441
|
+
a_scale,
|
|
1442
|
+
b_scale,
|
|
1443
|
+
desc_bias,
|
|
1444
|
+
a.stride(0),
|
|
1445
|
+
a.stride(1),
|
|
1446
|
+
b.stride(0),
|
|
1447
|
+
b.stride(1),
|
|
1448
|
+
c.stride(0),
|
|
1449
|
+
c.stride(1),
|
|
1450
|
+
dot_out_dtype=dot_out_dtype_triton,
|
|
1451
|
+
c_dtype=c_dtype_triton,
|
|
1452
|
+
bias_dtype=bias_dtype_triton,
|
|
1453
|
+
allow_tf32=allow_tf32,
|
|
1454
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
1455
|
+
GROUP_M=8,
|
|
1456
|
+
AB_DTYPE=False,
|
|
1457
|
+
NUM_SMS=NUM_SMS,
|
|
1458
|
+
USE_BIAS=bias is not None,
|
|
1459
|
+
)
|
|
1460
|
+
elif tma_persistent:
|
|
1461
|
+
# used by TMA persistent kernel
|
|
1462
|
+
desc_helper = TmaAutoTuneHelper()
|
|
1463
|
+
desc_helper.init_tma_descriptor("a")
|
|
1464
|
+
desc_helper.init_tma_descriptor("b")
|
|
1465
|
+
desc_helper.init_tma_descriptor("c")
|
|
1466
|
+
desc_helper.init_tma_descriptor("a_scale")
|
|
1467
|
+
desc_helper.init_tma_descriptor("b_scale")
|
|
1468
|
+
desc_helper.init_tma_descriptor("bias")
|
|
1469
|
+
|
|
1470
|
+
def persistent_grid_tma(META: dict[str, int]) -> tuple[int]:
|
|
1471
|
+
nonlocal desc_helper # noqa: F824
|
|
1472
|
+
assert a_scale is not None # Type narrowing for Pyre
|
|
1473
|
+
desc_helper.fill_2d_tma_descriptor(
|
|
1474
|
+
"a",
|
|
1475
|
+
a.data_ptr(),
|
|
1476
|
+
M,
|
|
1477
|
+
K,
|
|
1478
|
+
META["BLOCK_M"],
|
|
1479
|
+
META["BLOCK_K"],
|
|
1480
|
+
a.element_size(),
|
|
1481
|
+
)
|
|
1482
|
+
|
|
1483
|
+
desc_helper.fill_2d_tma_descriptor(
|
|
1484
|
+
"b",
|
|
1485
|
+
b.data_ptr(),
|
|
1486
|
+
N,
|
|
1487
|
+
K,
|
|
1488
|
+
META["BLOCK_N"],
|
|
1489
|
+
META["BLOCK_K"],
|
|
1490
|
+
b.element_size(),
|
|
1491
|
+
)
|
|
1492
|
+
desc_helper.fill_2d_tma_descriptor(
|
|
1493
|
+
"c",
|
|
1494
|
+
c.data_ptr(),
|
|
1495
|
+
M,
|
|
1496
|
+
N,
|
|
1497
|
+
META["BLOCK_M"],
|
|
1498
|
+
META["BLOCK_N"],
|
|
1499
|
+
c.element_size(),
|
|
1500
|
+
)
|
|
1501
|
+
desc_helper.fill_1d_tma_descriptor(
|
|
1502
|
+
"a_scale",
|
|
1503
|
+
a_scale.data_ptr(),
|
|
1504
|
+
M,
|
|
1505
|
+
META["BLOCK_M"],
|
|
1506
|
+
a_scale.element_size(),
|
|
1507
|
+
)
|
|
1508
|
+
desc_helper.fill_1d_tma_descriptor(
|
|
1509
|
+
"b_scale",
|
|
1510
|
+
b_scale.data_ptr(),
|
|
1511
|
+
N,
|
|
1512
|
+
META["BLOCK_N"],
|
|
1513
|
+
b_scale.element_size(),
|
|
1514
|
+
)
|
|
1515
|
+
if bias is not None:
|
|
1516
|
+
desc_helper.fill_1d_tma_descriptor(
|
|
1517
|
+
"bias",
|
|
1518
|
+
bias.data_ptr(),
|
|
1519
|
+
N,
|
|
1520
|
+
META["BLOCK_N"],
|
|
1521
|
+
bias.element_size(),
|
|
1522
|
+
)
|
|
1523
|
+
return (
|
|
1524
|
+
min(
|
|
1525
|
+
NUM_SMS,
|
|
1526
|
+
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
|
|
1527
|
+
),
|
|
1528
|
+
)
|
|
1529
|
+
|
|
1530
|
+
desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
|
|
1531
|
+
desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
|
|
1532
|
+
desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
|
|
1533
|
+
desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
|
|
1534
|
+
desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
|
|
1535
|
+
desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
|
|
1536
|
+
|
|
1537
|
+
bias_dtype_triton = None
|
|
1538
|
+
if bias is not None:
|
|
1539
|
+
bias_dtype_triton = map_dtype_to_triton(bias.dtype)
|
|
1540
|
+
|
|
1541
|
+
# pyre-ignore
|
|
1542
|
+
torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
|
|
1543
|
+
persistent_grid_tma
|
|
1544
|
+
](
|
|
1545
|
+
desc_a,
|
|
1546
|
+
desc_b,
|
|
1547
|
+
desc_c,
|
|
1548
|
+
M,
|
|
1549
|
+
N,
|
|
1550
|
+
K,
|
|
1551
|
+
m_key,
|
|
1552
|
+
n_key,
|
|
1553
|
+
k_key,
|
|
1554
|
+
desc_a_scale,
|
|
1555
|
+
desc_b_scale,
|
|
1556
|
+
desc_bias,
|
|
1557
|
+
a.stride(0),
|
|
1558
|
+
a.stride(1),
|
|
1559
|
+
b.stride(0),
|
|
1560
|
+
b.stride(1),
|
|
1561
|
+
c.stride(0),
|
|
1562
|
+
c.stride(1),
|
|
1563
|
+
dot_out_dtype=dot_out_dtype_triton,
|
|
1564
|
+
c_dtype=c_dtype_triton,
|
|
1565
|
+
bias_dtype=bias_dtype_triton,
|
|
1566
|
+
allow_tf32=allow_tf32,
|
|
1567
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
1568
|
+
GROUP_M=8,
|
|
1569
|
+
AB_DTYPE=False,
|
|
1570
|
+
NUM_SMS=NUM_SMS,
|
|
1571
|
+
USE_BIAS=bias is not None,
|
|
1572
|
+
)
|
|
1573
|
+
elif imprecise_acc:
|
|
1574
|
+
with torch.cuda.device(a.device.index):
|
|
1575
|
+
torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
|
|
1576
|
+
a,
|
|
1577
|
+
b,
|
|
1578
|
+
c,
|
|
1579
|
+
M,
|
|
1580
|
+
N,
|
|
1581
|
+
K,
|
|
1582
|
+
m_key,
|
|
1583
|
+
n_key,
|
|
1584
|
+
k_key,
|
|
1585
|
+
a_scale,
|
|
1586
|
+
b_scale,
|
|
1587
|
+
bias,
|
|
1588
|
+
a.stride(0),
|
|
1589
|
+
a.stride(1),
|
|
1590
|
+
b.stride(0),
|
|
1591
|
+
b.stride(1),
|
|
1592
|
+
c.stride(0),
|
|
1593
|
+
c.stride(1),
|
|
1594
|
+
dot_out_dtype=dot_out_dtype_triton,
|
|
1595
|
+
allow_tf32=allow_tf32,
|
|
1596
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
1597
|
+
GROUP_M=8,
|
|
1598
|
+
USE_BIAS=bias is not None,
|
|
1599
|
+
AB_DTYPE=False,
|
|
1600
|
+
)
|
|
1601
|
+
elif fp8_fast_accum:
|
|
1602
|
+
skip_scaling_a = a_scale is None
|
|
1603
|
+
torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
|
|
1604
|
+
a,
|
|
1605
|
+
b,
|
|
1606
|
+
c,
|
|
1607
|
+
M,
|
|
1608
|
+
N,
|
|
1609
|
+
K,
|
|
1610
|
+
m_key,
|
|
1611
|
+
n_key,
|
|
1612
|
+
k_key,
|
|
1613
|
+
a_scale,
|
|
1614
|
+
b_scale,
|
|
1615
|
+
bias,
|
|
1616
|
+
a.stride(0),
|
|
1617
|
+
a.stride(1),
|
|
1618
|
+
b.stride(0),
|
|
1619
|
+
b.stride(1),
|
|
1620
|
+
c.stride(0),
|
|
1621
|
+
c.stride(1),
|
|
1622
|
+
dot_out_dtype=dot_out_dtype_triton,
|
|
1623
|
+
allow_tf32=allow_tf32,
|
|
1624
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
1625
|
+
skip_scaling_a=skip_scaling_a,
|
|
1626
|
+
GROUP_M=8,
|
|
1627
|
+
USE_BIAS=bias is not None,
|
|
1628
|
+
AB_DTYPE=False,
|
|
1629
|
+
NUM_SMS=NUM_SMS,
|
|
1630
|
+
)
|
|
1631
|
+
else:
|
|
1632
|
+
torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
|
|
1633
|
+
persistent_grid
|
|
1634
|
+
](
|
|
1635
|
+
a,
|
|
1636
|
+
b,
|
|
1637
|
+
c,
|
|
1638
|
+
M,
|
|
1639
|
+
N,
|
|
1640
|
+
K,
|
|
1641
|
+
m_key,
|
|
1642
|
+
n_key,
|
|
1643
|
+
k_key,
|
|
1644
|
+
a_scale,
|
|
1645
|
+
b_scale,
|
|
1646
|
+
bias,
|
|
1647
|
+
a.stride(0),
|
|
1648
|
+
a.stride(1),
|
|
1649
|
+
b.stride(0),
|
|
1650
|
+
b.stride(1),
|
|
1651
|
+
c.stride(0),
|
|
1652
|
+
c.stride(1),
|
|
1653
|
+
dot_out_dtype=dot_out_dtype_triton,
|
|
1654
|
+
allow_tf32=allow_tf32,
|
|
1655
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
1656
|
+
GROUP_M=8,
|
|
1657
|
+
USE_BIAS=bias is not None,
|
|
1658
|
+
AB_DTYPE=False,
|
|
1659
|
+
NUM_SMS=NUM_SMS,
|
|
1660
|
+
)
|
|
1661
|
+
return c.view(output_shape)
|
|
1662
|
+
|
|
1663
|
+
|
|
1664
|
+
@matmul_fp8_row.register_fake
def matmul_fp8_row_meta(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: Optional[torch.Tensor],
    b_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dot_out_dtype: Optional[torch.dtype] = None,
    allow_tf32: bool = True,
    fp8_fast_accum: bool = True,
    imprecise_acc: bool = False,
    tma_persistent: bool = True,
    no_use_persistent: Optional[bool] = None,
    use_warp_specialization: bool = False,
) -> torch.Tensor:
    """Shape function for torch compile.

    The real op returns ``c.view(output_shape)``; this fake keeps every leading
    dimension of ``a`` (so 3D+ activations are supported, not only [M, K]) and
    appends ``N`` taken from ``b`` ([N, K]).
    NOTE(review): assumes ``output_shape == a.shape[:-1] + (N,)`` in
    ``matmul_fp8_row`` as in ``matmul_fp8_block`` -- confirm.

    Returns:
        torch.Tensor: Empty tensor of shape ``[*a.shape[:-1], N]`` on ``a.device``
        with dtype ``dot_out_dtype`` if given, else ``torch.bfloat16``.
    """
    N = b.shape[0]
    return torch.empty(
        (*a.shape[:-1], N),
        device=a.device,
        dtype=torch.bfloat16 if dot_out_dtype is None else dot_out_dtype,
    )
|
|
1687
|
+
|
|
1688
|
+
|
|
1689
|
+
# Autotune pruning hook for the block-scaled FP8 kernels.
def prune_configs_block(configs, named_args, **kwargs):
    """Drop autotune configs unusable with the requested K scale block size.

    First applies the generic ``early_config_prune``. Then it keeps only the
    configs whose ``BLOCK_K`` evenly divides ``scale_block_k``, so that every K
    tile falls inside a single K scale block.

    Args:
        configs: Candidate triton configs.
        named_args: Kernel arguments by name; must contain ``scale_block_k``.
        **kwargs: Forwarded to ``early_config_prune``.

    Returns:
        list: The surviving configs, in their original order.
    """
    candidates = early_config_prune(configs, named_args, **kwargs)
    k_scale_block = named_args["scale_block_k"]
    return [
        cfg for cfg in candidates if k_scale_block % cfg.kwargs["BLOCK_K"] == 0
    ]
|
|
1702
|
+
|
|
1703
|
+
|
|
1704
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],  # TODO caller side bin keys so similar shapes can use same triton.autotune.
    prune_configs_by={
        "early_config_prune": prune_configs_block,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_block_fastacc(
    A,
    B,
    C,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    scale_block_m: tl.constexpr,
    scale_block_n: tl.constexpr,
    scale_block_k: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_scale_am,
    stride_scale_ak,
    stride_scale_bn,
    stride_scale_bk,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    AB_DTYPE: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with block-wise scales (fast accumulation).

    Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles,
    with A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
    [scale_block_n, scale_block_k] tiles respectively.

    Raw tile products are accumulated into ``acc`` with ``tl.dot(a, b, acc)`` and
    the scales are folded in lazily. ``acc`` is kept in units of the scale of
    the K scale block currently being processed (true partial sum = acc * scale).
    When the next tile this program visits lies in a different scale block,
    ``acc`` is multiplied by ``scale / scale_next``. After the last tile it is
    multiplied by the final ``scale``.

    Todo:
        * Support scale_block_{mnk} < BLOCK_{MNK} for each dim.

    Args:
        A (TensorWrapper): [M, K] input tensor.
        B (TensorWrapper): [N, K] input tensor.
        C (TensorWrapper): [M, N] output tensor.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
        B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
        scale_block_m (int): Block size for M dimension of A_scale.
        scale_block_n (int): Block size for N dimension of B_scale.
        scale_block_k (int): Block size for K dimension of A_scale and B_scale.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        stride_scale_am (int): Stride of M dimension of A_scale.
        stride_scale_ak (int): Stride of K dimension of A_scale.
        stride_scale_bn (int): Stride of N dimension of B_scale.
        stride_scale_bk (int): Stride of K dimension of B_scale.
        dot_out_dtype (torch.dtype): Output type of tensor core.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M tiles per group in the L2-friendly pid swizzle.
        SPLIT_K (int): Number of programs (grid dim 1) splitting the K reduction;
            their partial results are combined with ``tl.atomic_add``.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
    """
    # Each output tile must sit inside a single scale block (scale indices are
    # computed once per tile), so scale blocks may equal the tiles but not be
    # smaller than them.
    assert BLOCK_M <= scale_block_m
    assert BLOCK_N <= scale_block_n
    assert BLOCK_K <= scale_block_k
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)

    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
    scale_m = pid_m * BLOCK_M // scale_block_m
    scale_n = pid_n * BLOCK_N // scale_block_n
    # Number of BLOCK_K tiles per K scale block (prune_configs_block keeps only
    # configs whose BLOCK_K divides scale_block_k).
    k_multiple = scale_block_k // BLOCK_K
    k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
    # Last valid column index of A_scale / B_scale along K. With SPLIT_K > 1 a
    # tail tile may start at or past K (its data is masked to zero); clamping
    # keeps its scale lookup in bounds and consistent across the rescale steps.
    last_scale_k = tl.cdiv(K, scale_block_k) - 1

    for k in range(0, k_tiles):

        k_remaining = K - k * (BLOCK_K * SPLIT_K)

        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)

        acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

        # Lazy scaling: a + c/s = (a*s + c) / s, so instead of scaling every tile
        # product, acc is converted into units of the next scale block only when
        # the scale block changes, and multiplied by the final scale at the end.
        # Due to split_k this program visits tiles pid_k = k * SPLIT_K + pid_z,
        # so the next tile it visits is pid_k + SPLIT_K (not pid_k + 1).
        pid_k = k * SPLIT_K + pid_z
        scale_k = tl.minimum(pid_k // k_multiple, last_scale_k)
        scale_k_next = tl.minimum((pid_k + SPLIT_K) // k_multiple, last_scale_k)
        is_last = k + 1 == k_tiles
        # The last tile must always apply its scale, even when K is a multiple
        # of BLOCK_K * SPLIT_K but not of scale_block_k.
        if is_last or (scale_k_next != scale_k):
            a_scale = tl.load(
                A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
            )
            b_scale = tl.load(
                B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
            )
            scale = a_scale * b_scale
            if is_last:
                scale_next_inv_scale = scale
            else:
                a_scale_next = tl.load(
                    A_scale + scale_m * stride_scale_am + scale_k_next * stride_scale_ak
                )
                b_scale_next = tl.load(
                    B_scale + scale_n * stride_scale_bn + scale_k_next * stride_scale_bk
                )
                scale_next = a_scale_next * b_scale_next
                scale_next_inv_scale = scale / scale_next
            acc *= scale_next_inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = acc.to(C.dtype.element_ty)
    c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c, acc, mask=mask)
    else:
        tl.atomic_add(c, acc, mask=mask)
|
|
1898
|
+
|
|
1899
|
+
|
|
1900
|
+
@triton.autotune(
    configs=MATMUL_CONFIGS,
    key=[
        "m_key",
        "n_key",
        "k_key",
    ],  # TODO caller side bin keys so similar shapes can use same triton.autotune.
    prune_configs_by={
        # Like the fast-accum kernel, each K tile must lie inside one K scale
        # block, so BLOCK_K has to divide scale_block_k.
        "early_config_prune": prune_configs_block,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def _kernel_matmul_fp8_block_slowacc(
    A,
    B,
    C,
    M,
    N,
    K,
    m_key,
    n_key,
    k_key,
    A_scale,
    B_scale,
    scale_block_m: tl.constexpr,
    scale_block_n: tl.constexpr,
    scale_block_k: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_scale_am,
    stride_scale_ak,
    stride_scale_bn,
    stride_scale_bk,
    dot_out_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    AB_DTYPE: tl.constexpr,
) -> None:
    """Matmul kernel of [M, K] @ [N, K] with block-wise scales (slow accumulation).

    Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles,
    with A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
    [scale_block_n, scale_block_k] tiles respectively. Every tile product is
    multiplied by its ``A_scale * B_scale`` before being added to the accumulator.

    Todo:
        * Support scale_block_{mnk} < BLOCK_{MNK} for each dim.

    Args:
        A (TensorWrapper): [M, K] input tensor.
        B (TensorWrapper): [N, K] input tensor.
        C (TensorWrapper): [M, N] output tensor.
        M (int): M dimension of input tensor.
        N (int): N dimension of input tensor.
        K (int): K dimension of input tensor.
        m_key (int): Autotuning key for M dimension of input tensor.
        n_key (int): Autotuning key for N dimension of input tensor.
        k_key (int): Autotuning key for K dimension of input tensor.
        A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
        B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
        scale_block_m (int): Block size for M dimension of A_scale.
        scale_block_n (int): Block size for N dimension of B_scale.
        scale_block_k (int): Block size for K dimension of A_scale and B_scale.
        stride_am (int): Stride of M dimension of A.
        stride_ak (int): Stride of K dimension of A.
        stride_bn (int): Stride of N dimension of B.
        stride_bk (int): Stride of K dimension of B.
        stride_cm (int): Stride of M dimension of C.
        stride_cn (int): Stride of N dimension of C.
        stride_scale_am (int): Stride of M dimension of A_scale.
        stride_scale_ak (int): Stride of K dimension of A_scale.
        stride_scale_bn (int): Stride of N dimension of B_scale.
        stride_scale_bk (int): Stride of K dimension of B_scale.
        dot_out_dtype (torch.dtype): Output type of tensor core.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        BLOCK_M (int): Block size for M dimension.
        BLOCK_N (int): Block size for N dimension.
        BLOCK_K (int): Block size for K dimension.
        GROUP_M (int): Number of M tiles per group in the L2-friendly pid swizzle.
        SPLIT_K (int): Number of programs (grid dim 1) splitting the K reduction;
            their partial results are combined with ``tl.atomic_add``.
        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
    """
    # Each output tile must sit inside a single scale block (scale indices are
    # computed once per tile), so scale blocks may equal the tiles but not be
    # smaller than them.
    assert BLOCK_M <= scale_block_m
    assert BLOCK_N <= scale_block_n
    assert BLOCK_K <= scale_block_k
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)

    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    scale_m = pid_m * BLOCK_M // scale_block_m
    scale_n = pid_n * BLOCK_N // scale_block_n
    _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
    # Last valid column index of A_scale / B_scale along K. With SPLIT_K > 1 a
    # tail tile may start at or past K; its data is masked to zero, but an
    # out-of-bounds (possibly NaN) scale would still poison 0 * scale.
    last_scale_k = tl.cdiv(K, scale_block_k) - 1

    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        # Due to split_k this program handles K tile pid_k = k * SPLIT_K + pid_z
        # and reads A_scale[scale_m, scale_k] and B_scale[scale_n, scale_k].
        pid_k = k * SPLIT_K + pid_z
        scale_k = tl.minimum(pid_k * BLOCK_K // scale_block_k, last_scale_k)
        a_scale = tl.load(
            A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
        )
        b_scale = tl.load(
            B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
        )
        scale = a_scale * b_scale

        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)

            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)

        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) * scale
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = acc.to(C.dtype.element_ty)
    c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c, acc, mask=mask)
    else:
        tl.atomic_add(c, acc, mask=mask)
|
|
2070
|
+
|
|
2071
|
+
|
|
2072
|
+
@torch.library.custom_op("triton::matmul_fp8_block", mutates_args=())
def matmul_fp8_block(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    scale_block_m: int = 256,
    scale_block_n: int = 256,
    scale_block_k: int = 256,
    dot_out_dtype: Optional[torch.dtype] = None,
    allow_tf32: bool = True,
    fp8_fast_accum: bool = True,
) -> Tensor:
    """Performs matmul on [M, K] and [N, K] fp8 matrices with block-wise scalings.

    ``a`` may carry extra leading dimensions. It is flattened to [M, K] with
    ``view`` and the result is viewed back to ``[*a.shape[:-1], N]``.

    Args:
        a (torch.Tensor): [M, K] input tensor.
        b (torch.Tensor): [N, K] input tensor.
        a_scale (torch.Tensor): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. A * A_scale = original A
        b_scale (torch.Tensor): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. B * B_scale = original B
        scale_block_m (int): Block size for M dimension of A_scale.
        scale_block_n (int): Block size for N dimension of B_scale.
        scale_block_k (int): Block size for K dimension of A_scale and B_scale.
        dot_out_dtype (torch.dtype): Output type of tensor core.
        allow_tf32 (bool): Whether to use TF32 for tensor core.
        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.

    Returns:
        Tensor: [*a.shape[:-1], N] output tensor, (a * a_scale) @ (b * b_scale).T
        with each scale broadcast over its scale block.

    Raises:
        AssertionError: If the inputs live on the CPU.
        Exception: If b, a_scale or b_scale is not on the same device as a.
    """
    # Get datatypes and constants to use.
    _, tl_fp8_dtype, _, _ = get_fp8_constants()
    # Handle 3D+ a shape
    a_shape = a.shape
    a = a.view(-1, a.size(-1))
    # View inputs into proper triton fp8 dtype.
    a_tl = reinterpret_fp8_type(a, tl_fp8_dtype)
    b_tl = reinterpret_fp8_type(b, tl_fp8_dtype)

    M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
        a_tl, b_tl, dot_out_dtype
    )

    output_shape = a_shape[:-1] + (N,)
    # Handle case where inputs are empty. Use the allocated output's dtype so
    # empty and non-empty results agree when dot_out_dtype is given.
    if (M == 0) or (N == 0) or (K == 0):
        return torch.zeros(output_shape, device=device, dtype=c.dtype)

    # launch kernel
    assert device != torch.device(
        "cpu"
    ), "Blockwise matmul not supported on cpu, please use row-wise instead."

    if b.device != a.device:
        raise Exception("'b' must be on the same device as 'a'")
    if a_scale.device != a.device:
        raise Exception("'a_scale' must be on the same device as 'a'")
    if b_scale.device != a.device:
        raise Exception("'b_scale' must be on the same device as 'a'")

    # Grid dim 0 enumerates output tiles; dim 1 splits the K reduction.
    def grid(META: dict[str, int]) -> tuple[int, int]:
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META["SPLIT_K"],
        )

    if fp8_fast_accum:
        with torch.cuda.device(a_tl.device.index):
            _kernel_matmul_fp8_block_fastacc[grid](
                a_tl,
                b_tl,
                c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                a_scale,
                b_scale,
                scale_block_m,
                scale_block_n,
                scale_block_k,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                a_scale.stride(0),
                a_scale.stride(1),
                b_scale.stride(0),
                b_scale.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                allow_tf32=allow_tf32,
                GROUP_M=8,
                AB_DTYPE=False,
            )
    else:
        with torch.cuda.device(a_tl.device.index):
            _kernel_matmul_fp8_block_slowacc[grid](
                a_tl,
                b_tl,
                c,
                M,
                N,
                K,
                m_key,
                n_key,
                k_key,
                a_scale,
                b_scale,
                scale_block_m,
                scale_block_n,
                scale_block_k,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                a_scale.stride(0),
                a_scale.stride(1),
                b_scale.stride(0),
                b_scale.stride(1),
                dot_out_dtype=dot_out_dtype_triton,
                allow_tf32=allow_tf32,
                GROUP_M=8,
                AB_DTYPE=False,
            )
    return c.view(output_shape)
|
|
2204
|
+
|
|
2205
|
+
|
|
2206
|
+
@matmul_fp8_block.register_fake
def matmul_fp8_block_meta(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    scale_block_m: int = 256,
    scale_block_n: int = 256,
    scale_block_k: int = 256,
    dot_out_dtype: Optional[torch.dtype] = None,
    allow_tf32: bool = True,
    fp8_fast_accum: bool = True,
) -> torch.Tensor:
    """Shape function for torch compile.

    ``matmul_fp8_block`` flattens ``a`` to 2D and returns the output buffer
    viewed as ``a.shape[:-1] + (N,)``, so the fake keeps every leading dimension
    of ``a`` rather than requiring a 2D input.
    NOTE(review): the dtype assumes prep_matmul allocates the output as
    ``dot_out_dtype`` when given and bfloat16 otherwise (same convention as the
    row-wise fake) -- confirm.

    Returns:
        torch.Tensor: Empty tensor of shape ``[*a.shape[:-1], N]`` on ``a.device``.
    """
    N = b.shape[0]
    return torch.empty(
        (*a.shape[:-1], N),
        device=a.device,
        dtype=torch.bfloat16 if dot_out_dtype is None else dot_out_dtype,
    )
|
|
2223
|
+
|
|
2224
|
+
|
|
2225
|
+
def get_matmul_tune(M: int, N: int, K: int) -> tuple[int, int, int]:
    """
    Build coarse autotuning keys for A @ B.T, where A is [M, K] and B is [N, K].

    Small M values (below 256) are used verbatim. Larger ones collapse into the
    bucket ``256 + M // 1024``, so nearby sequence lengths reuse one tuned config.
    N and K are passed through unchanged.

    Args:
        M (int): Number of rows in A.
        N (int): Number of rows in B.
        K (int): Shared inner dimension (cols of A and of B).

    Returns:
        tuple[int, int, int]: ``(m_key, n_key, k_key)`` autotuning keys.

    TODO: Refine this. For now it's useful for LLM inference where N, K dims are fixed
    and M dim varies due to seq_len.
    """
    m_key = M if M < 256 else 256 + M // 1024
    return m_key, N, K
|
|
2248
|
+
|
|
2249
|
+
|
|
2250
|
+
def prep_matmul(
    a: Union[TensorWrapper, torch.Tensor],
    b: Union[TensorWrapper, torch.Tensor],
    dot_out_dtype: Optional[torch.dtype],
) -> tuple[
    int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
]:
    """
    Shared bookkeeping for a @ b.T matmul.

    Validates both operands, derives the autotuning keys and allocates the
    output tensor.

    Args:
        a (torch.Tensor): [M, K] fp8 input tensor.
        b (torch.Tensor): [N, K] fp8 input tensor.
        dot_out_dtype (Optional[torch.dtype]): Requested output dtype. When None,
            the output is bfloat16 and the tensor-core accumulation type is
            float32; otherwise both use this dtype.

    Returns:
        M (int): Number of rows in A.
        N (int): Number of rows in B.
        K (int): Number of cols in A and cols in B.
        m_key (int): Autotuning key for M dim.
        n_key (int): Autotuning key for N dim.
        k_key (int): Autotuning key for K dim.
        c (Tensor): [M, N] output tensor.
        c_dtype_triton (tl.dtype): Type of output tensor.
        dot_out_dtype_triton (tl.dtype): Output type of tensor core.
        device (torch.device): Device of output tensor.

    Raises:
        AssertionError: If K does not match, if either operand is not an fp8
            dtype, or if dot_out_dtype is not a torch.dtype.
    """
    device = a.device

    # checks constraints
    assert (
        a.shape[1] == b.shape[1]
    ), f"incompatible dimensions, a: {a.shape}, b: {b.shape}"
    M, K = a.shape
    N, _ = b.shape
    m_key, n_key, k_key = get_matmul_tune(M, N, K)

    # Both operands must already be quantized to an fp8 format (torch or triton dtype).
    fp8_dtypes = (
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2fnuz,
        tl.float8e4nv,
        tl.float8e4b15,
        tl.float8e5,
        tl.float8e4b8,
    )
    assert a.dtype in fp8_dtypes, f"a must have an fp8 dtype, got {a.dtype}"
    assert b.dtype in fp8_dtypes, f"b must have an fp8 dtype, got {b.dtype}"

    # Validate dot_out_dtype before mapping it, so a wrong type fails with a clear message.
    if dot_out_dtype is None:
        c_dtype, c_dtype_triton = torch.bfloat16, tl.bfloat16
        dot_out_dtype_triton = tl.float32
    else:
        assert isinstance(
            dot_out_dtype, torch.dtype
        ), f"dot_out_dtype type {type(dot_out_dtype)} must be a torch.dtype"
        c_dtype = dot_out_dtype
        c_dtype_triton = map_dtype_to_triton(dot_out_dtype)
        dot_out_dtype_triton = c_dtype_triton

    # allocates output
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    return M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device
|
|
2325
|
+
|
|
2326
|
+
|
|
2327
|
+
# Autotune over the reduction tile width; a new tuning run happens per distinct K.
@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass. One program handles one row;
    program id `pid` is the flattened (b, m, n) row index.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): Higher precision input tensor with 4 dimensions [B, M, N, K].
        A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
        A_fp8 (Tensor): fp8 scaled tensor with last dim K_fp8. A_fp8 = A / A_scale
        scale_ub (Tensor): [1] Upper bound applied to each row's max-abs value
            before the scale is computed (only read when CLAMP_MAX).
        zero_start_index_M (Tensor): Per-(b, m) count of valid rows along N; rows
            at or beyond it are treated as empty (only read when JAGGED).
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3 (input row size)
        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Decompose pid into (b, m, n) and turn it into the start of this row in A.
    a_offset_base = (
        pid // (M * N) * stride_ab
        + (pid % (M * N)) // N * stride_am
        + (pid % (M * N)) % N * stride_an
    )
    # Same decomposition for the output row start.
    a_fp8_offset_base = (
        pid // (M * N) * stride_ob
        + (pid % (M * N)) // N * stride_om
        + (pid % (M * N)) % N * stride_on
    )

    # Number of input elements to read for this row (forced to 0 for empty jagged rows).
    K_in = K

    if JAGGED:
        # One jagged count per (b, m) pair.
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, dont process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE

    # Clamp max value appropriately: never below EPS (avoids division by zero
    # for all-zero rows), and never above scale_ub when one is given.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max
    # The stored per-row scale is the reciprocal of the multiplier applied below.
    tl.store(A_scale + pid, 1.0 / a_scale)
    n_offset = tl.arange(0, BLOCK_SIZE)

    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)

        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE
|
|
2477
|
+
|
|
2478
|
+
|
|
2479
|
+
def triton_quantize_fp8_row(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[Tensor, Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    Args:
        a (Tensor): higher precision input tensor with at most 4 dimensions; it is
            left-padded with singleton dims to 4-D for the kernel.
        scale_ub (Tensor): Upper bound applied to each row's max-abs value before the
            scale is computed.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels
            accepting specific sizes (e.g., multiple of 16). Padding is zero-filled.

    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: reciprocal scale tensor per row.

    Raises:
        Exception: If scale_ub or zero_start_index_M lives on a different device than a.
        ValueError: If align_rows_to is not a positive integer.
    """
    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")
    if align_rows_to is not None and align_rows_to <= 0:
        raise ValueError(
            f"'align_rows_to' must be a positive integer, got {align_rows_to}"
        )

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
    # If align_rows_to is provided, pad the last dimension to be a multiple of it
    if align_rows_to is not None:
        last_dim = a.shape[-1]
        padded_last_dim = (
            (last_dim + align_rows_to - 1) // align_rows_to
        ) * align_rows_to
        a_fp8 = torch.empty(
            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
        )
        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
    else:
        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # The kernel computes offsets into both the input and the (possibly padded,
    # hence larger) output, so switch to int64 indexing if either is too big for int32.
    use_int64 = max(a.numel(), a_fp8.numel()) > (2**31 - 1)
    grid = (num_rows,)
    # Pick a conservative value for inference shapes for disabling BufferOps.
    should_disable_bufferops = torch.version.hip is not None and a_shape[0] < 32
    with disable_bufferops(should_disable_bufferops):
        with torch.cuda.device(a.device.index):
            _kernel_quantize_fp8_row[grid](
                a,
                a_scale,
                a_fp8,
                scale_ub,
                zero_start_index_M,
                a.shape[0],
                a.shape[1],
                a.shape[2],
                a.shape[3],
                a_fp8.shape[3],
                a.stride(0),
                a.stride(1),
                a.stride(2),
                a.stride(3),
                a_fp8.stride(0),
                a_fp8.stride(1),
                a_fp8.stride(2),
                a_fp8.stride(3),
                (
                    zero_start_index_M.stride(0)
                    if zero_start_index_M is not None
                    else None
                ),
                (
                    zero_start_index_M.stride(1)
                    if zero_start_index_M is not None
                    else None
                ),
                TL_FP8_DTYPE=tl_dtype,
                MAX_FP8=max_fp8,
                EPS=eps,
                CLAMP_MAX=scale_ub is not None,
                JAGGED=zero_start_index_M is not None,
                USE_INT64=use_int64,
            )

    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
|
|
2572
|
+
|
|
2573
|
+
|
|
2574
|
+
# Autotune over the reduction tile width; a new tuning run happens per distinct K.
@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_packed_row(
    A,
    A_fp8,
    packed_scale,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    packed_scale_stride,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row, writing the fp32 reciprocal scale as 4 raw bytes.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass. One program handles one row;
    program id `pid` is the flattened (b, m, n) row index.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): Higher precision input tensor with 4 dimensions [B, M, N, K].
        A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / reciprocal_scale
        packed_scale (Tensor): Destination for the per-row reciprocal scale; the
            4 bytes of each fp32 value are written at
            packed_scale + pid * packed_scale_stride + {0, 1, 2, 3}.
        scale_ub (Tensor): [1] Upper bound applied to each row's max-abs value
            before the scale is computed (only read when CLAMP_MAX).
        zero_start_index_M (Tensor): Per-(b, m) count of valid rows along N; rows
            at or beyond it are treated as empty (only read when JAGGED).
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        packed_scale_stride (int): Stride between consecutive rows' scale bytes.
            It is multiplied by the flat row index, which assumes a uniform row
            stride (true for a contiguous output buffer) - confirm for new callers.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Decompose pid into (b, m, n) and turn it into the start of this row in A.
    a_offset_base = (
        pid // (M * N) * stride_ab
        + (pid % (M * N)) // N * stride_am
        + (pid % (M * N)) % N * stride_an
    )
    # Same decomposition for the output row start.
    a_fp8_offset_base = (
        pid // (M * N) * stride_ob
        + (pid % (M * N)) // N * stride_om
        + (pid % (M * N)) % N * stride_on
    )

    # Number of input elements to read for this row (forced to 0 for empty jagged rows).
    K_in = K

    if JAGGED:
        # One jagged count per (b, m) pair.
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, dont process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE

    # Clamp max value appropriately: never below EPS (avoids division by zero
    # for all-zero rows), and never above scale_ub when one is given.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max

    # Reinterpret the 32 bits of the fp32 reciprocal scale and split them into
    # four bytes (low byte first), so they can be stored through a uint8 pointer.
    # NOTE(review): the asm body is PTX, so this path presumably only targets NVIDIA GPUs.
    (fp8_0, fp8_1, fp8_2, fp8_3) = tl.inline_asm_elementwise(
        asm="""
        {
            // $4 is the input register
            .reg .b32 input;
            mov.b32 input, $4;
            mov.b32 $0, $4;
            shr.b32 $1, $4, 8;
            shr.b32 $2, $4, 16;
            shr.b32 $3, $4, 24;
        }
        """,
        constraints=("=r,=r,=r,=r," "r"),
        # A single fp32 value (the reciprocal scale) is passed in as a 32-bit register.
        args=[1.0 / a_scale],
        dtype=(
            tl.uint8,
            tl.uint8,
            tl.uint8,
            tl.uint8,
        ),
        is_pure=True,
        pack=1,
    )

    # There are some compiler issues with FP8 pointers
    packed_scale_ptr = packed_scale.to(tl.pointer_type(tl.uint8))
    tl.store(packed_scale_ptr + pid * packed_scale_stride, fp8_0)
    tl.store(packed_scale_ptr + pid * packed_scale_stride + 1, fp8_1)
    tl.store(packed_scale_ptr + pid * packed_scale_stride + 2, fp8_2)
    tl.store(packed_scale_ptr + pid * packed_scale_stride + 3, fp8_3)

    n_offset = tl.arange(0, BLOCK_SIZE)

    # Iterate over the full K even for empty jagged rows: loads are masked with
    # K_in (yielding 0.0), so such rows are written as zeros.
    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K,
        )

        n_offset += BLOCK_SIZE
|
|
2752
|
+
|
|
2753
|
+
|
|
2754
|
+
def triton_quantize_fp8_packed_row(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
    return_only_packed: Optional[bool] = False,
) -> tuple[Optional[Tensor], Optional[Tensor], Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    This packs the FP32 scale at the end of each row, so the fp8 scaled tensor and
    the reciprocal scale tensor per row are contiguous in memory: each output row
    is K fp8 values followed by the 4 bytes of that row's fp32 reciprocal scale.

    Args:
        a (Tensor): higher precision input tensor with at most 4 dimensions; it is
            left-padded with singleton dims to 4-D for the kernel.
        scale_ub (Tensor): Upper bound applied to each row's max-abs value before the
            scale is computed.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        return_only_packed (bool): Only return the packed tensor, do not unpack results if True

    Returns:
        torch.Tensor: fp8 scaled tensor (None when return_only_packed).
        torch.Tensor: reciprocal scale tensor per row (None when return_only_packed).
        torch.Tensor: The packed FP8 scaled tensor, with the scale at the end of each row.

    Raises:
        Exception: If scale_ub or zero_start_index_M lives on a different device than a.
    """
    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]

    # Allocate an extra 4-bytes at the end of each row for the scale.
    a_fp8 = torch.empty(
        (*a.shape[:-1], a.shape[-1] + 4), device=a.device, dtype=pt_dtype
    )

    # create a view of the packed scale
    packed_scale = a_fp8[..., -4:]

    # The packed output has num_rows * 4 more elements than the input and the
    # kernel indexes it (including the scale bytes), so size the int64 decision
    # on the larger of the two buffers.
    use_int64 = max(a.numel(), a_fp8.numel()) > (2**31 - 1)
    grid = (num_rows,)

    with torch.cuda.device(a.device.index):
        _kernel_quantize_fp8_packed_row[grid](
            a,
            a_fp8,
            packed_scale,
            scale_ub,
            zero_start_index_M,
            a.shape[0],
            a.shape[1],
            a.shape[2],
            a.shape[3],
            a.stride(0),
            a.stride(1),
            a.stride(2),
            a.stride(3),
            a_fp8.stride(0),
            a_fp8.stride(1),
            a_fp8.stride(2),
            a_fp8.stride(3),
            packed_scale.stride(2),  # this is the stride that matters
            zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
            zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
            TL_FP8_DTYPE=tl_dtype,
            MAX_FP8=max_fp8,
            EPS=eps,
            CLAMP_MAX=scale_ub is not None,
            JAGGED=zero_start_index_M is not None,
            USE_INT64=use_int64,
        )
    if return_only_packed:
        return None, None, a_fp8.view((*a_shape[:-1], a_shape[-1] + 4))

    # Extract the original shape data without the extra 4 bytes per row
    # The data is still contiguous in memory, so we have to unpack it.
    final_fp8_view = a_fp8[..., :-4].view(a_shape)
    # Gather the 4 trailing bytes of every row and reinterpret them as fp32.
    scale_view = packed_scale.reshape((num_rows * 4)).view(torch.float32)

    # the difference with the packed API is that it also
    # returns the full packed tensor as a third return value
    return final_fp8_view, scale_view.view(a_shape[:-1]), a_fp8
|
|
2843
|
+
|
|
2844
|
+
|
|
2845
|
+
@torch.library.custom_op("triton::quantize_fp8_packed_row", mutates_args=())
def quantize_fp8_packed_row(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a to fp8 with row-wise scalings and optionally move to output device.

    Args:
        a (Tensor): Input high precision tensor. Required to have no more than 4 dimension
        scale_ub (Tensor): Upper bound applied to each row's max-abs value before the
            scale is computed.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
            Only honored by the Triton path.
        use_triton (bool): Whether to use triton kernel or pytorch. Forced to False
            for CPU inputs.
        output_device (torch.device): Device to move the scaled tensors to. Only
            honored by the PyTorch path; the Triton path returns tensors on a.device.

    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: The reciprocal scale tensor per row.
    """

    if a.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        # Only the unpacked fp8 tensor and scale are returned; the packed buffer is dropped.
        a_fp8, scale, _ = triton_quantize_fp8_packed_row(
            a, scale_ub, zero_start_index_M, return_only_packed=False
        )
        assert a_fp8 is not None
        assert scale is not None
        return a_fp8, scale
    # else use pytorch implementation.
    if output_device is None:
        output_device = a.device

    a_shape = a.shape
    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()
    row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
    # Apply clamping.
    if scale_ub is not None:
        row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
    else:
        # pyre-ignore[6]: Incompatible parameter type [6]
        row_max = torch.clamp(row_max, min=eps)
    a_scale = max_fp8 / row_max.to(torch.float32)  # pyre-ignore
    a_scale[a_scale == float("inf")] = 1.0  # pyre-ignore
    # Saturate to the fp8 range before casting, as the Triton kernel does. When
    # scale_ub caps row_max below a row's true max, a * a_scale exceeds max_fp8,
    # and the fp8 cast is not relied on to saturate.
    a_fp8 = torch.clamp(a * a_scale[..., None], min=-max_fp8, max=max_fp8)  # pyre-ignore
    # Cast and move data to output device (for cpu weight loading).
    a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
    a_scale = a_scale.to(output_device)  # pyre-ignore
    del a
    return a_fp8, (1 / a_scale).view(a_shape[:-1])  # pyre-ignore
|
|
2901
|
+
|
|
2902
|
+
|
|
2903
|
+
@torch.library.custom_op("triton::quantize_fp8_packed_row_raw", mutates_args=())
def quantize_fp8_packed_row_raw(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Quantize a to fp8 with row-wise scalings, returning only the packed buffer.

    Identical to quantize_fp8_packed_row, except it only returns the raw packed tensor.

    Args:
        a (Tensor): Input high precision tensor. Required to have no more than 4 dimension
        scale_ub (Tensor): Upper bound applied to each row's max-abs value before the
            scale is computed.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        use_triton (bool): Must be True (or effectively True); there is no PyTorch
            fallback. CPU inputs force it to False and therefore raise.
        output_device (torch.device): Accepted for signature parity with
            quantize_fp8_packed_row; not used.

    Returns:
        torch.Tensor: Packed fp8 tensor of shape (*a.shape[:-1], a.shape[-1] + 4),
            each row holding K fp8 values followed by the 4 bytes of its fp32
            reciprocal scale.

    Raises:
        Exception: When the Triton path is not taken.
    """

    if a.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if not use_triton:
        raise Exception(
            "No PyTorch implementation provided for triton::quantize_fp8_packed_row_raw"
        )
    # Only the packed buffer is requested; the unpacked outputs come back as None.
    _, _, packed_tensor = triton_quantize_fp8_packed_row(
        a, scale_ub, zero_start_index_M, return_only_packed=True
    )
    return packed_tensor
|
|
2940
|
+
|
|
2941
|
+
|
|
2942
|
+
@torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
def quantize_fp8_row(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a to fp8 with row-wise scalings and optionally move to output device.

    Args:
        a (Tensor): Input high precision tensor. Required to have no more than 4 dimension
        scale_ub (Tensor): Upper bound applied to each row's max-abs value before the
            scale is computed.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
            Only honored by the Triton path.
        use_triton (bool): Whether to use triton kernel or pytorch. Forced to False
            for CPU inputs.
        output_device (torch.device): Device to move the scaled tensors to. Only
            honored by the PyTorch path; the Triton path returns tensors on a.device.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels
            accepting specific sizes (e.g., multiple of 16). Padding is zero-filled
            on both the Triton and PyTorch paths.

    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: The reciprocal scale tensor per row.
    """

    if a.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        return triton_quantize_fp8_row(
            a,
            scale_ub,
            zero_start_index_M,
            align_rows_to=align_rows_to,
        )
    # else use pytorch implementation.
    if output_device is None:
        output_device = a.device

    a_shape = a.shape
    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()
    row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
    # Apply clamping.
    if scale_ub is not None:
        row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
    else:
        # pyre-ignore[6]: Incompatible parameter type [6]
        row_max = torch.clamp(row_max, min=eps)
    a_scale = max_fp8 / row_max.to(torch.float32)  # pyre-ignore
    a_scale[a_scale == float("inf")] = 1.0  # pyre-ignore
    # Saturate to the fp8 range before casting, as the Triton kernel does. When
    # scale_ub caps row_max below a row's true max, a * a_scale exceeds max_fp8,
    # and the fp8 cast is not relied on to saturate.
    a_fp8 = torch.clamp(a * a_scale[..., None], min=-max_fp8, max=max_fp8)  # pyre-ignore
    if align_rows_to is not None:
        # Zero-pad the last dim so the shape matches the Triton path and the
        # registered fake (quantize_fp8_row_meta).
        last_dim = a_fp8.shape[-1]
        padded_last_dim = (
            (last_dim + align_rows_to - 1) // align_rows_to
        ) * align_rows_to
        a_fp8 = torch.nn.functional.pad(a_fp8, (0, padded_last_dim - last_dim))
    # Cast and move data to output device (for cpu weight loading).
    a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
    a_scale = a_scale.to(output_device)  # pyre-ignore
    del a
    return a_fp8, (1 / a_scale).view(a_shape[:-1])  # pyre-ignore
|
|
3000
|
+
|
|
3001
|
+
|
|
3002
|
+
@quantize_fp8_row.register_fake
def quantize_fp8_row_meta(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape function for torch compile.

    Allocates empty placeholders with the shapes and dtypes of the operator's
    outputs:

    * an fp8 tensor shaped like ``a``; when ``align_rows_to`` is given, its last
      dimension is rounded up to the next multiple of ``align_rows_to``;
    * a float32 scale tensor of shape ``a.shape[:-1]``.

    Both placeholders live on ``output_device`` (default: ``a.device``).
    ``scale_ub``, ``zero_start_index_M`` and ``use_triton`` do not affect shapes.
    """
    target_device = a.device if output_device is None else output_device
    fp8_dtype = get_fp8_constants()[0]
    leading_dims = a.shape[:-1]

    fake_scale = torch.empty(leading_dims, device=target_device, dtype=torch.float32)

    if align_rows_to is None:
        out_shape = a.shape
    else:
        # Round the last dimension up to a multiple of align_rows_to.
        num_chunks = (a.shape[-1] + align_rows_to - 1) // align_rows_to
        out_shape = (*leading_dims, num_chunks * align_rows_to)

    fake_out = torch.empty(out_shape, device=target_device, dtype=fp8_dtype)
    return fake_out, fake_scale
|
|
3029
|
+
|
|
3030
|
+
|
|
3031
|
+
@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["N"],
)
@triton.jit
def _kernel_scale_fp8_row(
    A,
    x_scale,
    w_scale,
    scaled_out,
    M,
    N,
    stride_am,
    stride_an,
    stride_om,
    stride_on,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """
    Scale each row of A by x_scale and each column of A by w_scale.

    Each program handles the row selected by ``tl.program_id(0)`` and walks it
    in chunks of BLOCK_SIZE columns.

    Args:
        A (Tensor): [m, n] Input tensor to scale.
        x_scale (Tensor): [m] Row-wise scale tensor.
        w_scale (Tensor): [n] Col-wise scale tensor.
        scaled_out (Tensor): [m, n] Output tensor.
        M (int): Number of rows (unused by the kernel body).
        N (int): Number of columns.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        BLOCK_SIZE (int): Block size for data loads.
    """
    pid = tl.program_id(0)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Load activation scale for this row.
    row_scale = tl.load(x_scale + pid)

    # Iterate over chunks of the row and apply scales.
    for _k in range(0, tl.cdiv(N, BLOCK_SIZE)):
        n_mask = n_offset < N
        a = tl.load(A + pid * stride_am + n_offset * stride_an, mask=n_mask, other=0.0)
        # w_scale holds only N entries; on the final chunk n_offset can run
        # past N, so this load must be masked just like the load of A.
        col_scale = tl.load(w_scale + n_offset, mask=n_mask, other=0.0)
        scaled_a = a * row_scale * col_scale
        tl.store(
            scaled_out + pid * stride_om + n_offset * stride_on,
            scaled_a,
            mask=n_mask,
        )
        n_offset += BLOCK_SIZE
|
|
3089
|
+
|
|
3090
|
+
|
|
3091
|
+
def scale_fp8_row(
    a: Tensor,
    x_scale: Tensor,
    w_scale: Tensor,
) -> torch.Tensor:
    """
    Multiply a 2D tensor by per-row and per-column scales (no fp8 casting).

    Handy when pairing with kernels that cannot fuse row-wise scaling themselves.
    Element ``[i, j]`` of the result is ``a[i, j] * x_scale[i] * w_scale[j]``.

    Args:
        a (Tensor): Floating point tensor to scale; indexed as [rows, cols].
        x_scale (Tensor): Row-wise activation scales, one per row of ``a``.
        w_scale (Tensor): Column-wise weight scales, one per column of ``a``.

    Returns:
        torch.Tensor: The scaled tensor. For CPU inputs it comes from native
        PyTorch broadcasting, so its dtype follows PyTorch type promotion.
        Otherwise a Triton kernel writes into a fresh tensor with ``a``'s
        shape and dtype.

    Raises:
        Exception: For non-CPU inputs, if ``x_scale`` or ``w_scale`` is not on
            ``a``'s device.
    """
    on_cpu = a.device == torch.device("cpu")
    if on_cpu:
        # No Triton on CPU; plain broadcasting does the job.
        return a * x_scale[:, None] * w_scale[None, :]

    if x_scale.device != a.device:
        raise Exception("'x_scale' must be on the same device as 'a'")
    if w_scale.device != a.device:
        raise Exception("'w_scale' must be on the same device as 'a'")

    # Launch one Triton program per row of ``a``.
    rows = a.shape[0]
    result = torch.empty(a.shape, device=a.device, dtype=a.dtype)
    with torch.cuda.device(a.device.index):
        _kernel_scale_fp8_row[(rows,)](
            a,
            x_scale,
            w_scale,
            result,
            a.shape[0],
            a.shape[1],
            a.stride(0),
            a.stride(1),
            result.stride(0),
            result.stride(1),
        )
    return result
|
|
3134
|
+
|
|
3135
|
+
|
|
3136
|
+
@triton.jit
def _kernel_quantize_fp8_block(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    M,
    K,
    stride_am,
    stride_ak,
    stride_om,
    stride_ok,
    stride_a_scale_m,
    stride_a_scale_k,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    K_MAJOR: tl.constexpr,
) -> None:
    """Quantize and scale each [BLOCK_M, BLOCK_K] block.

    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(A[i:i+BLOCK_M, j:j+BLOCK_K])))

    Kernel naively iterates through matrix with [BLOCK_M, BLOCK_K] tiles: each
    program handles one tile, with tiles enumerated row-major over the
    [cdiv(M, BLOCK_M), cdiv(K, BLOCK_K)] tile grid.

    Todo:
        * Better tiling and ordering schemes.

    Args:
        A (Tensor): [M, K] higher precision input tensor.
        A_scale (Tensor): [cdiv(M, BLOCK_M), cdiv(K, BLOCK_K)] reciprocal scale tensor per block
            (transposed when K_MAJOR is False).
        A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale
        scale_ub (Tensor): [1] Upper bound applied to each block's max-abs value.
        M (int): Number of rows.
        K (int): Number of columns.
        stride_am (int): Stride of m dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_om (int): Stride of m dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_a_scale_m (int): Stride of m dimension of A_scale.
        stride_a_scale_k (int): Stride of k dimension of A_scale.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        BLOCK_M (int): Block size for M dimension of A_scale and kernel.
        BLOCK_K (int): Block size for K dimension of A_scale and kernel.
        K_MAJOR (bool): Whether output scales should be K major (True) or MN major (False).
    """
    pid = tl.program_id(0)
    grid_k = tl.cdiv(K, BLOCK_K)
    block_m = pid // grid_k
    block_k = pid % grid_k
    rm = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = block_k * BLOCK_K + tl.arange(0, BLOCK_K)
    a_offset = rm[:, None] * stride_am + rk[None, :] * stride_ak
    out_offset = rm[:, None] * stride_om + rk[None, :] * stride_ok
    a_mask = (rm < M)[:, None] & (rk < K)[None, :]
    # Out-of-range elements load as 0 so they cannot raise the block max.
    a_block = tl.load(A + a_offset, mask=a_mask, other=0.0)

    block_max = tl.max(tl.abs(a_block))
    # Apply appropriate clamping.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        block_max = tl.clamp(block_max, EPS, ub)
    else:
        block_max = tl.maximum(block_max, EPS)
    scale = MAX_FP8 / block_max

    # Write in transposed order if specified.
    if K_MAJOR:
        scale_offset = block_m * stride_a_scale_m + block_k * stride_a_scale_k
    else:
        scale_offset = block_k * stride_a_scale_m + block_m * stride_a_scale_k
    # Store the reciprocal so consumers can dequantize by multiplication.
    tl.store(A_scale + scale_offset, 1.0 / scale)
    a_fp8 = a_block * scale
    # Clamp A to fp8 range to make sure there's no overflow.
    # This is required for AMD. Nvidia's default saturation
    # handles it, but it's nice to have anyway.
    a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
    # `.to` returns a new value; assign it so the cast takes effect explicitly
    # instead of being discarded.
    a_fp8 = a_fp8.to(TL_FP8_DTYPE)
    tl.store(A_fp8 + out_offset, a_fp8, mask=a_mask)
|
|
3221
|
+
|
|
3222
|
+
|
|
3223
|
+
def triton_quantize_fp8_block(
    x: torch.Tensor,
    block_m: int = 256,
    block_k: int = 256,
    scale_ub: Optional[torch.Tensor] = None,
    k_major: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with block-wise scalings using a Triton kernel.

    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))

    Inputs with more than two dimensions are flattened to [-1, K]; the fp8 result is
    viewed back to the original shape.

    Args:
        x (torch.Tensor): [M, K] higher precision input tensor. Must not be on CPU.
        block_m (int): Block size for M dimension of scale.
        block_k (int): Block size for K dimension of scale.
        scale_ub: Optional upper bound applied to each block's max-abs value. Must be
            on the same device as ``x``.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor.
        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
            if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_m)].
    """
    assert x.device != torch.device(
        "cpu"
    ), "Blockwise quantization not support on cpu, please use row-wise quantization instead."

    if scale_ub is not None and scale_ub.device != x.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")

    original_shape = x.shape
    x_2d = x.view(-1, x.size(-1))
    fp8_pt_dtype, fp8_tl_dtype, fp8_max, epsilon = get_fp8_constants()
    num_rows, num_cols = x_2d.shape
    tiles_m = triton.cdiv(num_rows, block_m)
    tiles_k = triton.cdiv(num_cols, block_k)

    scale_shape = (tiles_m, tiles_k) if k_major else (tiles_k, tiles_m)
    scales = torch.empty(scale_shape, device=x_2d.device, dtype=torch.float32)
    quantized = torch.empty((num_rows, num_cols), device=x_2d.device, dtype=fp8_pt_dtype)

    # One program per [block_m, block_k] tile.
    _kernel_quantize_fp8_block[(tiles_m * tiles_k,)](
        x_2d,
        scales,
        quantized,
        scale_ub,
        num_rows,
        num_cols,
        x_2d.stride(0),
        x_2d.stride(1),
        quantized.stride(0),
        quantized.stride(1),
        scales.stride(0),
        scales.stride(1),
        # pyre-ignore[6]: Incompatible parameter type [6]
        TL_FP8_DTYPE=fp8_tl_dtype,
        # pyre-ignore[6]: Incompatible parameter type [6]
        MAX_FP8=fp8_max,
        # pyre-ignore[6]: Incompatible parameter type [6]
        EPS=epsilon,
        # pyre-ignore[6]: Incompatible parameter type [6]
        CLAMP_MAX=scale_ub is not None,
        # pyre-ignore[6]: Incompatible parameter type [6]
        BLOCK_M=block_m,
        # pyre-ignore[6]: Incompatible parameter type [6]
        BLOCK_K=block_k,
        # pyre-ignore[6]: Incompatible parameter type [6]
        K_MAJOR=k_major,
    )

    return quantized.view(original_shape), scales
|
|
3297
|
+
|
|
3298
|
+
|
|
3299
|
+
def quantize_fp8_block(
    x: torch.Tensor,
    block_m: int = 256,
    block_k: int = 256,
    scale_ub: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
    k_major: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.

    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))

    Inputs with more than two dimensions are flattened to [-1, K] for quantization and
    the fp8 output is viewed back to the input shape.

    Args:
        x (Tensor): [M, K] higher precision input tensor.
        block_m (int): Block size for M dimension of scale.
        block_k (int): Block size for K dimension of scale.
        scale_ub: Optional upper bound applied to each block's max-abs value.
        use_triton (bool): Whether to use triton kernel or pytorch. Forced to False
            for CPU inputs.
        output_device (torch.device): Device to move the scaled tensors to. Only used
            by the pytorch implementation; defaults to x's device.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor.
        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
            if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_m)].
    """
    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    if x.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, k_major)
        return xq.view(x_shape), x_scale
    # else use pytorch implementation.
    if not output_device:
        output_device = x.device

    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()

    M, K = x.shape
    grid_m = triton.cdiv(M, block_m)
    grid_k = triton.cdiv(K, block_k)

    # Zero-pad x to a whole number of blocks so it can be viewed as a block grid;
    # the zero padding cannot increase any block's max-abs value.
    padded_m = grid_m * block_m
    padded_k = grid_k * block_k
    x_padded = torch.zeros(padded_m, padded_k, dtype=x.dtype, device=x.device)
    x_padded[:M, :K] = x

    # Blockwise max.
    block_max = (
        x_padded.abs().reshape(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3))
    )

    # Apply clamping.
    if scale_ub is not None:
        block_max = torch.clamp(block_max, min=eps, max=scale_ub.item())
    else:
        block_max = torch.clamp(block_max, min=eps)
    x_scale = max_fp8 / block_max.to(torch.float32)  # pyre-ignore
    # pyre-ignore[16]: Undefined attribute [16]
    x_scale[x_scale == float("inf")] = 1.0
    x_fp8 = (
        x_padded
        # pyre-ignore[16]: Undefined attribute [16]
        * x_scale.repeat_interleave(block_m, dim=0).repeat_interleave(block_k, dim=1)
    )[:M, :K]
    # When scale_ub caps block_max below a block's true max-abs, scaled values can
    # exceed the fp8 range, and PyTorch's float -> fp8 cast turns such values into
    # NaN. Clamp first, as the Triton kernel does.
    x_fp8 = torch.clamp(x_fp8, min=-max_fp8, max=max_fp8)

    # Cast and move data to output device (for cpu weight loading).
    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
    x_scale = x_scale.to(output_device)  # pyre-ignore
    del x, x_padded
    if not k_major:
        x_scale = x_scale.t().contiguous()
    # Return reciprocal scales, matching the Triton path.
    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
|
|
3379
|
+
|
|
3380
|
+
|
|
3381
|
+
@triton.autotune(
    configs=[
        Config({"GROUP_LOAD": 2}),
        Config({"GROUP_LOAD": 4}),
        Config({"GROUP_LOAD": 8}),
        Config({"GROUP_LOAD": 16}),
        Config({"GROUP_LOAD": 32}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_group(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    m_sizes,
    M,
    K,
    stride_am,
    stride_ak,
    stride_om,
    stride_ok,
    stride_a_scale_m,
    stride_a_scale_k,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    USE_INT64: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    USE_M_MAJOR: tl.constexpr,
    G: tl.constexpr,
    GROUP_LOAD: tl.constexpr,
):
    """Quantize and scale each GROUP_SIZE chunk of each row.

    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(A[i:i+GROUP_SIZE])))

    Each kernel thread is responsible for one row and loads and processes a tunable
    number of groups at once. A trailing partial group (when K is not a multiple of
    GROUP_SIZE) is handled: its missing elements load as 0 and its scale is stored.

    Args:
        A (Tensor): [M, K] higher precision input tensor.
        A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
        A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a
        scale_ub (Tensor): [1] Upper bound applied to each group's max-abs value.
        m_sizes (Optional[Tensor]): [G] Number of rows in each group.
        M (int): Number of rows.
        K (int): Number of columns.
        stride_am (int): Stride of m dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_om (int): Stride of m dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_a_scale_m (int): Stride of m dimension of A_scale.
        stride_a_scale_k (int): Stride of k dimension of A_scale.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
        GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
        USE_M_MAJOR (bool): Whether to use grouped M-major layout for A_scale.
        G (int): Number of groups in A_scale, only relevant when m_sizes is provided.
        GROUP_LOAD (int): Number of groups to load and process simultaneously.
    """
    pid = tl.program_id(0)
    if USE_INT64:
        pid = pid.to(tl.int64)
    # We load group_size * group_load chunks at a time.
    row_offset = pid * stride_am
    out_offset = pid * stride_om
    scale_row_offset = pid * stride_a_scale_m
    k_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE)
    scale_k_offset = tl.arange(0, GROUP_LOAD)
    # Scale groups per row. Round up so that a trailing partial group
    # (K % GROUP_SIZE != 0) also gets its scale stored; callers allocate
    # cdiv(K, GROUP_SIZE) scales per row.
    NUM_GROUPS = tl.cdiv(K, GROUP_SIZE)

    # When dealing with an M-major grouped gemm, we need to figure out
    # which group this thread corresponds to and figure out the corresponding
    # scale offset.
    group_offset = 0
    group_cumsum = 0
    group_M = 0
    stop = False
    if USE_M_MAJOR and G > 0:
        # Iterate over groups to both compute the cumulative sum and find which group we are in.
        for i in range(G):
            if not stop:
                group_M = tl.cast(tl.load(m_sizes + i), pid.dtype)
                if (group_cumsum + group_M) <= pid:
                    group_cumsum += group_M
                else:
                    # Indicate we are finished computing cumsum.
                    stop = True

        group_offset = group_cumsum * NUM_GROUPS

    for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
        # Load groups of the input.
        chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
        a = tl.load(
            A + row_offset + chunk_offset * stride_ak, mask=chunk_offset < K, other=0.0
        )
        # View loaded chunk as a set of groups.
        a_grouped = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
        # Reduce over groups.
        group_max = tl.max(tl.abs(a_grouped), axis=1)
        # Apply clamping if specified.
        if CLAMP_MAX:
            ub = tl.load(scale_ub)
            group_max = tl.clamp(group_max, EPS, ub)
        else:
            group_max = tl.maximum(group_max, EPS)
        # Scale and quantize.
        a_scale = MAX_FP8 / group_max
        scale_chunk_offset = scale_k_offset + k * GROUP_LOAD

        if USE_M_MAJOR and G > 0:
            tl.store(
                A_scale
                + group_offset
                + (pid - group_cumsum) * stride_a_scale_k
                + (scale_chunk_offset * group_M),
                1.0 / a_scale,
                mask=scale_chunk_offset < NUM_GROUPS,
            )
        else:
            if USE_M_MAJOR:
                tl.store(
                    A_scale
                    + pid * stride_a_scale_k
                    + scale_chunk_offset * stride_a_scale_m,
                    1.0 / a_scale,
                    mask=scale_chunk_offset < NUM_GROUPS,
                )
            else:
                tl.store(
                    A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
                    1.0 / a_scale,
                    mask=scale_chunk_offset < NUM_GROUPS,
                )
        # Apply scale to input.
        a_fp8 = a_grouped * a_scale[:, None]
        # Clamp to FP8 range to avoid overflow
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
        # Write to output.
        tl.store(
            A_fp8 + out_offset + chunk_offset * stride_ok,
            tl.ravel(a_fp8),
            mask=chunk_offset < K,
        )
|
|
3532
|
+
|
|
3533
|
+
|
|
3534
|
+
def triton_quantize_fp8_group(
    x: torch.Tensor,
    group_size: int = 128,
    scale_ub: Optional[torch.Tensor] = None,
    m_sizes: Optional[torch.Tensor] = None,
    k_major: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with group-wise scalings.

    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))

    Inputs with more than two dimensions are flattened to [-1, K]; the fp8 result is
    viewed back to the original shape.

    Args:
        x (torch.Tensor): [M, K] higher precision input tensor. Must not be on CPU.
        group_size (int): Number of consecutive elements along K that share one scale.
        scale_ub: Optional upper bound applied to each group's max-abs value.
        m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
            When given, scales are written in a per-group M-major layout.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor.
        torch.Tensor: reciprocal scale tensor per group, allocated as
            [M, cdiv(K, group_size)] if k_major is True, otherwise [cdiv(K, group_size), M].
    """
    assert x.device != torch.device(
        "cpu"
    ), "Triton groupwise quantization not supported on cpu."

    if scale_ub is not None and scale_ub.device != x.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if m_sizes is not None and m_sizes.device != x.device:
        raise Exception("'m_sizes' must be on the same device as 'a'")

    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    M, K = x.shape
    k_groups = triton.cdiv(K, group_size)
    if k_major:
        x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
    else:
        x_scale = torch.empty((k_groups, M), device=x.device, dtype=torch.float32)
    x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
    # The kernel computes element offsets such as `pid * stride_am + col`.
    # Those overflow int32 once they pass 2**31 - 1, so switch to int64
    # indexing from that size on.
    use_int64 = x.numel() > (2**31 - 1)
    # One program per row.
    _kernel_quantize_fp8_group[(M,)](
        x,
        x_scale,
        x_fp8,
        scale_ub,
        m_sizes,
        M,
        K,
        x.stride(0),
        x.stride(1),
        x_fp8.stride(0),
        x_fp8.stride(1),
        x_scale.stride(0),
        x_scale.stride(1),
        TL_FP8_DTYPE=tl_dtype,
        MAX_FP8=max_fp8,
        EPS=eps,
        CLAMP_MAX=scale_ub is not None,
        USE_INT64=use_int64,
        GROUP_SIZE=group_size,
        USE_M_MAJOR=m_sizes is not None or k_major is False,
        G=m_sizes.numel() if m_sizes is not None else 0,
    )
    return x_fp8.view(x_shape), x_scale
|
|
3600
|
+
|
|
3601
|
+
|
|
3602
|
+
def quantize_fp8_group(
    x: torch.Tensor,
    group_size: int = 128,
    scale_ub: Optional[torch.Tensor] = None,
    m_sizes: Optional[torch.Tensor] = None,
    k_major: bool = True,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with group-wise scalings and optionally move to output device.

    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))

    Inputs with more than two dimensions are flattened to [-1, K] for quantization and
    the fp8 output is viewed back to the input shape.

    Args:
        x (Tensor): [M, K] higher precision input tensor.
        group_size (int): Number of consecutive elements along K that share one scale.
            The pytorch implementation requires K to be divisible by group_size.
        scale_ub: Optional upper bound applied to each group's max-abs value.
        m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
            Only supported by the triton implementation.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).
            This is needed because some kernels like cutlass require a special layout for scales.
        use_triton (bool): Whether to use triton kernel or pytorch. Forced to False for
            CPU inputs.
        output_device (torch.device): Device to move the scaled tensors to. Only used by
            the pytorch implementation; defaults to x's device.

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor.
        torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group if
            k_major is True, otherwise [cdiv(K, group_size), M].
    """
    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    if x.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        xq, x_scale = triton_quantize_fp8_group(
            x, group_size, scale_ub, m_sizes, k_major
        )
        return xq.view(x_shape), x_scale
    # else use pytorch implementation.
    if not output_device:
        output_device = x.device

    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()

    M, K = x.shape
    assert (
        K % group_size == 0
    ), "K must be divisible by group_size for cpu implementation."
    assert m_sizes is None, "m_sizes is not supported for cpu implementation."
    k_groups = triton.cdiv(K, group_size)
    # View input as collection of groups for reduction.
    x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
    # Reduce over groups.
    group_max = x_grouped.abs().amax(dim=2)
    # Apply clamping. Compare with None explicitly (as the triton path does):
    # tensor truthiness depends on the tensor's value and raises for multi-element tensors.
    if scale_ub is not None:
        group_max = torch.clamp(group_max, min=eps, max=scale_ub.item())
    else:
        group_max = torch.clamp(group_max, min=eps)
    x_scale = max_fp8 / group_max  # pyre-ignore
    # pyre-ignore[16]: Undefined attribute [16]
    x_scale[x_scale == float("inf")] = 1.0
    # pyre-ignore[16]: Undefined attribute [16]
    x_fp8 = x.view(-1, k_groups, group_size) * x_scale.unsqueeze(2)
    # When scale_ub caps group_max below a group's true max-abs, scaled values can
    # exceed the fp8 range, and PyTorch's float -> fp8 cast turns such values into
    # NaN. Clamp first, as the Triton kernel does.
    x_fp8 = torch.clamp(x_fp8, min=-max_fp8, max=max_fp8)
    # Cast and move data to output device (for cpu weight loading).
    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
    x_scale = x_scale.to(output_device)  # pyre-ignore
    if not k_major:
        x_scale = x_scale.t().contiguous()
    # Return reciprocal scales, matching the Triton path.
    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
|
|
3675
|
+
|
|
3676
|
+
|
|
3677
|
+
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return whether split-K should be considered for a GEMM shape.

    True when the output is thin (M or N below 64) and the reduction
    dimension K is longer than 1024.
    """
    thin_output = SIZE_M < 64 or SIZE_N < 64
    long_reduction = SIZE_K > 1024
    return thin_output and long_reduction
|
|
3679
|
+
|
|
3680
|
+
|
|
3681
|
+
# Force a failure instead of a warning when all configs are pruned.
# Read by prune_configs when its pruning leaves no configs; defaults to False.
# TODO: Determine a better approach for model level testing. We need
# to standardize our approach around prune_configs in general.
FORCE_FAILURE_ON_EMPTY_CONFIGS = False
|
|
3685
|
+
|
|
3686
|
+
|
|
3687
|
+
def is_invalid_config(config, N, M, K, mfma, use_bias):
    """
    Return True if ``config`` would produce incorrect results for this problem.

    These are the correctness checks used by ``prune_configs``. They are kept apart
    from the performance-oriented pruning so that, even when no config is considered
    optimal for a shape, a config that yields wrong results is never selected.

    Args:
        config: Triton config whose ``kwargs`` supply BLOCK_M, BLOCK_N, BLOCK_K,
            SPLIT_K and matrix_instr_nonkdim.
        N (int): GEMM N dimension.
        M (int): GEMM M dimension.
        K (int): GEMM K dimension (not consulted by the current checks).
        mfma (int): Largest matrix_instr_nonkdim accepted for this shape.
        use_bias (bool): Whether the GEMM adds a bias.

    Returns:
        bool: True if the config must be rejected.
    """
    params = config.kwargs
    block_m = params.get("BLOCK_M")
    block_n = params.get("BLOCK_N")
    block_k = params.get("BLOCK_K")
    split_k = params.get("SPLIT_K")
    nonkdim = params.get("matrix_instr_nonkdim")

    # Checks are evaluated left to right and stop at the first hit.
    return bool(
        nonkdim > mfma
        or (mfma == 4 and block_k < 64)
        # Some layouts misbehave when a thread would own fewer than one element.
        or block_m * block_n < 64
        or block_m < nonkdim
        or block_n < nonkdim
        or (M <= nonkdim and block_m != nonkdim)
        or (N <= nonkdim and block_n != nonkdim)
        # split_k cannot be used if there is a bias.
        or (use_bias and split_k != 1)
    )
|
|
3718
|
+
|
|
3719
|
+
|
|
3720
|
+
# Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py
|
|
3721
|
+
def prune_configs(configs, named_args, **kwargs):
|
|
3722
|
+
|
|
3723
|
+
pruned_configs = []
|
|
3724
|
+
M = named_args["M"]
|
|
3725
|
+
N = named_args["N"]
|
|
3726
|
+
K = named_args["K"]
|
|
3727
|
+
elemBytes_a = named_args["A"].element_size()
|
|
3728
|
+
elemBytes_b = named_args["B"].element_size()
|
|
3729
|
+
use_bias = kwargs["USE_BIAS"]
|
|
3730
|
+
|
|
3731
|
+
if M < 32 or N < 32:
|
|
3732
|
+
mfma = 16
|
|
3733
|
+
else:
|
|
3734
|
+
mfma = 32
|
|
3735
|
+
|
|
3736
|
+
for config in configs:
|
|
3737
|
+
BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
|
|
3738
|
+
BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
|
|
3739
|
+
BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
|
|
3740
|
+
SPLIT_K = config.kwargs.get("SPLIT_K")
|
|
3741
|
+
GROUP_M = config.kwargs.get("GROUP_M")
|
|
3742
|
+
if is_invalid_config(config, N, M, K, mfma, use_bias):
|
|
3743
|
+
continue
|
|
3744
|
+
# Skip BLOCK_SIZE that is too large compare to M/N
|
|
3745
|
+
# unless BLOCK_SIZE is already small enough
|
|
3746
|
+
if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
|
|
3747
|
+
continue
|
|
3748
|
+
if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
|
|
3749
|
+
continue
|
|
3750
|
+
# skip large split_k when not necessary
|
|
3751
|
+
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
|
3752
|
+
continue
|
|
3753
|
+
# skip large GROUP_M
|
|
3754
|
+
if GROUP_M * BLOCK_SIZE_M >= M and GROUP_M != 1:
|
|
3755
|
+
continue
|
|
3756
|
+
# out of shared memory resource
|
|
3757
|
+
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
|
3758
|
+
LDS = (
|
|
3759
|
+
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
|
3760
|
+
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
|
3761
|
+
)
|
|
3762
|
+
if LDS > 65536:
|
|
3763
|
+
continue
|
|
3764
|
+
pruned_configs.append(config)
|
|
3765
|
+
|
|
3766
|
+
print(f"{len(configs)=} {len(pruned_configs)=} for {M=} {N=} {K=}")
|
|
3767
|
+
if len(pruned_configs) == 0:
|
|
3768
|
+
if not FORCE_FAILURE_ON_EMPTY_CONFIGS:
|
|
3769
|
+
# Prune configs that can lead to incorrect results even if all configs are sub-optimal.
|
|
3770
|
+
candidate_configs = [
|
|
3771
|
+
c for c in configs if not is_invalid_config(c, N, M, K, mfma, use_bias)
|
|
3772
|
+
]
|
|
3773
|
+
print(f"No configs left after pruning! {M=} {N=} {K=}")
|
|
3774
|
+
pruned_configs = candidate_configs[:10]
|
|
3775
|
+
if len(pruned_configs) == 0:
|
|
3776
|
+
raise RuntimeError(
|
|
3777
|
+
"No valid configs left after pruning! Consider autotuning further with TritonBench"
|
|
3778
|
+
)
|
|
3779
|
+
return pruned_configs
|
|
3780
|
+
|
|
3781
|
+
|
|
3782
|
+
def get_full_non_persistent_tuning_space():
|
|
3783
|
+
configs = []
|
|
3784
|
+
|
|
3785
|
+
block_mn_range = [16, 32, 64, 128, 256]
|
|
3786
|
+
block_k_range = [16, 32, 64, 128, 256]
|
|
3787
|
+
split_k_range = [1]
|
|
3788
|
+
num_warps_range = [1, 2, 4, 8]
|
|
3789
|
+
group_m_range = [1, 2, 4, 8, 16, 32]
|
|
3790
|
+
num_stage_range = [2]
|
|
3791
|
+
waves_per_eu_range = [0]
|
|
3792
|
+
matrix_instr_nonkdim_range = [16, 32]
|
|
3793
|
+
kpack_range = [1, 2]
|
|
3794
|
+
|
|
3795
|
+
for block_m in block_mn_range:
|
|
3796
|
+
for block_n in block_mn_range:
|
|
3797
|
+
for block_k in block_k_range:
|
|
3798
|
+
for num_warps in num_warps_range:
|
|
3799
|
+
for group_m in group_m_range:
|
|
3800
|
+
for split_k in split_k_range:
|
|
3801
|
+
for num_stages in num_stage_range:
|
|
3802
|
+
for waves_per_eu in waves_per_eu_range:
|
|
3803
|
+
for (
|
|
3804
|
+
matrix_instr_nonkdim
|
|
3805
|
+
) in matrix_instr_nonkdim_range:
|
|
3806
|
+
for kpack in kpack_range:
|
|
3807
|
+
configs.append(
|
|
3808
|
+
triton.Config(
|
|
3809
|
+
{
|
|
3810
|
+
"BLOCK_M": block_m,
|
|
3811
|
+
"BLOCK_N": block_n,
|
|
3812
|
+
"BLOCK_K": block_k,
|
|
3813
|
+
"GROUP_M": group_m,
|
|
3814
|
+
"SPLIT_K": split_k,
|
|
3815
|
+
"waves_per_eu": waves_per_eu,
|
|
3816
|
+
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
|
3817
|
+
"kpack": kpack,
|
|
3818
|
+
},
|
|
3819
|
+
num_warps=num_warps,
|
|
3820
|
+
num_stages=num_stages,
|
|
3821
|
+
)
|
|
3822
|
+
)
|
|
3823
|
+
return configs
|
|
3824
|
+
|
|
3825
|
+
|
|
3826
|
+
MATMUL_CONFIGS_NON_PERSISTENT: list[Config] = get_full_non_persistent_tuning_space()
|
|
3827
|
+
# (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages)
|
|
3828
|
+
_MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K = [
|
|
3829
|
+
(16, 16, 256, 1, 1, 8, 16, 2, 2, 2),
|
|
3830
|
+
(16, 16, 256, 1, 1, 0, 16, 2, 2, 2),
|
|
3831
|
+
(32, 64, 512, 1, 1, 2, 16, 2, 8, 2),
|
|
3832
|
+
(64, 64, 256, 1, 1, 2, 16, 2, 4, 2),
|
|
3833
|
+
(256, 256, 128, 32, 1, 2, 16, 1, 8, 2),
|
|
3834
|
+
(256, 256, 128, 2, 1, 0, 32, 2, 8, 2),
|
|
3835
|
+
(256, 256, 128, 1, 1, 0, 32, 2, 8, 2),
|
|
3836
|
+
(256, 256, 128, 2, 1, 0, 16, 1, 8, 2),
|
|
3837
|
+
(256, 256, 64, 2, 1, 2, 16, 1, 8, 2),
|
|
3838
|
+
(128, 256, 64, 2, 1, 2, 16, 1, 4, 2),
|
|
3839
|
+
(256, 128, 128, 4, 1, 0, 16, 1, 8, 2),
|
|
3840
|
+
(128, 128, 128, 1, 1, 2, 16, 2, 4, 2),
|
|
3841
|
+
(128, 128, 256, 1, 1, 2, 16, 2, 8, 2),
|
|
3842
|
+
(128, 128, 64, 4, 1, 2, 16, 2, 4, 2),
|
|
3843
|
+
(128, 128, 64, 1, 1, 2, 16, 2, 4, 2),
|
|
3844
|
+
(128, 64, 64, 4, 1, 0, 16, 2, 4, 2),
|
|
3845
|
+
(128, 64, 64, 1, 1, 0, 16, 2, 4, 2),
|
|
3846
|
+
(256, 128, 128, 1, 1, 2, 16, 1, 8, 2),
|
|
3847
|
+
(128, 256, 128, 2, 1, 2, 16, 2, 4, 1),
|
|
3848
|
+
(256, 128, 64, 2, 1, 2, 16, 1, 4, 2),
|
|
3849
|
+
(128, 128, 256, 2, 1, 0, 16, 2, 8, 2),
|
|
3850
|
+
(128, 64, 128, 2, 1, 2, 16, 2, 4, 2),
|
|
3851
|
+
(128, 128, 64, 2, 1, 0, 16, 1, 4, 2),
|
|
3852
|
+
(128, 128, 128, 1, 1, 2, 16, 1, 4, 2),
|
|
3853
|
+
]
|
|
3854
|
+
|
|
3855
|
+
|
|
3856
|
+
def _should_skip_config(block_k, matrix_instr_nonkdim):
|
|
3857
|
+
"""Skip config if BLOCK_K=64 and matrix_instr_nonkdim=16 on GFX95+"""
|
|
3858
|
+
try:
|
|
3859
|
+
return (
|
|
3860
|
+
block_k == 64
|
|
3861
|
+
and matrix_instr_nonkdim == 16
|
|
3862
|
+
and torch.version.hip is not None
|
|
3863
|
+
and torch.cuda.get_device_capability() >= (9, 5)
|
|
3864
|
+
)
|
|
3865
|
+
except RuntimeError:
|
|
3866
|
+
# If no HIP GPUs are available, we can't check device capability
|
|
3867
|
+
# so we don't skip any configs
|
|
3868
|
+
return False
|
|
3869
|
+
|
|
3870
|
+
|
|
3871
|
+
MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K = [
|
|
3872
|
+
triton.Config(
|
|
3873
|
+
{
|
|
3874
|
+
"BLOCK_M": block_m,
|
|
3875
|
+
"BLOCK_N": block_n,
|
|
3876
|
+
"BLOCK_K": block_k,
|
|
3877
|
+
"GROUP_M": group_m,
|
|
3878
|
+
"SPLIT_K": split_k,
|
|
3879
|
+
"waves_per_eu": waves_per_eu,
|
|
3880
|
+
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
|
3881
|
+
"kpack": kpack,
|
|
3882
|
+
},
|
|
3883
|
+
num_warps=num_warps,
|
|
3884
|
+
num_stages=num_stages,
|
|
3885
|
+
)
|
|
3886
|
+
for block_m, block_n, block_k, group_m, split_k, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages in _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K
|
|
3887
|
+
if not _should_skip_config(block_k, matrix_instr_nonkdim)
|
|
3888
|
+
]
|
|
3889
|
+
|
|
3890
|
+
# Set this to enable full autotuning for proper benchmarking.
|
|
3891
|
+
# This should only be used when invoking the kernel through
|
|
3892
|
+
# Triton directly (e.g. TritonBench)
|
|
3893
|
+
#
|
|
3894
|
+
# NOTE: This will SIGNIFICANTLY increase autotuning time, often
|
|
3895
|
+
# taking hours. You should combine this with TRITON_PRINT_AUTOTUNING=1
|
|
3896
|
+
# to extract and add the optimal autotuning configs to
|
|
3897
|
+
# MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K.
|
|
3898
|
+
|
|
3899
|
+
FULL_NON_PERSISTENT_AUTOTUNING = False
|
|
3900
|
+
USED_MATMUL_NON_PERSISTENT_CONFIGS = (
|
|
3901
|
+
MATMUL_CONFIGS_NON_PERSISTENT
|
|
3902
|
+
if FULL_NON_PERSISTENT_AUTOTUNING
|
|
3903
|
+
else MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K
|
|
3904
|
+
)
|
|
3905
|
+
|
|
3906
|
+
|
|
3907
|
+
@triton.autotune(
|
|
3908
|
+
configs=USED_MATMUL_NON_PERSISTENT_CONFIGS,
|
|
3909
|
+
key=["M", "N", "K"],
|
|
3910
|
+
prune_configs_by={
|
|
3911
|
+
"early_config_prune": prune_configs,
|
|
3912
|
+
"perf_model": None,
|
|
3913
|
+
"top_k": None,
|
|
3914
|
+
},
|
|
3915
|
+
use_cuda_graph=FULL_NON_PERSISTENT_AUTOTUNING,
|
|
3916
|
+
)
|
|
3917
|
+
@triton.heuristics(
|
|
3918
|
+
{
|
|
3919
|
+
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
|
|
3920
|
+
}
|
|
3921
|
+
)
|
|
3922
|
+
@triton.jit
|
|
3923
|
+
def _kernel_matmul_fp8_row_non_persistent(
|
|
3924
|
+
A,
|
|
3925
|
+
B,
|
|
3926
|
+
C,
|
|
3927
|
+
M,
|
|
3928
|
+
N,
|
|
3929
|
+
K,
|
|
3930
|
+
m_key,
|
|
3931
|
+
n_key,
|
|
3932
|
+
k_key,
|
|
3933
|
+
A_scale,
|
|
3934
|
+
B_scale,
|
|
3935
|
+
Bias,
|
|
3936
|
+
stride_am,
|
|
3937
|
+
stride_ak,
|
|
3938
|
+
stride_bn,
|
|
3939
|
+
stride_bk,
|
|
3940
|
+
stride_cm,
|
|
3941
|
+
stride_cn,
|
|
3942
|
+
dot_out_dtype: tl.constexpr,
|
|
3943
|
+
allow_tf32: tl.constexpr,
|
|
3944
|
+
fp8_fast_accum: tl.constexpr,
|
|
3945
|
+
BLOCK_M: tl.constexpr,
|
|
3946
|
+
BLOCK_N: tl.constexpr,
|
|
3947
|
+
BLOCK_K: tl.constexpr,
|
|
3948
|
+
GROUP_M: tl.constexpr,
|
|
3949
|
+
SPLIT_K: tl.constexpr,
|
|
3950
|
+
EVEN_K: tl.constexpr,
|
|
3951
|
+
USE_BIAS: tl.constexpr,
|
|
3952
|
+
AB_DTYPE: tl.constexpr,
|
|
3953
|
+
) -> None:
|
|
3954
|
+
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
|
|
3955
|
+
|
|
3956
|
+
performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
|
|
3957
|
+
|
|
3958
|
+
Args:
|
|
3959
|
+
A (TensorWrapper): [M, K] input tensor.
|
|
3960
|
+
B (TensorWrapper): [N, K] input tensor.
|
|
3961
|
+
C (TensorWrapper): [M, N] output tensor.
|
|
3962
|
+
M (int): M dimension of input tensor.
|
|
3963
|
+
N (int): N dimension of input tensor.
|
|
3964
|
+
K (int): K dimension of input tensor.
|
|
3965
|
+
m_key (int): Autotuning key for M dimension of input tensor.
|
|
3966
|
+
n_key (int): Autotuning key for N dimension of input tensor.
|
|
3967
|
+
k_key (int): Autotuning key for K dimension of input tensor.
|
|
3968
|
+
A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
|
|
3969
|
+
B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
|
|
3970
|
+
Bias (tensorWrapper): [N] Optional bias tensor.
|
|
3971
|
+
stride_am (int): Stride of M dimension of A.
|
|
3972
|
+
stride_ak (int): Stride of K dimension of A.
|
|
3973
|
+
stride_bn (int): Stride of N dimension of B.
|
|
3974
|
+
stride_bk (int): Stride of K dimension of B.
|
|
3975
|
+
stride_cm (int): Stride of M dimension of C.
|
|
3976
|
+
stride_cn (int): Stride of N dimension of C.
|
|
3977
|
+
dot_out_dtype (torch.dtype): Output type of tensor core.
|
|
3978
|
+
allow_tf32 (bool): Whether to use TF32 for tensor core.
|
|
3979
|
+
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
|
|
3980
|
+
BLOCK_M (int): Block size for M dimension.
|
|
3981
|
+
BLOCK_N (int): Block size for N dimension.
|
|
3982
|
+
BLOCK_K (int): Block size for K dimension.
|
|
3983
|
+
GROUP_M (int): Number of groups for M dimension swizzle.
|
|
3984
|
+
SPLIT_K (int): Number of SM's to launch per row.
|
|
3985
|
+
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
|
|
3986
|
+
USE_BIAS (bool): Whether to use bias.
|
|
3987
|
+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
|
|
3988
|
+
"""
|
|
3989
|
+
tl.assume(M >= 0)
|
|
3990
|
+
tl.assume(N >= 0)
|
|
3991
|
+
tl.assume(K >= 0)
|
|
3992
|
+
tl.assume(stride_am >= 0)
|
|
3993
|
+
tl.assume(stride_ak >= 0)
|
|
3994
|
+
tl.assume(stride_bn >= 0)
|
|
3995
|
+
tl.assume(stride_bk >= 0)
|
|
3996
|
+
tl.assume(stride_cm >= 0)
|
|
3997
|
+
tl.assume(stride_cn >= 0)
|
|
3998
|
+
# Matrix multiplication.
|
|
3999
|
+
pid = tl.program_id(0)
|
|
4000
|
+
pid_z = tl.program_id(1)
|
|
4001
|
+
grid_m = tl.cdiv(M, BLOCK_M)
|
|
4002
|
+
grid_n = tl.cdiv(N, BLOCK_N)
|
|
4003
|
+
# Re-order program ID for better L2 performance (swizzle).
|
|
4004
|
+
width = GROUP_M * grid_n
|
|
4005
|
+
group_id = pid // width
|
|
4006
|
+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
|
4007
|
+
pid_m = group_id * GROUP_M + ((pid % width) % group_size)
|
|
4008
|
+
pid_n = (pid % width) // (group_size)
|
|
4009
|
+
tl.assume(pid_m >= 0)
|
|
4010
|
+
tl.assume(pid_n >= 0)
|
|
4011
|
+
# Do matrix multiplication.
|
|
4012
|
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
4013
|
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
4014
|
+
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
|
4015
|
+
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
|
4016
|
+
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
|
4017
|
+
# Pointers.
|
|
4018
|
+
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
|
4019
|
+
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
|
4020
|
+
acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype
|
|
4021
|
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
|
|
4022
|
+
|
|
4023
|
+
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
|
4024
|
+
if EVEN_K:
|
|
4025
|
+
a = tl.load(A)
|
|
4026
|
+
b = tl.load(B)
|
|
4027
|
+
else:
|
|
4028
|
+
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
|
4029
|
+
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
|
|
4030
|
+
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
|
|
4031
|
+
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
|
|
4032
|
+
if AB_DTYPE:
|
|
4033
|
+
a = a.to(C.dtype.element_ty)
|
|
4034
|
+
b = b.to(C.dtype.element_ty)
|
|
4035
|
+
if fp8_fast_accum:
|
|
4036
|
+
acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
|
|
4037
|
+
else:
|
|
4038
|
+
acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)
|
|
4039
|
+
|
|
4040
|
+
A += BLOCK_K * SPLIT_K * stride_ak
|
|
4041
|
+
B += BLOCK_K * SPLIT_K * stride_bk
|
|
4042
|
+
|
|
4043
|
+
# rematerialize rm and rn to save registers
|
|
4044
|
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
4045
|
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
4046
|
+
|
|
4047
|
+
# Invert scaling.
|
|
4048
|
+
a_scale = tl.load(A_scale + rm, mask=rm < M)
|
|
4049
|
+
b_scale = tl.load(B_scale + rn, mask=rn < N)
|
|
4050
|
+
# Invert vector, then multiply on matrix for speed.
|
|
4051
|
+
# pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
|
|
4052
|
+
scale = a_scale[:, None] * b_scale[None, :]
|
|
4053
|
+
acc *= scale
|
|
4054
|
+
|
|
4055
|
+
# Load and add bias if specified.
|
|
4056
|
+
if USE_BIAS:
|
|
4057
|
+
bias = tl.load(Bias + rn, mask=rn < N)
|
|
4058
|
+
acc += bias[None, :]
|
|
4059
|
+
|
|
4060
|
+
acc = acc.to(C.dtype.element_ty)
|
|
4061
|
+
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
|
4062
|
+
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
|
4063
|
+
# Handles write-back with reduction-splitting
|
|
4064
|
+
if SPLIT_K == 1:
|
|
4065
|
+
tl.store(C, acc, mask=mask)
|
|
4066
|
+
else:
|
|
4067
|
+
tl.atomic_add(C, acc, mask=mask)
|
|
4068
|
+
|
|
4069
|
+
|
|
4070
|
+
@triton.autotune(
|
|
4071
|
+
configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
|
|
4072
|
+
key=["M", "K"],
|
|
4073
|
+
)
|
|
4074
|
+
@triton.jit
|
|
4075
|
+
def _kernel_dequantize_fp8_row(
|
|
4076
|
+
xq_ptr,
|
|
4077
|
+
x_scale_ptr,
|
|
4078
|
+
x_dequant_ptr,
|
|
4079
|
+
M,
|
|
4080
|
+
K,
|
|
4081
|
+
stride_xm,
|
|
4082
|
+
stride_xk,
|
|
4083
|
+
stride_xdqm,
|
|
4084
|
+
stride_xdqk,
|
|
4085
|
+
BLOCK_M: tl.constexpr,
|
|
4086
|
+
BLOCK_K: tl.constexpr,
|
|
4087
|
+
NUM_STAGES: tl.constexpr,
|
|
4088
|
+
USE_INT64: tl.constexpr,
|
|
4089
|
+
):
|
|
4090
|
+
"""
|
|
4091
|
+
Kernel to dequantize FP8 tensor to BF16 tensor.
|
|
4092
|
+
Args:
|
|
4093
|
+
xq_ptr (tl.constexpr): Pointer to FP8 tensor.
|
|
4094
|
+
x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
|
|
4095
|
+
x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
|
|
4096
|
+
M (tl.constexpr): M dimension of input tensor.
|
|
4097
|
+
K (tl.constexpr): K dimension of input tensor (along which scales are applied)
|
|
4098
|
+
BLOCK_SIZE (tl.constexpr): Block size for the K dimension.
|
|
4099
|
+
"""
|
|
4100
|
+
pid = tl.program_id(axis=0)
|
|
4101
|
+
if USE_INT64:
|
|
4102
|
+
pid = pid.to(tl.int64)
|
|
4103
|
+
offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
4104
|
+
offs_k = tl.arange(0, BLOCK_K)
|
|
4105
|
+
scales = tl.load(x_scale_ptr + offs_m)
|
|
4106
|
+
|
|
4107
|
+
for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
|
|
4108
|
+
mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
|
|
4109
|
+
xq = tl.load(
|
|
4110
|
+
xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
|
|
4111
|
+
mask=mask,
|
|
4112
|
+
)
|
|
4113
|
+
x_dq = xq * scales[:, None]
|
|
4114
|
+
tl.store(
|
|
4115
|
+
x_dequant_ptr
|
|
4116
|
+
+ offs_m[:, None] * stride_xdqm
|
|
4117
|
+
+ offs_k[None, :] * stride_xdqk,
|
|
4118
|
+
x_dq,
|
|
4119
|
+
mask=mask,
|
|
4120
|
+
)
|
|
4121
|
+
offs_k += BLOCK_K
|
|
4122
|
+
|
|
4123
|
+
|
|
4124
|
+
def dequantize_fp8_row(
|
|
4125
|
+
xq: torch.Tensor,
|
|
4126
|
+
x_scale: torch.Tensor,
|
|
4127
|
+
) -> torch.Tensor:
|
|
4128
|
+
"""
|
|
4129
|
+
Rowwise Dequantize FP8 tensor to BF16 tensor along last axis.
|
|
4130
|
+
|
|
4131
|
+
Args:
|
|
4132
|
+
xq (torch.Tensor): FP8 tensor to be dequantized.
|
|
4133
|
+
x_scale (torch.Tensor): FP8 scale tensor.
|
|
4134
|
+
|
|
4135
|
+
Returns:
|
|
4136
|
+
torch.Tensor: Dequantized BF16 tensor.
|
|
4137
|
+
"""
|
|
4138
|
+
|
|
4139
|
+
assert (
|
|
4140
|
+
xq.is_contiguous() and x_scale.is_contiguous()
|
|
4141
|
+
), "Input tensors must be contiguous"
|
|
4142
|
+
x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
|
|
4143
|
+
|
|
4144
|
+
# Reshape to 2-d array keeping last dim only.
|
|
4145
|
+
K = xq.shape[-1]
|
|
4146
|
+
xq = xq.reshape(-1, K)
|
|
4147
|
+
M = xq.shape[0]
|
|
4148
|
+
use_int64 = xq.numel() > 2**31
|
|
4149
|
+
|
|
4150
|
+
def grid(meta: dict[str, int]) -> tuple[int]:
|
|
4151
|
+
return (triton.cdiv(M, meta["BLOCK_M"]),)
|
|
4152
|
+
|
|
4153
|
+
with torch.cuda.device(xq.device.index):
|
|
4154
|
+
_kernel_dequantize_fp8_row[grid](
|
|
4155
|
+
xq,
|
|
4156
|
+
x_scale,
|
|
4157
|
+
x_dequant,
|
|
4158
|
+
M,
|
|
4159
|
+
K,
|
|
4160
|
+
xq.stride(0),
|
|
4161
|
+
xq.stride(1),
|
|
4162
|
+
xq.stride(0), # Use squashed stride.
|
|
4163
|
+
xq.stride(1),
|
|
4164
|
+
USE_INT64=use_int64,
|
|
4165
|
+
)
|
|
4166
|
+
return x_dequant
|
|
4167
|
+
|
|
4168
|
+
|
|
4169
|
+
@triton.autotune(
|
|
4170
|
+
configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
|
|
4171
|
+
key=["M", "K"],
|
|
4172
|
+
)
|
|
4173
|
+
@triton.jit
|
|
4174
|
+
def _kernel_dequantize_fp8_packed_row(
|
|
4175
|
+
xq_ptr,
|
|
4176
|
+
x_scale_ptr,
|
|
4177
|
+
x_dequant_ptr,
|
|
4178
|
+
M,
|
|
4179
|
+
K,
|
|
4180
|
+
stride_xm,
|
|
4181
|
+
stride_xk,
|
|
4182
|
+
stride_xdqm,
|
|
4183
|
+
stride_xdqk,
|
|
4184
|
+
BLOCK_M: tl.constexpr,
|
|
4185
|
+
BLOCK_K: tl.constexpr,
|
|
4186
|
+
NUM_STAGES: tl.constexpr,
|
|
4187
|
+
USE_INT64: tl.constexpr,
|
|
4188
|
+
):
|
|
4189
|
+
"""
|
|
4190
|
+
Kernel to dequantize FP8 tensor to BF16 tensor.
|
|
4191
|
+
Args:
|
|
4192
|
+
xq_ptr (tl.constexpr): Pointer to FP8 tensor.
|
|
4193
|
+
x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
|
|
4194
|
+
x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
|
|
4195
|
+
M (tl.constexpr): M dimension of input tensor.
|
|
4196
|
+
K (tl.constexpr): K dimension of input tensor (along which scales are applied)
|
|
4197
|
+
BLOCK_SIZE (tl.constexpr): Block size for the K dimension.
|
|
4198
|
+
"""
|
|
4199
|
+
pid = tl.program_id(axis=0)
|
|
4200
|
+
if USE_INT64:
|
|
4201
|
+
pid = pid.to(tl.int64)
|
|
4202
|
+
offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
4203
|
+
offs_k = tl.arange(0, BLOCK_K)
|
|
4204
|
+
scales = tl.load(x_scale_ptr + offs_m)
|
|
4205
|
+
|
|
4206
|
+
for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
|
|
4207
|
+
mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
|
|
4208
|
+
|
|
4209
|
+
xq = tl.load(
|
|
4210
|
+
xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
|
|
4211
|
+
mask=mask,
|
|
4212
|
+
other=0.0,
|
|
4213
|
+
)
|
|
4214
|
+
x_dq = xq * scales[:, None]
|
|
4215
|
+
|
|
4216
|
+
tl.store(
|
|
4217
|
+
x_dequant_ptr
|
|
4218
|
+
+ offs_m[:, None] * stride_xdqm
|
|
4219
|
+
+ offs_k[None, :] * stride_xdqk,
|
|
4220
|
+
x_dq,
|
|
4221
|
+
mask=mask,
|
|
4222
|
+
)
|
|
4223
|
+
offs_k += BLOCK_K
|
|
4224
|
+
|
|
4225
|
+
|
|
4226
|
+
def dequantize_fp8_packed_row(
|
|
4227
|
+
xq: torch.Tensor,
|
|
4228
|
+
) -> torch.Tensor:
|
|
4229
|
+
"""
|
|
4230
|
+
Rowwise Dequantize FP8 tensor to BF16 tensor along last axis.
|
|
4231
|
+
|
|
4232
|
+
Args:
|
|
4233
|
+
xq (torch.Tensor): Packed FP8 tensor to be dequantized. The last 4 bytes of each row is the FP32 scale for that row.
|
|
4234
|
+
|
|
4235
|
+
Returns:
|
|
4236
|
+
torch.Tensor: Dequantized BF16 tensor.
|
|
4237
|
+
"""
|
|
4238
|
+
|
|
4239
|
+
# Create a view of the packed tensors, get the scale and actual xq tensor
|
|
4240
|
+
# This makes it much easier to write the kernel
|
|
4241
|
+
orig_shape = (*xq.shape[:-1], xq.shape[-1] - 4)
|
|
4242
|
+
actual_xq = xq[..., :-4].view(orig_shape)
|
|
4243
|
+
|
|
4244
|
+
assert xq.is_contiguous(), "Input tensors must be contiguous"
|
|
4245
|
+
x_dequant = torch.empty(orig_shape, dtype=torch.bfloat16, device=xq.device)
|
|
4246
|
+
|
|
4247
|
+
# Calculate number of rows when flattened
|
|
4248
|
+
num_rows = actual_xq.numel() // actual_xq.shape[-1]
|
|
4249
|
+
|
|
4250
|
+
# TODO: we take a perf hit from these reshapes, can we do better?
|
|
4251
|
+
# It's hard to skip this reshape, we can't create a int32/float32 view because of alignment issues
|
|
4252
|
+
scale_view = xq[..., -4:].reshape((num_rows * 4)).view(torch.float32)
|
|
4253
|
+
scale_view = scale_view.view(orig_shape[:-1])
|
|
4254
|
+
|
|
4255
|
+
# Reshape to 2-d array keeping last dim only.
|
|
4256
|
+
K = actual_xq.shape[-1]
|
|
4257
|
+
actual_xq = actual_xq.reshape(-1, K)
|
|
4258
|
+
M = actual_xq.shape[0]
|
|
4259
|
+
use_int64 = actual_xq.numel() > 2**31
|
|
4260
|
+
|
|
4261
|
+
def grid(meta: dict[str, int]) -> tuple[int]:
|
|
4262
|
+
return (triton.cdiv(M, meta["BLOCK_M"]),)
|
|
4263
|
+
|
|
4264
|
+
with torch.cuda.device(actual_xq.device.index):
|
|
4265
|
+
_kernel_dequantize_fp8_packed_row[grid](
|
|
4266
|
+
actual_xq,
|
|
4267
|
+
scale_view,
|
|
4268
|
+
x_dequant,
|
|
4269
|
+
M,
|
|
4270
|
+
K,
|
|
4271
|
+
actual_xq.stride(0),
|
|
4272
|
+
actual_xq.stride(1),
|
|
4273
|
+
x_dequant.stride(-2), # Use squashed stride.
|
|
4274
|
+
x_dequant.stride(-1),
|
|
4275
|
+
USE_INT64=use_int64,
|
|
4276
|
+
)
|
|
4277
|
+
|
|
4278
|
+
return x_dequant
|
|
4279
|
+
|
|
4280
|
+
|
|
4281
|
+
@triton.jit
|
|
4282
|
+
def _kernel_dequantize_fp8_block(
|
|
4283
|
+
xq_ptr,
|
|
4284
|
+
x_scale_ptr,
|
|
4285
|
+
x_dequant_ptr,
|
|
4286
|
+
M,
|
|
4287
|
+
K,
|
|
4288
|
+
BLOCK_M: tl.constexpr,
|
|
4289
|
+
BLOCK_K: tl.constexpr,
|
|
4290
|
+
):
|
|
4291
|
+
"""
|
|
4292
|
+
Kernel to dequantize FP8 tensor to BF16 tensor.
|
|
4293
|
+
Args:
|
|
4294
|
+
xq_ptr (tl.constexpr): Pointer to FP8 tensor.
|
|
4295
|
+
x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
|
|
4296
|
+
x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
|
|
4297
|
+
M (tl.constexpr): M dimension of input tensor.
|
|
4298
|
+
K (tl.constexpr): K dimension of input tensor.
|
|
4299
|
+
BLOCK_M (tl.constexpr): Block size for the M dimension.
|
|
4300
|
+
BLOCK_K (tl.constexpr): Block size for the K dimension.
|
|
4301
|
+
"""
|
|
4302
|
+
pid_m = tl.program_id(axis=0)
|
|
4303
|
+
pid_k = tl.program_id(axis=1)
|
|
4304
|
+
k = tl.cdiv(K, BLOCK_K)
|
|
4305
|
+
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
4306
|
+
offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
|
|
4307
|
+
offs = offs_m[:, None] * K + offs_k[None, :]
|
|
4308
|
+
mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
|
|
4309
|
+
xq = tl.load(xq_ptr + offs, mask=mask).to(tl.bfloat16)
|
|
4310
|
+
x_scale = tl.load(x_scale_ptr + pid_m * k + pid_k)
|
|
4311
|
+
x_dequant = xq * x_scale
|
|
4312
|
+
tl.store(x_dequant_ptr + offs, x_dequant, mask=mask)
|
|
4313
|
+
|
|
4314
|
+
|
|
4315
|
+
def dequantize_fp8_block(
|
|
4316
|
+
xq: torch.Tensor,
|
|
4317
|
+
x_scale: torch.Tensor,
|
|
4318
|
+
block_m: int = 256,
|
|
4319
|
+
block_k: int = 256,
|
|
4320
|
+
) -> torch.Tensor:
|
|
4321
|
+
"""
|
|
4322
|
+
Dequantize FP8 tensor to BF16 tensor.
|
|
4323
|
+
|
|
4324
|
+
Args:
|
|
4325
|
+
xq (torch.Tensor): FP8 tensor to be dequantized.
|
|
4326
|
+
x_scale (torch.Tensor): FP8 scale tensor.
|
|
4327
|
+
block_m (int): Block size for the M dimension.
|
|
4328
|
+
block_k (int): Block size for the K dimension.
|
|
4329
|
+
|
|
4330
|
+
Returns:
|
|
4331
|
+
torch.Tensor: Dequantized BF16 tensor.
|
|
4332
|
+
"""
|
|
4333
|
+
|
|
4334
|
+
assert (
|
|
4335
|
+
xq.is_contiguous() and x_scale.is_contiguous()
|
|
4336
|
+
), "Input tensors must be contiguous"
|
|
4337
|
+
assert xq.dim() == 2 and x_scale.dim() == 2, "Input tensors must have 2 dimensions"
|
|
4338
|
+
M, K = xq.size()
|
|
4339
|
+
x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
|
|
4340
|
+
|
|
4341
|
+
def grid(meta: dict[str, int]) -> tuple[int, int]:
|
|
4342
|
+
return (
|
|
4343
|
+
triton.cdiv(M, meta["BLOCK_M"]),
|
|
4344
|
+
triton.cdiv(K, meta["BLOCK_K"]),
|
|
4345
|
+
)
|
|
4346
|
+
|
|
4347
|
+
with torch.cuda.device(xq.device.index):
|
|
4348
|
+
_kernel_dequantize_fp8_block[grid](
|
|
4349
|
+
xq, x_scale, x_dequant, M, K, BLOCK_M=block_m, BLOCK_K=block_k # pyre-ignore[6]
|
|
4350
|
+
)
|
|
4351
|
+
return x_dequant
|
|
4352
|
+
|
|
4353
|
+
|
|
4354
|
+
# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
|
|
4355
|
+
def to_mxfp8(
|
|
4356
|
+
data_hp: torch.Tensor,
|
|
4357
|
+
block_size: int = 32,
|
|
4358
|
+
):
|
|
4359
|
+
assert data_hp.dtype in (
|
|
4360
|
+
torch.bfloat16,
|
|
4361
|
+
torch.float,
|
|
4362
|
+
), f"{data_hp.dtype} is not supported yet"
|
|
4363
|
+
assert (
|
|
4364
|
+
data_hp.shape[-1] % block_size == 0
|
|
4365
|
+
), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
|
|
4366
|
+
assert data_hp.is_contiguous(), "unsupported"
|
|
4367
|
+
|
|
4368
|
+
orig_shape = data_hp.shape
|
|
4369
|
+
data_hp = data_hp.reshape(
|
|
4370
|
+
*orig_shape[:-1], orig_shape[-1] // block_size, block_size
|
|
4371
|
+
)
|
|
4372
|
+
|
|
4373
|
+
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
|
|
4374
|
+
|
|
4375
|
+
data_hp = data_hp.to(torch.float32)
|
|
4376
|
+
max_abs = max_abs.to(torch.float32)
|
|
4377
|
+
|
|
4378
|
+
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
|
|
4379
|
+
max_pos = F8E4M3_MAX
|
|
4380
|
+
|
|
4381
|
+
# RCEIL
|
|
4382
|
+
def _to_mx_rceil(
|
|
4383
|
+
data_hp: torch.Tensor,
|
|
4384
|
+
max_abs: torch.Tensor,
|
|
4385
|
+
max_pos: float,
|
|
4386
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
4387
|
+
E8M0_EXPONENT_BIAS = 127
|
|
4388
|
+
descale = max_abs / max_pos
|
|
4389
|
+
exponent = torch.where(
|
|
4390
|
+
torch.isnan(descale),
|
|
4391
|
+
0xFF, # Handle biased exponent for nan
|
|
4392
|
+
# NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
|
|
4393
|
+
(
|
|
4394
|
+
torch.clamp(
|
|
4395
|
+
torch.ceil(torch.log2(descale)),
|
|
4396
|
+
min=-E8M0_EXPONENT_BIAS,
|
|
4397
|
+
max=E8M0_EXPONENT_BIAS,
|
|
4398
|
+
)
|
|
4399
|
+
+ E8M0_EXPONENT_BIAS
|
|
4400
|
+
).to(torch.uint8),
|
|
4401
|
+
)
|
|
4402
|
+
|
|
4403
|
+
descale_fp = torch.where(
|
|
4404
|
+
exponent == 0,
|
|
4405
|
+
1.0,
|
|
4406
|
+
torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
|
|
4407
|
+
)
|
|
4408
|
+
|
|
4409
|
+
# scale and saturated cast the data elements to max of target dtype
|
|
4410
|
+
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
|
|
4411
|
+
return exponent, data_lp
|
|
4412
|
+
|
|
4413
|
+
scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
|
|
4414
|
+
|
|
4415
|
+
# cast to target dtype
|
|
4416
|
+
data_lp = data_lp.to(torch.float8_e4m3fn)
|
|
4417
|
+
# need to reshape at the end to help inductor fuse things
|
|
4418
|
+
data_lp = data_lp.reshape(orig_shape)
|
|
4419
|
+
|
|
4420
|
+
scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
|
|
4421
|
+
scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
|
|
4422
|
+
return scale_e8m0_biased, data_lp
|