mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0

msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py (new file, @@ -0,0 +1,218 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from msprobe.core.common.const import CompareConst


class StandardConfig:
    """
    Standard configuration class for managing precision and comparison thresholds.

    This class provides a centralized way to manage the small value thresholds, absolute tolerances,
    and relative tolerances (rtol) used in precision comparisons. It allows for different thresholds
    based on the data type, with default values provided for common data types.

    Attributes:
        _small_value (dict): A dictionary mapping data types to their corresponding small value thresholds.
        _small_value_atol (dict): A dictionary mapping data types to their corresponding absolute tolerances.
        _rtol (dict): A dictionary mapping data types to their corresponding relative tolerances.

    Methods:
        get_small_value(dtype): Retrieves the small value threshold for the given data type.
        get_small_value_atol(dtype): Retrieves the absolute tolerance for the given data type.
        get_rtol(dtype): Retrieves the relative tolerance for the given data type.

    Example:
        >>> small_value = StandardConfig.get_small_value(torch.float32)
        >>> atol = StandardConfig.get_small_value_atol(torch.float32)
        >>> rtol = StandardConfig.get_rtol(torch.float32)
        >>> print(small_value, atol, rtol)
        1e-6 1e-9 1e-6

    Note:
        The data type is expected to be a PyTorch data type. If the data type is not found in the dictionary,
        the default value is returned.

    See Also:
        torch.dtype: PyTorch data types.
    """
    _small_value = {
        torch.float16: 2**-10,
        torch.bfloat16: 2**-10,
        torch.float32: 2**-20,
        "default": 2**-20
    }
    _threshold_small_value_atol = {
        torch.float16: 2**-16,
        torch.bfloat16: 1e-16,
        torch.float32: 2**-30,
        "default": 2**-30
    }
    _benchmark_small_value_atol = {
        torch.float16: 1e-16,
        torch.bfloat16: 1e-16,
        torch.float32: 2**-30,
        "default": 2**-30
    }
    _rtol = {
        torch.float16: 2**-10,
        torch.bfloat16: 2**-8,
        torch.float32: 2**-20,
        "default": 2**-20
    }
    _accumulative_error_bound = {
        torch.float16: 2**-8,
        torch.bfloat16: 2**-7,
        torch.float32: 2**-11,
        "default": 2**-11
    }
    _small_value_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _rmse_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _max_rel_err_threshold = {
        'error_threshold': 10,
        'warning_threshold': 1,
        "default": 1
    }
    _mean_rel_err_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _eb_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _minmum_err = {
        'torch.float16': 2**-11,
        'torch.bfloat16': 2**-8,
        'torch.float32': 2**-14,
        'default': 2**-14
    }
    _accumulative_error_eb_threshold = {
        'torch.float16': 2**-20,
        'torch.bfloat16': 2**-7,
        'torch.float32': 2**-14,
        'default': 2**-14
    }

    _fp32_mean_ulp_err_threshold = 64
    ulp_err_proportion_ratio = 1
    _fp32_ulp_err_proportion = 0.05
    _fp16_ulp_err_proportion = 0.001
    _special_samll_value = 1

    @classmethod
    def get_small_value(cls, dtype, standard):
        if standard == CompareConst.ACCUMULATIVE_ERROR_COMPARE:
            return cls._special_samll_value
        return cls._small_value.get(dtype, cls._small_value["default"])

    @classmethod
    def get_small_value_atol(cls, dtype, standard):
        standard_dict = {
            CompareConst.ABSOLUTE_THRESHOLD: cls._threshold_small_value_atol,
            CompareConst.BENCHMARK: cls._benchmark_small_value_atol
        }
        small_value_atol_standard = standard_dict.get(standard, cls._benchmark_small_value_atol)
        return small_value_atol_standard.get(dtype, small_value_atol_standard["default"])

    @classmethod
    def get_rtol(cls, dtype):
        return cls._rtol.get(dtype, cls._rtol["default"])

    @classmethod
    def get_small_value_threshold(cls, threshold_type):
        return cls._small_value_threshold.get(threshold_type, cls._small_value_threshold["default"])

    @classmethod
    def get_rmse_threshold(cls, threshold_type):
        return cls._rmse_threshold.get(threshold_type, cls._rmse_threshold["default"])

    @classmethod
    def get_max_rel_err_threshold(cls, threshold_type):
        return cls._max_rel_err_threshold.get(threshold_type, cls._max_rel_err_threshold["default"])

    @classmethod
    def get_mean_rel_err_threshold(cls, threshold_type):
        return cls._mean_rel_err_threshold.get(threshold_type, cls._mean_rel_err_threshold["default"])

    @classmethod
    def get_eb_threshold(cls, threshold_type):
        return cls._eb_threshold.get(threshold_type, cls._eb_threshold["default"])

    @classmethod
    def get_benchmark_threshold(cls, metric):
        metric_threshold_functions = {
            'small_value': StandardConfig.get_small_value_threshold,
            'rmse': StandardConfig.get_rmse_threshold,
            'max_rel_err': StandardConfig.get_max_rel_err_threshold,
            'mean_rel_err': StandardConfig.get_mean_rel_err_threshold,
            'eb': StandardConfig.get_eb_threshold
        }

        threshold_func = metric_threshold_functions.get(metric)
        return threshold_func('error_threshold')

    @classmethod
    def get_fp32_mean_ulp_err_threshold(cls):
        return cls._fp32_mean_ulp_err_threshold

    @classmethod
    def get_ulp_err_proportion_ratio_threshold(cls):
        return cls.ulp_err_proportion_ratio

    @classmethod
    def get_fp32_ulp_err_proportion_threshold(cls):
        return cls._fp32_ulp_err_proportion

    @classmethod
    def get_fp16_ulp_err_proportion_threshold(cls):
        return cls._fp16_ulp_err_proportion

    @classmethod
    def get_ulp_threshold(cls, dtype):
        ulp_err_proportion_ratio_threshold = StandardConfig.get_ulp_err_proportion_ratio_threshold()
        if dtype == torch.float32:
            mean_ulp_err_threshold = StandardConfig.get_fp32_mean_ulp_err_threshold()
            ulp_err_proportion_threshold = StandardConfig.get_fp32_ulp_err_proportion_threshold()
            return mean_ulp_err_threshold, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold
        else:
            ulp_err_proportion_threshold = StandardConfig.get_fp16_ulp_err_proportion_threshold()
            return None, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold

    @classmethod
    def get_minmum_err(cls, dtype):
        return cls._minmum_err.get(dtype, cls._minmum_err["default"])

    @classmethod
    def get_accumulative_error_bound(cls, dtype):
        return cls._accumulative_error_bound.get(dtype, cls._accumulative_error_bound["default"])

    @classmethod
    def get_accumulative_error_eb_threshold(cls, dtype):
        return cls._accumulative_error_eb_threshold.get(dtype, cls._accumulative_error_eb_threshold["default"])
```
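
The core pattern in this new file is a set of class-level, dtype-keyed dictionaries with a `"default"` fallback. Below is a minimal standalone sketch of that lookup pattern for readers skimming the diff; the names and helper here are illustrative only (not msprobe APIs), while the threshold values are the fp16/bf16/fp32 rtol values from the file above.

```python
import torch

# Illustrative only: mirrors StandardConfig's dtype -> threshold lookup with a
# "default" fallback. Values taken from the _rtol table in the diff above.
_RTOL = {
    torch.float16: 2**-10,
    torch.bfloat16: 2**-8,
    torch.float32: 2**-20,
    "default": 2**-20,
}


def get_rtol(dtype):
    # Unknown dtypes (e.g. torch.float64) fall back to the "default" entry.
    return _RTOL.get(dtype, _RTOL["default"])


if __name__ == "__main__":
    print(get_rtol(torch.float16))  # 0.0009765625 (2**-10)
    print(get_rtol(torch.float64))  # 9.5367431640625e-07 (falls back to 2**-20)
```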

msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py (new file, @@ -0,0 +1,104 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \
    ulp_standard_api, thousandth_standard_api, accumulative_error_standard_api, BINARY_COMPARE_UNSUPPORT_LIST
from msprobe.core.common.const import CompareConst


class StandardRegistry:
    """
    Registry class for managing comparison standards and functions.

    This class provides a centralized registry for different comparison standards and their corresponding functions.
    It allows for dynamic registration of comparison functions based on the standard category.

    Attributes:
        comparison_functions (dict): A dictionary mapping standard categories to their corresponding comparison
        functions.
        standard_categories (dict): A dictionary mapping standard names to their corresponding API categories.

    Methods:
        _get_standard_category(api_name, dtype): Determines the standard category for a given API name and data type.
        register(standard, func): Registers a comparison function for a given standard category.
        get_comparison_function(api_name, dtype): Retrieves the comparison function for a given API name and data type.

    Note:
        The data type is used to determine the standard category if it is not supported by binary comparison.
        If the API name is not found in any standard category, it defaults to the 'benchmark' category.

    See Also:
        BaseCompare: The base class for comparison classes.
    """
    def __init__(self):
        self.comparison_functions = {}
        self.api_standard_function_map = {
            CompareConst.ABSOLUTE_THRESHOLD: absolute_standard_api,
            CompareConst.BINARY_CONSISTENCY: binary_standard_api,
            CompareConst.ULP_COMPARE: ulp_standard_api,
            CompareConst.THOUSANDTH_STANDARD: thousandth_standard_api,
            CompareConst.ACCUMULATIVE_ERROR_COMPARE: accumulative_error_standard_api
        }

    def register(self, standard: str, func: Callable) -> None:
        """
        Registers a comparison function for a given standard category.

        Args:
            standard (str): The name of the standard category.
            func (Callable): The comparison function to be registered.

        Raises:
            ValueError: If the standard category is not supported.
        """
        if not callable(func):
            raise ValueError("The function to be registered must be callable.")
        self.comparison_functions[standard] = func

    def get_comparison_function(self, api_name, dtype=None):
        standard = self._get_standard_category(api_name, dtype)
        return self.comparison_functions.get(standard)

    def _get_standard_category(self, api_name, dtype=None):
        """
        Determines the standard category for a given API name and data type.

        This method checks if the provided data type is supported for binary comparison.
        If it is, the method returns 'binary_consistency'. Otherwise, it iterates over the
        api_standard_function_map to find a matching category for the API name.

        Args:
            api_name (str): The name of the API for which to determine the standard category.
            dtype (type, optional): The data type to check against the BINARY_COMPARE_UNSUPPORT_LIST. Defaults to None.

        Returns:
            str: The name of the standard category that matches the API name and data type, or 'benchmark' if no match
            is found.

        Note:
            This method assumes that the api_standard_function_map is properly populated with standard categories and
            their corresponding API functions.
            The BINARY_COMPARE_UNSUPPORT_LIST should be defined and contain all data types that are not supported for
            binary comparison.
        """
        if dtype and dtype not in BINARY_COMPARE_UNSUPPORT_LIST:
            return CompareConst.BINARY_CONSISTENCY
        for name, category in self.api_standard_function_map.items():
            if api_name in category:
                return name
        return CompareConst.BENCHMARK
```
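
The registry above dispatches each API to a comparison function by standard name and falls back to a benchmark standard when the API is not listed anywhere. A minimal, self-contained sketch of that dispatch pattern follows; the class, standard names, and API names in it are made up for illustration and are not msprobe's.

```python
from typing import Callable, Dict, Set


class MiniRegistry:
    """Illustrative registry: standard name -> comparison callable."""

    def __init__(self):
        self.comparison_functions: Dict[str, Callable] = {}
        # Which APIs belong to which standard (example data only).
        self.api_map: Dict[str, Set[str]] = {
            "absolute_threshold": {"mul", "add"},
            "ulp_compare": {"matmul"},
        }

    def register(self, standard: str, func: Callable) -> None:
        if not callable(func):
            raise ValueError("The function to be registered must be callable.")
        self.comparison_functions[standard] = func

    def get_comparison_function(self, api_name: str) -> Callable:
        for standard, apis in self.api_map.items():
            if api_name in apis:
                return self.comparison_functions[standard]
        # Unlisted APIs fall back to the benchmark standard.
        return self.comparison_functions["benchmark"]


registry = MiniRegistry()
registry.register("absolute_threshold", lambda: "absolute threshold compare")
registry.register("ulp_compare", lambda: "ulp compare")
registry.register("benchmark", lambda: "benchmark compare")
print(registry.get_comparison_function("matmul")())   # ulp compare
print(registry.get_comparison_function("softmax")())  # benchmark compare
```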

msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py (new file, @@ -0,0 +1,63 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rel_err_ratio
from msprobe.core.common.const import CompareConst
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare


class ThousandthStdCompare(BaseCompare):
    """
    Thousandth standard comparison class for calculating accuracy metrics.

    A subclass of BaseCompare, specifically designed to compare the relative error
    between benchmark and device outputs, focusing on errors within a thousandth (0.001) threshold.

    Attributes:
        rel_err_orign (float or array-like): The original relative error values to be compared.
        compare_column (object): An object to store and update comparison metrics.

    Methods:
        _compute_metrics(): Computes the relative error metrics, specifically the thousandth error ratio.
    """
    def __init__(self, input_data):
        self.rel_err_orign = input_data.rel_err_orign
        self.compare_column = input_data.compare_column

    def _pre_compare(self):
        pass

    def _compute_metrics(self):
        """
        Computes the relative error metrics for the comparison, specifically focusing on errors within a thousandth
        (0.001) threshold.

        This method calculates the proportion of relative errors that are within the thousandth threshold.
        It uses the `get_rel_err_ratio` function to determine the ratio of relative errors that are less than or
        equal to the
        specified threshold defined in `CompareConst.THOUSAND_RATIO_THRESHOLD`.

        Returns:
            dict: A dictionary containing the computed relative error metric.
            The dictionary has the following key:
            - 'rel_err_thousandth': The proportion of relative errors within the thousandth threshold.
        """
        rel_err_thousandth, _ = get_rel_err_ratio(self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)

        return {
            'rel_err_thousandth': rel_err_thousandth
        }
```
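
The thousandth standard reduces to a single number: the fraction of elements whose relative error is at or below 0.001. The sketch below is only a rough standalone illustration of that metric, not the package's `get_rel_err_ratio`; the function name and the zero-division guard are assumptions made for the example.

```python
import numpy as np


def thousandth_ratio(bench: np.ndarray, device: np.ndarray, threshold: float = 0.001) -> float:
    # Relative error against the benchmark, with a crude guard against division by zero.
    denom = np.where(bench == 0, 1e-30, np.abs(bench))
    rel_err = np.abs(device - bench) / denom
    # Fraction of elements whose relative error is within the 0.001 threshold.
    return float(np.mean(rel_err <= threshold))


bench = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
device = np.array([1.0005, 2.001, 3.1, 4.0], dtype=np.float32)
print(thousandth_ratio(bench, device))  # 0.75: three of the four elements pass
```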

msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py (new file, @@ -0,0 +1,200 @@):

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
import numpy as np
import torch

from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare, BasePrecisionCompare
from msprobe.core.common.const import Const, CompareConst
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_ulp_err
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
    is_inf_or_nan


UlpInfNanConsistency = namedtuple('UlpInfNanConsistency', ['mean_ulp_err_inf_nan_consistency',
                                                           'ulp_err_proportion_ratio_inf_nan_consistency'])


class UlpCompare(BaseCompare):
    """
    Ulp compare comparison class for calculating accuracy metrics.

    Attributes:
        bench_output (array-like): The benchmark output values.
        device_output (array-like): The device output values.
        dtype (torch.dtype): The data type of the outputs (e.g., torch.float32 or torch.float16).
        ulp_err (array-like): The ULP errors calculated from the benchmark and device outputs.

    Methods:
        _stat_max_ulp_err(ulp_err): Calculates the maximum ULP error.
        _stat_mean_ulp_err(ulp_err): Calculates the mean ULP error.
        _stat_ulp_error_proportion(ulp_err): Calculates the proportion of ULP errors exceeding a threshold.
        _pre_compare(): Prepares for comparison by calculating ULP errors.
        _compute_metrics(): Computes the ULP error metrics.
    """
    def __init__(self, input_data):
        super(UlpCompare, self).__init__(input_data)

    @staticmethod
    def _stat_max_ulp_err(ulp_err):
        return np.max(ulp_err)

    @staticmethod
    def _stat_mean_ulp_err(ulp_err):
        return np.mean(ulp_err)

    def _stat_ulp_error_proportion(self, ulp_err):
        if self.dtype == torch.float32:
            return np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / self.bench_output.size
        else:
            return np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / self.bench_output.size

    def _pre_compare(self):
        self.ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype)

    def _compute_metrics(self):
        """
        Computes the ULP error metrics for the comparison.

        This method calculates three key metrics:
        1. Maximum ULP error: The maximum difference in ULPs between the benchmark and device outputs.
        2. Mean ULP error: The average difference in ULPs between the benchmark and device outputs.
        3. ULP error proportion: The proportion of ULP errors that exceed a certain threshold.

        Args:
            None (this method uses instance variables)

        Returns:
            dict: A dictionary containing the computed ULP error metrics.
            The dictionary has the following keys:
            - "max_ulp_error": The maximum ULP error.
            - "mean_ulp_error": The mean ULP error.
            - "ulp_error_proportion": The proportion of ULP errors exceeding the threshold.
        """
        max_ulp_error = self._stat_max_ulp_err(self.ulp_err)
        mean_ulp_error = self._stat_mean_ulp_err(self.ulp_err)

        ulp_error_proportion = self._stat_ulp_error_proportion(self.ulp_err)

        return {
            "max_ulp_error": max_ulp_error,
            "mean_ulp_error": mean_ulp_error,
            "ulp_error_proportion": ulp_error_proportion
        }


class UlpPrecisionCompare(BasePrecisionCompare):
    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.ULP_COMPARE_ALGORITHM_NAME

    @staticmethod
    def _compute_ulp_err_proportion_ratio(npu_value, gpu_value, dtype):
        column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)
        else:
            return calc_ratio(npu_value, gpu_value, dtype), True, ""

    def _compute_mean_ulp_err(self):
        column_name = ApiPrecisionCompareColumn.MEAN_ULP_ERR
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            _, mean_ulp_err_inf_nan_consistency, message = check_inf_or_nan(npu_value, gpu_value, column_name)
            return npu_value, mean_ulp_err_inf_nan_consistency, message
        else:
            return npu_value, True, ""

    def _compute_ulp_err_proportion(self):
        column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        return npu_value, gpu_value

    def _get_status(self, metrics, inf_nan_consistency):
        ulp_inf_nan_consistency = inf_nan_consistency.mean_ulp_err_inf_nan_consistency and \
            inf_nan_consistency.ulp_err_proportion_ratio_inf_nan_consistency

        if not ulp_inf_nan_consistency:
            status_dict = {
                CompareConst.ULP_ERR_STATUS: CompareConst.ERROR
            }
            compare_result = CompareConst.ERROR
            metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + \
                "ERROR: ULP误差不满足标准\n"
            metrics.update({CompareConst.COMPARE_RESULT: compare_result})
            return metrics

        dtype = self.row_npu.get(ApiPrecisionCompareColumn.DEVICE_DTYPE)
        mean_ulp_err = metrics.get(CompareConst.MEAN_ULP_ERR)
        ulp_err_proportion = metrics.get(CompareConst.ULP_ERR_PROPORTION)
        ulp_err_proportion_ratio = metrics.get(CompareConst.ULP_ERR_PROPORTION_RATIO)
        if dtype == Const.TORCH_FLOAT32:
            status, final_message = \
                self._get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio)
        else:
            status, final_message = \
                self._get_fp16_ulp_err_status(ulp_err_proportion, ulp_err_proportion_ratio)
        metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + final_message

        status_dict = {
            CompareConst.ULP_ERR_STATUS: status
        }
        compare_result = status
        metrics.update(status_dict)
        metrics.update({CompareConst.COMPARE_RESULT: compare_result})
        return metrics

    def _get_fp32_ulp_err_status(self, mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio):
        mean_ulp_err_threshold, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold = \
            StandardConfig.get_ulp_threshold(torch.float32)
        if mean_ulp_err < mean_ulp_err_threshold:
            return CompareConst.PASS, ""
        elif ulp_err_proportion < ulp_err_proportion_threshold:
            return CompareConst.PASS, ""
        elif ulp_err_proportion_ratio < ulp_err_proportion_ratio_threshold:
            return CompareConst.PASS, ""
        compare_message = "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR, compare_message

    def _get_fp16_ulp_err_status(self, ulp_err_proportion, ulp_err_proportion_ratio):
        _, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold = \
            StandardConfig.get_ulp_threshold(torch.float16)
        if ulp_err_proportion < ulp_err_proportion_threshold:
            return CompareConst.PASS, ""
        elif ulp_err_proportion_ratio < ulp_err_proportion_ratio_threshold:
            return CompareConst.PASS, ""
        compare_message = "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR, compare_message

    def _compute_ratio(self):
        compare_message = ""
        mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err()
        compare_message += mean_ulp_err_message
        npu_ulp_err_proportion, gpu_ulp_err_proportion = self._compute_ulp_err_proportion()
        ulp_err_proportion_ratio, ulp_err_proportion_ratio_inf_nan_consistency, ulp_err_proportion_ratio_message = \
            self._compute_ulp_err_proportion_ratio(npu_ulp_err_proportion, gpu_ulp_err_proportion, str(self.dtype))
        compare_message += ulp_err_proportion_ratio_message
        metrics = {
            CompareConst.MEAN_ULP_ERR: mean_ulp_err,
            CompareConst.ULP_ERR_PROPORTION: npu_ulp_err_proportion,
            CompareConst.ULP_ERR_PROPORTION_RATIO: ulp_err_proportion_ratio,
            CompareConst.COMPARE_MESSAGE: compare_message
        }
        return metrics, UlpInfNanConsistency(mean_ulp_err_inf_nan_consistency,
                                             ulp_err_proportion_ratio_inf_nan_consistency)
```
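
ULP comparison measures the gap between device and benchmark outputs in units of the benchmark value's floating-point spacing. The package's actual `get_ulp_err` lives in `compare/algorithm.py` and is not shown in this diff; the sketch below is only a rough standalone illustration of the idea, reusing the fp32 mean-ULP-error threshold of 64 from StandardConfig above, with the function name and spacing-based formula being assumptions of the example.

```python
import numpy as np


def ulp_err(bench: np.ndarray, device: np.ndarray) -> np.ndarray:
    # Difference expressed in units of the benchmark's floating-point spacing (ULP).
    # Illustration only; not msprobe's get_ulp_err implementation.
    bench32 = bench.astype(np.float32)
    device32 = device.astype(np.float32)
    spacing = np.spacing(np.abs(bench32))
    return np.abs(device32 - bench32) / spacing


bench = np.array([1.0, 2.0, 3.0], dtype=np.float32)
device = bench + np.float32(1e-7)
err = ulp_err(bench, device)
print(err)              # per-element ULP error
print(err.mean() < 64)  # fp32 mean ULP error threshold from StandardConfig
```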

msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py (+57 -1); the Chinese comments in the original hunk are translated to English below:

```diff
@@ -28,6 +28,7 @@ from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import load_pt
 from msprobe.core.common.const import Const, FileCheckConst, CompareConst
 
+
 TORCH_TYPE = ["torch.device", "torch.dtype"]
 TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
 FLOAT_TYPE = [
@@ -310,6 +311,19 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
             kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
         elif value is None:
             kwargs_params[key] = None
+        elif key == 'atten_mask' and api_name == 'npu_fusion_attention':
+            sparse_mode = kwargs_params.get('sparse_mode', {})
+            if isinstance(sparse_mode, dict):
+                sparse_mode_value = sparse_mode.get('value', 0)
+            elif isinstance(sparse_mode, int):
+                sparse_mode_value = sparse_mode
+            else:
+                msg = f'The sparse_mode value is not int or dict, but {type(sparse_mode)}'
+                raise CompareException(CompareException.INVALID_PARAM_ERROR, msg)
+            if sparse_mode_value in Const.FA_SPECIAL_SPARSE_MODE:
+                kwargs_params[key] = gen_atten_mask(value, convert_type, real_data_path)
+            else:
+                kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
         elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
             kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
         elif value.get('type') in TORCH_TYPE:
@@ -319,6 +333,30 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
     return kwargs_params
 
 
+def gen_atten_mask(info, convert_type, real_data_path):
+    """
+    Function Description:
+        Based on API basic information, generate input parameters: atten_mask, for API forward running
+    Parameter:
+        info: API basic information. Dict
+        convert_type: convert ori_type to dist_type flag.
+        real_data_path: the root directory for storing real data.
+    """
+    check_object_type(info, dict)
+    data_type = info.get('type')
+    data_path = info.get('datapath', info.get('data_name'))
+    data_path = get_full_data_path(data_path, real_data_path)
+    data = None
+    if data_type in TENSOR_DATA_LIST:
+        if data_path:
+            data = gen_real_tensor(data_path, convert_type)
+        else:
+            # Generate a 2048x2048 triangular matrix with 1s on the diagonal and 0s elsewhere.
+            # This is the atten_mask shape when sparse_mode of npu_fusion_attention is in [2, 3, 4].
+            data = torch.triu(torch.ones([2048, 2048]), diagonal=1).to(torch.bool)
+    return data
+
+
 def gen_torch_kwargs(kwargs_params, key, value):
     if value.get('type') != "torch.device":
         module_name, attribute_name = get_module_and_atttribute_name(value.get('value'))
@@ -346,6 +384,23 @@ def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=No
     return kwargs_item_result
 
 
+def get_output_dtype(api_info):
+    """
+    Function Description:
+        Based on API basic information, get the output data dtype
+    Parameter:
+        api_info: API basic information. Dict
+    """
+    output_dtype = None
+    output_info = api_info.get(Const.OUTPUT)
+    if output_info and isinstance(output_info[0], dict):
+        output_str_dtype = output_info[0].get(Const.DTYPE)
+        if output_str_dtype in Const.TORCH_FLOAT_DTYPE:
+            module_name, attribute_name = get_module_and_atttribute_name(output_str_dtype)
+            output_dtype = get_attribute(module_name, attribute_name)
+    return output_dtype
+
+
 def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
     """
     Function Description:
@@ -372,4 +427,5 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
     else:
         logger.warning(f'Warning: No args in {api_info} ')
         args_params = []
-
+    output_dtype = get_output_dtype(api_info)
+    return args_params, kwargs_params, output_dtype
```
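
The `data_generate.py` change special-cases `atten_mask` for `npu_fusion_attention`: when `sparse_mode` is one of the special values, the mask is regenerated as a fixed 2048x2048 upper-triangular boolean tensor instead of being read from dumped data. Below is a standalone sketch of just that mask construction; only `torch` is assumed, the helper name is made up, and the surrounding msprobe helpers (`check_object_type`, `gen_real_tensor`, and so on) are not reproduced.

```python
import torch


def build_special_atten_mask(size: int = 2048) -> torch.Tensor:
    # Boolean mask that is True strictly above the main diagonal, matching the
    # tensor constructed in gen_atten_mask above for the special sparse modes.
    return torch.triu(torch.ones([size, size]), diagonal=1).to(torch.bool)


mask = build_special_atten_mask()
print(mask.shape, mask.dtype)                # torch.Size([2048, 2048]) torch.bool
print(mask[0, 0].item(), mask[0, 1].item())  # False True
```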