mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,581 +1,606 @@
1
- import argparse
2
- import math
3
- import os
4
- import sys
5
- from collections import namedtuple
6
-
7
- import torch
8
- import pandas as pd
9
-
10
- from msprobe.core.common.utils import write_csv
11
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
- API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
14
- ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
15
- BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
16
- check_inf_or_nan
17
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
19
- from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
20
- from msprobe.pytorch.common.log import logger
21
- from msprobe.core.common.utils import CompareException
22
- from msprobe.core.common.const import CompareConst, FileCheckConst, Const
23
-
24
# Input and output CSV locations for one compare run.
CompareConfig = namedtuple(
    'CompareConfig',
    'npu_csv_path gpu_csv_path result_csv_path details_csv_path',
)

# Per-metric inf/nan consistency flags returned by BenchmarkStandard._compare_ratio.
BenchmarkInf_Nan_Consistency = namedtuple(
    'BenchmarkInf_Nan_Consistency',
    'small_value_inf_nan_consistency rmse_inf_nan_consistency max_rel_inf_nan_consistency '
    'mean_rel_inf_nan_consistency eb_inf_nan_consistency',
)

unsupported_message = 'This data type does not support benchmark compare.'
31
-
32
# Fallback used by BenchmarkStandard._get_status when an algorithm or a
# threshold key is missing from benchmark_algorithms_thresholds.
DEFAULT_THRESHOLD = 1

# Thresholds applied to each NPU/GPU metric ratio: a ratio above
# error_threshold is an ERROR, above warning_threshold a WARNING.
benchmark_algorithms_thresholds = dict(
    small_value=dict(error_threshold=2, warning_threshold=1),
    rmse=dict(error_threshold=2, warning_threshold=1),
    max_rel_err=dict(error_threshold=10, warning_threshold=1),
    mean_rel_err=dict(error_threshold=2, warning_threshold=1),
    eb=dict(error_threshold=2, warning_threshold=1),
)
56
-
57
# User-facing messages keyed first by status attribute name, then by status
# level. Only four metrics have entries; there is no "eb_status" key.
benchmark_message = dict(
    small_value_err_status={
        CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n",
    },
    rmse_status={
        CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n",
    },
    max_rel_err_status={
        CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n",
    },
    mean_rel_err_status={
        CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n",
    },
)
75
-
76
-
77
class Standard:
    """Base class providing the NPU/GPU ratio helper shared by the precision standards."""

    @staticmethod
    def _calc_ratio(column_name, x, y, default_value):
        """Compute the absolute ratio between an NPU-side and a GPU-side statistic.

        Args:
            column_name: Name of the statistic; forwarded to ``check_inf_or_nan``
                so its message can name the column.
            x: NPU-side statistic, converted with ``convert_str_to_float``.
            y: GPU-side statistic, converted with ``convert_str_to_float``.
            default_value: Ratio returned when ``y`` is zero but ``x`` is not.

        Returns:
            tuple: ``(ratio, inf_nan_consistency, message)``. If either value is
            inf or nan, the result of ``check_inf_or_nan`` is returned unchanged.
            Otherwise the consistency flag is True and the message is empty.

        Note:
            ``math.isclose(v, 0.0)`` with its default tolerances (``abs_tol=0``)
            is True only for an exact zero. Tiny non-zero ``y`` values are
            therefore still divided by.
        """
        npu_value = convert_str_to_float(x)
        gpu_value = convert_str_to_float(y)
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)
        if not math.isclose(gpu_value, 0.0):
            return abs(npu_value / gpu_value), True, ""
        # GPU statistic is zero: 0 vs 0 counts as equal (ratio 1.0); any other
        # NPU value falls back to the caller-supplied default.
        ratio = 1.0 if math.isclose(npu_value, 0.0) else default_value
        return ratio, True, ""
106
-
107
-
108
class BenchmarkStandard(Standard):
    """Benchmark standard: grade NPU precision metrics by their ratio to the GPU metrics.

    Ratios are computed for the small-value error rate, RMSE, max and mean
    relative error, and EB. Each ratio is graded against
    ``benchmark_algorithms_thresholds``. The EB status is computed but is not
    part of ``check_result_list``, so it never affects ``final_result``.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        # Every ratio starts at 1 and every status at PASS until get_result runs.
        self.small_value_err_ratio = self.rmse_ratio = self.max_rel_err_ratio = \
            self.mean_rel_err_ratio = self.eb_ratio = 1
        self.small_value_err_status = self.rmse_status = self.max_rel_err_status = \
            self.mean_rel_err_status = self.eb_status = CompareConst.PASS
        self.check_result_list = []
        self.final_result = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return "%s" % self.api_name

    @staticmethod
    def _get_status(ratio, algorithm):
        """Grade one ratio against the thresholds configured for ``algorithm``.

        A nan or inf ratio is reported as PASS. A ratio above ``error_threshold``
        is ERROR, one above ``warning_threshold`` is WARNING, and anything else
        is PASS. Missing thresholds fall back to ``DEFAULT_THRESHOLD``.
        """
        if math.isnan(ratio) or math.isinf(ratio):
            return CompareConst.PASS
        limits = benchmark_algorithms_thresholds.get(algorithm, {})
        error_limit = limits.get('error_threshold', DEFAULT_THRESHOLD)
        warning_limit = limits.get('warning_threshold', DEFAULT_THRESHOLD)
        if ratio > error_limit:
            return CompareConst.ERROR
        if ratio > warning_limit:
            return CompareConst.WARNING
        return CompareConst.PASS

    def get_result(self):
        """Compute every per-metric status and derive ``final_result`` from them."""
        consistency = self._compare_ratio()
        graded_metrics = (
            ('small_value_err_status', self.small_value_err_ratio, 'small_value',
             consistency.small_value_inf_nan_consistency),
            ('rmse_status', self.rmse_ratio, 'rmse',
             consistency.rmse_inf_nan_consistency),
            ('max_rel_err_status', self.max_rel_err_ratio, 'max_rel_err',
             consistency.max_rel_inf_nan_consistency),
            ('mean_rel_err_status', self.mean_rel_err_ratio, 'mean_rel_err',
             consistency.mean_rel_inf_nan_consistency),
        )
        for status_attr, ratio, algorithm, is_consistent in graded_metrics:
            # An inf/nan mismatch between NPU and GPU is always an error.
            status = self._get_status(ratio, algorithm) if is_consistent else CompareConst.ERROR
            setattr(self, status_attr, status)
            self.check_result_list.append(status)
        # EB is graded for reporting only. It is not appended to
        # check_result_list, and its inf/nan consistency flag is not consulted.
        self.eb_status = self._get_status(self.eb_ratio, 'eb')
        if CompareConst.ERROR in self.check_result_list:
            self.final_result = CompareConst.ERROR
        elif CompareConst.WARNING in self.check_result_list:
            self.final_result = CompareConst.WARNING

    def to_column_value(self):
        """Return the ratio/status pairs in detail-CSV column order."""
        return [
            self.small_value_err_ratio, self.small_value_err_status,
            self.rmse_ratio, self.rmse_status,
            self.max_rel_err_ratio, self.max_rel_err_status,
            self.mean_rel_err_ratio, self.mean_rel_err_status,
            self.eb_ratio, self.eb_status,
        ]

    def _compare_ratio(self):
        """Compute all five NPU/GPU ratios and collect their inf/nan consistency flags.

        Each ratio is stored on ``self`` and each message is appended to
        ``compare_message``. A zero GPU value with a non-zero NPU value yields
        a ratio of 10000.0.
        """
        metric_columns = (
            ('small_value_err_ratio', ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
            ('rmse_ratio', ApiPrecisionCompareColumn.RMSE),
            ('max_rel_err_ratio', ApiPrecisionCompareColumn.MAX_REL_ERR),
            ('mean_rel_err_ratio', ApiPrecisionCompareColumn.MEAN_REL_ERR),
            ('eb_ratio', ApiPrecisionCompareColumn.EB),
        )
        flags = []
        for ratio_attr, column in metric_columns:
            ratio, is_consistent, message = self._calc_ratio(
                column,
                self.npu_precision.get(column),
                self.gpu_precision.get(column),
                10000.0,
            )
            setattr(self, ratio_attr, ratio)
            self.compare_message += message
            flags.append(is_consistent)
        return BenchmarkInf_Nan_Consistency(*flags)
200
-
201
-
202
class ULPStandard(Standard):
    """ULP standard: grade an API by its mean ULP error and ULP error proportion."""

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.mean_ulp_err = 0
        self.ulp_err_proportion = 0
        self.ulp_err_proportion_ratio = 1
        self.ulp_err_status = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    def get_result(self):
        """Set ``ulp_err_status`` from the NPU and GPU ULP statistics.

        Any inf/nan inconsistency between NPU and GPU, in either the mean ULP
        error or the ULP error proportion, makes the status ERROR. Otherwise the
        dtype-dependent thresholds in ``_get_ulp_status`` decide the status.
        """
        self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        inf_nan_consistency = True
        if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
            _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
                                                               ApiPrecisionCompareColumn.MEAN_ULP_ERR)
            self.compare_message += message
        self.ulp_err_proportion = convert_str_to_float(
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
        self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
            ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
            self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
        inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
        self.compare_message += message
        if inf_nan_consistency:
            self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
        else:
            self.ulp_err_status = CompareConst.ERROR

    @staticmethod
    def _is_float32(dtype):
        """Return True if ``dtype`` denotes float32, given as a ``torch.dtype`` or its string form.

        DEVICE_DTYPE is read from a CSV row, so it arrives as a string such as
        ``"torch.float32"``. Comparing that string directly with ``torch.float32``
        is always False, which silently sent float32 APIs down the stricter
        non-float32 path.
        """
        return dtype == torch.float32 or str(dtype) == str(torch.float32)

    def _get_ulp_status(self, dtype):
        """Grade the ULP statistics with dtype-dependent thresholds.

        float32 passes if the mean ULP error is < 64, the ULP error proportion
        is < 0.05, or the NPU/GPU proportion ratio is < 1. Other dtypes pass if
        the proportion is < 0.001 or the ratio is < 1. On failure, an error line
        is appended to ``compare_message``.
        """
        if self._is_float32(dtype):
            passed = (self.mean_ulp_err < 64
                      or self.ulp_err_proportion < 0.05
                      or self.ulp_err_proportion_ratio < 1)
        else:
            passed = self.ulp_err_proportion < 0.001 or self.ulp_err_proportion_ratio < 1
        if passed:
            return CompareConst.PASS
        self.compare_message += "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR
256
-
257
-
258
def write_detail_csv(content, save_path):
    """Write ``content`` as a single row to ``save_path`` via ``write_csv``.

    Float items are rendered as fixed-point strings with
    ``msCheckerConfig.precision`` decimal places; other items are passed through
    unchanged.
    """
    formatted_row = [
        f"{value:.{msCheckerConfig.precision}f}" if isinstance(value, float) else value
        for value in content
    ]
    write_csv([formatted_row], save_path)
264
-
265
-
266
def api_precision_compare(config):
    """Compare offline NPU and GPU precision CSVs and write the result and detail CSVs.

    Args:
        config: A ``CompareConfig`` holding the NPU/GPU input CSV paths and the
            result/detail output CSV paths.

    Raises:
        Exception: Whatever ``pd.read_csv`` raised when an input CSV cannot be
            read. The error is logged first. Errors raised while analysing the
            data are logged rather than propagated.
    """
    logger.info("Start compare task")
    logger.info(f"Compare task result will be saved in {config.result_csv_path}")
    logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
    try:
        npu_data = pd.read_csv(config.npu_csv_path)
    except Exception as err:
        logger.error(f"Open npu csv Error: {err}")
        # Nothing below can run without the data. Previously execution fell
        # through to an UnboundLocalError on npu_data, hiding the real cause.
        raise
    check_csv_columns(npu_data.columns, "npu_csv")
    try:
        gpu_data = pd.read_csv(config.gpu_csv_path)
    except Exception as err:
        logger.error(f"Open gpu csv Error: {err}")
        raise
    check_csv_columns(gpu_data.columns, "gpu_csv")
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    write_csv(result_csv_title, config.result_csv_path)
    write_csv(detail_csv_title, config.details_csv_path)
    try:
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        # Best effort: keep whatever rows were written and still fix permissions.
        logger.error(f"Analyse csv Error: {err}")
    change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
    change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
290
-
291
-
292
def online_api_precision_compare(online_config):
    """Run the precision comparison on in-memory NPU/GPU data for one rank.

    The ``_rank*.csv`` placeholder in the configured output paths (joined onto
    ``Const.DEFAULT_PATH``) is replaced by this rank's number. Header rows are
    written only when an output file does not exist yet. Analysis errors are
    logged rather than raised, and file permissions are always updated at the end.
    """
    rank_csv_suffix = f"_rank{online_config.rank}.csv"

    def resolve_rank_path(configured_path):
        return os.path.join(Const.DEFAULT_PATH, configured_path).replace("_rank*.csv", rank_csv_suffix)

    result_csv_path = resolve_rank_path(online_config.result_csv_path)
    details_csv_path = resolve_rank_path(online_config.details_csv_path)
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    for csv_path, csv_title in ((result_csv_path, result_csv_title), (details_csv_path, detail_csv_title)):
        if not os.path.exists(csv_path):
            write_csv(csv_title, csv_path)
    config = CompareConfig("", "", result_csv_path, details_csv_path)
    try:
        npu_data = online_config.npu_data
        gpu_data = online_config.gpu_data
        check_csv_columns(npu_data.columns, "npu_csv")
        check_csv_columns(gpu_data.columns, "gpu_csv")
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        logger.error(f"Online api precision compare Error: {str(err)}")
    for csv_path in (result_csv_path, details_csv_path):
        change_mode(csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
312
-
313
-
314
def analyse_csv(npu_data, gpu_data, config):
    """Compare NPU/GPU precision records row by row and write the detail and result CSVs.

    Each NPU row is matched to the GPU row with the same ``API_NAME``. One detail row is
    written per matched record. When the API name changes, a summary row (forward and
    backward verdicts) for the previous API is appended to the result CSV.

    Raises:
        CompareException: if an API has more than one record in the GPU data.
    """
    forward_status, backward_status = [], []
    last_api_name, last_api_dtype, last_api_full_name = None, None, None
    for _, row_npu in npu_data.iterrows():
        message = ''
        compare_column = ApiPrecisionOutputColumn()
        full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
        row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
        # The name must split into exactly six SEP-separated segments; any other count
        # raises ValueError here.
        api_type, api_name, api_nums, direction_status, _, _ = full_api_name_with_direction_status.split(Const.SEP)
        api_full_name = Const.SEP.join([api_type, api_name, api_nums])
        if row_gpu.empty:
            logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
            continue
        if len(row_gpu) > 1:
            msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.'
            raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
        row_gpu = row_gpu.iloc[0]
        new_status = CompareConst.SPACE
        # The API produced no output for this record (e.g. requires_grad=False during the
        # backward pass), so there is nothing to compare: mark it as skipped.
        if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace():
            compare_column.api_name = full_api_name_with_direction_status
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
            new_status = CompareConst.SKIP
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
        else:
            compare_column.api_name = full_api_name_with_direction_status
            if api_name in thousandth_standard_api:
                new_status = record_thousandth_threshold_result(compare_column, row_npu)
            elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or \
                    api_name in binary_standard_api:
                new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
            elif api_name in absolute_standard_api:
                new_status = record_absolute_threshold_result(compare_column, row_npu)
            elif api_name in ulp_standard_api and \
                    row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST:
                us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
                new_status = record_ulp_compare_result(compare_column, us)
            elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST:
                bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
                new_status = record_benchmark_compare_result(compare_column, bs)
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)

        # A new API name starts: flush the summary of the previous API. The log line must
        # name the API whose result is being written (last_api_full_name), not the
        # current row's API.
        if last_api_name is not None and api_name != last_api_name:
            if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
                message = unsupported_message
                write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
                print_test_success(last_api_full_name, "skip", "skip")
                forward_status, backward_status = [], []
                message = ''
            else:
                forward_result = get_api_checker_result(forward_status)
                backward_result = get_api_checker_result(backward_status)
                message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
                write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
                print_test_success(last_api_full_name, forward_result, backward_result)
                forward_status, backward_status = [], []
                message = ''

        is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
        last_api_name = api_name
        last_api_full_name = api_full_name

        last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
        if not is_supported:
            continue

        if direction_status == 'forward':
            forward_status.append(new_status)
        elif direction_status == 'backward':
            backward_status.append(new_status)
        else:
            logger.error(f"Invalid direction status: {direction_status}")

    # Flush the summary of the final API.
    if last_api_name is not None:
        if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            message = unsupported_message
            write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
            print_test_success(last_api_full_name, "skip", "skip")
        else:
            forward_result = get_api_checker_result(forward_status)
            backward_result = get_api_checker_result(backward_status)
            message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
            write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
            print_test_success(last_api_full_name, forward_result, backward_result)
399
-
400
-
401
def print_test_success(api_full_name, forward_result, backward_result):
    """Log whether an API's forward and backward comparisons succeeded.

    Forward succeeds only on PASS. Backward also counts as successful when it is SPACE,
    which is what get_api_checker_result returns when no backward status was collected.
    """
    forward_ok = forward_result == CompareConst.PASS
    backward_ok = backward_result in (CompareConst.PASS, CompareConst.SPACE)
    logger.info(f"running api_full_name {api_full_name} compare, "
                f"is_fwd_success: {forward_ok}, is_bwd_success: {backward_ok}")
407
-
408
-
409
def check_error_rate(npu_error_rate):
    """Binary-consistency verdict: PASS only when the error rate is exactly zero, else ERROR."""
    if convert_str_to_float(npu_error_rate) == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
411
-
412
-
413
def get_absolute_threshold_result(row_npu):
    """Grade an NPU row with the absolute-threshold criterion.

    The inf/nan, relative-error and absolute-error ratios are read from the row. Each
    check passes only when its ratio is exactly 0, and the overall result is ERROR as
    soon as any check fails.

    Returns:
        dict holding the three ratios, their per-check results and
        ``absolute_threshold_result``.
    """
    def _grade(ratio):
        return CompareConst.PASS if ratio == 0 else CompareConst.ERROR

    inf_nan_error_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO])
    rel_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO])
    abs_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO])

    inf_nan_result = _grade(inf_nan_error_ratio)
    rel_err_result = _grade(rel_err_ratio)
    abs_err_result = _grade(abs_err_ratio)
    any_error = CompareConst.ERROR in (inf_nan_result, rel_err_result, abs_err_result)

    return {
        "inf_nan_error_ratio": inf_nan_error_ratio,
        "inf_nan_result": inf_nan_result,
        "rel_err_ratio": rel_err_ratio,
        "rel_err_result": rel_err_result,
        "abs_err_ratio": abs_err_ratio,
        "abs_err_result": abs_err_result,
        "absolute_threshold_result": CompareConst.ERROR if any_error else CompareConst.PASS,
    }
436
-
437
-
438
def get_api_checker_result(status):
    """Collapse the statuses collected for one direction of an API into a single verdict.

    No statuses -> SPACE; only SKIPs -> SKIP; otherwise the worst status wins, with
    ERROR above WARNING above PASS.
    """
    if not status:
        return CompareConst.SPACE
    if all(item == CompareConst.SKIP for item in status):
        return CompareConst.SKIP
    worst = next((level for level in (CompareConst.ERROR, CompareConst.WARNING) if level in status), None)
    return CompareConst.PASS if worst is None else worst
447
-
448
-
449
def check_csv_columns(columns, csv_type):
    """Ensure every column required by the precision compare is present.

    Args:
        columns: column labels of the loaded data (e.g. ``DataFrame.columns``).
        csv_type: label used in the error message, such as ``"npu_csv"``.

    Raises:
        CompareException: INVALID_DATA_ERROR naming the missing columns.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # Fixed: the original message ran "in" and the csv label together.
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
455
-
456
-
457
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Fill ``compare_column`` using the binary-consistency (exact match) criterion.

    The verdict comes from ``check_error_rate`` on the row's error rate. On ERROR, the
    message also receives any API-specific hint from ``CompareMessage``.

    Returns:
        The verdict (PASS or ERROR).
    """
    error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    verdict = check_error_rate(error_rate)
    compare_column.error_rate = error_rate
    compare_column.error_rate_status = verdict
    compare_column.compare_result = verdict
    compare_column.compare_algorithm = "二进制一致法"  # "binary consistency method"
    if compare_column.error_rate_status == CompareConst.ERROR:
        compare_column.compare_message = "ERROR: 二进制一致错误率超过阈值\n" + CompareMessage.get(api_name, "")
    else:
        compare_column.compare_message = ''
    return verdict
469
-
470
-
471
def record_absolute_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` using the absolute-threshold criterion.

    Copies the ratios and per-check results from ``get_absolute_threshold_result`` and
    writes one ERROR line into the message for every failing check.

    Returns:
        The overall absolute-threshold verdict.
    """
    graded = get_absolute_threshold_result(row_npu)
    compare_column.inf_nan_error_ratio = graded.get("inf_nan_error_ratio")
    compare_column.inf_nan_error_ratio_status = graded.get("inf_nan_result")
    compare_column.rel_err_ratio = graded.get("rel_err_ratio")
    compare_column.rel_err_ratio_status = graded.get("rel_err_result")
    compare_column.abs_err_ratio = graded.get("abs_err_ratio")
    compare_column.abs_err_ratio_status = graded.get("abs_err_result")
    compare_column.compare_result = graded.get("absolute_threshold_result")
    compare_column.compare_algorithm = "绝对阈值法"  # "absolute threshold method"
    failure_lines = (
        (compare_column.inf_nan_error_ratio_status, "ERROR: inf/nan错误率超过阈值\n"),
        (compare_column.rel_err_ratio_status, "ERROR: 相对误差错误率超过阈值\n"),
        (compare_column.abs_err_ratio_status, "ERROR: 绝对误差错误率超过阈值\n"),
    )
    compare_column.compare_message = ''.join(
        text for status, text in failure_lines if status == CompareConst.ERROR)
    return compare_column.compare_result
490
-
491
-
492
def record_benchmark_compare_result(compare_column, bs):
    """Copy a BenchmarkStandard's metrics and verdicts onto ``compare_column``.

    Runs ``bs.get_result()`` first. It then appends the ERROR/WARNING text from
    ``benchmark_message`` for every metric status that has an entry there.

    Returns:
        The benchmark's final verdict.
    """
    bs.get_result()
    copied_attrs = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for attr in copied_attrs:
        setattr(compare_column, attr, getattr(bs, attr))
    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"  # "benchmark comparison method"
    extra_messages = [
        messages[getattr(compare_column, status_attr)]
        for status_attr, messages in benchmark_message.items()
        if getattr(compare_column, status_attr) in messages
    ]
    compare_column.compare_message = bs.compare_message + ''.join(extra_messages)
    return compare_column.compare_result
512
-
513
-
514
def record_ulp_compare_result(compare_column, us):
    """Copy a ULPStandard's metrics and verdict onto ``compare_column``.

    Returns:
        The ULP verdict (``us.ulp_err_status`` after ``us.get_result()``).
    """
    us.get_result()
    for attr in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, attr, getattr(us, attr))
    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"  # "ULP error comparison method"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
524
-
525
-
526
def check_thousandth_rate(thousandth_rate):
    """PASS when the rate is at least 0.999, otherwise ERROR (NaN also yields ERROR).

    Presumably the rate is the fraction of elements within the "thousandth" relative-error
    bound; TODO confirm against the producer of REL_ERR_THOUSANDTH.
    """
    pass_threshold = 0.999
    if convert_str_to_float(thousandth_rate) >= pass_threshold:
        return CompareConst.PASS
    return CompareConst.ERROR
528
-
529
-
530
def record_thousandth_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` using the thousandth-indicator criterion.

    Returns:
        The verdict from ``check_thousandth_rate`` (PASS or ERROR).
    """
    thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    verdict = check_thousandth_rate(thousandth)
    compare_column.rel_err_thousandth = thousandth
    compare_column.rel_err_thousandth_status = verdict
    compare_column.compare_result = verdict
    compare_column.compare_algorithm = "双千指标法"  # "dual-thousandth indicator method"
    is_error = compare_column.rel_err_thousandth_status == CompareConst.ERROR
    compare_column.compare_message = "ERROR: 双千指标不达标\n" if is_error else ''
    return compare_column.compare_result
541
-
542
-
543
def _api_precision_compare(parser=None):
    """CLI entry: register the options on ``parser`` (a new one if none is given),
    parse ``sys.argv`` and run the comparison command."""
    parser = parser if parser else argparse.ArgumentParser()
    _api_precision_compare_parser(parser)
    _api_precision_compare_command(parser.parse_args(sys.argv[1:]))
549
-
550
-
551
def _api_precision_compare_command(args):
    """Validate the CLI paths, prepare the output directory and run the offline compare."""
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
    # Fall back to the current directory when no output path is given.
    out_path = "./"
    if args.out_path:
        out_path = os.path.realpath(args.out_path)
    check_path_before_create(out_path)
    create_directory(out_path)
    out_path = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()
    compare_config = CompareConfig(
        npu_csv_path,
        gpu_csv_path,
        os.path.join(out_path, API_PRECISION_COMPARE_RESULT_FILE_NAME),
        os.path.join(out_path, API_PRECISION_COMPARE_DETAILS_FILE_NAME),
    )
    api_precision_compare(compare_config)
563
-
564
-
565
def _api_precision_compare_parser(parser):
    """Register the api_precision_compare command-line options on ``parser``.

    The command validates and reads both the NPU and the GPU detail CSV, so ``-gpu`` is
    required just like ``-npu``; its help text already said "<Required>". Making it
    required lets argparse report a missing path up front, instead of the empty default
    failing later during path validation.
    """
    parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
                        help="<Required> , Accuracy_checking_details.csv generated on the NPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
                        help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                        help="<optional> The api precision compare task result out path.",
                        required=False)
577
-
578
-
579
if __name__ == '__main__':
    # Script entry point: parse sys.argv, run the comparison, then report completion.
    _api_precision_compare()
    logger.info("Compare task completed.")
1
+ import argparse
2
+ import math
3
+ import os
4
+ import sys
5
+ from collections import namedtuple
6
+
7
+ import torch
8
+ import pandas as pd
9
+
10
+ from msprobe.core.common.file_utils import write_csv
11
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
+ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
14
+ ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
15
+ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
16
+ check_inf_or_nan
17
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
19
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
20
+ from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory
21
+ from msprobe.pytorch.common.log import logger
22
+ from msprobe.core.common.utils import CompareException
23
+ from msprobe.core.common.const import Const, CompareConst, FileCheckConst
24
+
25
# Input detail CSVs (NPU, GPU) and the result/detail CSVs this module writes.
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
# Per-metric inf/nan consistency flags returned by BenchmarkStandard._compare_ratio.
BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', [
    'small_value_inf_nan_consistency',
    'rmse_inf_nan_consistency',
    'max_rel_inf_nan_consistency',
    'mean_rel_inf_nan_consistency',
    'eb_inf_nan_consistency',
])
unsupported_message = 'This data type does not support benchmark compare.'

# Threshold used when an algorithm has no entry in benchmark_algorithms_thresholds.
DEFAULT_THRESHOLD = 1

# NPU/GPU ratio thresholds per benchmark metric: BenchmarkStandard._get_status grades a
# ratio above error_threshold as ERROR and above warning_threshold as WARNING.
benchmark_algorithms_thresholds = {
    algorithm: {'error_threshold': error_threshold, 'warning_threshold': 1}
    for algorithm, error_threshold in (
        ('small_value', 2),
        ('rmse', 2),
        ('max_rel_err', 10),
        ('mean_rel_err', 2),
        ('eb', 2),
    )
}

# Message appended to the detail row for each benchmark status attribute, keyed by the
# status value (ERROR/WARNING).
benchmark_message = {
    "small_value_err_status": {
        CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n",
    },
    "rmse_status": {
        CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n",
    },
    "max_rel_err_status": {
        CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n",
    },
    "mean_rel_err_status": {
        CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n",
    },
}
76
+
77
+
78
class Standard:
    """Shared helper for precision standards that compare NPU statistics with GPU ones."""

    @staticmethod
    def _calc_ratio(column_name, x, y, default_value):
        """Return the ratio of an NPU statistic to the matching GPU statistic.

        Args:
            column_name: name of the statistic, forwarded to ``check_inf_or_nan``.
            x: NPU-side value (converted with ``convert_str_to_float``).
            y: GPU-side value (converted with ``convert_str_to_float``).
            default_value: ratio returned when ``y`` is zero but ``x`` is not.

        Returns:
            ``(ratio, inf_nan_consistency, message)``. If either value is inf or nan, the
            tuple comes straight from ``check_inf_or_nan``. Otherwise the consistency is
            True and the message is empty.
        """
        npu_value = convert_str_to_float(x)
        gpu_value = convert_str_to_float(y)

        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)

        # NOTE: math.isclose with default tolerances (abs_tol=0.0) is only true for an
        # exact 0.0 here, not for values that are merely small.
        if not math.isclose(gpu_value, 0.0):
            return abs(npu_value / gpu_value), True, ""
        ratio = 1.0 if math.isclose(npu_value, 0.0) else default_value
        return ratio, True, ""
107
+
108
+
109
class BenchmarkStandard(Standard):
    """Benchmark comparison of one API record: grade NPU/GPU ratios of error statistics.

    Each metric's NPU/GPU ratio comes from ``_calc_ratio`` and is graded against
    ``benchmark_algorithms_thresholds``. ``final_result`` is the worst status among the
    small-value, RMSE, max-relative-error and mean-relative-error checks. As written, the
    EB ratio and status are computed and reported but do not affect ``final_result``, and
    the EB inf/nan consistency flag is not consulted.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        # Mapping-like rows (read with .get) holding the NPU and GPU statistics.
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        # NPU/GPU ratios, filled in by _compare_ratio().
        self.small_value_err_ratio = 1
        self.rmse_ratio = 1
        self.max_rel_err_ratio = 1
        self.mean_rel_err_ratio = 1
        self.eb_ratio = 1
        # Per-metric statuses, filled in by get_result().
        self.small_value_err_status = CompareConst.PASS
        self.rmse_status = CompareConst.PASS
        self.max_rel_err_status = CompareConst.PASS
        self.mean_rel_err_status = CompareConst.PASS
        self.eb_status = CompareConst.PASS
        self.check_result_list = []
        self.final_result = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return "%s" % self.api_name

    @staticmethod
    def _get_status(ratio, algorithm):
        """Grade a ratio against the thresholds of ``algorithm``; nan/inf ratios PASS."""
        if math.isnan(ratio) or math.isinf(ratio):
            return CompareConst.PASS
        thresholds = benchmark_algorithms_thresholds.get(algorithm, {})
        if ratio > thresholds.get('error_threshold', DEFAULT_THRESHOLD):
            return CompareConst.ERROR
        if ratio > thresholds.get('warning_threshold', DEFAULT_THRESHOLD):
            return CompareConst.WARNING
        return CompareConst.PASS

    def get_result(self):
        """Compute all ratios, grade each metric and derive ``final_result``.

        A metric whose inf/nan values are inconsistent between NPU and GPU is ERROR.
        """
        consistency = self._compare_ratio()
        graded_checks = (
            ('small_value_err_status', self.small_value_err_ratio, 'small_value',
             consistency.small_value_inf_nan_consistency),
            ('rmse_status', self.rmse_ratio, 'rmse', consistency.rmse_inf_nan_consistency),
            ('max_rel_err_status', self.max_rel_err_ratio, 'max_rel_err',
             consistency.max_rel_inf_nan_consistency),
            ('mean_rel_err_status', self.mean_rel_err_ratio, 'mean_rel_err',
             consistency.mean_rel_inf_nan_consistency),
        )
        for status_attr, ratio, algorithm, is_consistent in graded_checks:
            status = self._get_status(ratio, algorithm) if is_consistent else CompareConst.ERROR
            setattr(self, status_attr, status)
            self.check_result_list.append(status)
        # EB is graded for reporting only; it is not added to check_result_list.
        self.eb_status = self._get_status(self.eb_ratio, 'eb')
        if CompareConst.ERROR in self.check_result_list:
            self.final_result = CompareConst.ERROR
        elif CompareConst.WARNING in self.check_result_list:
            self.final_result = CompareConst.WARNING

    def to_column_value(self):
        return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
                self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
                self.mean_rel_err_status, self.eb_ratio, self.eb_status]

    def _compare_ratio(self):
        """Compute every NPU/GPU ratio, store it on ``self`` and collect inf/nan flags."""
        metrics = (
            ('small_value_err_ratio', ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
            ('rmse_ratio', ApiPrecisionCompareColumn.RMSE),
            ('max_rel_err_ratio', ApiPrecisionCompareColumn.MAX_REL_ERR),
            ('mean_rel_err_ratio', ApiPrecisionCompareColumn.MEAN_REL_ERR),
            ('eb_ratio', ApiPrecisionCompareColumn.EB),
        )
        consistencies = []
        for ratio_attr, column in metrics:
            ratio, is_consistent, message = self._calc_ratio(
                column, self.npu_precision.get(column), self.gpu_precision.get(column), 10000.0)
            setattr(self, ratio_attr, ratio)
            self.compare_message += message
            consistencies.append(is_consistent)
        return BenchmarkInf_Nan_Consistency(*consistencies)
201
+
202
+
203
class ULPStandard(Standard):
    """ULP (unit in the last place) error check for one API record.

    Uses the NPU mean ULP error, the NPU ULP error proportion and the NPU/GPU ratio of
    that proportion to set ``ulp_err_status``. An inf/nan mismatch between NPU and GPU
    forces ERROR.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        # Mapping-like rows (read with .get) holding the NPU and GPU statistics.
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.mean_ulp_err = 0
        self.ulp_err_proportion = 0
        self.ulp_err_proportion_ratio = 1
        self.ulp_err_status = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    def get_result(self):
        """Populate the ULP metrics and ``ulp_err_status``."""
        self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        inf_nan_consistency = True
        if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
            _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
                                                               ApiPrecisionCompareColumn.MEAN_ULP_ERR)
            self.compare_message += message
        self.ulp_err_proportion = convert_str_to_float(
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
        self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
            ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
            self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
        inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
        self.compare_message += message
        if inf_nan_consistency:
            self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
        else:
            self.ulp_err_status = CompareConst.ERROR

    def _get_ulp_status(self, dtype):
        """Grade the ULP metrics for ``dtype``; float32 uses looser criteria.

        Args:
            dtype: the device dtype, either a ``torch.dtype`` or its string form such as
                ``"torch.float32"``.
        """
        # Device dtypes read from the details CSV are strings, and a str never compares
        # equal to a torch.dtype, so the float32 branch used to be unreachable. Accept both
        # forms.
        if dtype in (torch.float32, str(torch.float32)):
            passed = (self.mean_ulp_err < 64 or self.ulp_err_proportion < 0.05
                      or self.ulp_err_proportion_ratio < 1)
        else:
            passed = self.ulp_err_proportion < 0.001 or self.ulp_err_proportion_ratio < 1
        if passed:
            return CompareConst.PASS
        self.compare_message += "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR
257
+
258
+
259
def write_detail_csv(content, save_path):
    """Append one detail row to ``save_path``.

    Float cells are formatted with ``msCheckerConfig.precision`` decimal places; all
    other cells are written unchanged.
    """
    formatted_row = [
        f"{item:.{msCheckerConfig.precision}f}" if isinstance(item, float) else item
        for item in content
    ]
    write_csv([formatted_row], save_path)
265
+
266
+
267
def _read_compare_csv(csv_path, device):
    """Load one detail CSV with pandas and check that it has the required columns.

    A read error is logged and re-raised. Previously execution continued and then failed
    with an UnboundLocalError on the never-assigned DataFrame.
    """
    try:
        csv_data = pd.read_csv(csv_path)
    except Exception as err:
        logger.error(f"Open {device} csv Error: {err}")
        raise
    check_csv_columns(csv_data.columns, f"{device}_csv")
    return csv_data


def api_precision_compare(config):
    """Run the offline API precision comparison described by ``config``.

    Reads the NPU and GPU detail CSVs and writes fresh header rows to the result and
    details CSVs. It then analyses the data and finally sets the permissions of both
    output files. Errors raised during the analysis are logged, not propagated.

    Raises:
        Exception: whatever ``pd.read_csv`` raised when an input CSV cannot be read.
        CompareException: when an input CSV lacks required columns.
    """
    logger.info("Start compare task")
    logger.info(f"Compare task result will be saved in {config.result_csv_path}")
    logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
    npu_data = _read_compare_csv(config.npu_csv_path, "npu")
    gpu_data = _read_compare_csv(config.gpu_csv_path, "gpu")
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    write_csv(result_csv_title, config.result_csv_path)
    write_csv(detail_csv_title, config.details_csv_path)
    try:
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        logger.error(f"Analyse csv Error: {err}")
    change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
    change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
291
+
292
+
293
def online_api_precision_compare(online_config):
    """Compare the NPU/GPU precision data of one rank collected during an online run.

    The output paths live under ``Const.DEFAULT_PATH``, with ``_rank*.csv`` replaced by
    this rank's number. Header rows are written only when a file does not exist yet.
    Errors from the column checks or the analysis are logged, not raised.
    """
    rank = online_config.rank

    def _rank_path(file_name):
        # Resolve under the default output dir and substitute this rank for "_rank*.csv".
        return os.path.join(Const.DEFAULT_PATH, file_name).replace("_rank*.csv", f"_rank{rank}.csv")

    result_csv_path = _rank_path(online_config.result_csv_path)
    details_csv_path = _rank_path(online_config.details_csv_path)
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    if not os.path.exists(result_csv_path):
        write_csv(result_csv_title, result_csv_path)
    if not os.path.exists(details_csv_path):
        write_csv(detail_csv_title, details_csv_path)
    config = CompareConfig("", "", result_csv_path, details_csv_path)
    try:
        npu_data = online_config.npu_data
        gpu_data = online_config.gpu_data
        check_csv_columns(npu_data.columns, "npu_csv")
        check_csv_columns(gpu_data.columns, "gpu_csv")
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        logger.error(f"Online api precision compare Error: {err}")
    for output_path in (result_csv_path, details_csv_path):
        change_mode(output_path, FileCheckConst.DATA_FILE_AUTHORITY)
313
+
314
+
315
def _write_api_summary(config, api_full_name, api_name, api_dtype, forward_status, backward_status):
    """Append the result-CSV row that summarises every record of one API call.

    ``api_full_name`` is written to the CSV and logged. The short ``api_name`` is the key
    for API-specific hints in ``CompareMessage``; it is the same key
    ``record_binary_consistency_result`` uses.
    """
    if api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
        write_csv([[api_full_name, CompareConst.SKIP, CompareConst.SKIP, unsupported_message]],
                  config.result_csv_path)
        print_test_success(api_full_name, CompareConst.SKIP, CompareConst.SKIP)
        return
    forward_result = get_api_checker_result(forward_status)
    backward_result = get_api_checker_result(backward_status)
    message = CompareMessage.get(api_name, "") if forward_result == CompareConst.ERROR else ""
    write_csv([[api_full_name, forward_result, backward_result, message]], config.result_csv_path)
    print_test_success(api_full_name, forward_result, backward_result)


def analyse_csv(npu_data, gpu_data, config):
    """Compare NPU/GPU precision records row by row and write the detail and result CSVs.

    Every NPU row is matched to the GPU row with the same ``API_NAME``, and a detail row is
    written per record. When the API call (the full name from
    ``extract_detailed_api_segments``) changes, one summary row for the previous call is
    appended to the result CSV. Rows whose name cannot be parsed, or whose status cannot
    be computed, are recorded as SKIP immediately.

    Raises:
        CompareException: if an API has more than one record in the GPU data.
    """
    forward_status, backward_status = [], []
    # The API call whose records are being collected: its full name (written to the
    # result CSV), its short name (key for CompareMessage hints) and its NPU dtype.
    last_api_full_name, last_api_name, last_api_dtype = None, None, None
    for _, row_npu in npu_data.iterrows():
        compare_column = ApiPrecisionOutputColumn()
        full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
        row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
        api_name, api_full_name, direction_status = extract_detailed_api_segments(
            full_api_name_with_direction_status)
        if not api_full_name:
            err_message = f"The API name {full_api_name_with_direction_status} is invalid."
            logger.error(err_message)
            compare_column.api_name = full_api_name_with_direction_status
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = err_message
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
            write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, err_message]],
                      config.result_csv_path)
            continue
        if row_gpu.empty:
            logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
            continue
        if len(row_gpu) > 1:
            msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.'
            raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
        row_gpu = row_gpu.iloc[0]
        try:
            new_status = get_api_status(row_npu, row_gpu, api_name, compare_column)
        except Exception as err:
            logger.error(f"Get api status error: {str(err)}")
            compare_column.api_name = full_api_name_with_direction_status
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = str(err)
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
            write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, str(err)]],
                      config.result_csv_path)
            continue

        write_detail_csv(compare_column.to_column_value(), config.details_csv_path)

        # A new API call starts: flush the summary of the previous one.
        if last_api_full_name is not None and api_full_name != last_api_full_name:
            _write_api_summary(config, last_api_full_name, last_api_name, last_api_dtype,
                               forward_status, backward_status)
            forward_status, backward_status = [], []

        last_api_full_name, last_api_name = api_full_name, api_name
        last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
        # Unsupported dtypes are summarised as SKIP; their statuses are not collected.
        if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            continue

        if direction_status == 'forward':
            forward_status.append(new_status)
        elif direction_status == 'backward':
            backward_status.append(new_status)
        else:
            logger.error(f"Invalid direction status: {direction_status}")

    if last_api_full_name is not None:
        _write_api_summary(config, last_api_full_name, last_api_name, last_api_dtype,
                           forward_status, backward_status)
397
+
398
+
399
def get_api_status(row_npu, row_gpu, api_name, compare_column):
    """Apply the matching precision standard to one NPU/GPU detail-row pair.

    ``compare_column`` is populated in place. It receives the API name and,
    depending on the standard chosen, that standard's metrics, algorithm name
    and message.

    Standard selection, in order:
      * blank NPU dtype (the API produced no output here) -> SKIP
      * ``api_name`` in ``thousandth_standard_api``       -> thousandth standard
      * dtype not in ``BINARY_COMPARE_UNSUPPORT_LIST`` or
        ``api_name`` in ``binary_standard_api``           -> binary consistency
      * ``api_name`` in ``absolute_standard_api``         -> absolute threshold
      * ``api_name`` in ``ulp_standard_api`` and dtype in
        ``ULP_COMPARE_SUPPORT_LIST``                      -> ULP comparison
      * dtype in ``BENCHMARK_COMPARE_SUPPORT_LIST``       -> benchmark comparison

    Args:
        row_npu: Detail-CSV row produced on the NPU.
        row_gpu: Matching detail-CSV row produced on the GPU (benchmark side).
        api_name: Short API name used to pick the standard.
        compare_column: Output-row object that is filled in place.

    Returns:
        The status of the selected standard. Returns ``CompareConst.SKIP`` for a
        blank dtype, or ``CompareConst.SPACE`` when no standard applies.
    """
    full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
    device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
    compare_column.api_name = full_api_name_with_direction_status
    # Fallback when none of the branches below match, e.g. a dtype listed in
    # BINARY_COMPARE_UNSUPPORT_LIST but absent from BENCHMARK_COMPARE_SUPPORT_LIST
    # for an API not covered by the other standards. Without this default the
    # function raised UnboundLocalError on such rows.
    new_status = CompareConst.SPACE

    # The API has no output here (e.g. requires_grad=False during backward):
    # skip the comparison.
    if device_dtype.isspace():
        compare_column.compare_result = CompareConst.SKIP
        compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
        new_status = CompareConst.SKIP
    elif api_name in thousandth_standard_api:
        new_status = record_thousandth_threshold_result(compare_column, row_npu)
    elif device_dtype not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in binary_standard_api:
        new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
    elif api_name in absolute_standard_api:
        new_status = record_absolute_threshold_result(compare_column, row_npu)
    elif api_name in ulp_standard_api and device_dtype in ULP_COMPARE_SUPPORT_LIST:
        us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
        new_status = record_ulp_compare_result(compare_column, us)
    elif device_dtype in BENCHMARK_COMPARE_SUPPORT_LIST:
        bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
        new_status = record_benchmark_compare_result(compare_column, bs)
    return new_status
424
+
425
+
426
def print_test_success(api_full_name, forward_result, backward_result):
    """Log whether the forward and backward comparisons of one API succeeded.

    Forward counts as successful only when it is ``CompareConst.PASS``.
    Backward is also treated as successful when it is ``CompareConst.SPACE``,
    the value reported when no backward status was collected.
    """
    fwd_ok = forward_result == CompareConst.PASS
    bwd_ok = backward_result in (CompareConst.PASS, CompareConst.SPACE)
    logger.info(f"running api_full_name {api_full_name} compare, "
                f"is_fwd_success: {fwd_ok}, "
                f"is_bwd_success: {bwd_ok}")
432
+
433
+
434
def check_error_rate(npu_error_rate):
    """Return PASS when the error rate (parsed as a float) equals 0, otherwise ERROR."""
    error_rate = convert_str_to_float(npu_error_rate)
    if error_rate == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
436
+
437
+
438
def get_absolute_threshold_result(row_npu):
    """Evaluate the absolute-threshold standard for one NPU detail row.

    Three ratios are read from the row and converted to floats: inf/nan error,
    relative error and absolute error. Each one passes only when it equals 0.

    Returns:
        A dict with each ratio, its individual PASS/ERROR status, and
        ``absolute_threshold_result``. The overall result is ERROR if any
        individual status is ERROR, otherwise PASS.
    """
    def _ratio_with_status(column):
        ratio = convert_str_to_float(row_npu[column])
        return ratio, (CompareConst.PASS if ratio == 0 else CompareConst.ERROR)

    inf_nan_error_ratio, inf_nan_result = _ratio_with_status(ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO)
    rel_err_ratio, rel_err_result = _ratio_with_status(ApiPrecisionCompareColumn.REL_ERR_RATIO)
    abs_err_ratio, abs_err_result = _ratio_with_status(ApiPrecisionCompareColumn.ABS_ERR_RATIO)

    individual_results = (inf_nan_result, rel_err_result, abs_err_result)
    overall = CompareConst.ERROR if CompareConst.ERROR in individual_results else CompareConst.PASS

    return {
        "inf_nan_error_ratio": inf_nan_error_ratio,
        "inf_nan_result": inf_nan_result,
        "rel_err_ratio": rel_err_ratio,
        "rel_err_result": rel_err_result,
        "abs_err_ratio": abs_err_ratio,
        "abs_err_result": abs_err_result,
        "absolute_threshold_result": overall,
    }
461
+
462
+
463
def get_api_checker_result(status):
    """Reduce a list of per-row statuses to a single API-level verdict.

    Precedence:
      * empty list                      -> ``CompareConst.SPACE``
      * every entry is SKIP             -> ``CompareConst.SKIP``
      * any ERROR                       -> ``CompareConst.ERROR``
      * any WARNING (and no ERROR)      -> ``CompareConst.WARNING``
      * otherwise                       -> ``CompareConst.PASS``
    """
    if not status:
        return CompareConst.SPACE
    only_skips = all(entry == CompareConst.SKIP for entry in status)
    if only_skips:
        return CompareConst.SKIP
    if CompareConst.ERROR in status:
        return CompareConst.ERROR
    if CompareConst.WARNING in status:
        return CompareConst.WARNING
    return CompareConst.PASS
472
+
473
+
474
def check_csv_columns(columns, csv_type):
    """Ensure every column required for the precision comparison is present.

    Args:
        columns: Column names found in the CSV (any container supporting ``in``).
        csv_type: Label for the CSV being checked; it is used in the error message.

    Raises:
        CompareException: ``INVALID_DATA_ERROR`` naming every missing column.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # A space is needed between "in" and the CSV label; it was missing before.
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
480
+
481
+
482
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Apply the binary-consistency standard to one NPU detail row.

    The row's error rate is judged by ``check_error_rate``. ``compare_column``
    is filled in place with:
      * the raw error rate and its status;
      * the overall result;
      * the algorithm label;
      * a message: an ERROR line when the check fails, always followed by the
        API's ``CompareMessage`` entry (empty string if none).

    Returns:
        The status from ``check_error_rate``.
    """
    error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    status = check_error_rate(error_rate)
    compare_column.error_rate = error_rate
    compare_column.error_rate_status = status
    compare_column.compare_result = status
    compare_column.compare_algorithm = "二进制一致法"
    failure_line = "ERROR: 二进制一致错误率超过阈值\n" if compare_column.error_rate_status == CompareConst.ERROR else ''
    compare_column.compare_message = failure_line + CompareMessage.get(api_name, "")
    return status
494
+
495
+
496
def record_absolute_threshold_result(compare_column, row_npu):
    """Apply the absolute-threshold standard to one NPU detail row.

    Copies each ratio and status from ``get_absolute_threshold_result`` onto
    ``compare_column``. It then records the algorithm label and builds a
    message with one ERROR line per failing sub-check, in this order:
    inf/nan, then relative error, then absolute error.

    Returns:
        The overall absolute-threshold result stored in ``compare_result``.
    """
    threshold = get_absolute_threshold_result(row_npu)
    # (attribute on compare_column, key in the threshold dict), in assignment order.
    attribute_sources = (
        ("inf_nan_error_ratio", "inf_nan_error_ratio"),
        ("inf_nan_error_ratio_status", "inf_nan_result"),
        ("rel_err_ratio", "rel_err_ratio"),
        ("rel_err_ratio_status", "rel_err_result"),
        ("abs_err_ratio", "abs_err_ratio"),
        ("abs_err_ratio_status", "abs_err_result"),
        ("compare_result", "absolute_threshold_result"),
    )
    for attr_name, result_key in attribute_sources:
        setattr(compare_column, attr_name, threshold.get(result_key))
    compare_column.compare_algorithm = "绝对阈值法"

    failure_lines = (
        (compare_column.inf_nan_error_ratio_status, "ERROR: inf/nan错误率超过阈值\n"),
        (compare_column.rel_err_ratio_status, "ERROR: 相对误差错误率超过阈值\n"),
        (compare_column.abs_err_ratio_status, "ERROR: 绝对误差错误率超过阈值\n"),
    )
    compare_column.compare_message = ''.join(
        line for check_status, line in failure_lines if check_status == CompareConst.ERROR
    )
    return compare_column.compare_result
515
+
516
+
517
def record_benchmark_compare_result(compare_column, bs):
    """Apply the benchmark standard and copy its results onto ``compare_column``.

    Calls ``bs.get_result()`` and mirrors each metric ratio and status onto
    ``compare_column``. It then records the final result, the algorithm label
    and ``bs.compare_message``. Finally it appends the text that
    ``benchmark_message`` maps to each status attribute's current value,
    whenever such a mapping exists.

    Returns:
        ``bs.final_result`` (as stored in ``compare_column.compare_result``).
    """
    bs.get_result()
    mirrored_fields = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for field_name in mirrored_fields:
        setattr(compare_column, field_name, getattr(bs, field_name))
    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"

    message = bs.compare_message
    for status_attr, messages in benchmark_message.items():
        current_status = getattr(compare_column, status_attr)
        if current_status in messages:
            message += messages[current_status]
    compare_column.compare_message = message
    return compare_column.compare_result
537
+
538
+
539
def record_ulp_compare_result(compare_column, us):
    """Apply the ULP-error standard and copy its results onto ``compare_column``.

    Calls ``us.get_result()``, then mirrors the ULP metrics and status. The
    ULP status also becomes the overall ``compare_result``. Finally it records
    the algorithm label and ``us.compare_message``.

    Returns:
        ``us.ulp_err_status`` (as stored in ``compare_column.compare_result``).
    """
    us.get_result()
    for field_name in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, field_name, getattr(us, field_name))
    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
549
+
550
+
551
def check_thousandth_rate(thousandth_rate):
    """Return PASS when the rate (parsed as a float) is at least 0.999, otherwise ERROR."""
    # Keep ``>=`` (not ``<`` with swapped branches) so a NaN rate still yields ERROR.
    meets_threshold = convert_str_to_float(thousandth_rate) >= 0.999
    if meets_threshold:
        return CompareConst.PASS
    return CompareConst.ERROR
553
+
554
+
555
def record_thousandth_threshold_result(compare_column, row_npu):
    """Apply the thousandth relative-error standard to one NPU detail row.

    The row's ``REL_ERR_THOUSANDTH`` value is judged by ``check_thousandth_rate``.
    ``compare_column`` receives the raw value, its status, the overall result,
    the algorithm label, and an ERROR message when the check fails.

    Returns:
        The status from ``check_thousandth_rate``.
    """
    rel_err_thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    status = check_thousandth_rate(rel_err_thousandth)
    compare_column.rel_err_thousandth = rel_err_thousandth
    compare_column.rel_err_thousandth_status = status
    compare_column.compare_result = status
    compare_column.compare_algorithm = "双千指标法"
    failed = compare_column.rel_err_thousandth_status == CompareConst.ERROR
    compare_column.compare_message = "ERROR: 双千指标不达标\n" if failed else ''
    return compare_column.compare_result
566
+
567
+
568
def _api_precision_compare(parser=None):
    """Command-line entry: parse ``sys.argv`` and run the API precision comparison.

    Args:
        parser: Optional ``argparse.ArgumentParser`` to extend. When falsy, a
            fresh parser is created. The precision-compare arguments are
            registered on whichever parser is used.
    """
    arg_parser = parser if parser else argparse.ArgumentParser()
    _api_precision_compare_parser(arg_parser)
    cli_args = arg_parser.parse_args(sys.argv[1:])
    _api_precision_compare_command(cli_args)
574
+
575
+
576
def _api_precision_compare_command(args):
    """Validate the CLI paths and run the API precision comparison.

    Steps:
      * Both detail CSV paths go through ``get_validated_result_csv_path``.
      * The output directory (``./`` when ``args.out_path`` is empty) is
        pre-checked, created, and then checked by ``FileChecker`` with
        ``WRITE_ABLE`` ability.
      * The result and details CSVs are written inside that directory under
        ``API_PRECISION_COMPARE_RESULT_FILE_NAME`` and
        ``API_PRECISION_COMPARE_DETAILS_FILE_NAME``.
    """
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')

    requested_out = args.out_path
    out_dir = os.path.realpath(requested_out) if requested_out else "./"
    check_path_before_create(out_dir)
    create_directory(out_dir)
    out_dir = FileChecker(out_dir, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()

    compare_config = CompareConfig(
        npu_csv_path,
        gpu_csv_path,
        os.path.join(out_dir, API_PRECISION_COMPARE_RESULT_FILE_NAME),
        os.path.join(out_dir, API_PRECISION_COMPARE_DETAILS_FILE_NAME),
    )
    api_precision_compare(compare_config)
588
+
589
+
590
+ def _api_precision_compare_parser(parser):
591
+ parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
592
+ help="<Required> , Accuracy_checking_details.csv generated on the NPU by using the "
593
+ "api_accuracy_checker tool.",
594
+ required=True)
595
+ parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
596
+ help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
597
+ "api_accuracy_checker tool.",
598
+ required=False)
599
+ parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
600
+ help="<optional> The api precision compare task result out path.",
601
+ required=False)
602
+
603
+
604
# Script entry point: run the comparison using command-line arguments, then log completion.
if __name__ == '__main__':
    _api_precision_compare()
    logger.info("Compare task completed.")