fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0

fbgemm_gpu/quantize/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx  # noqa F401
+from fbgemm_gpu.utils import TorchLibraryFragment
+
+lib = TorchLibraryFragment("fbgemm")
+
+lib.define(
+    """quantize_mx(
+        Tensor input,
+        int scale_bits,
+        int elem_ebits,
+        int elem_mbits,
+        float elem_max_norm,
+        int mx_group_size,
+        int? rounding_mode = None
+    ) -> Tensor
+    """
+)
+
+lib.define(
+    """dequantize_mx(
+        Tensor input,
+        int mx_group_size
+    ) -> Tensor
+    """
+)
+
+lib.register(
+    "quantize_mx",
+    {"CUDA": quantize_mx, "CPU": quantize_mx},
+)
+
+lib.register(
+    "dequantize_mx",
+    {"CUDA": dequantize_mx, "CPU": dequantize_mx},
+)
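
The hunk above only declares the two operator schemas and routes both the CPU and CUDA dispatch keys to the Python implementations imported from fbgemm_gpu.quantize.quantize_ops. As a rough illustration (not part of the package diff), a round trip through the registered operators could look like the sketch below; the assumption that importing fbgemm_gpu.quantize is sufficient to trigger registration, and the chosen tensor shape, are mine.

import torch
import fbgemm_gpu.quantize  # noqa: F401  (importing runs the define/register calls above)

x = torch.randn(4, 32)  # 128 elements, a multiple of mx_group_size
packed = torch.ops.fbgemm.quantize_mx(
    x.flatten(),
    scale_bits=8,
    elem_ebits=2,
    elem_mbits=3,
    elem_max_norm=6.0,
    mx_group_size=32,
)
restored = torch.ops.fbgemm.dequantize_mx(packed, mx_group_size=32)
assert restored.numel() == x.numel()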

fbgemm_gpu/quantize/quantize_ops.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import Union
+
+import torch
+
+from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
+
+
+def quantize_mx(
+    input: torch.Tensor,
+    scale_bits: int = 8,
+    elem_ebits: int = 2,
+    elem_mbits: int = 3,
+    elem_max_norm: float = 6.0,
+    mx_group_size: int = 32,
+    rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
+) -> torch.Tensor:
+    """
+    Registered quantize_mx ops for E2E comm.
+    (registration is done in __init__.py)
+    We use Triton implementation for quantization
+    Args:
+        input: FP32 tensor of size total_elems to be quantized
+        scale_bits: num bits of the shared exponent (i.e., 8 for MX4 e2m1)
+        elem_ebits: num bits of the exponent (i.e., 2 for MX4 e2m1)
+        elem_mbits: num bits of the mantissa incl. sign and implicit bits (
+            i.e., 3 for MX4 e2m1)
+        elem_max_norm: max value of the float (i.e., 6.0 for MX4 e2m1)
+        mx_group_size: num elements that share the max shared_exponent
+        rounding_mode: Which type of rounding to use when calculating shared exponent.
+
+    Return:
+        output: MX4 tensor packed into int8 values with size
+            (total_elems / 2 + total_elems / groupsize)
+            the shared exponent of each group is stored at the last byte
+            of output of each group
+    """
+    return fp32_to_mx4(
+        input, mx_group_size, rounding_mode=rounding_mode, use_triton=True
+    )
+
+
+def dequantize_mx(
+    input: torch.Tensor,
+    mx_group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Registered dequantize_mx ops for E2E comm
+    (registration is done in __init__.py to prevent multiple loading)
+    We use triton implementation for quantization
+    Args:
+        input: FP8 tensor (MX4 packed in FP8)
+        mx_group_size: number of elements that shares the same max shared_exponent
+
+    Return:
+        output: FP32 tensor with total elements (total_elems)
+    """
+    return mx4_to_fp32(input, mx_group_size, use_triton=True)
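
The docstrings above pin down the packed layout: each group of mx_group_size FP4 values occupies mx_group_size / 2 data bytes plus one shared-exponent byte, so a tensor of total_elems elements packs into total_elems / 2 + total_elems / mx_group_size bytes. As a hedged sketch (not part of the diff), calling the wrappers directly on CPU makes that arithmetic visible, e.g. 64 elements with a group size of 32 pack into 64/2 + 64/32 = 34 bytes:

import torch
from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx

x = torch.randn(64)                        # total_elems = 64
packed = quantize_mx(x, mx_group_size=32)  # 32 data bytes + 2 shared-exponent bytes
assert packed.numel() == 64 // 2 + 64 // 32
restored = dequantize_mx(packed, mx_group_size=32)
assert restored.numel() == x.numel()       # MX4 is lossy, so values only match approximately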

fbgemm_gpu/quantize_comm.py
@@ -0,0 +1,315 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# The code in this file is refactored from https://fburl.com/code/p2gy2gxb
+# based on "Amy Yang et al., Training Deep Learning Recommendation Model with
+# Quantized Collective Communications", DLP-KDD 2020.
+
+
+import logging
+from typing import Optional, TypeVar
+
+import torch
+
+from fbgemm_gpu.quantize_utils import (
+    bf16_to_fp32,
+    fp16_to_fp32,
+    fp32_to_bf16_with_clamp,
+    fp32_to_fp16_with_clamp,
+    fp32_to_hfp8_with_clamp,
+    fp32_to_mx4,
+    hfp8_to_fp32,
+    mx4_to_fp32,
+    RoundingMode,
+)
+
+from fbgemm_gpu.split_embedding_configs import SparseType
+
+from torch.autograd.profiler import record_function  # usort:skip
+from dataclasses import dataclass
+
+import fbgemm_gpu.quantize.quantize_ops  # noqa F401
+
+logger: logging.Logger = logging.getLogger()
+
+# FP8 configurations
+ebits, mbits, bias = 4, 3, 15
+max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
+
+# INT8 configurations
+ROW_DIM_DEFAULT = 32
+
+# MX4 configurations
+MX_GROUP_SIZE_DEFAULT = 32
+
+
+def none_throws(
+    # pyre-fixme[31]: Expression `typing.Optional[typing.TypeVar("_T")]` is not a
+    # valid type.
+    optional: Optional[TypeVar("_T")],
+    message: str = "Unexpected `None`",
+    # pyre-fixme[31]: Expression `typing.TypeVar("_T")` is not a valid type.
+) -> TypeVar("_T"):
+    if optional is None:
+        raise AssertionError(message)
+    return optional
+
+
+@dataclass
+class QuantizationContext:
+    row_dim: int = ROW_DIM_DEFAULT
+    row_dim_quant: int = -1
+    mx_group_size: int = MX_GROUP_SIZE_DEFAULT
+    rounding_mode: Optional[RoundingMode] = RoundingMode.even
+    padded_dim_sum_per_rank: Optional[list[int]] = None
+
+
+def _quantize_tensor(
+    input_tensor: torch.Tensor,
+    comm_precision: SparseType,
+    ctx: Optional[QuantizationContext] = None,
+    is_fwd: bool = True,
+) -> torch.Tensor:
+    if comm_precision == SparseType.FP32:
+        return input_tensor
+    elif comm_precision == SparseType.FP16:
+        return fp32_to_fp16_with_clamp(input_tensor)
+    elif comm_precision == SparseType.BF16:
+        return fp32_to_bf16_with_clamp(input_tensor)
+    elif comm_precision == SparseType.FP8:
+        # return fp32_to_hfp8_with_clamp(input_tensor, ebits, mbits, bias)
+        if ctx is not None and ctx.row_dim > 0:
+            ctx = none_throws(ctx)
+            row_dim = ctx.row_dim
+            input_2d = input_tensor.view((-1, row_dim)) if row_dim > 0 else input_tensor
+            input_2d_quant = torch.ops.fbgemm.FloatToFP8RowwiseQuantized(
+                input_2d, is_fwd
+            )
+            row_dim_quant = input_2d_quant.shape[1]
+            input_quant_all2all = None
+            input_quant_all2all = input_2d_quant.view((-1))
+            ctx.row_dim_quant = row_dim_quant
+            return input_quant_all2all
+        else:
+            return fp32_to_hfp8_with_clamp(input_tensor, ebits, mbits, bias)
+    elif comm_precision == SparseType.INT8:
+        ctx = none_throws(ctx)
+        row_dim = ctx.row_dim
+        input_2d = input_tensor.view((-1, row_dim)) if row_dim > 0 else input_tensor
+        input_2d_quant = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(input_2d)
+        row_dim_quant = input_2d_quant.shape[1]
+        input_quant_all2all = None
+        input_quant_all2all = input_2d_quant.view((-1))
+        ctx.row_dim_quant = row_dim_quant
+        return input_quant_all2all
+    elif comm_precision == SparseType.MX4:
+        mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
+        rounding_mode = ctx.rounding_mode if ctx is not None else RoundingMode.even
+        return fp32_to_mx4(
+            input_tensor, mx_group_size, rounding_mode=rounding_mode
+        ).view(-1)
+    else:
+        raise ValueError(f"comm_precision={comm_precision} is not supported")
+
+
+def _dequantize_tensor(
+    quantized_tensor: torch.Tensor,
+    comm_precision: SparseType,
+    ctx: Optional[QuantizationContext] = None,
+    is_fwd: bool = True,
+    fp8_output_dtype: Optional[SparseType] = None,
+) -> torch.Tensor:
+    if comm_precision == SparseType.FP32:
+        assert quantized_tensor.dtype == torch.float
+        return quantized_tensor
+    elif comm_precision == SparseType.FP16:
+        assert quantized_tensor.dtype == torch.half
+        return fp16_to_fp32(quantized_tensor)
+    elif comm_precision == SparseType.BF16:
+        assert quantized_tensor.dtype == torch.bfloat16
+        return bf16_to_fp32(quantized_tensor)
+    elif comm_precision == SparseType.FP8:
+        if ctx is not None and ctx.row_dim > 0:
+            row_dim_quant = ctx.row_dim_quant
+            quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+            # use provided fp8_output_dtype or default to FP32 (0)
+            output_dtype_int = (
+                fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
+            )
+            dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
+                quantized_tensor_2d,
+                is_fwd,
+                output_dtype_int,
+            )
+            return dequant_tensor.view(-1)
+        else:
+            assert quantized_tensor.dtype == torch.uint8
+            return hfp8_to_fp32(quantized_tensor, ebits, bias)
+    elif comm_precision == SparseType.INT8:
+        ctx = none_throws(ctx)
+        row_dim_quant = ctx.row_dim_quant
+        quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+        dequant_tensor = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(
+            quantized_tensor_2d
+        )
+        return dequant_tensor.view(-1)
+    elif comm_precision == SparseType.MX4:
+        mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
+        return mx4_to_fp32(quantized_tensor, mx_group_size)
+    else:
+        raise ValueError(f"comm_precision={comm_precision} is not supported")
+
+
+class QuantizedCommCodec:
+    # Concrete implementation of QuantizedCommCodec provided by FBGEMM functions.
+    def __init__(
+        self,
+        comm_precision: SparseType,
+        loss_scale: Optional[float] = None,
+        row_dim: Optional[int] = None,
+        is_fwd: bool = True,
+        rounding_mode: Optional[RoundingMode] = None,
+        fp8_output_dtype: Optional[SparseType] = None,
+    ) -> None:
+        if loss_scale is not None:
+            if comm_precision not in [SparseType.FP16, SparseType.BF16]:
+                logger.warning(
+                    f"Setting loss scale for comm_precision={comm_precision} is not supported. Overriding to None"
+                )
+                loss_scale = None
+
+        logger.info(
+            f"Creating QuantizedCommsCodec comm_precision:{comm_precision}, loss_scale:{loss_scale}"
+        )
+
+        self._comm_precision = comm_precision
+        self._loss_scale = loss_scale
+        self._is_fwd = is_fwd
+        self._row_dim: int = -1 if row_dim is None else row_dim
+        self._rounding_mode: Optional[RoundingMode] = rounding_mode
+        self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
+        if self._comm_precision == SparseType.MX4:
+            self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
+            self._rounding_mode = (
+                RoundingMode.even if rounding_mode is None else rounding_mode
+            )
+
+    def encode(
+        self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
+    ) -> torch.Tensor:
+        if self._loss_scale is not None:
+            input_tensor = self._loss_scale * input_tensor
+        with record_function(
+            f"## encoder {self._comm_precision} {self._loss_scale} ##"
+        ):
+            output = _quantize_tensor(
+                input_tensor,
+                self._comm_precision,
+                ctx,
+                self._is_fwd,
+            )
+        return output
+
+    def decode(
+        self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
+    ) -> torch.Tensor:
+        if self._loss_scale is not None:
+            input_tensor = input_tensor / self._loss_scale
+        with record_function(
+            f"## decoder {self._comm_precision} {self._loss_scale} ##"
+        ):
+            dequantized_tensor = _dequantize_tensor(
+                input_tensor,
+                self._comm_precision,
+                ctx,
+                self._is_fwd,
+                fp8_output_dtype=self._fp8_output_dtype,
+            )
+        return dequantized_tensor
+
+    def calc_quantized_size(
+        self, input_len: int, ctx: Optional[QuantizationContext] = None
+    ) -> int:
+        # Use the same logic in _float_to_fused8bitrowwise_gpu_t()
+        if self._comm_precision == SparseType.INT8 or (
+            self._comm_precision == SparseType.FP8 and self._row_dim > 0
+        ):
+            ctx = none_throws(ctx)
+            torch._check(
+                input_len % ctx.row_dim == 0,
+                lambda: f"input_len {input_len} is not a multiple of row dim {ctx.row_dim}",
+            )
+            assert input_len % ctx.row_dim == 0, (
+                f"input_len {input_len} is not a multiple of row dim {ctx.row_dim} "
+                "Please check your batch size (power of 2 batch size is recommended)"
+            )
+            nrows = input_len // ctx.row_dim
+            ncols = (ctx.row_dim + 3) // 4 * 4 + 2 * 4
+            return nrows * ncols
+        elif self._comm_precision == SparseType.MX4:
+            if ctx:
+                group_size = ctx.mx_group_size
+            else:
+                group_size = MX_GROUP_SIZE_DEFAULT
+            assert (
+                input_len % group_size == 0
+            ), f"input_len {input_len} needs to be multiple of group_size {group_size}"
+            # quantized output size = half input size + number of groups (shared exp)
+            ctx = none_throws(ctx)
+            return (input_len // 2) + (input_len // ctx.mx_group_size)
+        else:
+            return input_len
+
+    @property
+    def quantized_dtype(self) -> torch.dtype:
+        return self._comm_precision.as_dtype()
+
+    def create_context(self) -> Optional[QuantizationContext]:
+        # fp8 rowwise is activated when row_dim > 0
+        if self._comm_precision == SparseType.FP8:
+            return QuantizationContext(self._row_dim)
+        if self._comm_precision == SparseType.MX4:
+            return QuantizationContext(
+                row_dim=self._row_dim,
+                mx_group_size=self._row_dim,
+                rounding_mode=self._rounding_mode,
+            )
+        # int8 rowwise is default
+        return QuantizationContext()
+
+    def padded_size(
+        self,
+        input_tensor: torch.Tensor,
+        dim_per_rank: list[int],
+        my_rank: int,
+        qcomm_ctx: QuantizationContext,
+    ) -> tuple[int, int]:
+        if input_tensor.ndim == 1:
+            return input_tensor.shape[0], 0
+        # return padded size for the feature dimension (dim 1), 0 if no padding needed.
+        padded_dim_sum, padding_size = input_tensor.shape[1], 0
+        if self._comm_precision == SparseType.MX4:
+            group_size = qcomm_ctx.mx_group_size
+            padding_size_per_rank = [
+                group_size - (t if (t := dim_sum % group_size) > 0 else group_size)
+                for dim_sum in dim_per_rank
+            ]
+            padded_dim_sum_per_rank = [
+                a + b for a, b in zip(dim_per_rank, padding_size_per_rank)
+            ]
+            dim_sum, padding_size = (
+                dim_per_rank[my_rank],
+                padding_size_per_rank[my_rank],
+            )
+            assert input_tensor.ndim == 2 and input_tensor.shape[1] == dim_sum
+            qcomm_ctx.padded_dim_sum_per_rank = padded_dim_sum_per_rank
+            padded_dim_sum = padding_size + dim_sum
+            return padded_dim_sum, padding_size
+
+        return padded_dim_sum, padding_size
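
To make the QuantizedCommCodec contract above concrete, here is an illustrative sketch (not part of the diff) using the FP16 path, which needs no CUDA kernels or custom context; the loss-scale value and tensor size are arbitrary choices of mine, and it assumes an importable fbgemm_gpu build.

import torch
from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(SparseType.FP16, loss_scale=128.0)
ctx = codec.create_context()           # default (int8-style) context for FP16
grad = torch.randn(1024)
payload = codec.encode(grad, ctx)      # scaled by loss_scale, clamped, cast to fp16
assert payload.dtype == codec.quantized_dtype
assert payload.numel() == codec.calc_quantized_size(grad.numel(), ctx)
restored = codec.decode(payload, ctx)  # back to fp32, divided by loss_scale
assert restored.dtype == torch.float32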

fbgemm_gpu/quantize_utils.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import logging
+from typing import Optional, Union
+
+import torch  # isort:skip
+
+import fbgemm_gpu
+
+from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
+from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
+
+try:
+    # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+    open_source = bool(getattr(fbgemm_gpu, "open_source", False))
+except NotImplementedError:
+    open_source = False
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+if not open_source:
+    from mtia.kernels.triton.mx4.quantize import (
+        triton_dequantize_mx4 as mtia_dequantize_mx4,
+        triton_quantize_mx4 as mtia_quantize_mx4,
+    )
+
+logger: logging.Logger = logging.getLogger()
+
+try:
+    # pyre-ignore[21]
+    from fbgemm_gpu import open_source  # noqa: F401
+except Exception:
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+
+TORCH_HALF_MIN: float = torch.finfo(torch.float16).min
+TORCH_HALF_MAX: float = torch.finfo(torch.float16).max
+
+TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min
+TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max
+
+
+def fp32_to_mx4(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    ebits: int = 2,
+    mbits: int = 1,
+    rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.even,
+    stochastic_casting: bool = False,
+    use_triton: bool = True,
+) -> torch.Tensor:
+    """Quantize an FP32 tensor to MX4 with triton or native cuda impl.
+
+    Args:
+        tensor (torch.Tensor): FP32 tensor to quantize with M total elements.
+        group_size (int): Compute scale in chunks of group_size.
+        ebits (int): Number of exponent bits in target mx4 format.
+        mbits (int): Number of mantissa bits in target mx4 format.
+        rounding_mode (RoundingMode or int): Which type of rounding to use when computing exponent.
+            Only supported with use_triton=True.
+        stochastic_casting (bool): Whether to use stochastic casting when downcasting.
+        use_triton (bool): If set, use triton quantization, otherwise cuda.
+
+    Return:
+        output: MX4 tensor packed into int8 values with total elements (M / 2 + M / groupsize)
+    """
+    # Accelerated MX4 is only available on cuda, if input is on cpu, use python.
+    # Operate on flattened input.
+    if rounding_mode is None:
+        rounding_mode = RoundingMode.even
+
+    if not tensor.is_cuda and not tensor.is_mtia:
+        return py_quantize_mx4(
+            tensor,
+            group_size,
+            ebits=ebits,
+            mbits=mbits,
+            rounding_mode=rounding_mode,
+            stochastic_casting=stochastic_casting,
+        )
+
+    if use_triton:
+        if tensor.is_mtia:
+            return mtia_quantize_mx4(
+                tensor,
+                group_size,
+                ebits=ebits,
+                mbits=mbits,
+                rounding_mode=rounding_mode,
+                stochastic_casting=stochastic_casting,
+            )
+        return quantize_mx4(
+            tensor,
+            group_size,
+            ebits=ebits,
+            mbits=mbits,
+            rounding_mode=rounding_mode,
+            stochastic_casting=stochastic_casting,
+        )
+    else:
+        out = torch.ops.fbgemm.quantize_mx_cuda(
+            tensor.flatten(),
+            scale_bits=8,
+            elem_ebits=2,
+            elem_mbits=3,
+            elem_max_norm=6.0,
+            mx_group_size=group_size,
+        )
+        # Perserve input dimensions.
+        output_shape = list(tensor.shape[:-1]) + [-1]
+        return out.view(output_shape)
+
+
+def mx4_to_fp32(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    use_triton: bool = True,
+    ebits: int = 2,
+    mbits: int = 1,
+) -> torch.Tensor:
+    """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
+
+    Args:
+        tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
+        group_size (int): Compute scale in chunks of group_size.
+        use_triton (bool): If set, use triton quantization, otherwise cuda.
+        ebits (int): Number of exponent bits in target mx4 format.
+        mbits (int): Number of mantissa bits in target mx4 format.
+
+    Return:
+        output: FP32 tensor with total elements (M).
+    """
+    # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
+    if not tensor.is_cuda and not tensor.is_mtia:
+        return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+    if use_triton:
+        if tensor.is_mtia:
+            return mtia_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+    else:
+        return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
+
+
+def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()
+
+
+def fp32_to_bf16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.clamp(tensor, TORCH_BFLOAT16_MIN, TORCH_BFLOAT16_MAX).bfloat16()
+
+
+def fp32_to_hfp8_with_clamp(
+    tensor: torch.Tensor, ebits: int = 4, mbits: int = 3, bias: int = 15
+) -> torch.Tensor:
+    max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
+    return torch.ops.fbgemm.FloatToHFP8Quantized(
+        tensor.contiguous(),
+        ebits,
+        bias,
+        max_pos,
+    )
+
+
+def fp16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.float()
+
+
+def bf16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.view(torch.bfloat16).float()
+
+
+def hfp8_to_fp32(tensor: torch.Tensor, ebits: int = 4, bias: int = 15) -> torch.Tensor:
+    return torch.ops.fbgemm.HFP8QuantizedToFloat(
+        tensor.contiguous().view(torch.uint8),
+        ebits,
+        bias,
+    )
+
+
+def measure_fp16_quant_error(input_tensor: torch.Tensor) -> None:
+    # TODO: log to tensorboard
+
+    num_nan_fp32_tensor = torch.numel(input_tensor[torch.isnan(input_tensor)])
+    logger.info(
+        "num NaN in fp32 tensor: {}, ratio: {}.".format(
+            num_nan_fp32_tensor, num_nan_fp32_tensor / torch.numel(input_tensor)
+        )
+    )
+
+    logger.info(
+        "fp32 tensor profile: min: {}, max: {}, min abs:{}, max abs:{}.".format(
+            torch.min(input_tensor),
+            torch.max(input_tensor),
+            torch.min(torch.abs(input_tensor)),
+            torch.max(torch.abs(input_tensor)),
+        )
+    )
+
+    fp16_tensor = fp32_to_fp16_with_clamp(input_tensor)
+    num_nan_fp16_tensor = torch.numel(fp16_tensor[torch.isnan(fp16_tensor)])
+
+    logger.info(
+        "num NaN in fp16 tensor: {}, ratio: {}.".format(
+            num_nan_fp16_tensor, num_nan_fp16_tensor / torch.numel(input_tensor)
+        )
+    )
+
+    diff = torch.abs(input_tensor - fp16_tensor.float())
+    rel_diff = diff / torch.abs(input_tensor)
+    logger.info(
+        "fp32_to_fp16 abs error: min={}, max={}, avg={}.".format(
+            torch.min(diff), torch.max(diff), torch.mean(diff)
+        )
+    )
+
+    rel_diff_not_nan = rel_diff[torch.logical_not(torch.isnan(rel_diff))]
+    logger.info(
+        "fp32_to_fp16 rel error: min={}, max={}, avg={}.".format(
+            torch.min(rel_diff_not_nan),
+            torch.max(rel_diff_not_nan),
+            torch.mean(rel_diff_not_nan),
+        )
+    )
+
+    rel_diff_1_idx = torch.where(rel_diff == 1.0)
+    fp32_rel_err_1_vals = input_tensor[rel_diff_1_idx]
+    if torch.numel(fp32_rel_err_1_vals) > 0:
+        fp32_rel_err_1_vals = torch.abs(fp32_rel_err_1_vals)
+        logger.info(
+            "fp32_to_fp16 rel error == 1: fp32 min:{}, fp32 max:{}, fp32 avg:{}.".format(
+                torch.min(fp32_rel_err_1_vals),
+                torch.max(fp32_rel_err_1_vals),
+                torch.mean(fp32_rel_err_1_vals),
+            )
+        )
+
+        subrange_ratio = torch.numel(fp16_tensor[rel_diff_1_idx]) / torch.numel(
+            fp16_tensor
+        )
+        logger.info("sub fp16 range ratio: {}".format(subrange_ratio))
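
As a closing illustration (not part of the diff), the helpers above can be exercised on CPU, where fp32_to_mx4 and mx4_to_fp32 fall back to the pure-Python reference kernels py_quantize_mx4 / py_dequantize_mx4; the shapes below are my own choice, and since MX4 is lossy the round trip only matches approximately.

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode

x = torch.randn(2, 64)  # 128 elements, a multiple of the default group size
packed = fp32_to_mx4(x, group_size=32, rounding_mode=RoundingMode.even)
restored = mx4_to_fp32(packed, group_size=32).reshape(x.shape)
print("max abs round-trip error:", (restored - x).abs().max().item())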