mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -72,38 +72,53 @@ def check_need_convert(api_name):
|
|
|
72
72
|
return convert_type
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
def
|
|
75
|
+
def cross_entropy_process(api_info_dict):
|
|
76
76
|
"""
|
|
77
77
|
Function Description:
|
|
78
|
-
Preprocesses the API information.
|
|
78
|
+
Preprocesses the cross_entropy API information.
|
|
79
79
|
Parameter:
|
|
80
|
-
api_name: Name of the API.
|
|
81
80
|
api_info_dict: argument of the API.
|
|
82
81
|
Return api_info_dict:
|
|
83
|
-
convert_type: Type of conversion.
|
|
84
82
|
api_info_dict: Processed argument of the API.
|
|
85
83
|
"""
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
api_info_dict
|
|
89
|
-
|
|
84
|
+
if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
|
|
85
|
+
and 'Min' in api_info_dict['input_args'][1]:
|
|
86
|
+
if api_info_dict['input_args'][1]['Min'] <= 0:
|
|
87
|
+
# The second argument in cross_entropy should be -100 or not less than 0
|
|
88
|
+
api_info_dict['input_args'][1]['Min'] = 0
|
|
89
|
+
return api_info_dict
|
|
90
90
|
|
|
91
91
|
|
|
92
|
-
def
|
|
92
|
+
def histc_process(api_info_dict):
|
|
93
|
+
input_args = api_info_dict['input_args']
|
|
94
|
+
if input_args and input_args[0].get('dtype'):
|
|
95
|
+
dtype = input_args[0]['dtype']
|
|
96
|
+
if dtype in Const.TORCH_INT_DTYPE:
|
|
97
|
+
api_info_dict['input_args'][0]['dtype'] = Const.TORCH_FLOAT32
|
|
98
|
+
return api_info_dict
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
API_PROCESS_MAP = {
|
|
102
|
+
'cross_entropy': cross_entropy_process,
|
|
103
|
+
'histc': histc_process
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def api_info_preprocess(api_name, api_info_dict):
|
|
93
108
|
"""
|
|
94
109
|
Function Description:
|
|
95
|
-
Preprocesses the
|
|
110
|
+
Preprocesses the API information.
|
|
96
111
|
Parameter:
|
|
112
|
+
api_name: Name of the API.
|
|
97
113
|
api_info_dict: argument of the API.
|
|
98
114
|
Return api_info_dict:
|
|
115
|
+
convert_type: Type of conversion.
|
|
99
116
|
api_info_dict: Processed argument of the API.
|
|
100
117
|
"""
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
api_info_dict['input_args'][1]['Min'] = 0
|
|
106
|
-
return api_info_dict
|
|
118
|
+
convert_type = check_need_convert(api_name)
|
|
119
|
+
if api_name in API_PROCESS_MAP:
|
|
120
|
+
api_info_dict = API_PROCESS_MAP[api_name](api_info_dict)
|
|
121
|
+
return convert_type, api_info_dict
|
|
107
122
|
|
|
108
123
|
|
|
109
124
|
def initialize_save_path(save_path, dir_name):
|
|
@@ -16,10 +16,12 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
# 定义比对算法及比对标准
|
|
19
|
+
import math
|
|
19
20
|
import torch
|
|
20
21
|
import numpy as np
|
|
21
22
|
|
|
22
23
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
|
|
24
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
|
|
23
25
|
from msprobe.core.common.const import CompareConst
|
|
24
26
|
|
|
25
27
|
|
|
@@ -179,13 +181,13 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
|
|
|
179
181
|
|
|
180
182
|
def check_small_value(abs_err, small_value_mask, small_value_atol):
|
|
181
183
|
'''
|
|
182
|
-
|
|
184
|
+
新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
|
|
183
185
|
输入:
|
|
184
|
-
|
|
186
|
+
abs_err:npu输出和golden输出的绝对误差
|
|
185
187
|
normal_value_mask:npu输出和golden输出的正常值mask
|
|
186
|
-
|
|
188
|
+
atol:绝对误差的阈值
|
|
187
189
|
输出:
|
|
188
|
-
|
|
190
|
+
abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
|
|
189
191
|
'''
|
|
190
192
|
greater_mask = np.greater(abs_err, small_value_atol)
|
|
191
193
|
err_mask = np.logical_and(greater_mask, small_value_mask)
|
|
@@ -195,13 +197,13 @@ def check_small_value(abs_err, small_value_mask, small_value_atol):
|
|
|
195
197
|
|
|
196
198
|
def check_norm_value(normal_value_mask, rel_err, rtol):
|
|
197
199
|
'''
|
|
198
|
-
|
|
200
|
+
新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
|
|
199
201
|
输入:
|
|
200
|
-
|
|
202
|
+
rel_err:npu输出和golden输出的相对误差
|
|
201
203
|
normal_value_mask:npu输出和golden输出的正常值mask
|
|
202
|
-
|
|
204
|
+
rtol:相对误差的阈值
|
|
203
205
|
输出:
|
|
204
|
-
|
|
206
|
+
rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
|
|
205
207
|
'''
|
|
206
208
|
err_mask = np.greater(rel_err, rtol)
|
|
207
209
|
err_mask = np.logical_and(err_mask, normal_value_mask)
|
|
@@ -228,3 +230,34 @@ def get_ulp_err(bench_output, device_output, dtype):
|
|
|
228
230
|
def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
|
|
229
231
|
return (device_output.astype(data_type) - bench_output).astype(data_type) * \
|
|
230
232
|
np.exp2(-eb + exponent_num).astype(data_type)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def calc_ratio(x, y, dtype):
|
|
236
|
+
"""
|
|
237
|
+
Calculate the ratio between NPU and GPU statistical values.
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
x (float): Statistical value from the NPU side
|
|
241
|
+
y (float): Statistical value from the GPU side
|
|
242
|
+
dtype: Data type used to determine the minimum error value
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
float: The ratio of NPU to GPU statistical values
|
|
246
|
+
|
|
247
|
+
Notes:
|
|
248
|
+
- Takes absolute values of both x and y for calculation
|
|
249
|
+
- Uses StandardConfig.get_minmum_err(dtype) to get minimum error for the specified dtype
|
|
250
|
+
- Prevents division by zero by ensuring denominator is not less than minimum error
|
|
251
|
+
- Returns |x| / max(|y|, minimum_error)
|
|
252
|
+
"""
|
|
253
|
+
x, y = abs(x), abs(y)
|
|
254
|
+
minmum_err = StandardConfig.get_minmum_err(dtype)
|
|
255
|
+
err_y = max(y, minmum_err)
|
|
256
|
+
return x / err_y
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def compare_bool_tensor(bench_output, device_output):
|
|
260
|
+
error_nums = (bench_output != device_output).sum()
|
|
261
|
+
error_rate = float(error_nums / bench_output.size)
|
|
262
|
+
result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
|
|
263
|
+
return error_rate, result, ""
|
|
@@ -29,11 +29,15 @@ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
|
29
29
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
|
|
30
30
|
API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
|
|
31
31
|
ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
|
|
32
|
-
BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage
|
|
33
|
-
|
|
32
|
+
BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage
|
|
33
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_input import PrecisionCompareInput
|
|
34
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
|
|
35
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpPrecisionCompare
|
|
36
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkPrecisionCompare
|
|
37
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
|
|
34
38
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
|
|
35
39
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
|
|
36
|
-
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
|
|
40
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
|
|
37
41
|
from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
|
|
38
42
|
from msprobe.pytorch.common.log import logger
|
|
39
43
|
from msprobe.core.common.utils import CompareException
|
|
@@ -47,30 +51,6 @@ BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_va
|
|
|
47
51
|
'eb_inf_nan_consistency'])
|
|
48
52
|
UNSUPPORTED_MESSAGE = 'This data type does not support benchmark compare.'
|
|
49
53
|
|
|
50
|
-
DEFAULT_THRESHOLD = 1
|
|
51
|
-
|
|
52
|
-
benchmark_algorithms_thresholds = {
|
|
53
|
-
'small_value': {
|
|
54
|
-
'error_threshold': 2,
|
|
55
|
-
'warning_threshold': 1
|
|
56
|
-
},
|
|
57
|
-
'rmse': {
|
|
58
|
-
'error_threshold': 2,
|
|
59
|
-
'warning_threshold': 1
|
|
60
|
-
},
|
|
61
|
-
'max_rel_err': {
|
|
62
|
-
'error_threshold': 10,
|
|
63
|
-
'warning_threshold': 1
|
|
64
|
-
},
|
|
65
|
-
'mean_rel_err': {
|
|
66
|
-
'error_threshold': 2,
|
|
67
|
-
'warning_threshold': 1
|
|
68
|
-
},
|
|
69
|
-
'eb': {
|
|
70
|
-
'error_threshold': 2,
|
|
71
|
-
'warning_threshold': 1
|
|
72
|
-
}
|
|
73
|
-
}
|
|
74
54
|
|
|
75
55
|
benchmark_message = {
|
|
76
56
|
"small_value_err_status": {
|
|
@@ -92,189 +72,6 @@ benchmark_message = {
|
|
|
92
72
|
}
|
|
93
73
|
|
|
94
74
|
|
|
95
|
-
class Standard:
|
|
96
|
-
@staticmethod
|
|
97
|
-
def _calc_ratio(column_name, x, y, default_value):
|
|
98
|
-
'''
|
|
99
|
-
计算npu侧和gpu侧统计量的比值
|
|
100
|
-
输入:
|
|
101
|
-
column_name:统计量名称
|
|
102
|
-
x:npu侧统计量
|
|
103
|
-
y:gpu侧统计量
|
|
104
|
-
default:当x不接近0,y接近0,设置的比值默认值
|
|
105
|
-
输出:
|
|
106
|
-
ratio:统计量x和y的比值
|
|
107
|
-
inf_nan_consistency:不出现inf或nan时为True,出现inf或nan时必须同时为inf或-inf或nan才为True,否则为False
|
|
108
|
-
message:当出现inf或nan时的提示信息
|
|
109
|
-
'''
|
|
110
|
-
x, y = convert_str_to_float(x), convert_str_to_float(y)
|
|
111
|
-
|
|
112
|
-
if is_inf_or_nan(x) or is_inf_or_nan(y):
|
|
113
|
-
return check_inf_or_nan(x, y, column_name)
|
|
114
|
-
|
|
115
|
-
inf_nan_consistency = True
|
|
116
|
-
message = ""
|
|
117
|
-
if math.isclose(y, 0.0):
|
|
118
|
-
if math.isclose(x, 0.0):
|
|
119
|
-
return 1.0, inf_nan_consistency, message
|
|
120
|
-
else:
|
|
121
|
-
return default_value, inf_nan_consistency, message
|
|
122
|
-
else:
|
|
123
|
-
return abs(x / y), inf_nan_consistency, message
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
class BenchmarkStandard(Standard):
|
|
127
|
-
def __init__(self, api_name, npu_precision, gpu_precision):
|
|
128
|
-
self.api_name = api_name
|
|
129
|
-
self.npu_precision = npu_precision
|
|
130
|
-
self.gpu_precision = gpu_precision
|
|
131
|
-
self.small_value_err_ratio = 1
|
|
132
|
-
self.rmse_ratio = 1
|
|
133
|
-
self.max_rel_err_ratio = 1
|
|
134
|
-
self.mean_rel_err_ratio = 1
|
|
135
|
-
self.eb_ratio = 1
|
|
136
|
-
self.small_value_err_status = CompareConst.PASS
|
|
137
|
-
self.rmse_status = CompareConst.PASS
|
|
138
|
-
self.max_rel_err_status = CompareConst.PASS
|
|
139
|
-
self.mean_rel_err_status = CompareConst.PASS
|
|
140
|
-
self.eb_status = CompareConst.PASS
|
|
141
|
-
self.check_result_list = []
|
|
142
|
-
self.final_result = CompareConst.PASS
|
|
143
|
-
self.compare_message = ""
|
|
144
|
-
|
|
145
|
-
def __str__(self):
|
|
146
|
-
return "%s" % (self.api_name)
|
|
147
|
-
|
|
148
|
-
@staticmethod
|
|
149
|
-
def _get_status(ratio, algorithm):
|
|
150
|
-
if math.isnan(ratio) or math.isinf(ratio):
|
|
151
|
-
return CompareConst.PASS
|
|
152
|
-
error_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('error_threshold', DEFAULT_THRESHOLD)
|
|
153
|
-
warning_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('warning_threshold',
|
|
154
|
-
DEFAULT_THRESHOLD)
|
|
155
|
-
if ratio > error_threshold:
|
|
156
|
-
return CompareConst.ERROR
|
|
157
|
-
elif ratio > warning_threshold:
|
|
158
|
-
return CompareConst.WARNING
|
|
159
|
-
return CompareConst.PASS
|
|
160
|
-
|
|
161
|
-
def get_result(self):
|
|
162
|
-
inf_nan_consistency = self._compare_ratio()
|
|
163
|
-
small_value_inf_nan_consistency = inf_nan_consistency.small_value_inf_nan_consistency
|
|
164
|
-
rmse_inf_nan_consistency = inf_nan_consistency.rmse_inf_nan_consistency
|
|
165
|
-
max_rel_inf_nan_consistency = inf_nan_consistency.max_rel_inf_nan_consistency
|
|
166
|
-
mean_rel_inf_nan_consistency = inf_nan_consistency.mean_rel_inf_nan_consistency
|
|
167
|
-
eb_inf_nan_consistency = inf_nan_consistency.eb_inf_nan_consistency
|
|
168
|
-
self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') if \
|
|
169
|
-
small_value_inf_nan_consistency else CompareConst.ERROR
|
|
170
|
-
self.check_result_list.append(self.small_value_err_status)
|
|
171
|
-
self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
|
|
172
|
-
else CompareConst.ERROR
|
|
173
|
-
self.check_result_list.append(self.rmse_status)
|
|
174
|
-
self.max_rel_err_status = self._get_status(
|
|
175
|
-
self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency else CompareConst.ERROR
|
|
176
|
-
self.check_result_list.append(self.max_rel_err_status)
|
|
177
|
-
self.mean_rel_err_status = self._get_status(
|
|
178
|
-
self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency else CompareConst.ERROR
|
|
179
|
-
self.check_result_list.append(self.mean_rel_err_status)
|
|
180
|
-
self.eb_status = self._get_status(self.eb_ratio, 'eb')
|
|
181
|
-
if CompareConst.ERROR in self.check_result_list:
|
|
182
|
-
self.final_result = CompareConst.ERROR
|
|
183
|
-
elif CompareConst.WARNING in self.check_result_list:
|
|
184
|
-
self.final_result = CompareConst.WARNING
|
|
185
|
-
|
|
186
|
-
def to_column_value(self):
|
|
187
|
-
return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
|
|
188
|
-
self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
|
|
189
|
-
self.mean_rel_err_status, self.eb_ratio, self.eb_status]
|
|
190
|
-
|
|
191
|
-
def _compare_ratio(self):
|
|
192
|
-
|
|
193
|
-
self.small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = self._calc_ratio(
|
|
194
|
-
ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE,
|
|
195
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
|
|
196
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0)
|
|
197
|
-
self.compare_message += small_value_message
|
|
198
|
-
self.rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE,
|
|
199
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.RMSE),
|
|
200
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0)
|
|
201
|
-
self.compare_message += rmse_message
|
|
202
|
-
self.max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = self._calc_ratio(
|
|
203
|
-
ApiPrecisionCompareColumn.MAX_REL_ERR,
|
|
204
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
|
|
205
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
|
|
206
|
-
self.compare_message += max_rel_message
|
|
207
|
-
self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(
|
|
208
|
-
ApiPrecisionCompareColumn.MEAN_REL_ERR,
|
|
209
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
|
|
210
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
|
|
211
|
-
self.compare_message += mean_rel_message
|
|
212
|
-
self.eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB,
|
|
213
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.EB),
|
|
214
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
|
|
215
|
-
self.compare_message += eb_message
|
|
216
|
-
|
|
217
|
-
return BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
|
|
218
|
-
max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
|
|
219
|
-
eb_inf_nan_consistency)
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
class ULPStandard(Standard):
|
|
223
|
-
def __init__(self, api_name, npu_precision, gpu_precision):
|
|
224
|
-
self.api_name = api_name
|
|
225
|
-
self.npu_precision = npu_precision
|
|
226
|
-
self.gpu_precision = gpu_precision
|
|
227
|
-
self.mean_ulp_err = 0
|
|
228
|
-
self.ulp_err_proportion = 0
|
|
229
|
-
self.ulp_err_proportion_ratio = 1
|
|
230
|
-
self.ulp_err_status = CompareConst.PASS
|
|
231
|
-
self.compare_message = ""
|
|
232
|
-
|
|
233
|
-
def __str__(self):
|
|
234
|
-
return f"{self.api_name}"
|
|
235
|
-
|
|
236
|
-
def get_result(self):
|
|
237
|
-
self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
|
|
238
|
-
gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
|
|
239
|
-
inf_nan_consistency = True
|
|
240
|
-
if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
|
|
241
|
-
_, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
|
|
242
|
-
ApiPrecisionCompareColumn.MEAN_ULP_ERR)
|
|
243
|
-
self.compare_message += message
|
|
244
|
-
self.ulp_err_proportion = convert_str_to_float(
|
|
245
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
|
|
246
|
-
self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
|
|
247
|
-
ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
|
|
248
|
-
self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
|
|
249
|
-
self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
|
|
250
|
-
inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
|
|
251
|
-
self.compare_message += message
|
|
252
|
-
if inf_nan_consistency:
|
|
253
|
-
self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
|
|
254
|
-
else:
|
|
255
|
-
self.ulp_err_status = CompareConst.ERROR
|
|
256
|
-
|
|
257
|
-
def _get_ulp_status(self, dtype):
|
|
258
|
-
if dtype == torch.float32:
|
|
259
|
-
if self.mean_ulp_err < 64:
|
|
260
|
-
return CompareConst.PASS
|
|
261
|
-
elif self.ulp_err_proportion < 0.05:
|
|
262
|
-
return CompareConst.PASS
|
|
263
|
-
elif self.ulp_err_proportion_ratio < 1:
|
|
264
|
-
return CompareConst.PASS
|
|
265
|
-
else:
|
|
266
|
-
self.compare_message += "ERROR: ULP误差不满足标准\n"
|
|
267
|
-
return CompareConst.ERROR
|
|
268
|
-
else:
|
|
269
|
-
if self.ulp_err_proportion < 0.001:
|
|
270
|
-
return CompareConst.PASS
|
|
271
|
-
elif self.ulp_err_proportion_ratio < 1:
|
|
272
|
-
return CompareConst.PASS
|
|
273
|
-
else:
|
|
274
|
-
self.compare_message += "ERROR: ULP误差不满足标准\n"
|
|
275
|
-
return CompareConst.ERROR
|
|
276
|
-
|
|
277
|
-
|
|
278
75
|
def write_detail_csv(content, save_path):
|
|
279
76
|
rows = []
|
|
280
77
|
content = ["{:.{}f}".format(item, msCheckerConfig.precision) \
|
|
@@ -283,6 +80,17 @@ def write_detail_csv(content, save_path):
|
|
|
283
80
|
write_csv(rows, save_path)
|
|
284
81
|
|
|
285
82
|
|
|
83
|
+
def register_compare_func():
|
|
84
|
+
registry = StandardRegistry()
|
|
85
|
+
registry.register(CompareConst.ABSOLUTE_THRESHOLD, record_absolute_threshold_result)
|
|
86
|
+
registry.register(CompareConst.BINARY_CONSISTENCY, record_binary_consistency_result)
|
|
87
|
+
registry.register(CompareConst.ULP_COMPARE, record_ulp_compare_result)
|
|
88
|
+
registry.register(CompareConst.THOUSANDTH_STANDARD, record_thousandth_threshold_result)
|
|
89
|
+
registry.register(CompareConst.BENCHMARK, record_benchmark_compare_result)
|
|
90
|
+
registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, record_accumulative_error_compare_result)
|
|
91
|
+
return registry
|
|
92
|
+
|
|
93
|
+
|
|
286
94
|
def api_precision_compare(config):
|
|
287
95
|
logger.info("Start compare task")
|
|
288
96
|
logger.info(f"Compare task result will be saved in {config.result_csv_path}")
|
|
@@ -337,6 +145,8 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
337
145
|
forward_status, backward_status = [], []
|
|
338
146
|
last_api_name, last_api_dtype, last_api_full_name = None, None, None
|
|
339
147
|
last_api_skip_message = ''
|
|
148
|
+
registry = register_compare_func()
|
|
149
|
+
|
|
340
150
|
for _, row_npu in npu_data.iterrows():
|
|
341
151
|
message = ''
|
|
342
152
|
compare_column = ApiPrecisionOutputColumn()
|
|
@@ -362,7 +172,7 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
362
172
|
row_gpu = row_gpu.iloc[0]
|
|
363
173
|
new_status = CompareConst.SPACE
|
|
364
174
|
try:
|
|
365
|
-
new_status = get_api_status(row_npu, row_gpu, api_name, compare_column)
|
|
175
|
+
new_status = get_api_status(row_npu, row_gpu, api_name, compare_column, registry)
|
|
366
176
|
except Exception as err:
|
|
367
177
|
logger.error(f"Get api status error: {str(err)}")
|
|
368
178
|
compare_column.api_name = full_api_name_with_direction_status
|
|
@@ -383,7 +193,8 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
383
193
|
else:
|
|
384
194
|
forward_result = get_api_checker_result(forward_status)
|
|
385
195
|
backward_result = get_api_checker_result(backward_status)
|
|
386
|
-
|
|
196
|
+
_, base_api_name = extract_basic_api_segments(last_api_name)
|
|
197
|
+
message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
387
198
|
message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
|
|
388
199
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
389
200
|
print_test_success(last_api_name, forward_result, backward_result)
|
|
@@ -415,37 +226,30 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
415
226
|
else:
|
|
416
227
|
forward_result = get_api_checker_result(forward_status)
|
|
417
228
|
backward_result = get_api_checker_result(backward_status)
|
|
418
|
-
|
|
229
|
+
_, base_api_name = extract_basic_api_segments(last_api_name)
|
|
230
|
+
message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
419
231
|
message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
|
|
420
232
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
421
233
|
print_test_success(last_api_name, forward_result, backward_result)
|
|
422
234
|
last_api_skip_message = ''
|
|
423
235
|
|
|
424
236
|
|
|
425
|
-
def get_api_status(row_npu, row_gpu, api_name, compare_column):
|
|
237
|
+
def get_api_status(row_npu, row_gpu, api_name, compare_column, registry):
|
|
426
238
|
full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
|
|
427
239
|
# 当前API的输出为空(例如反向过程中requires_grad=False),跳过比对
|
|
428
|
-
if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace()
|
|
240
|
+
if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace() or \
|
|
241
|
+
row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in API_PRECISION_COMPARE_UNSUPPORT_LIST or \
|
|
242
|
+
row_npu[ApiPrecisionCompareColumn.SHAPE] == CompareConst.ZERO_SHAPE:
|
|
429
243
|
compare_column.api_name = full_api_name_with_direction_status
|
|
430
244
|
compare_column.compare_result = CompareConst.SKIP
|
|
431
245
|
compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
|
|
432
246
|
new_status = CompareConst.SKIP
|
|
433
247
|
else:
|
|
434
248
|
compare_column.api_name = full_api_name_with_direction_status
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
|
|
440
|
-
elif api_name in absolute_standard_api:
|
|
441
|
-
new_status = record_absolute_threshold_result(compare_column, row_npu)
|
|
442
|
-
elif api_name in ulp_standard_api and \
|
|
443
|
-
row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST:
|
|
444
|
-
us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
|
|
445
|
-
new_status = record_ulp_compare_result(compare_column, us)
|
|
446
|
-
elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST:
|
|
447
|
-
bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
|
|
448
|
-
new_status = record_benchmark_compare_result(compare_column, bs)
|
|
249
|
+
dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
|
|
250
|
+
input_data = PrecisionCompareInput(row_npu, row_gpu, dtype, compare_column)
|
|
251
|
+
comparison_func = registry.get_comparison_function(api_name, dtype)
|
|
252
|
+
new_status = comparison_func(input_data)
|
|
449
253
|
return new_status
|
|
450
254
|
|
|
451
255
|
|
|
@@ -505,21 +309,24 @@ def check_csv_columns(columns, csv_type):
|
|
|
505
309
|
raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
|
|
506
310
|
|
|
507
311
|
|
|
508
|
-
def record_binary_consistency_result(
|
|
312
|
+
def record_binary_consistency_result(input_data):
|
|
313
|
+
row_npu = input_data.row_npu
|
|
314
|
+
compare_column = input_data.compare_column
|
|
509
315
|
new_status = check_error_rate(row_npu[ApiPrecisionCompareColumn.ERROR_RATE])
|
|
510
316
|
compare_column.error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
|
|
511
317
|
compare_column.error_rate_status = new_status
|
|
512
318
|
compare_column.compare_result = new_status
|
|
513
|
-
compare_column.compare_algorithm =
|
|
319
|
+
compare_column.compare_algorithm = CompareConst.BINARY_CONSISTENCY_ALGORITHM_NAME
|
|
514
320
|
message = ''
|
|
515
321
|
if compare_column.error_rate_status == CompareConst.ERROR:
|
|
516
322
|
message += "ERROR: 二进制一致错误率超过阈值\n"
|
|
517
|
-
message += CompareMessage.get(api_name, "")
|
|
518
323
|
compare_column.compare_message = message
|
|
519
324
|
return new_status
|
|
520
325
|
|
|
521
326
|
|
|
522
|
-
def record_absolute_threshold_result(
|
|
327
|
+
def record_absolute_threshold_result(input_data):
|
|
328
|
+
row_npu = input_data.row_npu
|
|
329
|
+
compare_column = input_data.compare_column
|
|
523
330
|
absolute_threshold_result = get_absolute_threshold_result(row_npu)
|
|
524
331
|
compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
|
|
525
332
|
compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
|
|
@@ -528,62 +335,88 @@ def record_absolute_threshold_result(compare_column, row_npu):
|
|
|
528
335
|
compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
|
|
529
336
|
compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
|
|
530
337
|
compare_column.compare_result = absolute_threshold_result.get("absolute_threshold_result")
|
|
531
|
-
compare_column.compare_algorithm =
|
|
338
|
+
compare_column.compare_algorithm = CompareConst.ABSOLUTE_THRESHOLD_ALGORITHM_NAME
|
|
532
339
|
message = ''
|
|
533
340
|
if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
|
|
534
|
-
message += "ERROR: inf/nan
|
|
341
|
+
message += "ERROR: inf/nan错误率超过阈值"
|
|
535
342
|
if compare_column.rel_err_ratio_status == CompareConst.ERROR:
|
|
536
|
-
message += "ERROR:
|
|
343
|
+
message += "ERROR: 相对误差错误率超过阈值"
|
|
537
344
|
if compare_column.abs_err_ratio_status == CompareConst.ERROR:
|
|
538
|
-
message += "ERROR:
|
|
345
|
+
message += "ERROR: 绝对误差错误率超过阈值"
|
|
539
346
|
compare_column.compare_message = message
|
|
540
347
|
return compare_column.compare_result
|
|
541
348
|
|
|
542
349
|
|
|
543
|
-
def record_benchmark_compare_result(
|
|
544
|
-
bs
|
|
545
|
-
|
|
546
|
-
compare_column.small_value_err_status = bs.small_value_err_status
|
|
547
|
-
compare_column.rmse_ratio = bs.rmse_ratio
|
|
548
|
-
compare_column.rmse_status = bs.rmse_status
|
|
549
|
-
compare_column.max_rel_err_ratio = bs.max_rel_err_ratio
|
|
550
|
-
compare_column.max_rel_err_status = bs.max_rel_err_status
|
|
551
|
-
compare_column.mean_rel_err_ratio = bs.mean_rel_err_ratio
|
|
552
|
-
compare_column.mean_rel_err_status = bs.mean_rel_err_status
|
|
553
|
-
compare_column.eb_ratio = bs.eb_ratio
|
|
554
|
-
compare_column.eb_status = bs.eb_status
|
|
555
|
-
compare_column.compare_result = bs.final_result
|
|
556
|
-
compare_column.compare_algorithm = "标杆比对法"
|
|
557
|
-
compare_column.compare_message = bs.compare_message
|
|
350
|
+
def record_benchmark_compare_result(input_data):
|
|
351
|
+
bs = BenchmarkPrecisionCompare(input_data)
|
|
352
|
+
compare_result = bs.compare()
|
|
558
353
|
for status_attr, messages in benchmark_message.items():
|
|
559
|
-
status_value = getattr(compare_column, status_attr)
|
|
354
|
+
status_value = getattr(input_data.compare_column, status_attr)
|
|
560
355
|
if status_value in messages:
|
|
561
|
-
compare_column.compare_message += messages[status_value]
|
|
562
|
-
return
|
|
356
|
+
input_data.compare_column.compare_message += messages[status_value]
|
|
357
|
+
return compare_result
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def record_ulp_compare_result(input_data):
|
|
361
|
+
us = UlpPrecisionCompare(input_data)
|
|
362
|
+
compare_result = us.compare()
|
|
363
|
+
return compare_result
|
|
563
364
|
|
|
564
365
|
|
|
565
|
-
def
|
|
566
|
-
|
|
567
|
-
compare_column
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
366
|
+
def record_accumulative_error_compare_result(input_data):
|
|
367
|
+
row_npu = input_data.row_npu
|
|
368
|
+
compare_column = input_data.compare_column
|
|
369
|
+
absolute_threshold_result = get_absolute_threshold_result(row_npu)
|
|
370
|
+
threshold_result = absolute_threshold_result.get("absolute_threshold_result")
|
|
371
|
+
eb, eb_result = check_eb(row_npu)
|
|
372
|
+
accumulative_error_compare_result = CompareConst.PASS
|
|
373
|
+
if CompareConst.ERROR in [threshold_result, eb_result]:
|
|
374
|
+
accumulative_error_compare_result = CompareConst.ERROR
|
|
375
|
+
|
|
376
|
+
compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
|
|
377
|
+
compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
|
|
378
|
+
compare_column.rel_err_ratio = absolute_threshold_result.get("rel_err_ratio")
|
|
379
|
+
compare_column.rel_err_ratio_status = absolute_threshold_result.get("rel_err_result")
|
|
380
|
+
compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
|
|
381
|
+
compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
|
|
382
|
+
compare_column.eb_ratio = eb
|
|
383
|
+
compare_column.eb_status = eb_result
|
|
384
|
+
compare_column.compare_result = accumulative_error_compare_result
|
|
385
|
+
compare_column.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME
|
|
386
|
+
message = []
|
|
387
|
+
if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
|
|
388
|
+
message.append("ERROR: inf/nan错误率超过阈值\n")
|
|
389
|
+
if compare_column.rel_err_ratio_status == CompareConst.ERROR:
|
|
390
|
+
message.append("ERROR: 相对误差错误率超过阈值\n")
|
|
391
|
+
if compare_column.abs_err_ratio_status == CompareConst.ERROR:
|
|
392
|
+
message.append("ERROR: 绝对误差错误率超过阈值\n")
|
|
393
|
+
if compare_column.eb_status == CompareConst.ERROR:
|
|
394
|
+
message.append("ERROR: 误差均衡性超过阈值\n")
|
|
395
|
+
compare_column.compare_message = "\n".join(message)
|
|
574
396
|
return compare_column.compare_result
|
|
575
397
|
|
|
576
398
|
|
|
399
|
+
def check_eb(row_npu):
|
|
400
|
+
eb = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.EB])
|
|
401
|
+
dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
|
|
402
|
+
eb_threshold = StandardConfig.get_accumulative_error_eb_threshold(dtype)
|
|
403
|
+
eb_result = CompareConst.PASS if eb <= eb_threshold else CompareConst.ERROR
|
|
404
|
+
return eb, eb_result
|
|
405
|
+
|
|
406
|
+
|
|
577
407
|
def check_thousandth_rate(thousandth_rate):
|
|
578
|
-
return CompareConst.PASS if convert_str_to_float(thousandth_rate) >=
|
|
408
|
+
return CompareConst.PASS if convert_str_to_float(thousandth_rate) >= CompareConst.THOUSANDTH_PASS_VALUE \
|
|
409
|
+
else CompareConst.ERROR
|
|
579
410
|
|
|
580
411
|
|
|
581
|
-
def record_thousandth_threshold_result(
|
|
412
|
+
def record_thousandth_threshold_result(input_data):
|
|
413
|
+
row_npu = input_data.row_npu
|
|
414
|
+
compare_column = input_data.compare_column
|
|
582
415
|
new_status = check_thousandth_rate(row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH])
|
|
583
416
|
compare_column.rel_err_thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
|
|
584
417
|
compare_column.rel_err_thousandth_status = new_status
|
|
585
418
|
compare_column.compare_result = new_status
|
|
586
|
-
compare_column.compare_algorithm =
|
|
419
|
+
compare_column.compare_algorithm = CompareConst.THOUSANDTH_STANDARD_ALGORITHM_NAME
|
|
587
420
|
message = ''
|
|
588
421
|
if compare_column.rel_err_thousandth_status == CompareConst.ERROR:
|
|
589
422
|
message += "ERROR: 双千指标不达标\n"
|
|
@@ -66,6 +66,7 @@ BinaryCompareStandard:
|
|
|
66
66
|
- greater_
|
|
67
67
|
- greater_equal
|
|
68
68
|
- greater_equal_
|
|
69
|
+
- histc
|
|
69
70
|
- isfinite
|
|
70
71
|
- isnan
|
|
71
72
|
- less
|
|
@@ -130,4 +131,6 @@ ULPStandard:
|
|
|
130
131
|
ThousandthStandard:
|
|
131
132
|
- conv1d
|
|
132
133
|
- conv2d
|
|
133
|
-
|
|
134
|
+
|
|
135
|
+
AccumulativeErrorStandard:
|
|
136
|
+
- test_api
|