mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc/在线精度比对.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,545 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import sys
5
+ from collections import namedtuple
6
+
7
+ import torch
8
+ import pandas as pd
9
+
10
+ from msprobe.pytorch.api_accuracy_checker.common.utils import write_csv
11
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
+ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
14
+ ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, ThousandthStandardApi, \
15
+ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
16
+ check_inf_or_nan
17
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path
19
+ from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
20
+ from msprobe.pytorch.common.log import logger
21
+ from msprobe.core.common.utils import CompareException
22
+ from msprobe.core.common.const import CompareConst, FileCheckConst
23
+
24
# Input/output locations of one API precision comparison run.
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])

# One inf/nan-consistency flag per benchmark statistic, in the order
# BenchmarkStandard._compare_ratio evaluates them.
BenchmarkInf_Nan_Consistency = namedtuple(
    'BenchmarkInf_Nan_Consistency',
    [
        'small_value_inf_nan_consistency',
        'rmse_inf_nan_consistency',
        'max_rel_inf_nan_consistency',
        'mean_rel_inf_nan_consistency',
        'eb_inf_nan_consistency',
    ],
)

# Message written to the result CSV for APIs whose dtype cannot be compared.
unsupported_message = 'This data type does not support benchmark compare.'

# Fallback threshold used by BenchmarkStandard._get_status when an algorithm
# (or one of its threshold keys) is missing from the table below.
DEFAULT_THRESHOLD = 1

# NPU/GPU ratio thresholds per benchmark algorithm: a ratio strictly above
# 'error_threshold' grades as ERROR, strictly above 'warning_threshold' as WARNING.
benchmark_algorithms_thresholds = {
    'small_value': {'error_threshold': 2, 'warning_threshold': 1},
    'rmse': {'error_threshold': 2, 'warning_threshold': 1},
    'max_rel_err': {'error_threshold': 10, 'warning_threshold': 1},
    'mean_rel_err': {'error_threshold': 2, 'warning_threshold': 1},
    'eb': {'error_threshold': 2, 'warning_threshold': 1},
}

# Per-level explanation texts, keyed by the matching BenchmarkStandard status
# attribute name. The descriptions say "<statistic> ratio exceeds the threshold"
# for small-value error, RMSE, max relative error and mean relative error.
benchmark_message = {
    status_column: {
        CompareConst.ERROR: f"ERROR: {description}\n",
        CompareConst.WARNING: f"WARNING: {description}\n",
    }
    for status_column, description in (
        ("small_value_err_status", "小值域错误比值超过阈值"),
        ("rmse_status", "均方根误差比值超过阈值"),
        ("max_rel_err_status", "相对误差最大值比值超过阈值"),
        ("mean_rel_err_status", "相对误差平均值比值超过阈值"),
    )
}
75
+
76
+
77
class Standard:
    """Shared helpers for the NPU-versus-GPU precision standards."""

    @staticmethod
    def _calc_ratio(column_name, x, y, default_value):
        """Return the ratio between an NPU-side and a GPU-side statistic.

        Args:
            column_name: name of the statistic; forwarded to ``check_inf_or_nan``.
            x: NPU-side value, converted with ``convert_str_to_float``.
            y: GPU-side value, converted with ``convert_str_to_float``.
            default_value: ratio returned when the GPU value is zero but the NPU
                value is not.

        Returns:
            A ``(ratio, inf_nan_consistency, message)`` tuple. If either value is
            inf or nan, the tuple is whatever ``check_inf_or_nan`` returns (per the
            original design notes, consistency holds only when both sides are the
            same inf/-inf/nan). Otherwise ``ratio`` is ``abs(x / y)`` (1.0 when both
            are zero, ``default_value`` when only ``y`` is zero), consistency is
            True and the message is empty.
        """
        npu_value = convert_str_to_float(x)
        gpu_value = convert_str_to_float(y)

        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)

        # NOTE: math.isclose(v, 0.0) with default tolerances (abs_tol=0) is only
        # True for an exact zero, so tiny non-zero GPU values still divide.
        if not math.isclose(gpu_value, 0.0):
            ratio = abs(npu_value / gpu_value)
        elif math.isclose(npu_value, 0.0):
            ratio = 1.0
        else:
            ratio = default_value
        return ratio, True, ""
106
+
107
+
108
class BenchmarkStandard(Standard):
    """Benchmark-standard grading of one API output, NPU versus GPU.

    ``npu_precision`` / ``gpu_precision`` are the matching precision rows (any
    object with ``.get(column)``). Each statistic is turned into an NPU/GPU ratio
    and graded against ``benchmark_algorithms_thresholds``.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        # Ratios start at 1 (NPU equal to GPU) until _compare_ratio fills them in.
        self.small_value_err_ratio = 1
        self.rmse_ratio = 1
        self.max_rel_err_ratio = 1
        self.mean_rel_err_ratio = 1
        self.eb_ratio = 1
        self.small_value_err_status = CompareConst.PASS
        self.rmse_status = CompareConst.PASS
        self.max_rel_err_status = CompareConst.PASS
        self.mean_rel_err_status = CompareConst.PASS
        self.eb_status = CompareConst.PASS
        # Statuses that decide final_result (EB is not included, see get_result).
        self.check_result_list = []
        self.final_result = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return "%s" % (self.api_name)

    @staticmethod
    def _get_status(ratio, algorithm):
        """Grade ``ratio`` with the ERROR/WARNING thresholds configured for ``algorithm``.

        nan and inf ratios are treated as PASS here; inf/nan mismatches are
        handled separately through the consistency flags.
        """
        if math.isnan(ratio) or math.isinf(ratio):
            return CompareConst.PASS
        thresholds = benchmark_algorithms_thresholds.get(algorithm, {})
        if ratio > thresholds.get('error_threshold', DEFAULT_THRESHOLD):
            return CompareConst.ERROR
        if ratio > thresholds.get('warning_threshold', DEFAULT_THRESHOLD):
            return CompareConst.WARNING
        return CompareConst.PASS

    def get_result(self):
        """Compute all ratios, grade them and set ``final_result``.

        A statistic whose NPU/GPU values are inf/nan-inconsistent is graded ERROR.
        The EB ratio is graded for reporting only and does not affect the verdict.
        """
        consistency = self._compare_ratio()
        # (status attribute, ratio, threshold key, inf/nan consistency) in report order.
        graded = (
            ('small_value_err_status', self.small_value_err_ratio, 'small_value',
             consistency.small_value_inf_nan_consistency),
            ('rmse_status', self.rmse_ratio, 'rmse',
             consistency.rmse_inf_nan_consistency),
            ('max_rel_err_status', self.max_rel_err_ratio, 'max_rel_err',
             consistency.max_rel_inf_nan_consistency),
            ('mean_rel_err_status', self.mean_rel_err_ratio, 'mean_rel_err',
             consistency.mean_rel_inf_nan_consistency),
        )
        for status_attr, ratio, algorithm, is_consistent in graded:
            status = self._get_status(ratio, algorithm) if is_consistent else CompareConst.ERROR
            setattr(self, status_attr, status)
            self.check_result_list.append(status)

        self.eb_status = self._get_status(self.eb_ratio, 'eb')

        if CompareConst.ERROR in self.check_result_list:
            self.final_result = CompareConst.ERROR
        elif CompareConst.WARNING in self.check_result_list:
            self.final_result = CompareConst.WARNING

    def to_column_value(self):
        """Return the ratio/status pairs in details-CSV column order."""
        return [
            self.small_value_err_ratio, self.small_value_err_status,
            self.rmse_ratio, self.rmse_status,
            self.max_rel_err_ratio, self.max_rel_err_status,
            self.mean_rel_err_ratio, self.mean_rel_err_status,
            self.eb_ratio, self.eb_status,
        ]

    def _compare_ratio(self):
        """Fill every ``*_ratio`` attribute and collect inf/nan messages.

        Returns:
            ``BenchmarkInf_Nan_Consistency`` with one flag per statistic.
        """
        # (ratio attribute, precision column), in evaluation/message order.
        statistics = (
            ('small_value_err_ratio', ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
            ('rmse_ratio', ApiPrecisionCompareColumn.RMSE),
            ('max_rel_err_ratio', ApiPrecisionCompareColumn.MAX_REL_ERR),
            ('mean_rel_err_ratio', ApiPrecisionCompareColumn.MEAN_REL_ERR),
            ('eb_ratio', ApiPrecisionCompareColumn.EB),
        )
        flags = []
        for ratio_attr, column in statistics:
            ratio, is_consistent, message = self._calc_ratio(
                column,
                self.npu_precision.get(column),
                self.gpu_precision.get(column),
                10000.0,
            )
            setattr(self, ratio_attr, ratio)
            flags.append(is_consistent)
            self.compare_message += message
        return BenchmarkInf_Nan_Consistency(*flags)
200
+
201
+
202
class ULPStandard(Standard):
    """ULP-error based grading of one API output, NPU versus GPU.

    ``npu_precision`` / ``gpu_precision`` are the matching precision rows (any
    object with ``.get(column)``).
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.mean_ulp_err = 0
        self.ulp_err_proportion = 0
        self.ulp_err_proportion_ratio = 1
        self.ulp_err_status = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    def get_result(self):
        """Read the ULP statistics and set ``ulp_err_status``.

        If the mean ULP error or the ULP error proportion is inf/nan-inconsistent
        between NPU and GPU, the status is ERROR. Otherwise it is graded by
        ``_get_ulp_status``.
        """
        self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        inf_nan_consistency = True
        if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
            _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
                                                               ApiPrecisionCompareColumn.MEAN_ULP_ERR)
            self.compare_message += message
        self.ulp_err_proportion = convert_str_to_float(
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
        self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
            ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
            self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
        inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
        self.compare_message += message
        if inf_nan_consistency:
            self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
        else:
            self.ulp_err_status = CompareConst.ERROR

    def _get_ulp_status(self, dtype):
        """Grade the ULP statistics for an output of the given dtype.

        float32 passes if the mean ULP error is below 64, or the ULP error
        proportion is below 0.05, or that proportion's NPU/GPU ratio is below 1.
        Any other dtype passes if the proportion is below 0.001 or its ratio is
        below 1. On failure an explanatory message is appended to
        ``compare_message`` and ERROR is returned.

        Args:
            dtype: either a ``torch.dtype`` or its string form as stored in the
                precision CSV (e.g. ``"torch.float32"``).
        """
        # The DEVICE_DTYPE cell comes from a CSV, so it is normally the string
        # "torch.float32". A ``torch.dtype`` object never compares equal to a
        # string, so both forms are accepted here; otherwise float32 outputs would
        # silently be graded with the stricter fp16/bf16 thresholds.
        if dtype in (torch.float32, str(torch.float32)):
            if self.mean_ulp_err < 64 or self.ulp_err_proportion < 0.05 or self.ulp_err_proportion_ratio < 1:
                return CompareConst.PASS
        elif self.ulp_err_proportion < 0.001 or self.ulp_err_proportion_ratio < 1:
            return CompareConst.PASS
        self.compare_message += "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR
256
+
257
+
258
def write_detail_csv(content, save_path):
    """Append one row to the details CSV at ``save_path``.

    Float cells are rendered with ``msCheckerConfig.precision`` decimal places;
    all other cells are written unchanged.
    """
    formatted_row = [
        f"{cell:.{msCheckerConfig.precision}f}" if isinstance(cell, float) else cell
        for cell in content
    ]
    write_csv([formatted_row], save_path)
264
+
265
+
266
def api_precision_compare(config):
    """Run the API precision comparison described by ``config``.

    Reads the NPU and GPU precision CSVs, checks that their required columns are
    present, writes the result/details CSV headers and then fills them via
    ``analyse_csv``. Both output files get ``DATA_FILE_AUTHORITY`` permissions
    afterwards.

    Args:
        config: ``CompareConfig`` with ``npu_csv_path``, ``gpu_csv_path``,
            ``result_csv_path`` and ``details_csv_path``.

    Raises:
        Whatever ``pd.read_csv`` raises when an input CSV cannot be read (it is
        logged first), and ``CompareException`` when an input CSV lacks required
        columns. Errors raised during the analysis itself are logged and not
        re-raised, so the output files still get their permissions set.
    """
    logger.info("Start compare task")
    logger.info(f"Compare task result will be saved in {config.result_csv_path}")
    logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
    # Previously a failed read was only logged, and the next line then failed
    # with an unrelated UnboundLocalError; now the real error propagates.
    npu_data = _read_precision_csv(config.npu_csv_path, "npu")
    check_csv_columns(npu_data.columns, "npu_csv")
    gpu_data = _read_precision_csv(config.gpu_csv_path, "gpu")
    check_csv_columns(gpu_data.columns, "gpu_csv")
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    write_csv(result_csv_title, config.result_csv_path)
    write_csv(detail_csv_title, config.details_csv_path)
    try:
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        # Best effort: report the failure but still fix permissions on what was written.
        logger.error(f"Analyse csv Error: {err}")
    change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
    change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)


def _read_precision_csv(csv_path, side):
    """Load one precision CSV with pandas; log and re-raise if it cannot be read."""
    try:
        return pd.read_csv(csv_path)
    except Exception as err:
        logger.error(f"Open {side} csv Error: {err}")
        raise
290
+
291
+
292
def analyse_csv(npu_data, gpu_data, config):
    """Grade every NPU precision row against its GPU counterpart and write the reports.

    Each NPU row is matched to the GPU row with the same API_NAME, graded, and
    appended to ``config.details_csv_path``. Consecutive rows sharing the same
    (bare) API name are aggregated into one line of ``config.result_csv_path``
    holding the forward and backward verdicts.

    Args:
        npu_data: DataFrame loaded from the NPU precision CSV.
        gpu_data: DataFrame loaded from the GPU precision CSV.
        config: ``CompareConfig`` providing ``details_csv_path`` and ``result_csv_path``.

    Raises:
        CompareException: if an API has more than one matching row in ``gpu_data``.
    """
    forward_status, backward_status = [], []
    last_api_name, last_api_dtype = None, None
    for _, row_npu in npu_data.iterrows():
        full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
        row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
        # The full name has exactly six dot-separated fields; field 1 is the API
        # name and field 3 the direction ('forward' / 'backward').
        _, api_name, _, direction_status, _, _ = full_api_name_with_direction_status.split(".")
        if row_gpu.empty:
            logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
            continue
        if len(row_gpu) > 1:
            msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.'
            raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
        row_gpu = row_gpu.iloc[0]
        new_status = _compare_row_and_write_detail(api_name, full_api_name_with_direction_status,
                                                   row_npu, row_gpu, config.details_csv_path)

        # NOTE(review): grouping uses only the bare API name, so adjacent rows of
        # different call indices (or prefixes) of the same API are merged into a
        # single result row — confirm this is intended.
        if last_api_name is not None and api_name != last_api_name:
            _write_api_summary(last_api_name, last_api_dtype, forward_status, backward_status,
                               config.result_csv_path)
            forward_status, backward_status = [], []

        device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
        last_api_name = api_name
        last_api_dtype = device_dtype
        if device_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            continue

        if direction_status == 'forward':
            forward_status.append(new_status)
        elif direction_status == 'backward':
            backward_status.append(new_status)
        else:
            logger.error(f"Invalid direction status: {direction_status}")

    if last_api_name is not None:
        _write_api_summary(last_api_name, last_api_dtype, forward_status, backward_status,
                           config.result_csv_path)


def _compare_row_and_write_detail(api_name, full_api_name, row_npu, row_gpu, details_csv_path):
    """Grade one NPU/GPU row pair, append it to the details CSV and return its status.

    Returns ``CompareConst.SPACE`` when no comparison standard applies to the row.
    """
    compare_column = ApiPrecisionOutputColumn()
    compare_column.api_name = full_api_name
    device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
    new_status = CompareConst.SPACE
    if device_dtype.isspace():
        # The API produced no output (e.g. requires_grad=False in the backward
        # pass), so the comparison is skipped.
        compare_column.compare_result = CompareConst.SKIP
        compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
        new_status = CompareConst.SKIP
    elif api_name in ThousandthStandardApi:
        new_status = record_thousandth_threshold_result(compare_column, row_npu)
    elif device_dtype not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in BinaryStandardApi:
        new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
    elif api_name in AbsoluteStandardApi:
        new_status = record_absolute_threshold_result(compare_column, row_npu)
    elif api_name in ULPStandardApi and device_dtype in ULP_COMPARE_SUPPORT_LIST:
        us = ULPStandard(full_api_name, row_npu, row_gpu)
        new_status = record_ulp_compare_result(compare_column, us)
    elif device_dtype in BENCHMARK_COMPARE_SUPPORT_LIST:
        bs = BenchmarkStandard(full_api_name, row_npu, row_gpu)
        new_status = record_benchmark_compare_result(compare_column, bs)
    write_detail_csv(compare_column.to_column_value(), details_csv_path)
    return new_status


def _write_api_summary(api_name, api_dtype, forward_status, backward_status, result_csv_path):
    """Append the aggregated forward/backward verdict of one API to the result CSV.

    If the dtype recorded for the API's last row is unsupported, the API is
    written as skipped. Otherwise, a forward ERROR adds the API-specific
    ``CompareMessage`` text (if any) to the message column.
    """
    if api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
        write_csv([[api_name, "skip", "skip", unsupported_message]], result_csv_path)
        return
    forward_result = get_api_checker_result(forward_status)
    backward_result = get_api_checker_result(backward_status)
    message = CompareMessage.get(api_name, "") if forward_result == CompareConst.ERROR else ""
    write_csv([[api_name, forward_result, backward_result, message]], result_csv_path)
371
+
372
+
373
def check_error_rate(npu_error_rate):
    """Return PASS when the NPU error rate converts to exactly zero, else ERROR."""
    if convert_str_to_float(npu_error_rate) == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
375
+
376
+
377
def get_absolute_threshold_result(row_npu):
    """Evaluate the absolute-threshold metrics of one NPU details row.

    Each of the inf/nan, relative-error and absolute-error ratios passes only
    when it is exactly zero; the overall verdict is ERROR if any one fails.

    Args:
        row_npu: a row of the NPU details CSV, indexable by
            ``ApiPrecisionCompareColumn`` column names.

    Returns:
        dict with the three converted ratios, their individual PASS/ERROR
        results, and ``absolute_threshold_result`` (the combined verdict).
    """
    def _zero_is_pass(value):
        # A ratio of exactly zero is the only passing value for this standard.
        return CompareConst.PASS if value == 0 else CompareConst.ERROR

    # Conversions happen in the same order as the column list below.
    inf_nan_error_ratio, rel_err_ratio, abs_err_ratio = (
        convert_str_to_float(row_npu[column])
        for column in (
            ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO,
            ApiPrecisionCompareColumn.REL_ERR_RATIO,
            ApiPrecisionCompareColumn.ABS_ERR_RATIO,
        )
    )

    inf_nan_result = _zero_is_pass(inf_nan_error_ratio)
    rel_err_result = _zero_is_pass(rel_err_ratio)
    abs_err_result = _zero_is_pass(abs_err_ratio)

    every_metric_passed = CompareConst.ERROR not in (inf_nan_result, rel_err_result, abs_err_result)
    absolute_threshold_result = CompareConst.PASS if every_metric_passed else CompareConst.ERROR

    return dict(
        inf_nan_error_ratio=inf_nan_error_ratio,
        inf_nan_result=inf_nan_result,
        rel_err_ratio=rel_err_ratio,
        rel_err_result=rel_err_result,
        abs_err_ratio=abs_err_ratio,
        abs_err_result=abs_err_result,
        absolute_threshold_result=absolute_threshold_result,
    )
400
+
401
+
402
def get_api_checker_result(status):
    """Collapse a list of per-row statuses into a single API verdict.

    Args:
        status: sequence of per-row results (``CompareConst`` values).

    Returns:
        ``CompareConst.SPACE`` for an empty sequence; ``CompareConst.SKIP`` if
        every entry is SKIP; otherwise the most severe status present, with
        ERROR taking precedence over WARNING, falling back to PASS.
    """
    if not status:
        return CompareConst.SPACE
    only_skipped = all(entry == CompareConst.SKIP for entry in status)
    if only_skipped:
        return CompareConst.SKIP
    if CompareConst.ERROR in status:
        return CompareConst.ERROR
    if CompareConst.WARNING in status:
        return CompareConst.WARNING
    return CompareConst.PASS
411
+
412
+
413
def check_csv_columns(columns, csv_type):
    """Ensure a details CSV contains every column the comparison needs.

    Args:
        columns: column names present in the CSV (any container supporting ``in``).
        csv_type: label of the CSV being checked; used only in the error message.

    Raises:
        CompareException: with ``INVALID_DATA_ERROR`` listing the missing columns.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # Separate "in" from the CSV label; previously the two ran together.
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
419
+
420
+
421
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Fill ``compare_column`` using the binary-consistency standard.

    Args:
        api_name: API name; used to look up an extra hint in ``CompareMessage``
            when the check fails.
        compare_column: output record whose error-rate, result, algorithm and
            message attributes are overwritten.
        row_npu: a row of the NPU details CSV, indexable by
            ``ApiPrecisionCompareColumn`` column names.

    Returns:
        The PASS/ERROR status produced by ``check_error_rate``.
    """
    raw_error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    verdict = check_error_rate(raw_error_rate)

    compare_column.error_rate = raw_error_rate
    compare_column.error_rate_status = verdict
    compare_column.compare_result = verdict
    compare_column.compare_algorithm = "二进制一致法"  # "binary consistency method"

    if compare_column.error_rate_status != CompareConst.ERROR:
        compare_column.compare_message = ''
    else:
        # "ERROR: binary consistency error rate exceeds threshold", plus any API-specific hint.
        compare_column.compare_message = "ERROR: 二进制一致错误率超过阈值\n" + CompareMessage.get(api_name, "")
    return verdict
433
+
434
+
435
def record_absolute_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` using the absolute-threshold standard.

    Args:
        compare_column: output record whose ratio/status, result, algorithm
            and message attributes are overwritten.
        row_npu: a row of the NPU details CSV, passed through to
            ``get_absolute_threshold_result``.

    Returns:
        The combined absolute-threshold verdict stored in
        ``compare_column.compare_result``.
    """
    outcome = get_absolute_threshold_result(row_npu)

    # (attribute on compare_column, key in the outcome dict), in assignment order.
    attribute_sources = (
        ("inf_nan_error_ratio", "inf_nan_error_ratio"),
        ("inf_nan_error_ratio_status", "inf_nan_result"),
        ("rel_err_ratio", "rel_err_ratio"),
        ("rel_err_ratio_status", "rel_err_result"),
        ("abs_err_ratio", "abs_err_ratio"),
        ("abs_err_ratio_status", "abs_err_result"),
        ("compare_result", "absolute_threshold_result"),
    )
    for attr_name, outcome_key in attribute_sources:
        setattr(compare_column, attr_name, outcome.get(outcome_key))
    compare_column.compare_algorithm = "绝对阈值法"  # "absolute threshold method"

    # One line per failing metric: inf/nan, relative error, absolute error.
    failure_notes = (
        (compare_column.inf_nan_error_ratio_status, "ERROR: inf/nan错误率超过阈值\n"),
        (compare_column.rel_err_ratio_status, "ERROR: 相对误差错误率超过阈值\n"),
        (compare_column.abs_err_ratio_status, "ERROR: 绝对误差错误率超过阈值\n"),
    )
    compare_column.compare_message = "".join(
        note for metric_status, note in failure_notes if metric_status == CompareConst.ERROR
    )
    return compare_column.compare_result
454
+
455
+
456
def record_benchmark_compare_result(compare_column, bs):
    """Fill ``compare_column`` from a benchmark-standard evaluator.

    Args:
        compare_column: output record whose metric, result, algorithm and
            message attributes are overwritten.
        bs: evaluator object; ``bs.get_result()`` is called first, and the
            ratio/status attributes, ``final_result`` and ``compare_message``
            are then read from it (presumably populated by that call).

    Returns:
        ``bs.final_result``, as stored in ``compare_column.compare_result``.
    """
    bs.get_result()

    mirrored_metrics = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for metric in mirrored_metrics:
        setattr(compare_column, metric, getattr(bs, metric))

    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"  # "benchmark comparison method"
    compare_column.compare_message = bs.compare_message

    # Append the configured message for every status attribute whose value has one.
    for attr_name, message_by_status in benchmark_message.items():
        current = getattr(compare_column, attr_name)
        if current not in message_by_status:
            continue
        compare_column.compare_message += message_by_status[current]
    return compare_column.compare_result
476
+
477
+
478
def record_ulp_compare_result(compare_column, us):
    """Fill ``compare_column`` from a ULP-error evaluator.

    Args:
        compare_column: output record whose ULP metrics, result, algorithm and
            message attributes are overwritten.
        us: evaluator object; ``us.get_result()`` is called first, then its
            ULP attributes and ``compare_message`` are copied.

    Returns:
        ``us.ulp_err_status``, as stored in ``compare_column.compare_result``.
    """
    us.get_result()
    for metric in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, metric, getattr(us, metric))
    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"  # "ULP error comparison method"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
488
+
489
+
490
def check_thousandth_rate(thousandth_rate):
    """Classify a relative-error-thousandth rate.

    Args:
        thousandth_rate: value read from the details CSV; converted with
            ``convert_str_to_float``.

    Returns:
        ``CompareConst.PASS`` when the converted rate is at least 0.999,
        otherwise ``CompareConst.ERROR``.
    """
    meets_target = convert_str_to_float(thousandth_rate) >= 0.999
    if meets_target:
        return CompareConst.PASS
    return CompareConst.ERROR
492
+
493
+
494
def record_thousandth_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` using the thousandth-threshold standard.

    Args:
        compare_column: output record whose thousandth rate/status, result,
            algorithm and message attributes are overwritten.
        row_npu: a row of the NPU details CSV, indexable by
            ``ApiPrecisionCompareColumn`` column names.

    Returns:
        The PASS/ERROR status stored in ``compare_column.compare_result``.
    """
    raw_rate = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    verdict = check_thousandth_rate(raw_rate)

    compare_column.rel_err_thousandth = raw_rate
    compare_column.rel_err_thousandth_status = verdict
    compare_column.compare_result = verdict
    compare_column.compare_algorithm = "双千指标法"  # "double-thousandth metric method"

    failed = compare_column.rel_err_thousandth_status == CompareConst.ERROR
    # "ERROR: double-thousandth metric not met"
    compare_column.compare_message = "ERROR: 双千指标不达标\n" if failed else ''
    return compare_column.compare_result
505
+
506
+
507
def _api_precision_compare(parser=None):
    """Command-line entry point: parse ``sys.argv`` and run the comparison.

    Args:
        parser: optional ``argparse.ArgumentParser``; a new one is created
            when it is falsy. The comparison arguments are registered on it
            before parsing.
    """
    cli_parser = parser if parser else argparse.ArgumentParser()
    _api_precision_compare_parser(cli_parser)
    _api_precision_compare_command(cli_parser.parse_args(sys.argv[1:]))
513
+
514
+
515
def _api_precision_compare_command(args):
    """Validate CLI paths, prepare the output directory and run the comparison.

    Args:
        args: parsed namespace providing ``npu_csv_path``, ``gpu_csv_path``
            and ``out_path``. An empty ``out_path`` means the current
            directory (``"./"``).
    """
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')

    if args.out_path:
        requested_dir = os.path.realpath(args.out_path)
    else:
        requested_dir = "./"
    check_path_before_create(requested_dir)
    create_directory(requested_dir)

    # Re-validate the directory (type and write permission) and use the checked path.
    output_dir = FileChecker(requested_dir, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()

    compare_config = CompareConfig(
        npu_csv_path,
        gpu_csv_path,
        os.path.join(output_dir, API_PRECISION_COMPARE_RESULT_FILE_NAME),
        os.path.join(output_dir, API_PRECISION_COMPARE_DETAILS_FILE_NAME),
    )
    api_precision_compare(compare_config)
527
+
528
+
529
+ def _api_precision_compare_parser(parser):
530
+ parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
531
+ help="<Required> , Accuracy_checking_details.csv generated on the NPU by using the "
532
+ "api_accuracy_checker tool.",
533
+ required=True)
534
+ parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
535
+ help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
536
+ "api_accuracy_checker tool.",
537
+ required=False)
538
+ parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
539
+ help="<optional> The api precision compare task result out path.",
540
+ required=False)
541
+
542
+
543
if __name__ == '__main__':
    # Script entry: build a fresh parser from sys.argv, run the comparison, then report completion.
    _api_precision_compare(parser=None)
    logger.info("Compare task completed.")
@@ -0,0 +1,133 @@
1
+ # Copyright (c) 2024 Huawei Technologies Co., Ltd
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the BSD 3-Clause License (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://opensource.org/licenses/BSD-3-Clause
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ AbsoluteThreshStandard:
17
+ - mul
18
+ - mul_
19
+ - __mul__
20
+ - __imul__
21
+ - __rmul__
22
+ - add
23
+ - add_
24
+ - __add__
25
+ - __iadd__
26
+ - __radd__
27
+ - div
28
+ - div_
29
+ - __div__
30
+ - __idiv__
31
+ - divide
32
+ - divide_
33
+ - leaky_relu
34
+ - leaky_relu_
35
+ - prelu
36
+ - reciprocal
37
+ - reciprocal_
38
+ - rsqrt
39
+ - rsqrt_
40
+ - square
41
+ - square_
42
+ - sub
43
+ - sub_
44
+ - rsub
45
+ - __isub__
46
+ - __sub__
47
+
48
+ BinaryCompareStandard:
49
+ - abs
50
+ - abs_
51
+ - absolute
52
+ - absolute_
53
+ - argmin
54
+ - bitwise_and
55
+ - bitwise_and_
56
+ - broadcast_to
57
+ - ceil
58
+ - ceil_
59
+ - equal
60
+ - fill_
61
+ - flatten
62
+ - floor
63
+ - floor_
64
+ - gather
65
+ - greater
66
+ - greater_
67
+ - greater_equal
68
+ - greater_equal_
69
+ - isfinite
70
+ - isnan
71
+ - less
72
+ - less_
73
+ - less_equal
74
+ - less_equal_
75
+ - logical_and
76
+ - logical_and_
77
+ - logical_not
78
+ - logical_not_
79
+ - logical_or
80
+ - logical_or_
81
+ - masked_fill
82
+ - masked_fill_
83
+ - max_pool3d
84
+ - maximum
85
+ - minimum
86
+ - neg
87
+ - neg_
88
+ - nonzero
89
+ - not_equal
90
+ - not_equal_
91
+ - one_hot
92
+ - pad
93
+ - relu
94
+ - reshape
95
+ - round
96
+ - round_
97
+ - select
98
+ - sign
99
+ - sign_
100
+ - sort
101
+ - tile
102
+ - topk
103
+ - transpose
104
+ - transpose_
105
+ - tril
106
+ - tril_
107
+ - triu
108
+ - triu_
109
+ - type_as
110
+
111
+ ULPStandard:
112
+ - __matmul__
113
+ - addbmm
114
+ - addbmm_
115
+ - addmm
116
+ - addmm_
117
+ - baddbmm
118
+ - baddbmm_
119
+ - bilinear
120
+ - bmm
121
+ - chain_matmul
122
+ - hspmm
123
+ - linear
124
+ - matmul
125
+ - mm
126
+ - mv
127
+ - smm
128
+ - sspaddmm
129
+
130
+ ThousandthStandard:
131
+ - conv1d
132
+ - conv2d
133
+