mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1902 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
from typing import Dict, Optional, Tuple
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import triton # @manual
|
|
14
|
+
import triton.language as tl # @manual
|
|
15
|
+
from mslk.utils.triton.fp8_utils import get_fp8_constants
|
|
16
|
+
from triton import Config # @manual
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# True whenever the GITHUB_ENV environment variable is set (to any value).
# NOTE(review): presumably used to detect GitHub-hosted CI runs; the consumer
# is not visible in this part of the file.
running_on_github: bool = os.getenv("GITHUB_ENV") is not None
|
|
21
|
+
|
|
22
|
+
try:
    # pyre-ignore[21]
    from triton.fb.compat import disable_bufferops  # @manual
except ModuleNotFoundError:
    # `triton.fb.compat` is missing from this Triton build (e.g. opensource),
    # so provide a stand-in with the same call shape so that callers can always
    # write `with disable_bufferops(flag): ...`.
    # TODO(njriasan): Remove when we integrate triton.fb.compat into every Triton
    # version.
    from contextlib import contextmanager

    @contextmanager
    def disable_bufferops(_unused: bool):
        """No-op context manager: ignores its argument and yields None."""
        yield
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass.

    Each program instance (``tl.program_id(0)``) handles exactly one row; the
    program id is decomposed into (b, m, n) coordinates using M and N.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): higher precision input tensor of 4 dimension.
        A_scale (Tensor): [B * M * N] reciprocal scale tensor per row
            (the kernel stores ``row_max / MAX_FP8`` at index ``pid``).
        A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / A_scale
        scale_ub (Tensor): [1] Upper bound, only loaded when CLAMP_MAX is set.
            It is applied to the row max before the scale is derived.
        zero_start_index_M (Tensor): Jagged index, only loaded when JAGGED is
            set. One value is read per (b, m) pair; rows whose n index is
            >= that value are treated as empty (no input is read for them).
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3 (input row size)
        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Split the flat row id into (b, m, n) and turn it into element offsets for
    # the start of this row in the input ...
    a_offset_base = (
        pid // (M * N) * stride_ab
        + (pid % (M * N)) // N * stride_am
        + (pid % (M * N)) % N * stride_an
    )
    # ... and in the output, which has its own strides (its rows may be padded).
    a_fp8_offset_base = (
        pid // (M * N) * stride_ob
        + (pid % (M * N)) // N * stride_om
        + (pid % (M * N)) % N * stride_on
    )

    # Number of input elements actually read from this row.
    K_in = K

    if JAGGED:
        # One jagged-index entry per (b, m) pair.
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, dont process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    # With K_in == 0 the loop body never runs and cur_max stays 0.0, which the
    # clamp below then raises to EPS.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE

    # Clamp max value appropriately.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max
    # The stored value is the reciprocal (dequantization) scale.
    tl.store(A_scale + pid, 1.0 / a_scale)
    # Restart the column offsets for the second pass over the row.
    n_offset = tl.arange(0, BLOCK_SIZE)

    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)

        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def triton_quantize_fp8_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    Inputs with fewer than 4 dimensions are padded with leading singleton
    dimensions for the kernel launch; the results are viewed back to the
    caller's rank before returning.

    Args:
        a (Tensor): higher precision input tensor of at most 4 dimensions.
        scale_ub (Tensor): Upper bound applied to each row's max before the
            scale is derived. Must be on the same device as `a`.
        zero_start_index_M (Tensor): Jagged index, viewed as
            [a.shape[0], a.shape[1]] of the 4-D input. Rows along dimension 2
            whose index is >= the corresponding entry are treated as empty by
            the kernel. Must be on the same device as `a`.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)

    Returns:
        torch.Tensor: fp8 scaled tensor (last dimension padded when
            `align_rows_to` is given).
        torch.Tensor: reciprocal scale tensor per row.

    Raises:
        Exception: if `scale_ub` or `zero_start_index_M` lives on a different
            device than `a`.
        AssertionError: if `a` has more than 4 dimensions.
    """
    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    # Remember the caller's shape so the outputs can be viewed back to it.
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
    # If align_rows_to is provided, pad the last dimension to be a multiple of it
    if align_rows_to is not None:
        last_dim = a.shape[-1]
        padded_last_dim = (
            (last_dim + align_rows_to - 1) // align_rows_to
        ) * align_rows_to
        a_fp8 = torch.empty(
            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
        )
        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
    else:
        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # If any tensor the kernel indexes is sufficiently large, we need to use
    # int64 indexing. Looking at a.numel() alone is not enough:
    #   * the output can be larger than the input when rows are padded, and
    #   * a non-contiguous input addresses offsets beyond its element count.
    # This condition is never narrower than the previous `a.numel()` check.
    input_extent = 1 + sum(
        (size - 1) * stride for size, stride in zip(a.shape, a.stride())
    )
    use_int64 = max(a.numel(), a_fp8.numel(), input_extent) > (2**31 - 1)
    # One program instance per row.
    grid = (num_rows,)
    # Pick a conservative value for inference shapes for disabling BufferOps.
    should_disable_bufferops = torch.version.hip is not None and a_shape[0] < 32
    with disable_bufferops(should_disable_bufferops):
        # Launch on the device that holds `a`.
        with torch.cuda.device(a.device.index):
            _kernel_quantize_fp8_row[grid](
                a,
                a_scale,
                a_fp8,
                scale_ub,
                zero_start_index_M,
                a.shape[0],
                a.shape[1],
                a.shape[2],
                a.shape[3],
                a_fp8.shape[3],
                a.stride(0),
                a.stride(1),
                a.stride(2),
                a.stride(3),
                a_fp8.stride(0),
                a_fp8.stride(1),
                a_fp8.stride(2),
                a_fp8.stride(3),
                (
                    zero_start_index_M.stride(0)
                    if zero_start_index_M is not None
                    else None
                ),
                (
                    zero_start_index_M.stride(1)
                    if zero_start_index_M is not None
                    else None
                ),
                TL_FP8_DTYPE=tl_dtype,
                MAX_FP8=max_fp8,
                EPS=eps,
                CLAMP_MAX=scale_ub is not None,
                JAGGED=zero_start_index_M is not None,
                USE_INT64=use_int64,
            )

    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_packed_row(
    A,
    A_fp8,
    packed_scale,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    packed_scale_stride,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row, writing the scale as 4 raw bytes per row.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass.

    Each program instance (``tl.program_id(0)``) handles exactly one row; the
    program id is decomposed into (b, m, n) coordinates using M and N.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): higher precision input tensor of 4 dimension.
        A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A * (MAX_FP8 / row_max)
        packed_scale (Tensor): Destination for the per-row reciprocal scale.
            It is reinterpreted as a uint8 pointer and row ``pid`` gets four
            bytes at ``pid * packed_scale_stride + {0, 1, 2, 3}``.
        scale_ub (Tensor): [1] Upper bound, only loaded when CLAMP_MAX is set.
            It is applied to the row max before the scale is derived.
        zero_start_index_M (Tensor): Jagged index, only loaded when JAGGED is
            set. One value is read per (b, m) pair; rows whose n index is
            >= that value are treated as empty (no input is read for them).
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        packed_scale_stride (int): Stride of the packed scale, indexing into a_fp8.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Split the flat row id into (b, m, n) and turn it into element offsets for
    # the start of this row in the input ...
    a_offset_base = (
        pid // (M * N) * stride_ab
        + (pid % (M * N)) // N * stride_am
        + (pid % (M * N)) % N * stride_an
    )
    # ... and in the output, which has its own strides.
    a_fp8_offset_base = (
        pid // (M * N) * stride_ob
        + (pid % (M * N)) // N * stride_om
        + (pid % (M * N)) % N * stride_on
    )

    # Number of input elements actually read from this row.
    K_in = K

    if JAGGED:
        # One jagged-index entry per (b, m) pair.
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, dont process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    # With K_in == 0 the loop body never runs and cur_max stays 0.0, which the
    # clamp below then raises to EPS.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE

    # Clamp max value appropriately.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max

    # Split the 32-bit pattern of the reciprocal scale into four values, shifted
    # right by 0, 8, 16 and 24 bits, so it can be stored byte by byte below.
    # NOTE(review): the outputs are declared uint8, so presumably only the low
    # byte of each 32-bit result register is kept - confirm against Triton's
    # inline_asm_elementwise semantics.
    (fp8_0, fp8_1, fp8_2, fp8_3) = tl.inline_asm_elementwise(
        asm="""
        {
            // $4 is the input register
            .reg .b32 input;
            mov.b32 input, $4;
            mov.b32 $0, $4;
            shr.b32 $1, $4, 8;
            shr.b32 $2, $4, 16;
            shr.b32 $3, $4, 24;
        }
        """,
        constraints=("=r,=r,=r,=r,r"),
        # The single input is this row's reciprocal scale (1.0 / a_scale).
        args=[1.0 / a_scale],
        dtype=(
            tl.uint8,
            tl.uint8,
            tl.uint8,
            tl.uint8,
        ),
        is_pure=True,
        pack=1,
    )

    # There are some compiler issues with FP8 pointers
    packed_scale_ptr = packed_scale.to(tl.pointer_type(tl.uint8))
    # Four consecutive bytes per row, lowest-order byte first.
    tl.store(packed_scale_ptr + pid * packed_scale_stride, fp8_0)
    tl.store(packed_scale_ptr + pid * packed_scale_stride + 1, fp8_1)
    tl.store(packed_scale_ptr + pid * packed_scale_stride + 2, fp8_2)
    tl.store(packed_scale_ptr + pid * packed_scale_stride + 3, fp8_3)

    # Restart the column offsets for the second pass over the row.
    n_offset = tl.arange(0, BLOCK_SIZE)

    # The loop and the store mask use K (not K_in), so an empty jagged row still
    # gets all K output elements written (as zeros, since nothing is loaded).
    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K,
        )

        n_offset += BLOCK_SIZE
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def triton_quantize_fp8_packed_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    return_only_packed: Optional[bool] = False,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    This packs the FP32 scale at the end of each row, so the fp8 scaled tensor and the reciprocal scale tensor per row are contiguous in memory.

    Args:
        a (Tensor): higher precision input tensor of at most 4 dimensions.
        scale_ub (Tensor): Upper bound applied to each row's max before the
            scale is derived. Must be on the same device as `a`.
        zero_start_index_M (Tensor): Jagged index, viewed as
            [a.shape[0], a.shape[1]] of the 4-D input. Rows along dimension 2
            whose index is >= the corresponding entry are treated as empty by
            the kernel. Must be on the same device as `a`.
        return_only_packed (bool): Only return the packed tensor, do not unpack results if True

    Returns:
        Optional[torch.Tensor]: fp8 scaled tensor (None if `return_only_packed`).
        Optional[torch.Tensor]: reciprocal scale tensor per row (None if
            `return_only_packed`).
        torch.Tensor: The packed FP8 scaled tensor, with the scale at the end of each row.
            With `return_only_packed` it is viewed to the caller's rank
            (last dimension + 4); otherwise the 4-D buffer is returned as is.

    Raises:
        Exception: if `scale_ub` or `zero_start_index_M` lives on a different
            device than `a`.
        AssertionError: if `a` has more than 4 dimensions.
    """
    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    # Remember the caller's shape so the outputs can be viewed back to it.
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]

    # Allocate an extra 4-bytes at the end of each row for the scale.
    a_fp8 = torch.empty(
        (*a.shape[:-1], a.shape[-1] + 4), device=a.device, dtype=pt_dtype
    )

    # create a view of the packed scale
    packed_scale = a_fp8[..., -4:]

    # If any tensor the kernel indexes is sufficiently large, we need to use
    # int64 indexing. Looking at a.numel() alone is not enough:
    #   * the packed output holds 4 extra elements per row, so its offsets
    #     (and `pid * packed_scale_stride`) exceed those of the input, and
    #   * a non-contiguous input addresses offsets beyond its element count.
    # This condition is never narrower than the previous `a.numel()` check.
    input_extent = 1 + sum(
        (size - 1) * stride for size, stride in zip(a.shape, a.stride())
    )
    use_int64 = max(a.numel(), a_fp8.numel(), input_extent) > (2**31 - 1)
    # One program instance per row.
    grid = (num_rows,)

    # Launch on the device that holds `a`.
    with torch.cuda.device(a.device.index):
        _kernel_quantize_fp8_packed_row[grid](
            a,
            a_fp8,
            packed_scale,
            scale_ub,
            zero_start_index_M,
            a.shape[0],
            a.shape[1],
            a.shape[2],
            a.shape[3],
            a.stride(0),
            a.stride(1),
            a.stride(2),
            a.stride(3),
            a_fp8.stride(0),
            a_fp8.stride(1),
            a_fp8.stride(2),
            a_fp8.stride(3),
            packed_scale.stride(2),  # this is the stride that matters
            zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
            zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
            TL_FP8_DTYPE=tl_dtype,
            MAX_FP8=max_fp8,
            EPS=eps,
            CLAMP_MAX=scale_ub is not None,
            JAGGED=zero_start_index_M is not None,
            USE_INT64=use_int64,
        )
    if return_only_packed:
        return None, None, a_fp8.view((*a_shape[:-1], a_shape[-1] + 4))

    # Extract the original shape data without the extra 4 bytes per row
    # The data is still contiguous in memory, so we have to unpack it.
    final_fp8_view = a_fp8[..., :-4].view(a_shape)
    # Gather the 4 trailing bytes of every row and reinterpret each group of
    # four as one float32.
    scale_view = a_fp8[..., -4:].reshape((num_rows * 4)).view(torch.float32)

    # the difference with the packed API is that it also
    # returns the full packed tensor as a third return value
    return final_fp8_view, scale_view.view(a_shape[:-1]), a_fp8
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@torch.library.custom_op("triton::quantize_fp8_packed_row", mutates_args=())
def quantize_fp8_packed_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a to fp8 with row-wise scalings and optionally move to output device.

    Args:
        a (Tensor): Input high precision tensor. Required to have no more than 4 dimension
        scale_ub (Tensor): Upper bound applied to each row's max before the
            scale is derived.
        zero_start_index_M (Tensor): Jagged index forwarded to the triton
            implementation. It is not used by the pytorch fallback.
        use_triton (bool): Whether to use triton kernel or pytorch. Forced to
            False when `a` is on the cpu.
        output_device (torch.device): Device to optionally move the scaled tensors to.
            Only honored by the pytorch fallback; defaults to `a.device`.
    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: The reciprocal scale tensor per row.
    """

    if a.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        # ignore the packed tensor here, we aren't testing it
        a_fp8, scale, _ = triton_quantize_fp8_packed_row(
            a, scale_ub, zero_start_index_M, return_only_packed=False
        )
        assert a_fp8 is not None
        assert scale is not None
        return a_fp8, scale
    # else use pytorch implementation.
    if output_device is None:
        output_device = a.device

    a_shape = a.shape
    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()
    row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
    # Apply clamping.
    if scale_ub is not None:
        row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
    else:
        # pyre-ignore[6]: Incompatible parameter type [6]
        row_max = torch.clamp(row_max, min=eps)
    # The scale is computed directly; no buffer is pre-allocated for it (the
    # previous torch.empty() result was discarded by this assignment).
    a_scale = max_fp8 / row_max.to(torch.float32)  # pyre-ignore
    # Guard against an infinite scale so the product below stays finite.
    a_scale[a_scale == float("inf")] = 1.0  # pyre-ignore
    a_fp8 = a * a_scale[..., None]  # pyre-ignore
    # Cast and move data to output device (for cpu weight loading).
    a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
    a_scale = a_scale.to(output_device)  # pyre-ignore
    del a
    return a_fp8, (1 / a_scale).view(a_shape[:-1])  # pyre-ignore
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
@torch.library.custom_op("triton::quantize_fp8_packed_row_raw", mutates_args=())
def quantize_fp8_packed_row_raw(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Quantize a to fp8 with row-wise scalings and return only the packed tensor.

    Same triton path as quantize_fp8_packed_row, but the unpacked fp8 and scale
    tensors are not produced. There is no pytorch fallback for this op.

    Args:
        a (Tensor): Input high precision tensor. Required to have no more than 4 dimension
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        use_triton (bool): Must be True; the op raises otherwise.
        output_device (torch.device): Accepted for signature parity with
            quantize_fp8_packed_row; not used by this function.
    Returns:
        torch.Tensor: The packed fp8 tensor, with 4 extra trailing elements per
            row holding that row's scale.
    Raises:
        Exception: if `a` is on the cpu or `use_triton` is False.
    """
    input_on_cpu = a.device == torch.device("cpu")
    if input_on_cpu:
        logger.info("Triton does not support cpu, falling back to torch ops.")

    # Only the triton implementation exists for the raw packed output.
    if input_on_cpu or not use_triton:
        raise Exception(
            "No PyTorch implementation provided for triton::quantize_fp8_packed_row_raw"
        )

    # The first two results are None when only the packed tensor is requested.
    _fp8_unused, _scale_unused, packed_tensor = triton_quantize_fp8_packed_row(
        a, scale_ub, zero_start_index_M, return_only_packed=True
    )
    return packed_tensor
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
@torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
def quantize_fp8_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a to fp8 with row-wise scalings and optionally move to output device.

    On non-CPU inputs with use_triton=True the work is delegated to
    triton_quantize_fp8_row. Otherwise a pure PyTorch implementation is used.

    Args:
        a (Tensor): Input high precision tensor. The original documentation
            states it must have no more than 4 dimensions.
        scale_ub (Tensor): Maximum allowed value for scale. The torch path reads
            it with .item(), so it must hold a single element.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
            Only forwarded to the triton path; the torch path does not read it.
        use_triton (bool): Whether to use triton kernel or pytorch.
        output_device (torch.device): Device to optionally move the scaled tensors to.
            Only honoured by the torch path; defaults to a.device.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels
            accepting specific sizes (e.g., multiple of 16). The last dimension of the
            fp8 output is rounded up to a multiple of this value.

    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: The reciprocal scale tensor per row, shaped a.shape[:-1].
    """
    if a.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        return triton_quantize_fp8_row(
            a,
            scale_ub,
            zero_start_index_M,
            align_rows_to=align_rows_to,
        )
    # else use pytorch implementation.
    if output_device is None:
        output_device = a.device

    a_shape = a.shape
    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()
    row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
    # Apply clamping.
    if scale_ub is not None:
        row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
    else:
        # pyre-ignore[6]: Incompatible parameter type [6]
        row_max = torch.clamp(row_max, min=eps)
    a_scale = max_fp8 / row_max.to(torch.float32)  # pyre-ignore
    a_scale[a_scale == float("inf")] = 1.0  # pyre-ignore
    a_fp8 = a * a_scale[..., None]  # pyre-ignore
    # When scale_ub caps row_max, scaled values can exceed the fp8 range.
    # Saturate explicitly so the cast below cannot overflow.
    a_fp8 = torch.clamp(a_fp8, -max_fp8, max_fp8)
    if align_rows_to is not None:
        # Round the row length up to a multiple of align_rows_to so the output
        # shape matches the one advertised by the registered fake kernel.
        # NOTE(review): padding is zero-filled; confirm against the triton path.
        last_dim = a_shape[-1]
        padded_last_dim = (
            (last_dim + align_rows_to - 1) // align_rows_to
        ) * align_rows_to
        if padded_last_dim != last_dim:
            a_fp8 = torch.nn.functional.pad(a_fp8, (0, padded_last_dim - last_dim))
    # Cast and move data to output device (for cpu weight loading).
    a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
    a_scale = a_scale.to(output_device)  # pyre-ignore
    return a_fp8, (1 / a_scale).view(a_shape[:-1])  # pyre-ignore
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
@quantize_fp8_row.register_fake
def quantize_fp8_row_meta(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Shape function for torch compile.

    Produces uninitialised tensors with the shapes, dtypes and devices of the
    quantize_fp8_row outputs. No quantization work is performed.

    Returns:
        torch.Tensor: fp8 placeholder shaped like a, with the last dimension
            rounded up to a multiple of align_rows_to when that is given.
        torch.Tensor: float32 placeholder shaped a.shape[:-1] for the row scales.
    """
    target_device = a.device if output_device is None else output_device
    fp8_dtype = get_fp8_constants()[0]

    # One scale per row, i.e. every dimension except the last.
    fake_scale = torch.empty(
        a.shape[:-1], device=target_device, dtype=torch.float32
    )

    out_shape = a.shape
    if align_rows_to is not None:
        # Ceil-divide the row length, then scale back up to the aligned size.
        num_aligned_chunks = (a.shape[-1] + align_rows_to - 1) // align_rows_to
        out_shape = (*a.shape[:-1], num_aligned_chunks * align_rows_to)

    fake_out = torch.empty(out_shape, device=target_device, dtype=fp8_dtype)
    return fake_out, fake_scale
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["N"],
)
@triton.jit
def _kernel_scale_fp8_row(
    A,
    x_scale,
    w_scale,
    scaled_out,
    M,
    N,
    stride_am,
    stride_an,
    stride_om,
    stride_on,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """
    Scale each row of A by x_scale and each column of A by w_scale.

    One program handles one row (indexed by program id 0) and walks it in
    chunks of BLOCK_SIZE columns.

    Args:
        A (Tensor): [m, n] Input tensor to scale.
        x_scale (Tensor): [m] Row-wise scale tensor, indexed with unit stride.
        w_scale (Tensor): [n] Col-wise scale tensor, indexed with unit stride.
        scaled_out (Tensor): [m, n] Output tensor.
        M (int): Number of rows. Not read by the kernel body.
        N (int): Number of columns.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        BLOCK_SIZE (int): Block size for data loads.
    """
    pid = tl.program_id(0)
    n_offset = tl.arange(0, BLOCK_SIZE)
    # Load activation scale for this row.
    row_scale = tl.load(x_scale + pid)

    # Iterate over chunks of the row and apply scales.
    for _k in range(0, tl.cdiv(N, BLOCK_SIZE)):
        # Columns past N must not be touched in any of the three tensors.
        col_mask = n_offset < N
        a = tl.load(
            A + pid * stride_am + n_offset * stride_an, mask=col_mask, other=0.0
        )
        # Masked as well: w_scale only holds N entries, so an unmasked load
        # would read out of bounds on the final partial chunk.
        col_scale = tl.load(w_scale + n_offset, mask=col_mask, other=0.0)
        scaled_a = a * row_scale * col_scale
        tl.store(
            scaled_out + pid * stride_om + n_offset * stride_on,
            scaled_a,
            mask=col_mask,
        )
        n_offset += BLOCK_SIZE
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
def scale_fp8_row(
    a: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
) -> torch.Tensor:
    """
    Multiply each row of a by x_scale and each column by w_scale.

    Useful when combining with kernels that do not support fused rowwise
    scaling. CPU inputs are handled with plain PyTorch broadcasting; any other
    device launches the triton kernel with one program per row.

    Args:
        a (Tensor): Input floating point tensor to be scaled. Indexed as a
            2D [rows, cols] tensor.
        x_scale (Tensor): Row-wise activation scale tensor.
        w_scale (Tensor): Col-wise weight scale tensor.

    Returns:
        Tensor: Scaled tensor with the shape and dtype of a.

    Raises:
        Exception: If a is not on the CPU and either scale tensor lives on a
            different device than a.
    """
    device = a.device
    if device == torch.device("cpu"):
        # On CPU we'll just use native pytorch to scale.
        row_factors = x_scale[:, None]
        col_factors = w_scale[None, :]
        return a * row_factors * col_factors

    if x_scale.device != device:
        raise Exception("'x_scale' must be on the same device as 'a'")
    if w_scale.device != device:
        raise Exception("'w_scale' must be on the same device as 'a'")

    # Otherwise, use a fast triton kernel to implement.
    # We'll parallelize over rows.
    row_count = a.shape[0]
    col_count = a.shape[1]
    result = torch.empty(a.shape, device=device, dtype=a.dtype)
    launch_grid = (row_count,)
    with torch.cuda.device(device.index):
        _kernel_scale_fp8_row[launch_grid](
            a,
            x_scale,
            w_scale,
            result,
            row_count,
            col_count,
            a.stride(0),
            a.stride(1),
            result.stride(0),
            result.stride(1),
        )

    return result
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
@triton.jit
def _kernel_quantize_fp8_block(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    M,
    K,
    stride_am,
    stride_ak,
    stride_om,
    stride_ok,
    stride_a_scale_m,
    stride_a_scale_k,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    K_MAJOR: tl.constexpr,
) -> None:
    """Quantize and scale each [BLOCK_M, BLOCK_K] block.

    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(A[i:i+BLOCK_M, j:j+BLOCK_K])))

    Kernel naively iterates through matrix with [BLOCK_M, BLOCK_K] tiles.
    Each program handles exactly one tile; the program id is decomposed into a
    (block_m, block_k) pair with block_k varying fastest.

    Todo:
        * Better tiling and ordering schemes.

    Args:
        A (Tensor): [M, K] higher precision input tensor.
        A_scale (Tensor): [cdiv(M, BLOCK_M), cdiv(K, BLOCK_K)] reciprocal scale tensor per block.
        A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale
        scale_ub (Tensor): [1] Maximum allowed value for scale. Only read when CLAMP_MAX.
        M (int): Number of rows.
        K (int): Number of columns.
        stride_am (int): Stride of m dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_om (int): Stride of m dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_a_scale_m (int): Stride of m dimension of A_scale.
        stride_a_scale_k (int): Stride of k dimension of A_scale.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        BLOCK_M (int): Block size for M dimension of A_scale and kernel.
        BLOCK_K (int): Block size for K dimension of A_scale and kernel.
        K_MAJOR (bool): Whether output scales should be K major (True) or MN major (False).
    """
    pid = tl.program_id(0)
    grid_k = tl.cdiv(K, BLOCK_K)
    block_m = pid // grid_k
    block_k = pid % grid_k
    rm = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = block_k * BLOCK_K + tl.arange(0, BLOCK_K)
    a_offset = rm[:, None] * stride_am + rk[None, :] * stride_ak
    out_offset = rm[:, None] * stride_om + rk[None, :] * stride_ok
    # Edge tiles may extend past the matrix; out-of-range lanes load 0.0 so
    # they do not affect the block maximum.
    a_mask = (rm < M)[:, None] & (rk < K)[None, :]
    a_block = tl.load(A + a_offset, mask=a_mask, other=0.0)

    block_max = tl.max(tl.abs(a_block))
    # Apply appropriate clamping.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        block_max = tl.clamp(block_max, EPS, ub)
    else:
        block_max = tl.maximum(block_max, EPS)
    scale = MAX_FP8 / block_max

    # Write in transposed order if specified.
    if K_MAJOR:
        scale_offset = block_m * stride_a_scale_m + block_k * stride_a_scale_k
    else:
        scale_offset = block_k * stride_a_scale_m + block_m * stride_a_scale_k
    # The stored value is the reciprocal scale (used to dequantize).
    tl.store(A_scale + scale_offset, 1.0 / scale)
    a_fp8 = a_block * scale
    # Clamp A to fp8 range to make sure there's no overflow.
    # This is required for AMD. Nvidia's default saturation
    # handles it, but it's nice to have anyway.
    a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
    # `.to` returns a new tensor, so the result must be reassigned for the
    # explicit fp8 cast to take effect.
    a_fp8 = a_fp8.to(TL_FP8_DTYPE)
    tl.store(A_fp8 + out_offset, a_fp8, mask=a_mask)
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
def triton_quantize_fp8_block(
    x: torch.Tensor,
    block_m: int = 256,
    block_k: int = 256,
    scale_ub: Optional[torch.Tensor] = None,
    k_major: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with block-wise scalings.

    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))

    The input is viewed as 2D by flattening all leading dimensions into rows,
    and one kernel program is launched per [block_m, block_k] tile.

    Args:
        x (torch.Tensor): [M, K] higher precision input tensor. Must not be on the CPU.
        block_m (int): Block size for M dimension of scale.
        block_k (int): Block size for K dimension of scale.
        scale_ub: Maximum allowed value for scale. Must be on the same device as x.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor : [M, K] fp8 scaled tensor, viewed back to the shape of x.
        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
        if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].

    Raises:
        AssertionError: If x is on the CPU.
        Exception: If scale_ub is on a different device than x.
    """
    assert x.device != torch.device("cpu"), (
        "Blockwise quantization not support on cpu, please use row-wise quantization instead."
    )

    if scale_ub is not None and scale_ub.device != x.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")

    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    M, K = x.shape
    grid_m = triton.cdiv(M, block_m)
    grid_k = triton.cdiv(K, block_k)
    if k_major:
        x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
    else:
        x_scale = torch.empty((grid_k, grid_m), device=x.device, dtype=torch.float32)
    x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)

    # Launch on the device that owns x, which may not be the current device.
    # This mirrors the guard scale_fp8_row uses around its kernel launch.
    with torch.cuda.device(x.device.index):
        _kernel_quantize_fp8_block[(grid_m * grid_k,)](
            x,
            x_scale,
            x_fp8,
            scale_ub,
            M,
            K,
            x.stride(0),
            x.stride(1),
            x_fp8.stride(0),
            x_fp8.stride(1),
            x_scale.stride(0),
            x_scale.stride(1),
            # pyre-ignore[6]: Incompatible parameter type [6]
            TL_FP8_DTYPE=tl_dtype,
            # pyre-ignore[6]: Incompatible parameter type [6]
            MAX_FP8=max_fp8,
            # pyre-ignore[6]: Incompatible parameter type [6]
            EPS=eps,
            # pyre-ignore[6]: Incompatible parameter type [6]
            CLAMP_MAX=scale_ub is not None,
            # pyre-ignore[6]: Incompatible parameter type [6]
            BLOCK_M=block_m,
            # pyre-ignore[6]: Incompatible parameter type [6]
            BLOCK_K=block_k,
            # pyre-ignore[6]: Incompatible parameter type [6]
            K_MAJOR=k_major,
        )

    return x_fp8.view(x_shape), x_scale
|
|
1006
|
+
|
|
1007
|
+
|
|
1008
|
+
@torch.library.custom_op("triton::quantize_fp8_block", mutates_args=())
def quantize_fp8_block(
    x: torch.Tensor,
    block_m: int = 256,
    block_k: int = 256,
    scale_ub: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
    k_major: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.

    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))

    The input is viewed as 2D by flattening all leading dimensions into rows.
    Non-CPU inputs with use_triton=True go through triton_quantize_fp8_block;
    otherwise a pure PyTorch implementation is used.

    Args:
        x (Tensor): [M, K] higher precision input tensor.
        block_m (int): Block size for M dimension of scale.
        block_k (int): Block size for K dimension of scale.
        scale_ub: Maximum allowed value for scale. The torch path reads it with
            .item(), so it must hold a single element.
        use_triton (bool): Whether to use triton kernel or pytorch.
        output_device (torch.device): Device to optionally move the scaled tensors to.
            Only honoured by the torch path; defaults to x.device.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor, viewed back to the shape of x.
        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
        if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
    """
    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    if x.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, k_major)
        return xq.view(x_shape), x_scale
    # else use pytorch implementation.
    if output_device is None:
        output_device = x.device

    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()

    M, K = x.shape
    grid_m = triton.cdiv(M, block_m)
    grid_k = triton.cdiv(K, block_k)

    # Pad x to multiple of block size.
    padded_m = grid_m * block_m
    padded_k = grid_k * block_k
    x_padded = torch.zeros(padded_m, padded_k, dtype=x.dtype, device=x.device)
    x_padded[:M, :K] = x

    # Blockwise max.
    block_max = (
        x_padded.abs().reshape(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3))
    )

    # Apply clamping.
    if scale_ub is not None:
        block_max = torch.clamp(block_max, min=eps, max=scale_ub.item())
    else:
        block_max = torch.clamp(block_max, min=eps)
    x_scale = max_fp8 / block_max.to(torch.float32)  # pyre-ignore
    # pyre-ignore[16]: Undefined attribute [16]
    x_scale[x_scale == float("inf")] = 1.0
    # Expand each per-block scale over its tile, apply it, then drop the padding.
    x_fp8 = (
        x_padded
        # pyre-ignore[16]: Undefined attribute [16]
        * x_scale.repeat_interleave(block_m, dim=0).repeat_interleave(block_k, dim=1)
    )[:M, :K]
    # Clamp to the fp8 range before casting, as the triton kernel does: when
    # scale_ub caps block_max the scaled values can exceed max_fp8.
    x_fp8 = torch.clamp(x_fp8, -max_fp8, max_fp8)

    # Cast and move data to output device (for cpu weight loading).
    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
    x_scale = x_scale.to(output_device)  # pyre-ignore
    del x, x_padded
    if not k_major:
        x_scale = x_scale.t().contiguous()
    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
|
|
1089
|
+
|
|
1090
|
+
|
|
1091
|
+
@quantize_fp8_block.register_fake
def quantize_fp8_block_meta(
    a: torch.Tensor,
    block_m: int = 256,
    block_k: int = 256,
    scale_ub: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
    k_major: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Shape function for torch compile.

    Produces uninitialised tensors with the shapes, dtypes and devices of the
    quantize_fp8_block outputs. No quantization work is performed.

    Returns:
        torch.Tensor: fp8 placeholder with the shape of a.
        torch.Tensor: float32 placeholder for the per-block scales, shaped
            [cdiv(rows, block_m), cdiv(K, block_k)] when k_major, otherwise
            the transpose, where rows is the product of all leading dims of a.
    """
    if output_device is None:
        output_device = a.device
    a_shape = a.shape
    dtype = get_fp8_constants()[0]
    fake_out = torch.empty(a_shape, device=output_device, dtype=dtype)

    # quantize_fp8_block views its input as (-1, K) before tiling, so the scale
    # tensor is always 2D and its row count comes from all leading dims.
    # A running product (rather than numel // K) also works when K is 0.
    num_rows = 1
    for dim in a_shape[:-1]:
        num_rows = num_rows * dim
    scale_m = triton.cdiv(num_rows, block_m)
    scale_k = triton.cdiv(a_shape[-1], block_k)
    scale_out_shape = (scale_m, scale_k) if k_major else (scale_k, scale_m)
    fake_scale = torch.empty(
        scale_out_shape,
        device=output_device,
        dtype=torch.float32,
    )
    return fake_out, fake_scale
|
|
1118
|
+
|
|
1119
|
+
|
|
1120
|
+
@triton.autotune(
    configs=[
        Config({"GROUP_LOAD": 2}),
        Config({"GROUP_LOAD": 4}),
        Config({"GROUP_LOAD": 8}),
        Config({"GROUP_LOAD": 16}),
        Config({"GROUP_LOAD": 32}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_group(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    m_sizes,
    M,
    K,
    stride_am,
    stride_ak,
    stride_om,
    stride_ok,
    stride_a_scale_m,
    stride_a_scale_k,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    USE_INT64: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    USE_M_MAJOR: tl.constexpr,
    G: tl.constexpr,
    GROUP_LOAD: tl.constexpr,
):
    """Quantize and scale each GROUP_SIZE chunk of each row.

    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(A[i:i+GROUP_SIZE])))

    Each program (indexed by program id 0) owns one row and, per loop
    iteration, loads and processes GROUP_LOAD groups at once.

    Args:
        A (Tensor): [M, K] higher precision input tensor.
        A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
        A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale
        scale_ub (Tensor): [1] Maximum allowed value for scale. Only read when CLAMP_MAX.
        m_sizes (Optional[Tensor]): [G] Number of rows in each group.
        M (int): Number of rows. Not read by the kernel body.
        K (int): Number of columns.
        stride_am (int): Stride of m dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_om (int): Stride of m dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_a_scale_m (int): Stride of m dimension of A_scale.
        stride_a_scale_k (int): Stride of k dimension of A_scale.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
        GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
        USE_M_MAJOR (bool): Whether to use grouped M-major layout for A_scale.
        G (int): Number of groups in A_scale, only relevant when m_sizes is provided.
        GROUP_LOAD (int): Number of groups to load and process simultaneously.
    """
    pid = tl.program_id(0)
    if USE_INT64:
        pid = pid.to(tl.int64)
    # Base offsets of this program's row in the input, output and scale tensors.
    in_row_base = pid * stride_am
    out_row_base = pid * stride_om
    scale_row_base = pid * stride_a_scale_m
    # Every loop iteration covers GROUP_LOAD * GROUP_SIZE contiguous columns.
    k_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE)
    scale_k_offset = tl.arange(0, GROUP_LOAD)
    NUM_GROUPS: tl.constexpr = K // GROUP_SIZE

    # For an M-major grouped gemm layout, locate the row group containing this
    # program's row and derive where that group's scales begin.
    group_offset = 0
    group_cumsum = 0
    group_M = 0
    stop = False
    if USE_M_MAJOR and G > 0:
        # A single pass accumulates the sizes of the preceding groups and stops
        # updating once the group containing `pid` has been reached.
        for i in range(G):
            if not stop:
                group_M = tl.cast(tl.load(m_sizes + i), pid.dtype)
                if (group_cumsum + group_M) <= pid:
                    group_cumsum += group_M
                else:
                    # The owning group is found; freeze cumsum and group_M.
                    stop = True

        group_offset = group_cumsum * NUM_GROUPS

    for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
        # Fetch the next batch of groups from the input row.
        chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
        in_bounds = chunk_offset < K
        a = tl.load(
            A + in_row_base + chunk_offset * stride_ak, mask=in_bounds, other=0.0
        )
        # Reinterpret the flat chunk as one row per group.
        a_grouped = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
        # Absolute maximum of every group.
        group_max = tl.max(tl.abs(a_grouped), axis=1)
        # Bound the maximum from below by EPS and, optionally, above by scale_ub.
        if CLAMP_MAX:
            ub = tl.load(scale_ub)
            group_max = tl.clamp(group_max, EPS, ub)
        else:
            group_max = tl.maximum(group_max, EPS)
        # Forward scale used for quantization; its reciprocal is what gets stored.
        a_scale = MAX_FP8 / group_max
        scale_chunk_offset = scale_k_offset + k * GROUP_LOAD
        inv_scale = 1.0 / a_scale
        scale_mask = scale_chunk_offset < NUM_GROUPS

        if USE_M_MAJOR and G > 0:
            # Grouped M-major: index within the owning row group.
            tl.store(
                A_scale
                + group_offset
                + (pid - group_cumsum) * stride_a_scale_k
                + (scale_chunk_offset * group_M),
                inv_scale,
                mask=scale_mask,
            )
        elif USE_M_MAJOR:
            # Plain M-major: row index walks the k stride, group index the m stride.
            tl.store(
                A_scale
                + pid * stride_a_scale_k
                + scale_chunk_offset * stride_a_scale_m,
                inv_scale,
                mask=scale_mask,
            )
        else:
            # K-major: scales of one row are laid out along the k stride.
            tl.store(
                A_scale + scale_row_base + scale_chunk_offset * stride_a_scale_k,
                inv_scale,
                mask=scale_mask,
            )
        # Quantize: scale every group, saturate to the fp8 range, then cast.
        a_fp8 = a_grouped * a_scale[:, None]
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
        # Flatten back to the chunk layout and write out.
        tl.store(
            A_fp8 + out_row_base + chunk_offset * stride_ok,
            tl.ravel(a_fp8),
            mask=in_bounds,
        )
|
|
1271
|
+
|
|
1272
|
+
|
|
1273
|
+
def triton_quantize_fp8_group(
    x: torch.Tensor,
    group_size: int = 128,
    scale_ub: Optional[torch.Tensor] = None,
    m_sizes: Optional[torch.Tensor] = None,
    k_major: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with group-wise scalings.

    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))

    The input is viewed as 2D by flattening all leading dimensions into rows,
    and one kernel program is launched per row.

    Args:
        x (torch.Tensor): [M, K] higher precision input tensor. Must not be on the CPU.
        group_size (int): Number of consecutive elements along K that share one scale.
        scale_ub: Maximum allowed value for scale. Must be on the same device as x.
        m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
            Must be on the same device as x.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor, viewed back to the shape of x.
        torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group
        if k_major is True, otherwise [cdiv(K, group_size), M].

    Raises:
        AssertionError: If x is on the CPU.
        Exception: If scale_ub or m_sizes is on a different device than x.
    """
    assert x.device != torch.device("cpu"), (
        "Triton groupwise quantization not supported on cpu."
    )

    if scale_ub is not None and scale_ub.device != x.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if m_sizes is not None and m_sizes.device != x.device:
        raise Exception("'m_sizes' must be on the same device as 'a'")

    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    M, K = x.shape
    k_groups = triton.cdiv(K, group_size)
    if k_major:
        x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
    else:
        x_scale = torch.empty((k_groups, M), device=x.device, dtype=torch.float32)
    x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
    # Launch on the device that owns x, which may not be the current device.
    # This mirrors the guard scale_fp8_row uses around its kernel launch.
    with torch.cuda.device(x.device.index):
        _kernel_quantize_fp8_group[(M,)](
            x,
            x_scale,
            x_fp8,
            scale_ub,
            m_sizes,
            M,
            K,
            x.stride(0),
            x.stride(1),
            x_fp8.stride(0),
            x_fp8.stride(1),
            x_scale.stride(0),
            x_scale.stride(1),
            TL_FP8_DTYPE=tl_dtype,
            MAX_FP8=max_fp8,
            EPS=eps,
            CLAMP_MAX=scale_ub is not None,
            # Offsets are signed 32-bit by default, so they overflow once the
            # element count exceeds 2**31 - 1 (not 2**32 - 1).
            USE_INT64=x.numel() > (2**31 - 1),
            GROUP_SIZE=group_size,
            USE_M_MAJOR=m_sizes is not None or k_major is False,
            G=m_sizes.numel() if m_sizes is not None else 0,
        )
    return x_fp8.view(x_shape), x_scale
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
def quantize_fp8_group(
    x: torch.Tensor,
    group_size: int = 128,
    scale_ub: Optional[torch.Tensor] = None,
    m_sizes: Optional[torch.Tensor] = None,
    k_major: bool = True,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with group-wise scalings and optionally move to output device.

    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))

    The input is viewed as 2D by flattening all leading dimensions into rows.
    Non-CPU inputs with use_triton=True go through triton_quantize_fp8_group;
    otherwise a pure PyTorch implementation is used.

    Args:
        x (Tensor): [M, K] higher precision input tensor.
        group_size (int): Number of consecutive elements along K that share one scale.
        scale_ub: Maximum allowed value for scale. The torch path reads it with
            .item(), so it must hold a single element.
        m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
            Not supported by the torch path.
        k_major (bool): Whether output scales should be K major (True) or MN major (False).
            This is needed because some kernels like cutlass require a special layout for scales.
        use_triton (bool): Whether to use triton kernel or pytorch.
        output_device (torch.device): Device to optionally move the scaled tensors to.
            Only honoured by the torch path; defaults to x.device.

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor, viewed back to the shape of x.
        torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group
        if k_major is True, otherwise [cdiv(K, group_size), M].

    Raises:
        AssertionError: On the torch path, if K is not divisible by group_size
            or m_sizes is provided.
    """
    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    if x.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
        xq, x_scale = triton_quantize_fp8_group(
            x, group_size, scale_ub, m_sizes, k_major
        )
        return xq.view(x_shape), x_scale
    # else use pytorch implementation.
    if output_device is None:
        output_device = x.device

    # Get constants.
    pt_dtype, _, max_fp8, eps = get_fp8_constants()

    M, K = x.shape
    assert K % group_size == 0, (
        "K must be divisible by group_size for cpu implementation."
    )
    assert m_sizes is None, "m_sizes is not supported for cpu implementation."
    k_groups = triton.cdiv(K, group_size)
    # View input as collection of groups for reduction.
    x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
    # Reduce over groups.
    group_max = x_grouped.abs().amax(dim=2)
    # Apply clamping. Test for presence explicitly: tensor truthiness would
    # skip the bound when scale_ub holds 0 and raise for multi-element tensors.
    if scale_ub is not None:
        group_max = torch.clamp(group_max, min=eps, max=scale_ub.item())
    else:
        group_max = torch.clamp(group_max, min=eps)
    x_scale = max_fp8 / group_max  # pyre-ignore
    # pyre-ignore[16]: Undefined attribute [16]
    x_scale[x_scale == float("inf")] = 1.0
    # pyre-ignore[16]: Undefined attribute [16]
    x_fp8 = x.view(-1, k_groups, group_size) * x_scale.unsqueeze(2)
    # Clamp to the fp8 range before casting, as the triton kernel does: when
    # scale_ub caps group_max the scaled values can exceed max_fp8.
    x_fp8 = torch.clamp(x_fp8, -max_fp8, max_fp8)
    # Cast and move data to output device (for cpu weight loading).
    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
    x_scale = x_scale.to(output_device)  # pyre-ignore
    if not k_major:
        x_scale = x_scale.t().contiguous()
    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
|
|
1414
|
+
|
|
1415
|
+
|
|
1416
|
+
@triton.autotune(
    configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
    key=["M", "K"],
)
@triton.jit
def _kernel_dequantize_fp8_row(
    xq_ptr,
    x_scale_ptr,
    x_dequant_ptr,
    M,
    K,
    stride_xm,
    stride_xk,
    stride_xdqm,
    stride_xdqk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    NUM_STAGES: tl.constexpr,
    USE_INT64: tl.constexpr,
):
    """
    Kernel to dequantize a row-wise scaled tensor.

    Each program handles BLOCK_M consecutive rows and walks the K dimension
    in BLOCK_K-wide tiles, multiplying every element by the scale of its row.

    Args:
        xq_ptr: Pointer to the quantized (M, K) input.
        x_scale_ptr: Pointer to the per-row scales; indexed by row with unit
            stride, so it must hold at least M elements.
        x_dequant_ptr: Pointer to the (M, K) output buffer.
        M: Number of rows.
        K: Number of columns (the axis each row scale is applied along).
        stride_xm: Row stride of the input.
        stride_xk: Column stride of the input.
        stride_xdqm: Row stride of the output.
        stride_xdqk: Column stride of the output.
        BLOCK_M (tl.constexpr): Rows handled per program.
        BLOCK_K (tl.constexpr): Columns handled per loop iteration.
        NUM_STAGES (tl.constexpr): Pipelining depth forwarded to tl.range.
        USE_INT64 (tl.constexpr): Promote the program id to int64 so the
            offset arithmetic derived from it is done in 64 bits.
    """
    pid = tl.program_id(axis=0)
    if USE_INT64:
        pid = pid.to(tl.int64)
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    # Mask the scale load: the last program may cover rows >= M, and reading
    # their scales would run past the end of the scale buffer.
    scales = tl.load(x_scale_ptr + offs_m, mask=offs_m < M, other=0.0)

    for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        xq = tl.load(
            xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
            mask=mask,
        )
        x_dq = xq * scales[:, None]
        tl.store(
            x_dequant_ptr
            + offs_m[:, None] * stride_xdqm
            + offs_k[None, :] * stride_xdqk,
            x_dq,
            mask=mask,
        )
        # Advance to the next tile of columns.
        offs_k += BLOCK_K
|
|
1468
|
+
|
|
1469
|
+
|
|
1470
|
+
def dequantize_fp8_row(
    xq: torch.Tensor,
    x_scale: torch.Tensor,
) -> torch.Tensor:
    """
    Dequantize a row-wise scaled FP8 tensor to BF16 along its last axis.

    Args:
        xq (torch.Tensor): Contiguous FP8 tensor to be dequantized.
        x_scale (torch.Tensor): Contiguous tensor of per-row scales.

    Returns:
        torch.Tensor: BF16 tensor with the same shape as the input `xq`.
    """
    inputs_contiguous = xq.is_contiguous() and x_scale.is_contiguous()
    assert inputs_contiguous, "Input tensors must be contiguous"

    # Allocate the result before flattening so it keeps the caller's shape.
    result = torch.empty_like(xq, dtype=torch.bfloat16)

    # Collapse every leading dimension into a single row axis.
    num_cols = xq.shape[-1]
    xq = xq.reshape(-1, num_cols)
    num_rows = xq.shape[0]
    needs_int64 = xq.numel() > 2**31

    # The output is addressed with the flattened input's strides.
    row_stride, col_stride = xq.stride(0), xq.stride(1)

    def launch_grid(meta: Dict[str, int]) -> Tuple[int]:
        # One program per BLOCK_M rows.
        return (triton.cdiv(num_rows, meta["BLOCK_M"]),)

    with torch.cuda.device(xq.device.index):
        _kernel_dequantize_fp8_row[launch_grid](
            xq,
            x_scale,
            result,
            num_rows,
            num_cols,
            row_stride,
            col_stride,
            row_stride,
            col_stride,
            USE_INT64=needs_int64,
        )
    return result
|
|
1513
|
+
|
|
1514
|
+
|
|
1515
|
+
@triton.autotune(
    configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
    key=["M", "K"],
)
@triton.jit
def _kernel_dequantize_fp8_packed_row(
    xq_ptr,
    x_scale_ptr,
    x_dequant_ptr,
    M,
    K,
    stride_xm,
    stride_xk,
    stride_xdqm,
    stride_xdqk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    NUM_STAGES: tl.constexpr,
    USE_INT64: tl.constexpr,
):
    """
    Kernel to dequantize the payload of a packed, row-wise scaled tensor.

    Each program handles BLOCK_M consecutive rows and walks the K dimension
    in BLOCK_K-wide tiles, multiplying every element by the scale of its row.
    The scales are supplied through a separate pointer.

    Args:
        xq_ptr: Pointer to the quantized (M, K) payload; may be strided.
        x_scale_ptr: Pointer to the per-row scales; indexed by row with unit
            stride, so it must hold at least M elements.
        x_dequant_ptr: Pointer to the (M, K) output buffer.
        M: Number of rows.
        K: Number of payload columns (the axis each row scale is applied along).
        stride_xm: Row stride of the input.
        stride_xk: Column stride of the input.
        stride_xdqm: Row stride of the output.
        stride_xdqk: Column stride of the output.
        BLOCK_M (tl.constexpr): Rows handled per program.
        BLOCK_K (tl.constexpr): Columns handled per loop iteration.
        NUM_STAGES (tl.constexpr): Pipelining depth forwarded to tl.range.
        USE_INT64 (tl.constexpr): Promote the program id to int64 so the
            offset arithmetic derived from it is done in 64 bits.
    """
    pid = tl.program_id(axis=0)
    if USE_INT64:
        pid = pid.to(tl.int64)
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    # Mask the scale load: the last program may cover rows >= M, and reading
    # their scales would run past the end of the scale buffer.
    scales = tl.load(x_scale_ptr + offs_m, mask=offs_m < M, other=0.0)

    for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)

        xq = tl.load(
            xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
            mask=mask,
            other=0.0,
        )
        x_dq = xq * scales[:, None]

        tl.store(
            x_dequant_ptr
            + offs_m[:, None] * stride_xdqm
            + offs_k[None, :] * stride_xdqk,
            x_dq,
            mask=mask,
        )
        # Advance to the next tile of columns.
        offs_k += BLOCK_K
|
|
1570
|
+
|
|
1571
|
+
|
|
1572
|
+
def dequantize_fp8_packed_row(
    xq: torch.Tensor,
) -> torch.Tensor:
    """
    Dequantize a packed, row-wise scaled FP8 tensor to BF16 along its last axis.

    Args:
        xq (torch.Tensor): Contiguous packed FP8 tensor to be dequantized. The
            last 4 bytes of each row are the FP32 scale for that row; the rest
            of the row is the quantized payload. The scale extraction assumes a
            one-byte element type.

    Returns:
        torch.Tensor: BF16 tensor shaped like `xq` with the last dimension
        shortened by 4.
    """
    # Check contiguity before slicing, since the views below rely on it.
    assert xq.is_contiguous(), "Input tensors must be contiguous"

    # Split the packed tensor into payload and scale views.
    # This makes it much easier to write the kernel.
    orig_shape = (*xq.shape[:-1], xq.shape[-1] - 4)
    actual_xq = xq[..., :-4].view(orig_shape)

    x_dequant = torch.empty(orig_shape, dtype=torch.bfloat16, device=xq.device)

    # Number of rows once all leading dimensions are flattened.
    num_rows = actual_xq.numel() // actual_xq.shape[-1]

    # TODO: we take a perf hit from these reshapes, can we do better?
    # It's hard to skip this reshape, we can't create a int32/float32 view because of alignment issues
    scale_view = xq[..., -4:].reshape((num_rows * 4)).view(torch.float32)
    scale_view = scale_view.view(orig_shape[:-1])

    # Reshape to 2-d array keeping last dim only.
    K = actual_xq.shape[-1]
    actual_xq = actual_xq.reshape(-1, K)
    M = actual_xq.shape[0]
    use_int64 = actual_xq.numel() > 2**31

    # Take the output strides from a 2-d view: indexing stride(-2) on the
    # original-shaped output raises IndexError when the input is a single
    # 1-d packed row.
    x_dequant_2d = x_dequant.view(M, K)

    def grid(meta: Dict[str, int]) -> Tuple[int]:
        # One program per BLOCK_M rows.
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch.cuda.device(actual_xq.device.index):
        _kernel_dequantize_fp8_packed_row[grid](
            actual_xq,
            scale_view,
            x_dequant_2d,
            M,
            K,
            actual_xq.stride(0),
            actual_xq.stride(1),
            x_dequant_2d.stride(0),  # Use squashed stride.
            x_dequant_2d.stride(1),
            USE_INT64=use_int64,
        )

    # x_dequant_2d aliases x_dequant, so the original-shaped tensor is filled.
    return x_dequant
|
|
1625
|
+
|
|
1626
|
+
|
|
1627
|
+
@triton.jit
def _kernel_quantize_fp8_tensor(
    A,
    A_fp8,
    global_max_ptr,
    blocks_done_ptr,
    scale_ready_ptr,
    scale_out_ptr,
    N,
    num_sms,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """Fused persistent kernel that finds the global abs-max and quantizes.

    Every program strides over the flat input in steps of `num_sms` chunks of
    BLOCK_SIZE elements, first to reduce its share to a local abs-max, then,
    after the scale has been published, to write the quantized values.

    The spin-wait between the two phases only terminates if every launched
    program gets to run; the launcher is expected to start exactly `num_sms`
    programs for that reason.
    NOTE(review): co-residency of all programs is an assumption of this
    design, not something the kernel itself checks.

    Args:
        A (Tensor): Input, addressed as a flat buffer of N elements.
        A_fp8 (Tensor): Output fp8 buffer, addressed with the same flat offsets.
        global_max_ptr (Tensor): int32 slot holding the bit pattern of the
            running global max (must be initialized to 0).
        blocks_done_ptr (Tensor): Atomic counter of finished programs
            (must be initialized to 0).
        scale_ready_ptr (Tensor): Ready flag (must be initialized to 0).
        scale_out_ptr (Tensor): Receives global_max / MAX_FP8, i.e. the
            reciprocal (dequantization) scale.
        N (int): Total number of elements.
        num_sms (int): Chunk stride and expected number of programs.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Lower bound applied to the global max.
        BLOCK_SIZE (int): Elements processed per chunk.
    """
    pid = tl.program_id(0)

    # Phase 1: Each block finds max across all its assigned chunks
    local_max = 0.0
    chunk_id = pid
    num_chunks = tl.cdiv(N, BLOCK_SIZE)

    while chunk_id < num_chunks:
        # NOTE(review): offsets are built without an explicit int64 cast;
        # confirm callers keep N within int32 range.
        offset = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # Out-of-range lanes load 0.0, which cannot raise an abs-max.
        a = tl.load(A + offset, mask=offset < N, other=0.0)
        chunk_max = tl.max(tl.abs(a))
        local_max = tl.maximum(local_max, chunk_max)
        chunk_id += num_sms

    # Atomically update global max using integer atomics on float bits
    # (the values are non-negative, so int32 ordering of the bit patterns
    # matches float ordering).
    local_max_int = local_max.to(tl.float32, bitcast=False).to(tl.int32, bitcast=True)
    tl.atomic_max(global_max_ptr, local_max_int)

    # Increment completed block counter
    old_count = tl.atomic_add(blocks_done_ptr, 1)

    # Last block to finish computes the scale
    if old_count == num_sms - 1:
        global_max_int = tl.load(global_max_ptr)
        global_max_float = global_max_int.to(tl.float32, bitcast=True)
        # Clamp so an all-zero input does not produce a zero scale.
        global_max_float = tl.maximum(global_max_float, EPS)
        scale = tl.div_rn(global_max_float, MAX_FP8)
        # Publish the scale first, then raise the flag the other programs poll.
        tl.store(scale_out_ptr, scale)
        tl.atomic_xchg(scale_ready_ptr, 1)

    # Phase 2: Spin-wait for scale to be ready
    # Safe because all num_sms blocks are guaranteed to be running
    # (atomic_add of 0 is used as an atomic read of the flag).
    while tl.atomic_add(scale_ready_ptr, 0) == 0:
        pass

    # Load scale and quantize all assigned chunks
    scale = tl.load(scale_out_ptr)
    chunk_id = pid

    while chunk_id < num_chunks:
        offset = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + offset, mask=offset < N, other=0.0)
        # Multiply by 1/scale (= MAX_FP8 / global_max) to map into fp8 range.
        a_fp8 = a * tl.div_rn(1.0, scale)
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
        tl.store(A_fp8 + offset, a_fp8, mask=offset < N)
        chunk_id += num_sms
|
|
1708
|
+
|
|
1709
|
+
|
|
1710
|
+
def _get_num_sms(device: torch.device) -> int:
    """Return the multiprocessor (SM) count reported for `device`."""
    device_props = torch.cuda.get_device_properties(device)
    return device_props.multi_processor_count
|
|
1713
|
+
|
|
1714
|
+
|
|
1715
|
+
def triton_quantize_fp8_tensor(
    a: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Triton implementation to quantize a tensor to fp8 with a single scale.

    Uses a fused persistent kernel with atomic operations for inter-block
    coordination. Exactly num_sms programs are launched because the kernel
    spin-waits on a flag and relies on every program being able to run.

    The kernel addresses input and output as flat buffers with identical
    offsets, so both must share one dense memory layout. Inputs whose strides
    differ from the freshly allocated output (e.g. sliced views) are first
    copied into that layout.

    Args:
        a (Tensor): Input tensor to be quantized.

    Returns:
        torch.Tensor: fp8 quantized tensor, allocated with torch.empty_like(a).
        torch.Tensor: scalar reciprocal scale tensor (fp32).
    """
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    N = a.numel()

    BLOCK_SIZE = 4096
    # Launch exactly num_sms blocks to guarantee concurrent execution
    num_sms = _get_num_sms(a.device)

    # Allocate synchronization buffers (initialized to 0)
    global_max = torch.zeros(1, device=a.device, dtype=torch.int32)
    blocks_done = torch.zeros(1, device=a.device, dtype=torch.int32)
    scale_ready = torch.zeros(1, device=a.device, dtype=torch.int32)
    scale_out = torch.empty((), device=a.device, dtype=torch.float32)

    # Output matches the shape of a; empty_like keeps a's strides when a is
    # dense and picks a dense layout otherwise.
    a_fp8 = torch.empty_like(a, dtype=pt_dtype)

    # The kernel reads a and writes a_fp8 at the same flat offsets. If the
    # strides differ, a flat read of a would touch the wrong storage elements,
    # so materialize a in exactly the output's layout.
    if a.stride() != a_fp8.stride():
        a = torch.empty_like(a_fp8, dtype=a.dtype).copy_(a)

    with torch.cuda.device(a.device.index):
        _kernel_quantize_fp8_tensor[(num_sms,)](
            a,
            a_fp8,
            global_max,
            blocks_done,
            scale_ready,
            scale_out,
            N,
            num_sms,
            # pyre-ignore[6]: Incompatible parameter type
            TL_FP8_DTYPE=tl_dtype,
            # pyre-ignore[6]: Incompatible parameter type
            MAX_FP8=max_fp8,
            # pyre-ignore[6]: Incompatible parameter type
            EPS=eps,
            # pyre-ignore[6]: Incompatible parameter type
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return a_fp8, scale_out
|
|
1769
|
+
|
|
1770
|
+
|
|
1771
|
+
@torch.library.custom_op("triton::quantize_fp8_tensor", mutates_args=())
def quantize_fp8_tensor(
    a: torch.Tensor,
    use_triton: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with a single scale factor across the entire tensor.

    The quantization scale is MAX_FP8 / max(max(abs(a)), eps) and is applied
    uniformly; its reciprocal is returned so callers can dequantize by
    multiplication. CPU inputs always take the PyTorch fallback. An empty
    input yields an empty output and the reciprocal scale eps / MAX_FP8.

    Args:
        a (Tensor): Input tensor of any shape. May be non-contiguous.
        use_triton (bool): Whether to use optimized triton kernel.

    Returns:
        torch.Tensor: fp8 quantized tensor with the same shape as the input.
        torch.Tensor: scalar reciprocal scale tensor (fp32).
    """
    # Triton cannot run on CPU tensors; compare the device type so an explicit
    # device index does not defeat the check.
    if a.device.type == "cpu":
        use_triton = False

    if use_triton:
        a_fp8, reciprocal_scale = triton_quantize_fp8_tensor(a)
        return a_fp8, reciprocal_scale

    # Fallback to PyTorch implementation
    pt_dtype, _, max_fp8, eps = get_fp8_constants()

    if a.numel() == 0:
        # torch.max raises on an empty tensor; treat the max as 0 so the
        # clamp below takes it to eps.
        tensor_max = torch.zeros((), device=a.device, dtype=torch.float32)
    else:
        tensor_max = torch.max(torch.abs(a)).to(torch.float32)
    # Keep the max away from zero so the division below stays finite.
    tensor_max = torch.clamp(tensor_max, min=eps)

    scale = max_fp8 / tensor_max  # pyre-ignore[58]
    a_scaled = a.to(torch.float32) * scale
    # Clamp so rounding cannot push values past the fp8 range before the cast.
    a_scaled = torch.clamp(a_scaled, -max_fp8, max_fp8)
    a_fp8 = a_scaled.to(pt_dtype)

    reciprocal_scale = (1.0 / scale).to(torch.float32)  # pyre-ignore[16]

    return a_fp8, reciprocal_scale
|
|
1811
|
+
|
|
1812
|
+
|
|
1813
|
+
@quantize_fp8_tensor.register_fake
def quantize_fp8_tensor_meta(
    a: torch.Tensor,
    use_triton: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation of quantize_fp8_tensor: describes the outputs without computing them."""
    fp8_dtype, _, _, _ = get_fp8_constants()
    return (
        # empty_like keeps the input's memory format (e.g., channels_last_3d).
        torch.empty_like(a, dtype=fp8_dtype),
        # The reciprocal scale is a 0-dim fp32 tensor on the input's device.
        torch.empty((), device=a.device, dtype=torch.float32),
    )
|
|
1824
|
+
|
|
1825
|
+
|
|
1826
|
+
@triton.jit
def _kernel_dequantize_fp8_block(
    xq_ptr,
    x_scale_ptr,
    x_dequant_ptr,
    M,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Kernel to dequantize a block-wise scaled tensor.

    Each program handles one BLOCK_M x BLOCK_K tile of a row-major (M, K)
    matrix and multiplies it by the single scale stored for that tile.

    Args:
        xq_ptr: Pointer to the quantized input, addressed as row * K + col.
        x_scale_ptr: Pointer to the tile scales, addressed as
            tile_row * cdiv(K, BLOCK_K) + tile_col.
        x_dequant_ptr: Pointer to the output, addressed like the input.
        M: Number of rows.
        K: Number of columns.
        BLOCK_M (tl.constexpr): Tile height.
        BLOCK_K (tl.constexpr): Tile width.
    """
    tile_row = tl.program_id(axis=0)
    tile_col = tl.program_id(axis=1)
    tiles_per_row = tl.cdiv(K, BLOCK_K)

    rows = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_col * BLOCK_K + tl.arange(0, BLOCK_K)
    flat_offs = rows[:, None] * K + cols[None, :]
    # Guard the ragged edge tiles.
    in_bounds = (rows[:, None] < M) & (cols[None, :] < K)

    tile = tl.load(xq_ptr + flat_offs, mask=in_bounds).to(tl.bfloat16)
    tile_scale = tl.load(x_scale_ptr + tile_row * tiles_per_row + tile_col)
    tl.store(x_dequant_ptr + flat_offs, tile * tile_scale, mask=in_bounds)
|
|
1858
|
+
|
|
1859
|
+
|
|
1860
|
+
def dequantize_fp8_block(
    xq: torch.Tensor,
    x_scale: torch.Tensor,
    block_m: int = 256,
    block_k: int = 256,
) -> torch.Tensor:
    """
    Dequantize a block-wise scaled FP8 matrix to BF16.

    Args:
        xq (torch.Tensor): Contiguous 2-d FP8 tensor to be dequantized.
        x_scale (torch.Tensor): Contiguous 2-d tensor holding one scale per
            (block_m x block_k) tile of `xq`.
        block_m (int): Block size for the M dimension.
        block_k (int): Block size for the K dimension.

    Returns:
        torch.Tensor: BF16 tensor with the same shape as `xq`.
    """
    # Checked in the same order as a single combined assertion would be.
    assert xq.is_contiguous(), "Input tensors must be contiguous"
    assert x_scale.is_contiguous(), "Input tensors must be contiguous"
    assert xq.dim() == 2, "Input tensors must have 2 dimensions"
    assert x_scale.dim() == 2, "Input tensors must have 2 dimensions"

    num_rows, num_cols = xq.size()
    result = torch.empty_like(xq, dtype=torch.bfloat16)

    def launch_grid(meta: Dict[str, int]) -> Tuple[int, int]:
        # One program per (BLOCK_M x BLOCK_K) tile.
        tiles_m = triton.cdiv(num_rows, meta["BLOCK_M"])
        tiles_k = triton.cdiv(num_cols, meta["BLOCK_K"])
        return (tiles_m, tiles_k)

    with torch.cuda.device(xq.device.index):
        _kernel_dequantize_fp8_block[launch_grid](
            xq,
            x_scale,
            result,
            num_rows,
            num_cols,
            BLOCK_M=block_m,  # pyre-ignore[6]
            BLOCK_K=block_k,  # pyre-ignore[6]
        )
    return result
|