mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,545 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
from collections import namedtuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import write_csv
|
|
11
|
+
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
12
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
|
|
13
|
+
API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
|
|
14
|
+
ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, ThousandthStandardApi, \
|
|
15
|
+
BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
|
|
16
|
+
check_inf_or_nan
|
|
17
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
|
|
18
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path
|
|
19
|
+
from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
21
|
+
from msprobe.core.common.utils import CompareException
|
|
22
|
+
from msprobe.core.common.const import CompareConst, FileCheckConst
|
|
23
|
+
|
|
24
|
+
# Bundle of the four CSV paths used by a comparison run (field names: NPU input,
# GPU input, result output, details output).
CompareConfig = namedtuple(
    'CompareConfig',
    ('npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'),
)

# One inf/nan-consistency flag per benchmark metric; this is the value
# returned by BenchmarkStandard._compare_ratio.
BenchmarkInf_Nan_Consistency = namedtuple(
    'BenchmarkInf_Nan_Consistency',
    (
        'small_value_inf_nan_consistency',
        'rmse_inf_nan_consistency',
        'max_rel_inf_nan_consistency',
        'mean_rel_inf_nan_consistency',
        'eb_inf_nan_consistency',
    ),
)

# Message written to the result CSV for APIs whose dtype is in the unsupported list.
unsupported_message = 'This data type does not support benchmark compare.'

# Fallback threshold used when an algorithm has no entry in the table below.
DEFAULT_THRESHOLD = 1

# Per-algorithm ratio thresholds: a ratio above 'error_threshold' is reported as
# an error, otherwise a ratio above 'warning_threshold' is reported as a warning.
benchmark_algorithms_thresholds = {
    'small_value': dict(error_threshold=2, warning_threshold=1),
    'rmse': dict(error_threshold=2, warning_threshold=1),
    'max_rel_err': dict(error_threshold=10, warning_threshold=1),
    'mean_rel_err': dict(error_threshold=2, warning_threshold=1),
    'eb': dict(error_threshold=2, warning_threshold=1),
}
|
|
56
|
+
|
|
57
|
+
# Messages keyed first by a status attribute name of BenchmarkStandard and then
# by the status value. The Chinese texts are runtime strings and are kept
# verbatim; the trailing comments give an English gloss.
benchmark_message = {
    "small_value_err_status": {
        # "small-value-domain error ratio exceeds the threshold"
        CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n",
    },
    "rmse_status": {
        # "root-mean-square error ratio exceeds the threshold"
        CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n",
    },
    "max_rel_err_status": {
        # "maximum relative error ratio exceeds the threshold"
        CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n",
    },
    "mean_rel_err_status": {
        # "mean relative error ratio exceeds the threshold"
        CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n",
    },
}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class Standard:
    """Base class providing the ratio helper shared by the comparison standards."""

    @staticmethod
    def _calc_ratio(column_name, x, y, default_value):
        """Compute the ratio between an NPU-side and a GPU-side statistic.

        Args:
            column_name: name of the statistic (passed on to ``check_inf_or_nan``).
            x: NPU-side statistic.
            y: GPU-side statistic.
            default_value: ratio reported when ``x`` is not close to 0 but ``y`` is.

        Returns:
            A ``(ratio, inf_nan_consistency, message)`` tuple. When either value
            is inf or nan the tuple is whatever ``check_inf_or_nan`` returns;
            according to the original documentation the flag is then True only
            if both sides are inf, -inf or nan together, and ``message``
            describes the inf/nan situation. Otherwise the flag is True and the
            message is empty.
        """
        npu_value = convert_str_to_float(x)
        gpu_value = convert_str_to_float(y)

        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)

        # NOTE: math.isclose is called with its default tolerances (abs_tol=0.0).
        if not math.isclose(gpu_value, 0.0):
            ratio = abs(npu_value / gpu_value)
        elif math.isclose(npu_value, 0.0):
            # Both sides are (close to) zero: treat them as equal.
            ratio = 1.0
        else:
            ratio = default_value
        return ratio, True, ""
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class BenchmarkStandard(Standard):
    """Benchmark comparison of one API: NPU/GPU ratios of five statistics and their statuses."""

    def __init__(self, api_name, npu_precision, gpu_precision):
        """Store the API name and the NPU/GPU precision records.

        ``npu_precision`` and ``gpu_precision`` are only accessed through
        ``.get(column)``; every ratio starts at 1 and every status at PASS.
        """
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision

        self.small_value_err_ratio = 1
        self.small_value_err_status = CompareConst.PASS
        self.rmse_ratio = 1
        self.rmse_status = CompareConst.PASS
        self.max_rel_err_ratio = 1
        self.max_rel_err_status = CompareConst.PASS
        self.mean_rel_err_ratio = 1
        self.mean_rel_err_status = CompareConst.PASS
        self.eb_ratio = 1
        self.eb_status = CompareConst.PASS

        self.check_result_list = []
        self.final_result = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return "%s" % self.api_name

    @staticmethod
    def _get_status(ratio, algorithm):
        """Map a ratio to PASS / WARNING / ERROR using the thresholds of ``algorithm``.

        A nan or infinite ratio yields PASS here. Algorithms missing from
        ``benchmark_algorithms_thresholds`` fall back to ``DEFAULT_THRESHOLD``.
        """
        if not math.isfinite(ratio):
            return CompareConst.PASS
        thresholds = benchmark_algorithms_thresholds.get(algorithm, {})
        if ratio > thresholds.get('error_threshold', DEFAULT_THRESHOLD):
            return CompareConst.ERROR
        if ratio > thresholds.get('warning_threshold', DEFAULT_THRESHOLD):
            return CompareConst.WARNING
        return CompareConst.PASS

    def _status_or_error(self, ratio, algorithm, consistent):
        """Return the threshold status, or ERROR when the inf/nan check was inconsistent."""
        return self._get_status(ratio, algorithm) if consistent else CompareConst.ERROR

    def get_result(self):
        """Compute all ratios, the per-metric statuses and the final result.

        The ``eb`` status is computed from its ratio only and is not added to
        ``check_result_list``, so it does not influence ``final_result``.
        """
        consistency = self._compare_ratio()

        self.small_value_err_status = self._status_or_error(
            self.small_value_err_ratio, 'small_value', consistency.small_value_inf_nan_consistency)
        self.check_result_list.append(self.small_value_err_status)

        self.rmse_status = self._status_or_error(
            self.rmse_ratio, 'rmse', consistency.rmse_inf_nan_consistency)
        self.check_result_list.append(self.rmse_status)

        self.max_rel_err_status = self._status_or_error(
            self.max_rel_err_ratio, 'max_rel_err', consistency.max_rel_inf_nan_consistency)
        self.check_result_list.append(self.max_rel_err_status)

        self.mean_rel_err_status = self._status_or_error(
            self.mean_rel_err_ratio, 'mean_rel_err', consistency.mean_rel_inf_nan_consistency)
        self.check_result_list.append(self.mean_rel_err_status)

        self.eb_status = self._get_status(self.eb_ratio, 'eb')

        # The most severe status present wins; PASS stays when neither is found.
        for severity in (CompareConst.ERROR, CompareConst.WARNING):
            if severity in self.check_result_list:
                self.final_result = severity
                break

    def to_column_value(self):
        """Return the ratio/status pairs as a flat list, in detail-column order."""
        return [
            self.small_value_err_ratio, self.small_value_err_status,
            self.rmse_ratio, self.rmse_status,
            self.max_rel_err_ratio, self.max_rel_err_status,
            self.mean_rel_err_ratio, self.mean_rel_err_status,
            self.eb_ratio, self.eb_status,
        ]

    def _compare_ratio(self):
        """Fill in the five ratios, collect inf/nan messages and return the consistency flags."""

        def ratio_for(column):
            # 10000.0 is the ratio reported when only the GPU-side value is close to zero.
            return self._calc_ratio(column, self.npu_precision.get(column),
                                    self.gpu_precision.get(column), 10000.0)

        self.small_value_err_ratio, small_value_ok, note = ratio_for(
            ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE)
        self.compare_message += note

        self.rmse_ratio, rmse_ok, note = ratio_for(ApiPrecisionCompareColumn.RMSE)
        self.compare_message += note

        self.max_rel_err_ratio, max_rel_ok, note = ratio_for(ApiPrecisionCompareColumn.MAX_REL_ERR)
        self.compare_message += note

        self.mean_rel_err_ratio, mean_rel_ok, note = ratio_for(ApiPrecisionCompareColumn.MEAN_REL_ERR)
        self.compare_message += note

        self.eb_ratio, eb_ok, note = ratio_for(ApiPrecisionCompareColumn.EB)
        self.compare_message += note

        return BenchmarkInf_Nan_Consistency(small_value_ok, rmse_ok, max_rel_ok, mean_rel_ok, eb_ok)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class ULPStandard(Standard):
    """ULP-based comparison of one API between the NPU and GPU precision records."""

    def __init__(self, api_name, npu_precision, gpu_precision):
        """Store the API name and the NPU/GPU precision records (accessed via ``.get``)."""
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.mean_ulp_err = 0
        self.ulp_err_proportion = 0
        self.ulp_err_proportion_ratio = 1
        self.ulp_err_status = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    def get_result(self):
        """Read the ULP statistics, check inf/nan consistency and set ``ulp_err_status``.

        The status is ERROR as soon as either the mean ULP error or the ULP
        error proportion is inf/nan-inconsistent between NPU and GPU;
        otherwise it is decided by ``_get_ulp_status``.
        """
        mean_column = ApiPrecisionCompareColumn.MEAN_ULP_ERR
        proportion_column = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION

        self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(mean_column))
        gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(mean_column))
        inf_nan_consistency = True
        if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
            _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err, mean_column)
            self.compare_message += message

        self.ulp_err_proportion = convert_str_to_float(self.npu_precision.get(proportion_column))
        self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
            proportion_column,
            self.npu_precision.get(proportion_column),
            self.gpu_precision.get(proportion_column), 10000.0)
        inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
        self.compare_message += message

        if inf_nan_consistency:
            self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
        else:
            self.ulp_err_status = CompareConst.ERROR

    def _get_ulp_status(self, dtype):
        """Return PASS or ERROR for the current ULP statistics.

        Args:
            dtype: device dtype of the output. It is read from a CSV cell, so it
                normally arrives as text such as ``"torch.float32"``; a real
                ``torch.dtype`` is accepted as well.

        float32 passes when the mean ULP error is below 64, or the error
        proportion is below 0.05, or the NPU/GPU proportion ratio is below 1.
        Every other dtype passes when the error proportion is below 0.001 or
        the ratio is below 1. On failure an error line is appended to
        ``compare_message``.
        """
        # Compare string forms: ``"torch.float32" == torch.float32`` is always
        # False, which previously made the float32 rule unreachable for CSV input.
        is_float32 = str(dtype).strip() == str(torch.float32)
        if is_float32:
            passed = (self.mean_ulp_err < 64
                      or self.ulp_err_proportion < 0.05
                      or self.ulp_err_proportion_ratio < 1)
        else:
            passed = (self.ulp_err_proportion < 0.001
                      or self.ulp_err_proportion_ratio < 1)
        if passed:
            return CompareConst.PASS
        self.compare_message += "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def write_detail_csv(content, save_path):
    """Hand one detail row to ``write_csv`` after formatting its float items.

    Args:
        content: iterable of cell values for a single row.
        save_path: path forwarded to ``write_csv``.

    Float items are rendered as fixed-point text with
    ``msCheckerConfig.precision`` decimals; all other items are passed through
    unchanged.
    """
    formatted_row = []
    for item in content:
        if isinstance(item, float):
            item = f"{item:.{msCheckerConfig.precision}f}"
        formatted_row.append(item)
    write_csv([formatted_row], save_path)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def api_precision_compare(config):
    """Run the NPU-vs-GPU precision comparison described by ``config``.

    Args:
        config: object exposing ``npu_csv_path``, ``gpu_csv_path``,
            ``result_csv_path`` and ``details_csv_path``.

    Raises:
        Exception: whatever ``pd.read_csv`` raised when one of the input CSVs
            cannot be read. The error is logged and then re-raised; previously
            execution continued and failed with an unrelated
            ``UnboundLocalError`` on the missing DataFrame.

    Errors raised while analysing the data are logged but not propagated, so
    that the permissions of the output files are still adjusted afterwards.
    """
    logger.info("Start compare task")
    logger.info(f"Compare task result will be saved in {config.result_csv_path}")
    logger.info(f"Compare task detail will be saved in {config.details_csv_path}")

    try:
        npu_data = pd.read_csv(config.npu_csv_path)
    except Exception as err:
        logger.error("Open npu csv Error: %s" % str(err))
        raise
    check_csv_columns(npu_data.columns, "npu_csv")

    try:
        gpu_data = pd.read_csv(config.gpu_csv_path)
    except Exception as err:
        logger.error("Open gpu csv Error: %s" % str(err))
        raise
    check_csv_columns(gpu_data.columns, "gpu_csv")

    # Write the header rows before any data rows are produced.
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    write_csv(result_csv_title, config.result_csv_path)
    write_csv(detail_csv_title, config.details_csv_path)

    try:
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        # Deliberately best-effort: keep going so the file modes below are still set.
        logger.error("Analyse csv Error: %s" % str(err))
    change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
    change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _write_api_summary(api_name, api_dtype, forward_status, backward_status, result_csv_path):
    """Write the one-line summary of a finished API to the result CSV.

    APIs whose last seen dtype is unsupported are reported as skipped; for the
    others the forward/backward statuses are aggregated, and the API-specific
    hint from ``CompareMessage`` is attached when the forward result is ERROR.
    """
    if api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
        write_csv([[api_name, "skip", "skip", unsupported_message]], result_csv_path)
        return
    forward_result = get_api_checker_result(forward_status)
    backward_result = get_api_checker_result(backward_status)
    message = CompareMessage.get(api_name, "") if forward_result == CompareConst.ERROR else ""
    write_csv([[api_name, forward_result, backward_result, message]], result_csv_path)


def analyse_csv(npu_data, gpu_data, config):
    """Compare every NPU row with its GPU counterpart and write detail/result rows.

    Args:
        npu_data: DataFrame of NPU records, iterated row by row.
        gpu_data: DataFrame of GPU records, looked up by API name.
        config: object exposing ``details_csv_path`` and ``result_csv_path``.

    Raises:
        CompareException: when an API name matches more than one GPU row.

    For each NPU row a detail row is written. Rows are expected to be grouped
    by API: a summary row is written whenever the API name changes, and once
    more for the last API after the loop.
    """
    forward_status, backward_status = [], []
    last_api_name, last_api_dtype = None, None

    for _, row_npu in npu_data.iterrows():
        compare_column = ApiPrecisionOutputColumn()
        full_api_name = row_npu[ApiPrecisionCompareColumn.API_NAME]
        row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name]
        # The full name is expected to have six dot-separated parts; the second
        # is the API name and the fourth the direction.
        _, api_name, _, direction_status, _, _ = full_api_name.split(".")
        if row_gpu.empty:
            logger.warning(f'This API : {full_api_name} does not exist in the GPU data.')
            continue
        if len(row_gpu) > 1:
            msg = f'This API : {full_api_name} has multiple records in the GPU data.'
            raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
        row_gpu = row_gpu.iloc[0]

        device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
        new_status = CompareConst.SPACE
        compare_column.api_name = full_api_name
        if device_dtype.isspace():
            # The API produced no output (e.g. requires_grad=False in the
            # backward pass): skip the comparison.
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
            new_status = CompareConst.SKIP
        elif api_name in ThousandthStandardApi:
            new_status = record_thousandth_threshold_result(compare_column, row_npu)
        elif device_dtype not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in BinaryStandardApi:
            new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
        elif api_name in AbsoluteStandardApi:
            new_status = record_absolute_threshold_result(compare_column, row_npu)
        elif api_name in ULPStandardApi and device_dtype in ULP_COMPARE_SUPPORT_LIST:
            us = ULPStandard(full_api_name, row_npu, row_gpu)
            new_status = record_ulp_compare_result(compare_column, us)
        elif device_dtype in BENCHMARK_COMPARE_SUPPORT_LIST:
            bs = BenchmarkStandard(full_api_name, row_npu, row_gpu)
            new_status = record_benchmark_compare_result(compare_column, bs)
        write_detail_csv(compare_column.to_column_value(), config.details_csv_path)

        # A new API name means the previous API is complete: summarise it and
        # start collecting statuses afresh.
        if last_api_name is not None and api_name != last_api_name:
            _write_api_summary(last_api_name, last_api_dtype, forward_status, backward_status,
                               config.result_csv_path)
            forward_status, backward_status = [], []

        last_api_name = api_name
        last_api_dtype = device_dtype
        if device_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            # Unsupported dtypes do not contribute to the API's status lists.
            continue

        if direction_status == 'forward':
            forward_status.append(new_status)
        elif direction_status == 'backward':
            backward_status.append(new_status)
        else:
            logger.error(f"Invalid direction status: {direction_status}")

    # Summarise the last API, which no name change has flushed yet.
    if last_api_name is not None:
        _write_api_summary(last_api_name, last_api_dtype, forward_status, backward_status,
                           config.result_csv_path)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def check_error_rate(npu_error_rate):
    """Classify a binary-consistency error rate.

    The raw value is converted with ``convert_str_to_float``; the check passes
    only when the converted rate is exactly zero.

    Returns:
        ``CompareConst.PASS`` for a zero error rate, otherwise ``CompareConst.ERROR``.
    """
    if convert_str_to_float(npu_error_rate) == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def get_absolute_threshold_result(row_npu):
    """Evaluate the absolute-threshold metrics of one NPU result row.

    Three ratios (inf/nan error, relative error, absolute error) are read from
    ``row_npu`` and converted with ``convert_str_to_float``. Each ratio passes
    only when it equals zero; the overall verdict is ERROR as soon as any single
    ratio fails.

    Returns:
        dict holding every converted ratio, its individual verdict, and the
        combined verdict under ``"absolute_threshold_result"``.
    """

    def _verdict(ratio):
        # A ratio of exactly zero is the only passing value.
        return CompareConst.PASS if ratio == 0 else CompareConst.ERROR

    inf_nan_error_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO])
    rel_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO])
    abs_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO])

    inf_nan_result = _verdict(inf_nan_error_ratio)
    rel_err_result = _verdict(rel_err_ratio)
    abs_err_result = _verdict(abs_err_ratio)

    any_failed = CompareConst.ERROR in (inf_nan_result, rel_err_result, abs_err_result)
    overall = CompareConst.ERROR if any_failed else CompareConst.PASS

    return {
        "inf_nan_error_ratio": inf_nan_error_ratio,
        "inf_nan_result": inf_nan_result,
        "rel_err_ratio": rel_err_ratio,
        "rel_err_result": rel_err_result,
        "abs_err_ratio": abs_err_ratio,
        "abs_err_result": abs_err_result,
        "absolute_threshold_result": overall,
    }
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def get_api_checker_result(status):
    """Collapse a list of per-row statuses into one API-level verdict.

    Precedence, highest first:
      * empty list        -> ``CompareConst.SPACE``
      * every entry SKIP  -> ``CompareConst.SKIP``
      * any ERROR         -> ``CompareConst.ERROR``
      * any WARNING       -> ``CompareConst.WARNING``
      * otherwise         -> ``CompareConst.PASS``
    """
    if not status:
        return CompareConst.SPACE

    only_skips = all(entry == CompareConst.SKIP for entry in status)
    if only_skips:
        return CompareConst.SKIP

    # ERROR outranks WARNING, so it is tested first.
    if CompareConst.ERROR in status:
        return CompareConst.ERROR
    if CompareConst.WARNING in status:
        return CompareConst.WARNING
    return CompareConst.PASS
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def check_csv_columns(columns, csv_type):
    """Verify that a CSV header contains every column the comparison needs.

    Args:
        columns: container of column names present in the CSV (anything
            supporting ``in``).
        csv_type: label of the CSV being checked; used only in the error message.

    Raises:
        CompareException: with ``INVALID_DATA_ERROR`` when one or more of the
            columns returned by ``ApiPrecisionCompareColumn.to_required_columns()``
            are absent.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # A space separates "in" from the csv label; previously the two words were fused.
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Fill ``compare_column`` with the binary-consistency verdict for one row.

    The error rate is read from ``row_npu`` and judged by ``check_error_rate``.
    On ERROR the message carries a fixed explanation plus any API-specific hint
    found in ``CompareMessage``.

    Returns:
        The status produced by ``check_error_rate``.
    """
    error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    new_status = check_error_rate(error_rate)

    compare_column.error_rate = error_rate
    compare_column.error_rate_status = new_status
    compare_column.compare_result = new_status
    compare_column.compare_algorithm = "二进制一致法"

    if new_status == CompareConst.ERROR:
        compare_column.compare_message = "ERROR: 二进制一致错误率超过阈值\n" + CompareMessage.get(api_name, "")
    else:
        compare_column.compare_message = ''
    return new_status
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def record_absolute_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` with the absolute-threshold verdict for one row.

    All metrics come from ``get_absolute_threshold_result(row_npu)``. For every
    metric whose status is ERROR, one explanatory line is appended to the
    compare message.

    Returns:
        The overall result stored on ``compare_column.compare_result``.
    """
    outcome = get_absolute_threshold_result(row_npu)

    # (attribute on compare_column, key in the outcome dict)
    field_map = (
        ("inf_nan_error_ratio", "inf_nan_error_ratio"),
        ("inf_nan_error_ratio_status", "inf_nan_result"),
        ("rel_err_ratio", "rel_err_ratio"),
        ("rel_err_ratio_status", "rel_err_result"),
        ("abs_err_ratio", "abs_err_ratio"),
        ("abs_err_ratio_status", "abs_err_result"),
        ("compare_result", "absolute_threshold_result"),
    )
    for attr_name, key in field_map:
        setattr(compare_column, attr_name, outcome.get(key))
    compare_column.compare_algorithm = "绝对阈值法"

    status_notes = (
        (compare_column.inf_nan_error_ratio_status, "ERROR: inf/nan错误率超过阈值\n"),
        (compare_column.rel_err_ratio_status, "ERROR: 相对误差错误率超过阈值\n"),
        (compare_column.abs_err_ratio_status, "ERROR: 绝对误差错误率超过阈值\n"),
    )
    compare_column.compare_message = ''.join(
        note for state, note in status_notes if state == CompareConst.ERROR
    )
    return compare_column.compare_result
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def record_benchmark_compare_result(compare_column, bs):
    """Copy the benchmark-comparison outcome from ``bs`` onto ``compare_column``.

    ``bs.get_result()`` is invoked first, then the ratio/status pairs are copied
    across. The compare message starts from ``bs.compare_message`` and is
    extended with any text that ``benchmark_message`` maps to the recorded
    status values.

    Returns:
        The final result stored on ``compare_column.compare_result``.
    """
    bs.get_result()

    mirrored_fields = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for field_name in mirrored_fields:
        setattr(compare_column, field_name, getattr(bs, field_name))

    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"
    compare_column.compare_message = bs.compare_message

    # Append the extra explanation registered for each status that has one.
    for status_attr, texts in benchmark_message.items():
        recorded = getattr(compare_column, status_attr)
        if recorded in texts:
            compare_column.compare_message += texts[recorded]
    return compare_column.compare_result
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def record_ulp_compare_result(compare_column, us):
    """Copy the ULP-comparison outcome from ``us`` onto ``compare_column``.

    ``us.get_result()` is invoked before any attribute of ``us`` is read. The
    ULP status doubles as the overall compare result.

    Returns:
        The result stored on ``compare_column.compare_result``.
    """
    us.get_result()

    for field_name in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, field_name, getattr(us, field_name))

    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def check_thousandth_rate(thousandth_rate):
    """Classify a relative-error thousandth rate.

    The raw value is converted with ``convert_str_to_float``; the check passes
    when the converted rate is at least 0.999.

    Returns:
        ``CompareConst.PASS`` or ``CompareConst.ERROR``.
    """
    if convert_str_to_float(thousandth_rate) >= 0.999:
        return CompareConst.PASS
    return CompareConst.ERROR
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def record_thousandth_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` with the thousandth-threshold verdict for one row.

    The rate is read from ``row_npu`` and judged by ``check_thousandth_rate``;
    an explanatory message is recorded only when the verdict is ERROR.

    Returns:
        The result stored on ``compare_column.compare_result``.
    """
    thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    verdict = check_thousandth_rate(thousandth)

    compare_column.rel_err_thousandth = thousandth
    compare_column.rel_err_thousandth_status = verdict
    compare_column.compare_result = verdict
    compare_column.compare_algorithm = "双千指标法"

    failed = compare_column.rel_err_thousandth_status == CompareConst.ERROR
    compare_column.compare_message = "ERROR: 双千指标不达标\n" if failed else ''
    return compare_column.compare_result
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def _api_precision_compare(parser=None):
    """Command-line entry: register the options, parse ``sys.argv`` and run the comparison.

    Args:
        parser: an existing ``argparse`` parser to extend; a fresh
            ``ArgumentParser`` is created when none (or a falsy value) is given.
    """
    parser = parser or argparse.ArgumentParser()
    _api_precision_compare_parser(parser)
    _api_precision_compare_command(parser.parse_args(sys.argv[1:]))
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def _api_precision_compare_command(args):
    """Validate the parsed arguments, prepare the output directory and start the comparison.

    Args:
        args: parsed namespace providing ``npu_csv_path``, ``gpu_csv_path`` and
            ``out_path``. An empty ``out_path`` falls back to ``"./"``.
    """
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')

    if args.out_path:
        out_path = os.path.realpath(args.out_path)
    else:
        out_path = "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    # Continue with the path returned by the writable-directory check.
    out_path = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()

    result_csv_path = os.path.join(out_path, API_PRECISION_COMPARE_RESULT_FILE_NAME)
    details_csv_path = os.path.join(out_path, API_PRECISION_COMPARE_DETAILS_FILE_NAME)
    api_precision_compare(CompareConfig(npu_csv_path, gpu_csv_path, result_csv_path, details_csv_path))
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def _api_precision_compare_parser(parser):
    """Register the command-line options of the API precision compare task.

    Options added to ``parser``:
      * ``-npu/--npu_csv_path`` (required): details CSV generated on the NPU.
      * ``-gpu/--gpu_csv_path`` (required): details CSV generated on the GPU.
      * ``-o/--out_path`` (optional): directory receiving the comparison output.

    Args:
        parser: an ``argparse.ArgumentParser`` (or sub-parser) to extend in place.
    """
    parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
                        help="<Required> Accuracy_checking_details.csv generated on the NPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    # The GPU csv is documented as <Required> and the command validates it
    # unconditionally, so argparse enforces it here and reports a clear usage
    # error instead of passing an empty path on to the validation step.
    parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
                        help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                        help="<optional> The api precision compare task result out path.",
                        required=False)
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
# Script entry point: run the compare task from the command-line arguments,
# then log that it has finished.
if __name__ == '__main__':
    _api_precision_compare()
    logger.info("Compare task completed.")
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright (c) 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the BSD 3-Clause License (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# https://opensource.org/licenses/BSD-3-Clause
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
AbsoluteThreshStandard:
|
|
17
|
+
- mul
|
|
18
|
+
- mul_
|
|
19
|
+
- __mul__
|
|
20
|
+
- __imul__
|
|
21
|
+
- __rmul__
|
|
22
|
+
- add
|
|
23
|
+
- add_
|
|
24
|
+
- __add__
|
|
25
|
+
- __iadd__
|
|
26
|
+
- __radd__
|
|
27
|
+
- div
|
|
28
|
+
- div_
|
|
29
|
+
- __div__
|
|
30
|
+
- __idiv__
|
|
31
|
+
- divide
|
|
32
|
+
- divide_
|
|
33
|
+
- leaky_relu
|
|
34
|
+
- leaky_relu_
|
|
35
|
+
- prelu
|
|
36
|
+
- reciprocal
|
|
37
|
+
- reciprocal_
|
|
38
|
+
- rsqrt
|
|
39
|
+
- rsqrt_
|
|
40
|
+
- square
|
|
41
|
+
- square_
|
|
42
|
+
- sub
|
|
43
|
+
- sub_
|
|
44
|
+
- rsub
|
|
45
|
+
- __isub__
|
|
46
|
+
- __sub__
|
|
47
|
+
|
|
48
|
+
BinaryCompareStandard:
|
|
49
|
+
- abs
|
|
50
|
+
- abs_
|
|
51
|
+
- absolute
|
|
52
|
+
- absolute_
|
|
53
|
+
- argmin
|
|
54
|
+
- bitwise_and
|
|
55
|
+
- bitwise_and_
|
|
56
|
+
- broadcast_to
|
|
57
|
+
- ceil
|
|
58
|
+
- ceil_
|
|
59
|
+
- equal
|
|
60
|
+
- fill_
|
|
61
|
+
- flatten
|
|
62
|
+
- floor
|
|
63
|
+
- floor_
|
|
64
|
+
- gather
|
|
65
|
+
- greater
|
|
66
|
+
- greater_
|
|
67
|
+
- greater_equal
|
|
68
|
+
- greater_equal_
|
|
69
|
+
- isfinite
|
|
70
|
+
- isnan
|
|
71
|
+
- less
|
|
72
|
+
- less_
|
|
73
|
+
- less_equal
|
|
74
|
+
- less_equal_
|
|
75
|
+
- logical_and
|
|
76
|
+
- logical_and_
|
|
77
|
+
- logical_not
|
|
78
|
+
- logical_not_
|
|
79
|
+
- logical_or
|
|
80
|
+
- logical_or_
|
|
81
|
+
- masked_fill
|
|
82
|
+
- masked_fill_
|
|
83
|
+
- max_pool3d
|
|
84
|
+
- maximum
|
|
85
|
+
- minimum
|
|
86
|
+
- neg
|
|
87
|
+
- neg_
|
|
88
|
+
- nonzero
|
|
89
|
+
- not_equal
|
|
90
|
+
- not_equal_
|
|
91
|
+
- one_hot
|
|
92
|
+
- pad
|
|
93
|
+
- relu
|
|
94
|
+
- reshape
|
|
95
|
+
- round
|
|
96
|
+
- round_
|
|
97
|
+
- select
|
|
98
|
+
- sign
|
|
99
|
+
- sign_
|
|
100
|
+
- sort
|
|
101
|
+
- tile
|
|
102
|
+
- topk
|
|
103
|
+
- transpose
|
|
104
|
+
- transpose_
|
|
105
|
+
- tril
|
|
106
|
+
- tril_
|
|
107
|
+
- triu
|
|
108
|
+
- triu_
|
|
109
|
+
- type_as
|
|
110
|
+
|
|
111
|
+
ULPStandard:
|
|
112
|
+
- __matmul__
|
|
113
|
+
- addbmm
|
|
114
|
+
- addbmm_
|
|
115
|
+
- addmm
|
|
116
|
+
- addmm_
|
|
117
|
+
- baddbmm
|
|
118
|
+
- baddbmm_
|
|
119
|
+
- bilinear
|
|
120
|
+
- bmm
|
|
121
|
+
- chain_matmul
|
|
122
|
+
- hspmm
|
|
123
|
+
- linear
|
|
124
|
+
- matmul
|
|
125
|
+
- mm
|
|
126
|
+
- mv
|
|
127
|
+
- smm
|
|
128
|
+
- sspaddmm
|
|
129
|
+
|
|
130
|
+
ThousandthStandard:
|
|
131
|
+
- conv1d
|
|
132
|
+
- conv2d
|
|
133
|
+
|