mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py (new file, @@ -0,0 +1,106 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \
    check_small_value
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
from msprobe.core.common.const import CompareConst


class AbsolutethdCompare(BaseCompare):
    """
    Absolute threshold compare class.

    This class is used to compare the absolute threshold of benchmark outputs and device outputs.
    It calculates various metrics such as inf_nan_error_ratio, rel_err_ratio, and abs_err_ratio
    to determine the accuracy of the device output compared to the benchmark output.

    Attributes:
        bench_output (np.ndarray): The output from the benchmark.
        device_output (np.ndarray): The output from the device.
        dtype (torch.dtype): The data type of the outputs.
        abs_bench (np.ndarray): The absolute value of the benchmark output.
        abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon.
        both_finite_mask (np.ndarray): A mask indicating where both outputs are finite.
        inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN.
        rtol (float): The relative tolerance for comparison.
        rel_err (np.ndarray): The relative error between the benchmark and device outputs.
        small_value (float): The small value threshold for comparison.
        small_value_atol (float): The absolute tolerance for small values.
        small_value_mask (np.ndarray): A mask indicating where values are small.
        normal_value_mask (np.ndarray): A mask indicating where values are normal.

    Methods:
        _get_rtol(): Gets the relative tolerance based on the data type.
        _get_rel_err(abs_err, abs_bench_with_eps): Calculates the relative error.
        _get_normal_value_mask(both_finite_mask, small_value_mask): Gets the mask for normal values.
        _pre_compare(): Prepares the comparison by calculating various metrics.
        _compute_metrics(): Computes the comparison metrics.

    Note:
        This class assumes that the input data is a dictionary containing 'bench_output', 'device_output',
        'compare_column' and 'dtype'.
        The 'dtype' should be a PyTorch data type.

    See Also:
        BaseCompare: The base class for comparison classes.
        StandardConfig: The class containing standard configuration values.
    """
    def __init__(self, input_data):
        super(AbsolutethdCompare, self).__init__(input_data)
        self.compare_algorithm = CompareConst.ABSOLUTE_THRESHOLD

    def _get_rtol(self):
        return StandardConfig.get_rtol(self.dtype)

    def _pre_compare(self):
        """
        Prepares the comparison by calculating various metrics.

        This method performs the following steps:
        1. Calculates the absolute benchmark values and their epsilon-adjusted versions.
        2. Determines masks for finite and infinite/NaN values in the outputs.
        3. Computes the absolute error between benchmark and device outputs.
        4. Retrieves the relative tolerance based on the data type.
        5. Calculates the relative error using the absolute error and epsilon-adjusted benchmark values.
        6. Determines the small value threshold and its absolute tolerance.
        7. Creates a mask for small values based on the benchmark values and finite mask.
        8. Creates a mask for normal values by excluding small values from the finite mask.
        """
        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
        self.abs_err = self.stat_abs_error()
        self.rtol = self._get_rtol()
        self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps)
        self.small_value, self.small_value_atol = self.get_small_value_threshold()
        self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
        self.normal_value_mask = self._get_normal_value_mask(self.both_finite_mask, self.small_value_mask)

    def _compute_metrics(self):
        inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype,
                                                  self.rtol)
        rel_err_ratio = check_norm_value(self.normal_value_mask, self.rel_err, self.rtol)
        abs_err_ratio = check_small_value(self.abs_err, self.small_value_mask, self.small_value_atol)
        return {
            "inf_nan_error_ratio": inf_nan_error_ratio,
            "rel_err_ratio": rel_err_ratio,
            "abs_err_ratio": abs_err_ratio
        }
```
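The mask arithmetic in `_pre_compare` is easier to follow with a concrete trace. A minimal sketch, using placeholder thresholds (1e-3, 1e-5) in place of the `StandardConfig` lookups and one plausible reading of the `check_*` helpers (fraction of elements whose error exceeds the tolerance), which are not shown in this diff:

```python
import numpy as np

# Illustrative only: reproduces the masks AbsolutethdCompare builds in _pre_compare.
bench = np.array([1.0, 1e-7, 2.0, 0.5], dtype=np.float32)
device = np.array([1.0001, 2e-7, np.inf, 0.49], dtype=np.float32)

eps = np.finfo(np.float32).eps
abs_bench = np.abs(bench)
abs_bench_with_eps = abs_bench + eps              # epsilon guards the division below
abs_err = np.abs(device - bench)
rel_err = abs_err / abs_bench_with_eps            # matches BaseCompare._get_rel_err

both_finite_mask = np.isfinite(bench) & np.isfinite(device)
small_value_mask = (abs_bench < 1e-3) & both_finite_mask   # placeholder small-value threshold
normal_value_mask = both_finite_mask & ~small_value_mask   # matches _get_normal_value_mask

rtol, small_value_atol = 1e-3, 1e-5               # placeholder tolerances
rel_err_ratio = np.mean(rel_err[normal_value_mask] > rtol)
abs_err_ratio = np.mean(abs_err[small_value_mask] > small_value_atol)
print(rel_err_ratio, abs_err_ratio)               # 0.5 0.0 for this toy input
```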
msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py (new file, @@ -0,0 +1,107 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \
    check_small_value, get_error_balance
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
from msprobe.core.common.const import CompareConst


class AccumulativeErrorCompare(BaseCompare):
    """
    Accumulative error compare class.

    This class compares benchmark outputs and device outputs against an accumulative error bound.
    It calculates various metrics such as inf_nan_error_ratio, rel_err_ratio, abs_err_ratio, and eb
    to determine the accuracy of the device output compared to the benchmark output.

    Attributes:
        bench_output (np.ndarray): The output from the benchmark.
        device_output (np.ndarray): The output from the device.
        dtype (torch.dtype): The data type of the outputs.
        abs_bench (np.ndarray): The absolute value of the benchmark output.
        abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon.
        both_finite_mask (np.ndarray): A mask indicating where both outputs are finite.
        inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN.
        bound (float): The tolerance for comparison.
        rel_err (np.ndarray): The relative error between the benchmark and device outputs.
        small_value (float): The small value threshold for comparison.
        small_value_atol (float): The absolute tolerance for small values.
        small_value_mask (np.ndarray): A mask indicating where values are small.
        normal_value_mask (np.ndarray): A mask indicating where values are normal.

    Methods:
        _get_bound(): Gets the accumulative error bound based on the data type.
        _get_rel_err(abs_err, abs_bench_with_eps): Calculates the relative error.
        _get_normal_value_mask(both_finite_mask, small_value_mask): Gets the mask for normal values.
        _pre_compare(): Prepares the comparison by calculating various metrics.
        _compute_metrics(): Computes the comparison metrics.

    Note:
        This class assumes that the input data is a dictionary containing 'bench_output', 'device_output',
        'compare_column' and 'dtype'.
        The 'dtype' should be a PyTorch data type.

    See Also:
        BaseCompare: The base class for comparison classes.
        StandardConfig: The class containing standard configuration values.
    """
    def __init__(self, input_data):
        super(AccumulativeErrorCompare, self).__init__(input_data)
        self.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE

    def _get_bound(self):
        return StandardConfig.get_accumulative_error_bound(self.dtype)

    def _pre_compare(self):
        """
        Prepares the comparison by calculating various metrics.

        This method performs the following steps:
        1. Calculates the absolute benchmark values and their epsilon-adjusted versions.
        2. Determines masks for finite and infinite/NaN values in the outputs.
        3. Computes the absolute error between benchmark and device outputs.
        4. Retrieves the tolerance based on the data type.
        5. Calculates the relative error using the absolute error and epsilon-adjusted benchmark values.
        6. Determines the small value threshold and its absolute tolerance.
        7. Creates a mask for small values based on the benchmark values and finite mask.
        8. Creates a mask for normal values by excluding small values from the finite mask.
        """
        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
        self.abs_err = self.stat_abs_error()
        self.bound = self._get_bound()
        self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps)
        self.small_value, self.small_value_atol = self.get_small_value_threshold()
        self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
        self.normal_value_mask = self._get_normal_value_mask(self.both_finite_mask, self.small_value_mask)

    def _compute_metrics(self):
        inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype,
                                                  self.bound)
        rel_err_ratio = check_norm_value(self.normal_value_mask, self.rel_err, self.bound)
        abs_err_ratio = check_small_value(self.abs_err, self.small_value_mask, self.bound)
        eb = get_error_balance(self.bench_output, self.device_output)
        return {
            "inf_nan_error_ratio": inf_nan_error_ratio,
            "rel_err_ratio": rel_err_ratio,
            "abs_err_ratio": abs_err_ratio,
            "eb": eb
        }
```
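For contrast with `AbsolutethdCompare`: the accumulative-error standard feeds a single dtype-dependent `bound` into all three checks and additionally reports an error balance. A toy rendering, where the bound value is a placeholder and the error-balance formula (mean signed error over mean absolute benchmark) is purely an assumed illustration; `get_error_balance`'s real definition lives in `algorithm.py` and is not part of this diff:

```python
import numpy as np

bench = np.array([1.0, 0.5, 0.25], dtype=np.float32)
device = np.array([1.001, 0.495, 0.251], dtype=np.float32)

bound = 2e-3                                # placeholder for StandardConfig.get_accumulative_error_bound
abs_err = np.abs(device - bench)
rel_err = abs_err / (np.abs(bench) + np.finfo(np.float32).eps)

rel_err_ratio = np.mean(rel_err > bound)    # the same single bound gates every check here
# Assumed, illustration-only error balance: signed mean error vs. benchmark magnitude.
eb = np.mean(device - bench) / np.mean(np.abs(bench))
print(rel_err_ratio, eb)
```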
msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py (new file, @@ -0,0 +1,151 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
import numpy as np
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import convert_str_to_float
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \
    get_finite_and_infinite_mask, get_small_value_mask
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig


class BaseCompare(ABC):
    """
    Base comparison class for benchmarking and device output.

    This class provides a foundation for comparing benchmark outputs with device outputs.
    It encapsulates the common logic for calculating accuracy metrics and
    provides a framework for subclasses to implement specific comparison logic.

    Attributes:
        bench_output (np.ndarray): The output from the benchmark.
        device_output (np.ndarray): The output from the device.
        compare_column (object): The column object to store comparison results.
        dtype (torch.dtype): The data type of the outputs.

    Methods:
        get_small_value_threshold(): Retrieves the small value threshold for the given data type.
        stat_abs_bench_with_eps(): Calculates the absolute benchmark output with epsilon.
        stat_abs_error(): Calculates the absolute error between the benchmark and device outputs.
        stat_finite_and_infinite_mask(): Generates masks for finite and infinite/NaN values.
        stat_small_value_mask(abs_bench, both_finite_mask, small_value): Creates a mask for small values.
        compare(): Performs the comparison and computes metrics.
        _pre_compare(): Pre-comparison hook for subclass-specific initialization.
        _compute_metrics(): Computes the comparison metrics.
        _post_compare(metrics): Post-comparison hook to update comparison results.

    Note:
        This class assumes that the input data is an instance of InputData containing the benchmark output,
        device output, comparison column, and data type. Subclasses should implement the _pre_compare,
        _compute_metrics, and _post_compare methods to provide specific comparison logic.

    See Also:
        InputData: The class containing input data for comparison.
        StandardConfig: The class containing standard configuration values.
    """
    def __init__(self, input_data):
        self.bench_output = input_data.bench_output
        self.device_output = input_data.device_output
        self.compare_column = input_data.compare_column
        self.dtype = input_data.dtype
        self.compare_algorithm = None

    @staticmethod
    def stat_small_value_mask(abs_bench, both_finite_mask, small_value):
        small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value)
        return small_value_mask

    @staticmethod
    def _get_rel_err(abs_err, abs_bench_with_eps):
        rel_err = abs_err / abs_bench_with_eps
        return rel_err

    @staticmethod
    def _get_normal_value_mask(both_finite_mask, small_value_mask):
        return np.logical_and(both_finite_mask, np.logical_not(small_value_mask))

    @abstractmethod
    def _pre_compare(self):
        raise NotImplementedError

    def get_small_value_threshold(self):
        small_value = StandardConfig.get_small_value(self.dtype, self.compare_algorithm)
        small_value_atol = StandardConfig.get_small_value_atol(self.dtype, self.compare_algorithm)
        return small_value, small_value_atol

    def stat_abs_bench_with_eps(self):
        abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(self.bench_output, self.dtype)
        return abs_bench, abs_bench_with_eps

    def stat_abs_error(self):
        abs_err = get_abs_err(self.bench_output, self.device_output)
        return abs_err

    def stat_finite_and_infinite_mask(self):
        both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(self.bench_output, self.device_output)
        return both_finite_mask, inf_nan_mask

    def compare(self):
        self._pre_compare()
        metrics = self._compute_metrics()
        self._post_compare(metrics)

    def _compute_metrics(self):
        return {}

    def _post_compare(self, metrics):
        self.compare_column.update(metrics)


class BasePrecisionCompare:
    def __init__(self, input_data):
        self.row_npu = input_data.row_npu
        self.row_gpu = input_data.row_gpu
        self.dtype = input_data.dtype
        self.compare_column = input_data.compare_column
        self.compare_algorithm = None

    @abstractmethod
    def _get_status(self, metrics, inf_nan_consistency):
        pass

    @abstractmethod
    def _compute_ratio(self):
        pass

    def compare(self):
        metrics, inf_nan_consistency = self._compute_ratio()
        compare_result = self._post_compare(metrics, inf_nan_consistency)
        return compare_result

    def _get_and_convert_values(self, column_name):
        npu_value = self.row_npu.get(column_name)
        gpu_value = self.row_gpu.get(column_name)
        if npu_value is None:
            raise ValueError(f"NPU value for column '{column_name}' is None.")
        if gpu_value is None:
            raise ValueError(f"GPU value for column '{column_name}' is None.")
        npu_value = convert_str_to_float(npu_value)
        gpu_value = convert_str_to_float(gpu_value)
        return npu_value, gpu_value

    def _post_compare(self, metrics, inf_nan_consistency):
        metrics = self._get_status(metrics, inf_nan_consistency)
        metrics.update({'compare_algorithm': self.compare_algorithm})
        self.compare_column.update(metrics)
        compare_result = metrics.get('compare_result')
        return compare_result
```
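`BaseCompare` is a textbook template method: `compare()` fixes the pipeline (`_pre_compare` → `_compute_metrics` → `_post_compare`) and subclasses fill in the stages. A self-contained miniature of that flow, with a plain dict standing in for the package's compare-column object (which this diff only shows being `update()`d):

```python
from abc import ABC, abstractmethod

class MiniBase(ABC):
    """Miniature of BaseCompare's template-method pipeline."""
    def __init__(self, compare_column):
        self.compare_column = compare_column

    def compare(self):                      # fixed pipeline, as in BaseCompare.compare
        self._pre_compare()
        metrics = self._compute_metrics()
        self._post_compare(metrics)

    @abstractmethod
    def _pre_compare(self):
        raise NotImplementedError

    def _compute_metrics(self):
        return {}

    def _post_compare(self, metrics):
        self.compare_column.update(metrics)

class EchoCompare(MiniBase):
    def _pre_compare(self):
        self.prepared = True                # stand-in for the mask/error setup
    def _compute_metrics(self):
        return {"prepared": self.prepared}

column = {}                                 # a dict supports update(), like the real column object
EchoCompare(column).compare()
print(column)                               # {'prepared': True}
```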
msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py (new file, @@ -0,0 +1,226 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import namedtuple
import numpy as np

from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare, BasePrecisionCompare
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_small_value_err_ratio, get_rel_err, \
    get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
    is_inf_or_nan
from msprobe.core.common.const import CompareConst


BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_value_inf_nan_consistency',
                                                                       'rmse_inf_nan_consistency',
                                                                       'max_rel_inf_nan_consistency',
                                                                       'mean_rel_inf_nan_consistency',
                                                                       'eb_inf_nan_consistency'])


class BenchmarkCompare(BaseCompare):
    """
    Benchmark comparison class for calculating accuracy metrics.

    This class is designed to compare the output of a benchmark test with the output of a device.
    It calculates various metrics such as small value error ratio, RMSE, error balance, max relative error,
    and mean relative error to assess the accuracy of the device output against the benchmark output.

    Attributes:
        bench_output (np.ndarray): The output from the benchmark.
        device_output (np.ndarray): The output from the device.
        dtype (torch.dtype): The data type of the outputs.
        abs_bench (np.ndarray): The absolute value of the benchmark output.
        abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon.
        both_finite_mask (np.ndarray): A mask indicating where both outputs are finite.
        inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN.
        abs_err (np.ndarray): The absolute error between the benchmark and device outputs.
        small_value (float): The small value threshold for comparison.
        small_value_atol (float): The absolute tolerance for small values.
        small_value_mask (np.ndarray): A mask indicating where values are small.
        rel_err (np.ndarray): The relative error between the benchmark and device outputs.
        abs_err_greater_mask (np.ndarray): A mask indicating where absolute error is greater than the small value
            tolerance.

    Methods:
        _get_abs_err_greater_mask(small_value_atol): Calculates a mask where absolute error is greater than the small
            value tolerance.
        _compute_rel_err(): Computes the relative error between the benchmark and device outputs.
        _pre_compare(): Prepares the comparison by calculating various metrics.
        _compute_metrics(): Computes the accuracy metrics.

    Note:
        This class assumes that the input data is a dictionary containing 'bench_output', 'device_output',
        'compare_column' and 'dtype'.
        The data type should be a PyTorch data type.

    See Also:
        BaseCompare: The base class for comparison classes.
        InputData: The class containing input data for comparison.
    """

    def __init__(self, input_data):
        super(BenchmarkCompare, self).__init__(input_data)
        self.compare_algorithm = CompareConst.BENCHMARK

    def _get_abs_err_greater_mask(self, small_value_atol):
        abs_err_greater_mask = np.greater(self.abs_err, small_value_atol)
        return abs_err_greater_mask

    def _compute_rel_err(self):
        rel_err = get_rel_err(self.abs_err, self.abs_bench_with_eps, self.small_value_mask, self.inf_nan_mask)
        return rel_err

    def _pre_compare(self):
        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
        self.abs_err = self.stat_abs_error()
        self.small_value, self.small_value_atol = self.get_small_value_threshold()
        self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
        self.rel_err = self._compute_rel_err()
        self.abs_err_greater_mask = self._get_abs_err_greater_mask(self.small_value_atol)

    def _compute_metrics(self):
        """
        Computes a comprehensive set of error metrics for the comparison between benchmark and device outputs.

        This method calculates five key metrics:
        1. Small Value Error Ratio: The proportion of errors associated with small values.
        2. Root Mean Square Error (RMSE): The square root of the mean of the squared errors.
        3. Error Balance (EB): A measure of the balance between the errors in the benchmark and device outputs.
        4. Maximum Relative Error: The maximum relative error between the benchmark and device outputs.
        5. Mean Relative Error: The mean relative error between the benchmark and device outputs.

        Returns:
            dict: A dictionary containing the computed error metrics.
                The dictionary has the following keys:
                - "small_value_err_ratio": The proportion of errors associated with small values.
                - "max_rel_error": The maximum relative error.
                - "mean_rel_error": The mean relative error.
                - "rmse": The root mean square error.
                - "eb": The error balance.
        """
        small_value_err_ratio = get_small_value_err_ratio(self.small_value_mask, self.abs_err_greater_mask)
        rmse = get_rmse(self.abs_err, np.logical_or(self.inf_nan_mask, self.small_value_mask))
        eb = get_error_balance(self.bench_output, self.device_output)
        max_rel_error = get_max_rel_err(self.rel_err)
        mean_rel_error = get_mean_rel_err(self.rel_err)

        return {
            "small_value_err_ratio": small_value_err_ratio,
            "max_rel_error": max_rel_error,
            "mean_rel_error": mean_rel_error,
            "rmse": rmse,
            "eb": eb
        }


class BenchmarkPrecisionCompare(BasePrecisionCompare):
    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.BENCHMARK_COMPARE_ALGORITHM_NAME

    @staticmethod
    def get_final_status(status_list):
        compare_result = CompareConst.PASS
        if CompareConst.ERROR in status_list:
            compare_result = CompareConst.ERROR
        elif CompareConst.WARNING in status_list:
            compare_result = CompareConst.WARNING
        return compare_result

    def _calc_ratio(self, column_name):
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)
        else:
            return calc_ratio(npu_value, gpu_value, str(self.dtype)), True, ""

    def _compute_ratio(self):
        compare_message = ""
        small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = \
            self._calc_ratio(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE)
        compare_message += small_value_message
        rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE)
        compare_message += rmse_message
        max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = \
            self._calc_ratio(ApiPrecisionCompareColumn.MAX_REL_ERR)
        compare_message += max_rel_message
        mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = \
            self._calc_ratio(ApiPrecisionCompareColumn.MEAN_REL_ERR)
        compare_message += mean_rel_message
        eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB)
        compare_message += eb_message

        metrics = {
            CompareConst.SMALL_VALUE_ERR_RATIO: small_value_err_ratio,
            CompareConst.RMSE_RATIO: rmse_ratio,
            CompareConst.MAX_REL_ERR_RATIO: max_rel_err_ratio,
            CompareConst.MEAN_REL_ERR_RATIO: mean_rel_err_ratio,
            CompareConst.EB_RATIO: eb_ratio,
            CompareConst.COMPARE_MESSAGE: compare_message
        }

        return metrics, \
            BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
                                       max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
                                       eb_inf_nan_consistency)

    def _get_threshold(self, metric):
        error_threshold = StandardConfig.get_benchmark_threshold(metric)
        return error_threshold

    def _get_single_metric_status(self, ratio, metric):
        if is_inf_or_nan(ratio):
            return CompareConst.PASS
        error_threshold = self._get_threshold(metric)
        if ratio > error_threshold:
            return CompareConst.ERROR
        return CompareConst.PASS

    def _get_status(self, metrics, inf_nan_consistency):
        small_value_err_ratio = metrics.get(CompareConst.SMALL_VALUE_ERR_RATIO)
        rmse_ratio = metrics.get(CompareConst.RMSE_RATIO)
        max_rel_err_ratio = metrics.get(CompareConst.MAX_REL_ERR_RATIO)
        mean_rel_err_ratio = metrics.get(CompareConst.MEAN_REL_ERR_RATIO)
        eb_ratio = metrics.get(CompareConst.EB_RATIO)

        small_value_err_status = self._get_single_metric_status(small_value_err_ratio, CompareConst.SMALL_VALUE) \
            if inf_nan_consistency.small_value_inf_nan_consistency else CompareConst.ERROR
        rmse_status = self._get_single_metric_status(rmse_ratio, CompareConst.RMSE) \
            if inf_nan_consistency.rmse_inf_nan_consistency else CompareConst.ERROR
        max_rel_err_status = self._get_single_metric_status(max_rel_err_ratio, CompareConst.MAX_REL_ERR) \
            if inf_nan_consistency.max_rel_inf_nan_consistency else CompareConst.ERROR
        mean_rel_err_status = self._get_single_metric_status(mean_rel_err_ratio, CompareConst.MEAN_REL_ERR) \
            if inf_nan_consistency.mean_rel_inf_nan_consistency else CompareConst.ERROR
        eb_status = self._get_single_metric_status(eb_ratio, CompareConst.EB) \
            if inf_nan_consistency.eb_inf_nan_consistency else CompareConst.ERROR
        status_list = [small_value_err_status, rmse_status, max_rel_err_status, mean_rel_err_status]
        compare_result = self.get_final_status(status_list)
        status_dict = {
            CompareConst.SMALL_VALUE_ERR_STATUS: small_value_err_status,
            CompareConst.RMSE_STATUS: rmse_status,
            CompareConst.MAX_REL_ERR_STATUS: max_rel_err_status,
            CompareConst.MEAN_REL_ERR_STATUS: mean_rel_err_status,
            CompareConst.EB_STATUS: eb_status
        }
        metrics.update(status_dict)
        metrics.update({CompareConst.COMPARE_RESULT: compare_result})
        return metrics
```
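The NPU-vs-GPU verdict logic is compact enough to restate standalone. A sketch of `_get_single_metric_status` and `get_final_status`, with string constants and an illustrative threshold in place of the `CompareConst` / `StandardConfig` values:

```python
import math

def single_metric_status(ratio, error_threshold=2.0):   # threshold value is illustrative
    # inf/NaN ratios pass here; inf/NaN *consistency* is judged separately upstream.
    if math.isinf(ratio) or math.isnan(ratio):
        return "pass"
    return "error" if ratio > error_threshold else "pass"

def final_status(status_list):
    # Worst status wins: ERROR > WARNING > PASS, mirroring get_final_status.
    if "error" in status_list:
        return "error"
    if "warning" in status_list:
        return "warning"
    return "pass"

statuses = [single_metric_status(r) for r in (1.3, 5.0, float("nan"))]
print(statuses, final_status(statuses))                  # ['pass', 'error', 'pass'] error
```

Note that in `_get_status` the error-balance status is written into the status dict but is not included in `status_list`, so `eb_status` does not affect the final verdict.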
msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py (new file, @@ -0,0 +1,68 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from msprobe.pytorch.api_accuracy_checker.compare.algorithm import compare_bool_tensor
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare


class BinaryCompare(BaseCompare):
    """
    Binary comparison class for comparing boolean tensors.

    This class is designed to compare the output of a binary operation between a benchmark and a device.
    It calculates the error rate of the comparison and provides a simple metric for assessing the accuracy.

    Attributes:
        bench_output (np.ndarray): The output from the benchmark.
        device_output (np.ndarray): The output from the device.
        compare_column (object): The column object to store comparison results.
        dtype (torch.dtype): The data type of the outputs.

    Methods:
        _compute_metrics(): Computes the comparison metrics, specifically the error rate.

    Note:
        This class assumes that the input data is an instance of InputData containing the benchmark output,
        device output, comparison column, and data type. The outputs are expected to be boolean tensors.

    See Also:
        BaseCompare: The base class for comparison classes.
        compare_bool_tensor: The function used to compare boolean tensors.
    """
    def __init__(self, input_data):
        super(BinaryCompare, self).__init__(input_data)

    def _pre_compare(self):
        pass

    def _compute_metrics(self):
        """
        Computes the error rate metric for the comparison between benchmark and device outputs.

        This method calculates the proportion of mismatches between the benchmark output and the device output.
        It uses the `compare_bool_tensor` function to compare the two tensors and extract the error rate.

        Returns:
            dict: A dictionary containing the computed error rate metric.
                The dictionary has the following key:
                - "error_rate": The proportion of mismatches between the benchmark and device outputs.
        """
        error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output)

        return {
            "error_rate": error_rate
        }
```
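Finally, the binary standard: for boolean outputs the only metric is the mismatch fraction. A toy equivalent of what the error rate means per the docstring above (the real computation is delegated to `compare_bool_tensor`, not shown in this diff):

```python
import numpy as np

bench = np.array([True, False, True, True])
device = np.array([True, True, True, False])

# Error rate = proportion of positions where the boolean outputs disagree.
error_rate = np.mean(bench != device)
print(error_rate)   # 0.5
```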