mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,581 +1,632 @@
1
- import argparse
2
- import math
3
- import os
4
- import sys
5
- from collections import namedtuple
6
-
7
- import torch
8
- import pandas as pd
9
-
10
- from msprobe.core.common.utils import write_csv
11
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
- API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
14
- ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
15
- BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
16
- check_inf_or_nan
17
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
19
- from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
20
- from msprobe.pytorch.common.log import logger
21
- from msprobe.core.common.utils import CompareException
22
- from msprobe.core.common.const import CompareConst, FileCheckConst, Const
23
-
24
- CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
25
- BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency',
26
- 'rmse_inf_nan_consistency',
27
- 'max_rel_inf_nan_consistency',
28
- 'mean_rel_inf_nan_consistency',
29
- 'eb_inf_nan_consistency'])
30
- unsupported_message = 'This data type does not support benchmark compare.'
31
-
32
- DEFAULT_THRESHOLD = 1
33
-
34
- benchmark_algorithms_thresholds = {
35
- 'small_value': {
36
- 'error_threshold': 2,
37
- 'warning_threshold': 1
38
- },
39
- 'rmse': {
40
- 'error_threshold': 2,
41
- 'warning_threshold': 1
42
- },
43
- 'max_rel_err': {
44
- 'error_threshold': 10,
45
- 'warning_threshold': 1
46
- },
47
- 'mean_rel_err': {
48
- 'error_threshold': 2,
49
- 'warning_threshold': 1
50
- },
51
- 'eb': {
52
- 'error_threshold': 2,
53
- 'warning_threshold': 1
54
- }
55
- }
56
-
57
- benchmark_message = {
58
- "small_value_err_status": {
59
- CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n",
60
- CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n"
61
- },
62
- "rmse_status": {
63
- CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n",
64
- CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n"
65
- },
66
- "max_rel_err_status": {
67
- CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n",
68
- CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n"
69
- },
70
- "mean_rel_err_status": {
71
- CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n",
72
- CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n"
73
- }
74
- }
75
-
76
-
77
- class Standard:
78
- @staticmethod
79
- def _calc_ratio(column_name, x, y, default_value):
80
- '''
81
- 计算npu侧和gpu侧统计量的比值
82
- 输入:
83
- column_name:统计量名称
84
- x:npu侧统计量
85
- y:gpu侧统计量
86
- default:当x不接近0,y接近0,设置的比值默认值
87
- 输出:
88
- ratio:统计量x和y的比值
89
- inf_nan_consistency:不出现inf或nan时为True,出现inf或nan时必须同时为inf或-inf或nan才为True,否则为False
90
- message:当出现inf或nan时的提示信息
91
- '''
92
- x, y = convert_str_to_float(x), convert_str_to_float(y)
93
-
94
- if is_inf_or_nan(x) or is_inf_or_nan(y):
95
- return check_inf_or_nan(x, y, column_name)
96
-
97
- inf_nan_consistency = True
98
- message = ""
99
- if math.isclose(y, 0.0):
100
- if math.isclose(x, 0.0):
101
- return 1.0, inf_nan_consistency, message
102
- else:
103
- return default_value, inf_nan_consistency, message
104
- else:
105
- return abs(x / y), inf_nan_consistency, message
106
-
107
-
108
- class BenchmarkStandard(Standard):
109
- def __init__(self, api_name, npu_precision, gpu_precision):
110
- self.api_name = api_name
111
- self.npu_precision = npu_precision
112
- self.gpu_precision = gpu_precision
113
- self.small_value_err_ratio = 1
114
- self.rmse_ratio = 1
115
- self.max_rel_err_ratio = 1
116
- self.mean_rel_err_ratio = 1
117
- self.eb_ratio = 1
118
- self.small_value_err_status = CompareConst.PASS
119
- self.rmse_status = CompareConst.PASS
120
- self.max_rel_err_status = CompareConst.PASS
121
- self.mean_rel_err_status = CompareConst.PASS
122
- self.eb_status = CompareConst.PASS
123
- self.check_result_list = []
124
- self.final_result = CompareConst.PASS
125
- self.compare_message = ""
126
-
127
- def __str__(self):
128
- return "%s" % (self.api_name)
129
-
130
- @staticmethod
131
- def _get_status(ratio, algorithm):
132
- if math.isnan(ratio) or math.isinf(ratio):
133
- return CompareConst.PASS
134
- error_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('error_threshold', DEFAULT_THRESHOLD)
135
- warning_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('warning_threshold',
136
- DEFAULT_THRESHOLD)
137
- if ratio > error_threshold:
138
- return CompareConst.ERROR
139
- elif ratio > warning_threshold:
140
- return CompareConst.WARNING
141
- return CompareConst.PASS
142
-
143
- def get_result(self):
144
- inf_nan_consistency = self._compare_ratio()
145
- small_value_inf_nan_consistency = inf_nan_consistency.small_value_inf_nan_consistency
146
- rmse_inf_nan_consistency = inf_nan_consistency.rmse_inf_nan_consistency
147
- max_rel_inf_nan_consistency = inf_nan_consistency.max_rel_inf_nan_consistency
148
- mean_rel_inf_nan_consistency = inf_nan_consistency.mean_rel_inf_nan_consistency
149
- eb_inf_nan_consistency = inf_nan_consistency.eb_inf_nan_consistency
150
- self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') if \
151
- small_value_inf_nan_consistency else CompareConst.ERROR
152
- self.check_result_list.append(self.small_value_err_status)
153
- self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
154
- else CompareConst.ERROR
155
- self.check_result_list.append(self.rmse_status)
156
- self.max_rel_err_status = self._get_status(self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency \
157
- else CompareConst.ERROR
158
- self.check_result_list.append(self.max_rel_err_status)
159
- self.mean_rel_err_status = self._get_status(self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency \
160
- else CompareConst.ERROR
161
- self.check_result_list.append(self.mean_rel_err_status)
162
- self.eb_status = self._get_status(self.eb_ratio, 'eb')
163
- if CompareConst.ERROR in self.check_result_list:
164
- self.final_result = CompareConst.ERROR
165
- elif CompareConst.WARNING in self.check_result_list:
166
- self.final_result = CompareConst.WARNING
167
-
168
- def to_column_value(self):
169
- return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
170
- self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
171
- self.mean_rel_err_status, self.eb_ratio, self.eb_status]
172
-
173
- def _compare_ratio(self):
174
-
175
- self.small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = self._calc_ratio(
176
- ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE,
177
- self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
178
- self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0)
179
- self.compare_message += small_value_message
180
- self.rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE,
181
- self.npu_precision.get(ApiPrecisionCompareColumn.RMSE),
182
- self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0)
183
- self.compare_message += rmse_message
184
- self.max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = self._calc_ratio(
185
- ApiPrecisionCompareColumn.MAX_REL_ERR,
186
- self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
187
- self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
188
- self.compare_message += max_rel_message
189
- self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(ApiPrecisionCompareColumn.MEAN_REL_ERR,
190
- self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
191
- self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
192
- self.compare_message += mean_rel_message
193
- self.eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB,
194
- self.npu_precision.get(ApiPrecisionCompareColumn.EB),
195
- self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
196
- self.compare_message += eb_message
197
-
198
- return BenchmarkInf_Nan_Consistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
199
- max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency, eb_inf_nan_consistency)
200
-
201
-
202
- class ULPStandard(Standard):
203
- def __init__(self, api_name, npu_precision, gpu_precision):
204
- self.api_name = api_name
205
- self.npu_precision = npu_precision
206
- self.gpu_precision = gpu_precision
207
- self.mean_ulp_err = 0
208
- self.ulp_err_proportion = 0
209
- self.ulp_err_proportion_ratio = 1
210
- self.ulp_err_status = CompareConst.PASS
211
- self.compare_message = ""
212
-
213
- def __str__(self):
214
- return f"{self.api_name}"
215
-
216
- def get_result(self):
217
- self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
218
- gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
219
- inf_nan_consistency = True
220
- if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
221
- _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
222
- ApiPrecisionCompareColumn.MEAN_ULP_ERR)
223
- self.compare_message += message
224
- self.ulp_err_proportion = convert_str_to_float(
225
- self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
226
- self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
227
- ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
228
- self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
229
- self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
230
- inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
231
- self.compare_message += message
232
- if inf_nan_consistency:
233
- self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
234
- else:
235
- self.ulp_err_status = CompareConst.ERROR
236
-
237
- def _get_ulp_status(self, dtype):
238
- if dtype == torch.float32:
239
- if self.mean_ulp_err < 64:
240
- return CompareConst.PASS
241
- elif self.ulp_err_proportion < 0.05:
242
- return CompareConst.PASS
243
- elif self.ulp_err_proportion_ratio < 1:
244
- return CompareConst.PASS
245
- else:
246
- self.compare_message += "ERROR: ULP误差不满足标准\n"
247
- return CompareConst.ERROR
248
- else:
249
- if self.ulp_err_proportion < 0.001:
250
- return CompareConst.PASS
251
- elif self.ulp_err_proportion_ratio < 1:
252
- return CompareConst.PASS
253
- else:
254
- self.compare_message += "ERROR: ULP误差不满足标准\n"
255
- return CompareConst.ERROR
256
-
257
-
258
- def write_detail_csv(content, save_path):
259
- rows = []
260
- content = ["{:.{}f}".format(item, msCheckerConfig.precision) \
261
- if isinstance(item, float) else item for item in content]
262
- rows.append(content)
263
- write_csv(rows, save_path)
264
-
265
-
266
- def api_precision_compare(config):
267
- logger.info("Start compare task")
268
- logger.info(f"Compare task result will be saved in {config.result_csv_path}")
269
- logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
270
- try:
271
- npu_data = pd.read_csv(config.npu_csv_path)
272
- except Exception as err:
273
- logger.error(f"Open npu csv Error: %s" % str(err))
274
- check_csv_columns(npu_data.columns, "npu_csv")
275
- try:
276
- gpu_data = pd.read_csv(config.gpu_csv_path)
277
- except Exception as err:
278
- logger.error(f"Open gpu csv Error: %s" % str(err))
279
- check_csv_columns(gpu_data.columns, "gpu_csv")
280
- detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
281
- result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
282
- write_csv(result_csv_title, config.result_csv_path)
283
- write_csv(detail_csv_title, config.details_csv_path)
284
- try:
285
- analyse_csv(npu_data, gpu_data, config)
286
- except Exception as err:
287
- logger.error(f"Analyse csv Error: %s" % str(err))
288
- change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
289
- change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
290
-
291
-
292
def online_api_precision_compare(online_config):
    """Run the precision comparison on in-memory data for one rank.

    Args:
        online_config: object exposing ``rank``, ``result_csv_path``,
            ``details_csv_path``, ``npu_data`` and ``gpu_data``.

    The ``_rank*.csv`` placeholder in both output paths is replaced with the
    concrete rank. Header rows are written only when the file does not exist
    yet. Any error during checking or analysis is logged, not raised.
    """
    rank_suffix = f"_rank{online_config.rank}.csv"

    def _rank_path(template):
        # Anchor the path under the default directory and fill in the rank.
        return os.path.join(Const.DEFAULT_PATH, template).replace("_rank*.csv", rank_suffix)

    result_csv_path = _rank_path(online_config.result_csv_path)
    details_csv_path = _rank_path(online_config.details_csv_path)
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    for csv_path, title in ((result_csv_path, result_csv_title), (details_csv_path, detail_csv_title)):
        if not os.path.exists(csv_path):
            write_csv(title, csv_path)
    config = CompareConfig("", "", result_csv_path, details_csv_path)
    try:
        npu_data = online_config.npu_data
        gpu_data = online_config.gpu_data
        check_csv_columns(npu_data.columns, "npu_csv")
        check_csv_columns(gpu_data.columns, "gpu_csv")
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        logger.error(f"Online api precision compare Error: {str(err)}")
    for csv_path in (result_csv_path, details_csv_path):
        change_mode(csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
312
-
313
-
314
def analyse_csv(npu_data, gpu_data, config):
    """Compare every NPU row with its GPU counterpart and write both CSVs.

    For each NPU row a detail row is written to ``config.details_csv_path``.
    Consecutive rows sharing the same API name segment form a group; when the
    name changes (and once more after the last row) one summary row is written
    to ``config.result_csv_path``.

    Args:
        npu_data: data frame of NPU accuracy-check details.
        gpu_data: data frame of GPU accuracy-check details.
        config: object exposing ``result_csv_path`` and ``details_csv_path``.

    Raises:
        CompareException: when an API has more than one record in the GPU data.
        ValueError: when an API name does not split into exactly six segments.
    """

    def _write_summary(name, full_name, dtype, fwd_status, bwd_status):
        # Emit the result row and the log line for one finished API group.
        if dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            write_csv([[name, "skip", "skip", unsupported_message]], config.result_csv_path)
            print_test_success(full_name, "skip", "skip")
            return
        forward_result = get_api_checker_result(fwd_status)
        backward_result = get_api_checker_result(bwd_status)
        message = CompareMessage.get(name, "") if forward_result == CompareConst.ERROR else ""
        write_csv([[name, forward_result, backward_result, message]], config.result_csv_path)
        print_test_success(full_name, forward_result, backward_result)

    forward_status, backward_status = [], []
    last_api_name, last_api_dtype, last_api_full_name = None, None, None
    for _, row_npu in npu_data.iterrows():
        compare_column = ApiPrecisionOutputColumn()
        full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
        row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
        api_type, api_name, api_nums, direction_status, _, _ = full_api_name_with_direction_status.split(Const.SEP)
        api_full_name = Const.SEP.join([api_type, api_name, api_nums])
        if row_gpu.empty:
            logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
            continue
        if len(row_gpu) > 1:
            msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.'
            raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
        row_gpu = row_gpu.iloc[0]
        device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
        new_status = CompareConst.SPACE
        compare_column.api_name = full_api_name_with_direction_status
        if device_dtype.isspace():
            # The API produced no output (e.g. requires_grad=False in the
            # backward pass), so the comparison is skipped.
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
            new_status = CompareConst.SKIP
        elif api_name in thousandth_standard_api:
            new_status = record_thousandth_threshold_result(compare_column, row_npu)
        elif device_dtype not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in binary_standard_api:
            new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
        elif api_name in absolute_standard_api:
            new_status = record_absolute_threshold_result(compare_column, row_npu)
        elif api_name in ulp_standard_api and device_dtype in ULP_COMPARE_SUPPORT_LIST:
            us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
            new_status = record_ulp_compare_result(compare_column, us)
        elif device_dtype in BENCHMARK_COMPARE_SUPPORT_LIST:
            bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
            new_status = record_benchmark_compare_result(compare_column, bs)
        write_detail_csv(compare_column.to_column_value(), config.details_csv_path)

        if last_api_name is not None and api_name != last_api_name:
            # The summary belongs to the group that just ended, so log its
            # full name rather than the current row's.
            _write_summary(last_api_name, last_api_full_name, last_api_dtype, forward_status, backward_status)
            forward_status, backward_status = [], []

        last_api_name = api_name
        last_api_full_name = api_full_name
        last_api_dtype = device_dtype
        if device_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            # Unsupported dtypes contribute no status to the group.
            continue

        if direction_status == 'forward':
            forward_status.append(new_status)
        elif direction_status == 'backward':
            backward_status.append(new_status)
        else:
            logger.error(f"Invalid direction status: {direction_status}")

    if last_api_name is not None:
        # Flush the final group, which no later name change would close.
        _write_summary(last_api_name, last_api_full_name, last_api_dtype, forward_status, backward_status)
399
-
400
-
401
def print_test_success(api_full_name, forward_result, backward_result):
    """Log whether the forward and backward comparison of an API succeeded.

    Forward succeeds only on ``CompareConst.PASS``. Backward also counts
    ``CompareConst.SPACE`` as success.
    """
    is_fwd_success = forward_result == CompareConst.PASS
    is_bwd_success = backward_result in (CompareConst.PASS, CompareConst.SPACE)
    summary = (f"running api_full_name {api_full_name} compare, "
               f"is_fwd_success: {is_fwd_success}, "
               f"is_bwd_success: {is_bwd_success}")
    logger.info(summary)
407
-
408
-
409
def check_error_rate(npu_error_rate):
    """Return ``CompareConst.PASS`` when the converted error rate equals 0, else ``CompareConst.ERROR``."""
    if convert_str_to_float(npu_error_rate) == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
411
-
412
-
413
def get_absolute_threshold_result(row_npu):
    """Evaluate the three absolute-threshold ratios of one NPU row.

    Each ratio passes only when it equals 0. The overall result is ERROR as
    soon as any single ratio fails.

    Returns:
        dict with each ratio, its status, and ``absolute_threshold_result``.
    """

    def _status(ratio):
        return CompareConst.PASS if ratio == 0 else CompareConst.ERROR

    inf_nan_error_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO])
    rel_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO])
    abs_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO])

    inf_nan_result = _status(inf_nan_error_ratio)
    rel_err_result = _status(rel_err_ratio)
    abs_err_result = _status(abs_err_ratio)

    any_failed = any(result == CompareConst.ERROR
                     for result in (inf_nan_result, rel_err_result, abs_err_result))
    absolute_threshold_result = CompareConst.ERROR if any_failed else CompareConst.PASS

    return {
        "inf_nan_error_ratio": inf_nan_error_ratio,
        "inf_nan_result": inf_nan_result,
        "rel_err_ratio": rel_err_ratio,
        "rel_err_result": rel_err_result,
        "abs_err_ratio": abs_err_ratio,
        "abs_err_result": abs_err_result,
        "absolute_threshold_result": absolute_threshold_result,
    }
436
-
437
-
438
def get_api_checker_result(status):
    """Collapse a list of per-row statuses into one status for the API.

    Empty input gives SPACE and all-SKIP gives SKIP. Otherwise ERROR wins
    over WARNING, and PASS is returned when neither is present.
    """
    if not status:
        return CompareConst.SPACE
    if not any(item != CompareConst.SKIP for item in status):
        return CompareConst.SKIP
    if CompareConst.ERROR in status:
        return CompareConst.ERROR
    if CompareConst.WARNING in status:
        return CompareConst.WARNING
    return CompareConst.PASS
447
-
448
-
449
def check_csv_columns(columns, csv_type):
    """Verify that all required columns are present.

    Args:
        columns: column labels of the loaded CSV.
        csv_type: label of the CSV being checked, used in the error message.

    Raises:
        CompareException: with ``INVALID_DATA_ERROR`` when at least one
            required column is missing.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # A space before the CSV label keeps the message readable
        # ("missing in npu_csv" rather than "missing innpu_csv").
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
455
-
456
-
457
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Fill ``compare_column`` with the binary-consistency verdict for one row.

    Returns:
        The status derived from the row's error rate.
    """
    error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    new_status = check_error_rate(error_rate)
    compare_column.error_rate = error_rate
    compare_column.error_rate_status = new_status
    compare_column.compare_result = new_status
    compare_column.compare_algorithm = "二进制一致法"
    if compare_column.error_rate_status == CompareConst.ERROR:
        # On failure, append any API-specific hint registered in CompareMessage.
        compare_column.compare_message = "ERROR: 二进制一致错误率超过阈值\n" + CompareMessage.get(api_name, "")
    else:
        compare_column.compare_message = ''
    return new_status
469
-
470
-
471
def record_absolute_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` with the absolute-threshold verdict for one row.

    Returns:
        The overall absolute-threshold result stored on ``compare_column``.
    """
    outcome = get_absolute_threshold_result(row_npu)
    # (attribute on compare_column, key in the outcome dict)
    field_sources = (
        ("inf_nan_error_ratio", "inf_nan_error_ratio"),
        ("inf_nan_error_ratio_status", "inf_nan_result"),
        ("rel_err_ratio", "rel_err_ratio"),
        ("rel_err_ratio_status", "rel_err_result"),
        ("abs_err_ratio", "abs_err_ratio"),
        ("abs_err_ratio_status", "abs_err_result"),
        ("compare_result", "absolute_threshold_result"),
    )
    for field, key in field_sources:
        setattr(compare_column, field, outcome.get(key))
    compare_column.compare_algorithm = "绝对阈值法"
    # One message line per failed check, in a fixed order.
    failure_notes = (
        ("inf_nan_error_ratio_status", "ERROR: inf/nan错误率超过阈值\n"),
        ("rel_err_ratio_status", "ERROR: 相对误差错误率超过阈值\n"),
        ("abs_err_ratio_status", "ERROR: 绝对误差错误率超过阈值\n"),
    )
    compare_column.compare_message = ''.join(
        note for field, note in failure_notes
        if getattr(compare_column, field) == CompareConst.ERROR
    )
    return compare_column.compare_result
490
-
491
-
492
def record_benchmark_compare_result(compare_column, bs):
    """Run the benchmark standard ``bs`` and copy its outcome into ``compare_column``.

    Returns:
        The final benchmark result stored on ``compare_column``.
    """
    bs.get_result()
    mirrored_fields = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for field in mirrored_fields:
        setattr(compare_column, field, getattr(bs, field))
    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"
    compare_column.compare_message = bs.compare_message
    # Append the canned text for every status that has an entry in benchmark_message.
    for status_attr, messages in benchmark_message.items():
        status_value = getattr(compare_column, status_attr)
        if status_value in messages:
            compare_column.compare_message += messages[status_value]
    return compare_column.compare_result
512
-
513
-
514
def record_ulp_compare_result(compare_column, us):
    """Run the ULP standard ``us`` and copy its outcome into ``compare_column``.

    Returns:
        The ULP status, which is also stored as the compare result.
    """
    us.get_result()
    for field in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, field, getattr(us, field))
    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
524
-
525
-
526
def check_thousandth_rate(thousandth_rate):
    """Return ``CompareConst.PASS`` when the converted rate is at least 0.999, else ``CompareConst.ERROR``."""
    if convert_str_to_float(thousandth_rate) >= 0.999:
        return CompareConst.PASS
    return CompareConst.ERROR
528
-
529
-
530
def record_thousandth_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` with the thousandth-threshold verdict for one row.

    Returns:
        The compare result stored on ``compare_column``.
    """
    rel_err_thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    new_status = check_thousandth_rate(rel_err_thousandth)
    compare_column.rel_err_thousandth = rel_err_thousandth
    compare_column.rel_err_thousandth_status = new_status
    compare_column.compare_result = new_status
    compare_column.compare_algorithm = "双千指标法"
    if compare_column.rel_err_thousandth_status == CompareConst.ERROR:
        compare_column.compare_message = "ERROR: 双千指标不达标\n"
    else:
        compare_column.compare_message = ''
    return compare_column.compare_result
541
-
542
-
543
def _api_precision_compare(parser=None):
    """CLI entry: register the arguments, parse ``sys.argv`` and run the command.

    Args:
        parser: optional argument parser to extend; a fresh
            ``argparse.ArgumentParser`` is created when none is given.
    """
    parser = parser or argparse.ArgumentParser()
    _api_precision_compare_parser(parser)
    cli_args = sys.argv[1:]
    _api_precision_compare_command(parser.parse_args(cli_args))
549
-
550
-
551
def _api_precision_compare_command(args):
    """Validate the CLI paths, prepare the output directory and start the comparison.

    Args:
        args: parsed arguments with ``npu_csv_path``, ``gpu_csv_path`` and
            ``out_path``; an empty ``out_path`` falls back to ``"./"``.
    """
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
    if args.out_path:
        out_path = os.path.realpath(args.out_path)
    else:
        out_path = "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    # Continue with the path returned by the writable-directory check.
    out_path = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()
    compare_config = CompareConfig(
        npu_csv_path,
        gpu_csv_path,
        os.path.join(out_path, API_PRECISION_COMPARE_RESULT_FILE_NAME),
        os.path.join(out_path, API_PRECISION_COMPARE_DETAILS_FILE_NAME),
    )
    api_precision_compare(compare_config)
563
-
564
-
565
- def _api_precision_compare_parser(parser):
566
- parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
567
- help="<Required> , Accuracy_checking_details.csv generated on the NPU by using the "
568
- "api_accuracy_checker tool.",
569
- required=True)
570
- parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
571
- help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
572
- "api_accuracy_checker tool.",
573
- required=False)
574
- parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
575
- help="<optional> The api precision compare task result out path.",
576
- required=False)
577
-
578
-
579
# Script entry point: run the CLI comparison, then log that the task finished.
if __name__ == '__main__':
    _api_precision_compare()
    logger.info("Compare task completed.")
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import argparse
19
+ import math
20
+ import os
21
+ import sys
22
+ from collections import namedtuple
23
+
24
+ import torch
25
+ import pandas as pd
26
+
27
+ from msprobe.core.common.file_utils import write_csv, read_csv
28
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
29
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
30
+ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
31
+ ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
32
+ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
33
+ check_inf_or_nan
34
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
35
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
36
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
37
+ from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory
38
+ from msprobe.pytorch.common.log import logger
39
+ from msprobe.core.common.utils import CompareException
40
+ from msprobe.core.common.const import Const, CompareConst, FileCheckConst
41
+
42
# Paths used by one comparison run: the two input CSVs and the two output CSVs.
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
# Per-metric inf/nan consistency flags returned by BenchmarkStandard._compare_ratio.
BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_value_inf_nan_consistency',
                                                                       'rmse_inf_nan_consistency',
                                                                       'max_rel_inf_nan_consistency',
                                                                       'mean_rel_inf_nan_consistency',
                                                                       'eb_inf_nan_consistency'])
# Message written to the result CSV for APIs whose dtype is in the unsupported list.
UNSUPPORTED_MESSAGE = 'This data type does not support benchmark compare.'

# Fallback threshold when an algorithm has no entry in the table below.
DEFAULT_THRESHOLD = 1

# NPU/GPU ratio thresholds per benchmark metric, read by BenchmarkStandard._get_status:
# a ratio above 'error_threshold' is an ERROR, above 'warning_threshold' a WARNING.
benchmark_algorithms_thresholds = {
    'small_value': {
        'error_threshold': 2,
        'warning_threshold': 1
    },
    'rmse': {
        'error_threshold': 2,
        'warning_threshold': 1
    },
    'max_rel_err': {
        'error_threshold': 10,
        'warning_threshold': 1
    },
    'mean_rel_err': {
        'error_threshold': 2,
        'warning_threshold': 1
    },
    'eb': {
        'error_threshold': 2,
        'warning_threshold': 1
    }
}

# Text appended to the compare message, keyed by status attribute and then by status value.
# NOTE(review): there is no "eb_status" entry, so EB never adds a message line -- confirm this is intended.
benchmark_message = {
    "small_value_err_status": {
        CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n"
    },
    "rmse_status": {
        CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n"
    },
    "max_rel_err_status": {
        CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n"
    },
    "mean_rel_err_status": {
        CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n"
    }
}
93
+
94
+
95
class Standard:
    """Base class holding the ratio computation shared by the comparison standards."""

    @staticmethod
    def _calc_ratio(column_name, x, y, default_value):
        """Compute the ratio between an NPU-side and a GPU-side statistic.

        Args:
            column_name: name of the statistic.
            x: NPU-side statistic.
            y: GPU-side statistic.
            default_value: ratio returned when ``y`` is zero but ``x`` is not.

        Returns:
            Tuple ``(ratio, inf_nan_consistency, message)``. When either value
            is inf or nan the tuple comes from ``check_inf_or_nan``. Per the
            original note, consistency is then True only if both sides are
            the same inf, -inf or nan, and ``message`` describes the case.
            Otherwise consistency is True and the message is empty.
        """
        npu_value = convert_str_to_float(x)
        gpu_value = convert_str_to_float(y)

        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)

        # With the default tolerances (abs_tol=0.0) math.isclose(v, 0.0) is
        # true only for an exact zero.
        if not math.isclose(gpu_value, 0.0):
            ratio = abs(npu_value / gpu_value)
        elif math.isclose(npu_value, 0.0):
            ratio = 1.0
        else:
            ratio = default_value
        return ratio, True, ""
124
+
125
+
126
class BenchmarkStandard(Standard):
    """Benchmark comparison: judges NPU/GPU ratios of five error statistics.

    ``get_result`` fills the ``*_ratio`` and ``*_status`` attributes,
    ``final_result`` and ``compare_message``.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        # Neutral defaults: every ratio is 1 and every status PASS until get_result runs.
        for ratio_attr in ("small_value_err_ratio", "rmse_ratio", "max_rel_err_ratio",
                           "mean_rel_err_ratio", "eb_ratio"):
            setattr(self, ratio_attr, 1)
        for status_attr in ("small_value_err_status", "rmse_status", "max_rel_err_status",
                            "mean_rel_err_status", "eb_status"):
            setattr(self, status_attr, CompareConst.PASS)
        self.check_result_list = []
        self.final_result = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return "%s" % self.api_name

    @staticmethod
    def _get_status(ratio, algorithm):
        """Map a ratio to PASS/WARNING/ERROR using the thresholds of ``algorithm``.

        A nan or infinite ratio is reported as PASS here.
        """
        if math.isnan(ratio) or math.isinf(ratio):
            return CompareConst.PASS
        thresholds = benchmark_algorithms_thresholds.get(algorithm, {})
        error_threshold = thresholds.get('error_threshold', DEFAULT_THRESHOLD)
        warning_threshold = thresholds.get('warning_threshold', DEFAULT_THRESHOLD)
        if ratio > error_threshold:
            return CompareConst.ERROR
        if ratio > warning_threshold:
            return CompareConst.WARNING
        return CompareConst.PASS

    def get_result(self):
        """Compute all ratios and statuses and derive ``final_result``."""
        consistency = self._compare_ratio()
        # (status attribute, ratio, threshold key, inf/nan consistency flag)
        judged_metrics = (
            ("small_value_err_status", self.small_value_err_ratio, 'small_value',
             consistency.small_value_inf_nan_consistency),
            ("rmse_status", self.rmse_ratio, 'rmse', consistency.rmse_inf_nan_consistency),
            ("max_rel_err_status", self.max_rel_err_ratio, 'max_rel_err',
             consistency.max_rel_inf_nan_consistency),
            ("mean_rel_err_status", self.mean_rel_err_ratio, 'mean_rel_err',
             consistency.mean_rel_inf_nan_consistency),
        )
        for status_attr, ratio, algorithm, is_consistent in judged_metrics:
            # Inconsistent inf/nan between the two sides is an error outright.
            status = self._get_status(ratio, algorithm) if is_consistent else CompareConst.ERROR
            setattr(self, status_attr, status)
            self.check_result_list.append(status)
        # The EB status is recorded but does not take part in the final verdict.
        self.eb_status = self._get_status(self.eb_ratio, 'eb')
        if CompareConst.ERROR in self.check_result_list:
            self.final_result = CompareConst.ERROR
        elif CompareConst.WARNING in self.check_result_list:
            self.final_result = CompareConst.WARNING

    def to_column_value(self):
        """Return the ratio/status pairs as a flat list in column order."""
        return [
            self.small_value_err_ratio, self.small_value_err_status,
            self.rmse_ratio, self.rmse_status,
            self.max_rel_err_ratio, self.max_rel_err_status,
            self.mean_rel_err_ratio, self.mean_rel_err_status,
            self.eb_ratio, self.eb_status,
        ]

    def _compare_ratio(self):
        """Compute every NPU/GPU ratio and collect the inf/nan consistency flags."""
        metric_columns = (
            ("small_value_err_ratio", ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
            ("rmse_ratio", ApiPrecisionCompareColumn.RMSE),
            ("max_rel_err_ratio", ApiPrecisionCompareColumn.MAX_REL_ERR),
            ("mean_rel_err_ratio", ApiPrecisionCompareColumn.MEAN_REL_ERR),
            ("eb_ratio", ApiPrecisionCompareColumn.EB),
        )
        consistency_flags = []
        for ratio_attr, column in metric_columns:
            ratio, is_consistent, note = self._calc_ratio(
                column, self.npu_precision.get(column), self.gpu_precision.get(column), 10000.0)
            setattr(self, ratio_attr, ratio)
            self.compare_message += note
            consistency_flags.append(is_consistent)
        return BenchmarkInfNanConsistency(*consistency_flags)
220
+
221
+
222
class ULPStandard(Standard):
    """ULP comparison: judges the NPU result by its ULP error statistics.

    ``get_result`` fills ``mean_ulp_err``, ``ulp_err_proportion``,
    ``ulp_err_proportion_ratio``, ``ulp_err_status`` and ``compare_message``.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.mean_ulp_err = 0
        self.ulp_err_proportion = 0
        self.ulp_err_proportion_ratio = 1
        self.ulp_err_status = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    def get_result(self):
        """Read the ULP statistics of both sides and derive ``ulp_err_status``."""
        self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        inf_nan_consistency = True
        if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
            _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
                                                               ApiPrecisionCompareColumn.MEAN_ULP_ERR)
            self.compare_message += message
        self.ulp_err_proportion = convert_str_to_float(
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
        self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
            ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
            self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
        inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
        self.compare_message += message
        if inf_nan_consistency:
            self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
        else:
            # Inconsistent inf/nan between the two sides is an error outright.
            self.ulp_err_status = CompareConst.ERROR

    def _get_ulp_status(self, dtype):
        """Return PASS or ERROR for the collected ULP statistics.

        Args:
            dtype: device dtype of the row, either the textual form
                (e.g. ``"torch.float32"``) or a ``torch.dtype``.

        float32 passes when the mean ULP error is below 64, the error
        proportion is below 0.05, or the proportion ratio is below 1. Other
        dtypes pass when the proportion is below 0.001 or the ratio is below 1.
        """
        # The dtype read from the row is text, so comparing it with the
        # torch.float32 object directly is never true. Comparing the string
        # forms accepts both text and torch.dtype values.
        if str(dtype) == str(torch.float32):
            if self.mean_ulp_err < 64:
                return CompareConst.PASS
            elif self.ulp_err_proportion < 0.05:
                return CompareConst.PASS
            elif self.ulp_err_proportion_ratio < 1:
                return CompareConst.PASS
            else:
                self.compare_message += "ERROR: ULP误差不满足标准\n"
                return CompareConst.ERROR
        else:
            if self.ulp_err_proportion < 0.001:
                return CompareConst.PASS
            elif self.ulp_err_proportion_ratio < 1:
                return CompareConst.PASS
            else:
                self.compare_message += "ERROR: ULP误差不满足标准\n"
                return CompareConst.ERROR
276
+
277
+
278
def write_detail_csv(content, save_path):
    """Format one detail row and hand it to ``write_csv``.

    Args:
        content: iterable of cell values for a single row.
        save_path: path of the detail CSV passed through to ``write_csv``.

    Float cells are rendered as fixed-point text with
    ``msCheckerConfig.precision`` decimal places; every other cell is passed
    through unchanged.
    """

    def _render(cell):
        if isinstance(cell, float):
            return "{:.{}f}".format(cell, msCheckerConfig.precision)
        return cell

    write_csv([list(map(_render, content))], save_path)
284
+
285
+
286
def api_precision_compare(config):
    """Compare the NPU and GPU accuracy-check CSVs and write the result files.

    Args:
        config: object exposing ``npu_csv_path``, ``gpu_csv_path``,
            ``result_csv_path`` and ``details_csv_path``.

    Raises:
        Exception: any error raised while reading either input CSV is logged
            and re-raised.
        CompareException: propagated from ``check_csv_columns`` when a
            required column is missing.
    """
    logger.info("Start compare task")
    logger.info(f"Compare task result will be saved in {config.result_csv_path}")
    logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
    try:
        npu_data = read_csv(config.npu_csv_path)
    except Exception as err:
        logger.error("Open npu csv Error: %s" % str(err))
        # Without the data frame the next statement would fail with an
        # unrelated UnboundLocalError; surface the real cause instead.
        raise
    check_csv_columns(npu_data.columns, "npu_csv")
    try:
        gpu_data = read_csv(config.gpu_csv_path)
    except Exception as err:
        logger.error("Open gpu csv Error: %s" % str(err))
        # Same reasoning as for the NPU file above.
        raise
    check_csv_columns(gpu_data.columns, "gpu_csv")
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    write_csv(result_csv_title, config.result_csv_path)
    write_csv(detail_csv_title, config.details_csv_path)
    try:
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        # Best effort: an analysis failure is logged and the files written so
        # far still get their permissions adjusted below.
        logger.error("Analyse csv Error: %s" % str(err))
    change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
    change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
310
+
311
+
312
def online_api_precision_compare(online_config):
    """Run the precision comparison on in-memory data for one rank.

    Args:
        online_config: object exposing ``rank``, ``result_csv_path``,
            ``details_csv_path``, ``npu_data`` and ``gpu_data``.

    The ``_rank*.csv`` placeholder in both output paths is replaced with the
    concrete rank. Header rows are written only when the file does not exist
    yet. Any error during checking or analysis is logged, not raised.
    """
    rank_suffix = f"_rank{online_config.rank}.csv"

    def _rank_path(template):
        # Anchor the path under the default directory and fill in the rank.
        return os.path.join(Const.DEFAULT_PATH, template).replace("_rank*.csv", rank_suffix)

    result_csv_path = _rank_path(online_config.result_csv_path)
    details_csv_path = _rank_path(online_config.details_csv_path)
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    for csv_path, title in ((result_csv_path, result_csv_title), (details_csv_path, detail_csv_title)):
        if not os.path.exists(csv_path):
            write_csv(title, csv_path)
    config = CompareConfig("", "", result_csv_path, details_csv_path)
    try:
        npu_data = online_config.npu_data
        gpu_data = online_config.gpu_data
        check_csv_columns(npu_data.columns, "npu_csv")
        check_csv_columns(gpu_data.columns, "gpu_csv")
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        logger.error(f"Online api precision compare Error: {str(err)}")
    for csv_path in (result_csv_path, details_csv_path):
        change_mode(csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
334
+
335
+
336
+ def analyse_csv(npu_data, gpu_data, config):
337
+ forward_status, backward_status = [], []
338
+ last_api_name, last_api_dtype, last_api_full_name = None, None, None
339
+ last_api_skip_message = ''
340
+ for _, row_npu in npu_data.iterrows():
341
+ message = ''
342
+ compare_column = ApiPrecisionOutputColumn()
343
+ full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
344
+ row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
345
+ api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
346
+ if not api_full_name:
347
+ err_message = f"The API name {full_api_name_with_direction_status} is invalid."
348
+ logger.error(err_message)
349
+ compare_column.api_name = full_api_name_with_direction_status
350
+ compare_column.compare_result = CompareConst.SKIP
351
+ compare_column.compare_message = err_message
352
+ write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
353
+ write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, err_message]],
354
+ config.result_csv_path)
355
+ continue
356
+ if row_gpu.empty:
357
+ logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
358
+ continue
359
+ if len(row_gpu) > 1:
360
+ msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.'
361
+ raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
362
+ row_gpu = row_gpu.iloc[0]
363
+ new_status = CompareConst.SPACE
364
+ try:
365
+ new_status = get_api_status(row_npu, row_gpu, api_name, compare_column)
366
+ except Exception as err:
367
+ logger.error(f"Get api status error: {str(err)}")
368
+ compare_column.api_name = full_api_name_with_direction_status
369
+ compare_column.compare_result = CompareConst.SKIP
370
+ compare_column.compare_message = str(err)
371
+ write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
372
+ write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, str(err)]],
373
+ config.result_csv_path)
374
+ continue
375
+
376
+ write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
377
+
378
+ if last_api_name is not None and api_full_name != last_api_name:
379
+ if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
380
+ message = UNSUPPORTED_MESSAGE
381
+ write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
382
+ print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
383
+ else:
384
+ forward_result = get_api_checker_result(forward_status)
385
+ backward_result = get_api_checker_result(backward_status)
386
+ message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
387
+ message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
388
+ write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
389
+ print_test_success(last_api_name, forward_result, backward_result)
390
+ last_api_skip_message = ''
391
+ forward_status, backward_status = [], []
392
+ message = ''
393
+
394
+ is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
395
+ last_api_name = api_full_name
396
+
397
+ last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
398
+ if not is_supported:
399
+ continue
400
+
401
+ if direction_status == 'forward':
402
+ forward_status.append(new_status)
403
+ last_api_skip_message = str(row_npu[ApiPrecisionCompareColumn.MESSAGE]) if new_status == CompareConst.SKIP \
404
+ else ''
405
+ elif direction_status == 'backward':
406
+ backward_status.append(new_status)
407
+ else:
408
+ logger.error(f"Invalid direction status: {direction_status}")
409
+
410
+ if last_api_name is not None:
411
+ if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
412
+ message = UNSUPPORTED_MESSAGE
413
+ write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
414
+ print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
415
+ else:
416
+ forward_result = get_api_checker_result(forward_status)
417
+ backward_result = get_api_checker_result(backward_status)
418
+ message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
419
+ message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
420
+ write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
421
+ print_test_success(last_api_name, forward_result, backward_result)
422
+ last_api_skip_message = ''
423
+
424
+
425
def get_api_status(row_npu, row_gpu, api_name, compare_column):
    """Compare one API record with the applicable standard and return its status.

    The record's full name is always stored on ``compare_column``; the other
    fields are filled in by whichever ``record_*`` helper is selected below.

    Args:
        row_npu: NPU-side record, indexed by ``ApiPrecisionCompareColumn`` keys.
        row_gpu: Matching GPU-side record; only handed to the ULP and benchmark
            standards.
        api_name: Short API name, looked up in the per-standard API collections.
        compare_column: Result holder that is updated in place.

    Returns:
        The status produced by the selected standard, or ``CompareConst.SKIP``
        when the device dtype is whitespace-only.

    Raises:
        ValueError: If no comparison standard applies to this API/dtype pair.
    """
    full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
    device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]

    # The current API has no output (e.g. requires_grad=False during the
    # backward pass), so skip the comparison.
    if device_dtype.isspace():
        compare_column.api_name = full_api_name_with_direction_status
        compare_column.compare_result = CompareConst.SKIP
        compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
        return CompareConst.SKIP

    compare_column.api_name = full_api_name_with_direction_status

    # The order of these checks is significant: the first matching standard wins.
    if api_name in thousandth_standard_api:
        return record_thousandth_threshold_result(compare_column, row_npu)
    if device_dtype not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in binary_standard_api:
        return record_binary_consistency_result(api_name, compare_column, row_npu)
    if api_name in absolute_standard_api:
        return record_absolute_threshold_result(compare_column, row_npu)
    if api_name in ulp_standard_api and device_dtype in ULP_COMPARE_SUPPORT_LIST:
        us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
        return record_ulp_compare_result(compare_column, us)
    if device_dtype in BENCHMARK_COMPARE_SUPPORT_LIST:
        bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
        return record_benchmark_compare_result(compare_column, bs)

    # Nothing matched: fail with an explicit message rather than an
    # UnboundLocalError on a never-assigned status variable.
    raise ValueError(
        f"No comparison standard is applicable to API {full_api_name_with_direction_status} "
        f"with dtype {device_dtype}."
    )
450
+
451
+
452
def print_test_success(api_full_name, forward_result, backward_result):
    """Log whether the forward and backward comparisons of an API succeeded.

    Forward succeeds only on ``CompareConst.PASS``; backward also succeeds when
    its result is ``CompareConst.SPACE``.
    """
    fwd_ok = forward_result == CompareConst.PASS
    if backward_result == CompareConst.PASS:
        bwd_ok = True
    else:
        bwd_ok = backward_result == CompareConst.SPACE
    logger.info(
        f"running api_full_name {api_full_name} compare, "
        f"is_fwd_success: {fwd_ok}, "
        f"is_bwd_success: {bwd_ok}"
    )
458
+
459
+
460
def check_error_rate(npu_error_rate):
    """Return ``CompareConst.PASS`` when the converted error rate equals 0, else ``CompareConst.ERROR``."""
    rate = convert_str_to_float(npu_error_rate)
    if rate == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
462
+
463
+
464
def get_absolute_threshold_result(row_npu):
    """Evaluate the absolute-threshold ratios of an NPU record.

    Each of the inf/nan, relative-error and absolute-error ratios passes only
    when it equals 0; the overall result is ERROR if any of them is ERROR.

    Returns:
        dict holding every converted ratio, its individual result, and the
        combined ``absolute_threshold_result``.
    """
    def _judge(ratio):
        # A ratio passes only when it is exactly zero.
        if ratio == 0:
            return CompareConst.PASS
        return CompareConst.ERROR

    inf_nan_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO])
    rel_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO])
    abs_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO])

    inf_nan_verdict = _judge(inf_nan_ratio)
    rel_verdict = _judge(rel_ratio)
    abs_verdict = _judge(abs_ratio)

    verdicts = (inf_nan_verdict, rel_verdict, abs_verdict)
    overall = CompareConst.ERROR if CompareConst.ERROR in verdicts else CompareConst.PASS

    summary = {}
    summary["inf_nan_error_ratio"] = inf_nan_ratio
    summary["inf_nan_result"] = inf_nan_verdict
    summary["rel_err_ratio"] = rel_ratio
    summary["rel_err_result"] = rel_verdict
    summary["abs_err_ratio"] = abs_ratio
    summary["abs_err_result"] = abs_verdict
    summary["absolute_threshold_result"] = overall
    return summary
487
+
488
+
489
def get_api_checker_result(status):
    """Collapse a list of per-record statuses into one API-level result.

    Precedence: empty list -> SPACE; all SKIP -> SKIP; any ERROR -> ERROR;
    any WARNING -> WARNING; otherwise PASS.
    """
    if not status:
        return CompareConst.SPACE

    only_skips = True
    for entry in status:
        if not (entry == CompareConst.SKIP):
            only_skips = False
            break
    if only_skips:
        return CompareConst.SKIP

    # ERROR outranks WARNING, so it is checked first.
    if CompareConst.ERROR in status:
        return CompareConst.ERROR
    if CompareConst.WARNING in status:
        return CompareConst.WARNING
    return CompareConst.PASS
498
+
499
+
500
def check_csv_columns(columns, csv_type):
    """Verify that a CSV header contains every required comparison column.

    Args:
        columns: Column names present in the CSV (anything supporting ``in``).
        csv_type: Label of the CSV being checked; used only in the error text.

    Raises:
        CompareException: With ``INVALID_DATA_ERROR`` when at least one column
            from ``ApiPrecisionCompareColumn.to_required_columns()`` is absent.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # A space separates "in" from the csv type so the message stays readable.
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
506
+
507
+
508
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Apply the binary-consistency standard and record the outcome.

    Stores the raw error rate, its status, the algorithm label and a message on
    ``compare_column``, then returns the status.
    """
    status = check_error_rate(row_npu[ApiPrecisionCompareColumn.ERROR_RATE])
    compare_column.error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    compare_column.error_rate_status = status
    compare_column.compare_result = status
    compare_column.compare_algorithm = "二进制一致法"

    pieces = []
    if compare_column.error_rate_status == CompareConst.ERROR:
        pieces.append("ERROR: 二进制一致错误率超过阈值\n")
        # Append the API-specific message when one is registered in CompareMessage.
        pieces.append(CompareMessage.get(api_name, ""))
    text = ''
    for piece in pieces:
        text += piece
    compare_column.compare_message = text
    return status
520
+
521
+
522
def record_absolute_threshold_result(compare_column, row_npu):
    """Apply the absolute-threshold standard and record the outcome.

    Copies every ratio/result from ``get_absolute_threshold_result`` onto
    ``compare_column``, builds one message line per failing ratio, and returns
    the overall result.
    """
    outcome = get_absolute_threshold_result(row_npu)

    # (attribute on compare_column, key in the outcome dict), in assignment order.
    attr_to_key = (
        ("inf_nan_error_ratio", "inf_nan_error_ratio"),
        ("inf_nan_error_ratio_status", "inf_nan_result"),
        ("rel_err_ratio", "rel_err_ratio"),
        ("rel_err_ratio_status", "rel_err_result"),
        ("abs_err_ratio", "abs_err_ratio"),
        ("abs_err_ratio_status", "abs_err_result"),
        ("compare_result", "absolute_threshold_result"),
    )
    for attr, key in attr_to_key:
        setattr(compare_column, attr, outcome.get(key))
    compare_column.compare_algorithm = "绝对阈值法"

    # (status attribute, text appended when that status is ERROR).
    failure_texts = (
        ("inf_nan_error_ratio_status", "ERROR: inf/nan错误率超过阈值\n"),
        ("rel_err_ratio_status", "ERROR: 相对误差错误率超过阈值\n"),
        ("abs_err_ratio_status", "ERROR: 绝对误差错误率超过阈值\n"),
    )
    text = ''
    for status_attr, failure_text in failure_texts:
        if getattr(compare_column, status_attr) == CompareConst.ERROR:
            text += failure_text
    compare_column.compare_message = text
    return compare_column.compare_result
541
+
542
+
543
def record_benchmark_compare_result(compare_column, bs):
    """Run the benchmark standard ``bs`` and record its outcome.

    Mirrors the ratio/status attributes of ``bs`` onto ``compare_column``, then
    extends the message with the ``benchmark_message`` entries that match the
    recorded statuses. Returns the final compare result.
    """
    bs.get_result()

    # Attributes copied one-to-one from the standard, in assignment order.
    mirrored = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for attr in mirrored:
        setattr(compare_column, attr, getattr(bs, attr))

    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"
    compare_column.compare_message = bs.compare_message

    for status_attr in benchmark_message:
        candidates = benchmark_message[status_attr]
        recorded = getattr(compare_column, status_attr)
        if recorded in candidates:
            compare_column.compare_message += candidates[recorded]
    return compare_column.compare_result
563
+
564
+
565
def record_ulp_compare_result(compare_column, us):
    """Run the ULP standard ``us`` and record its outcome.

    The ULP error status doubles as the overall compare result, which is also
    the return value.
    """
    us.get_result()

    # Attributes copied one-to-one from the standard, in assignment order.
    for attr in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, attr, getattr(us, attr))

    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
575
+
576
+
577
def check_thousandth_rate(thousandth_rate):
    """Return ``CompareConst.PASS`` when the converted rate is at least 0.999, else ``CompareConst.ERROR``."""
    rate = convert_str_to_float(thousandth_rate)
    if rate >= 0.999:
        return CompareConst.PASS
    return CompareConst.ERROR
579
+
580
+
581
def record_thousandth_threshold_result(compare_column, row_npu):
    """Apply the thousandth-rate standard and record the outcome.

    Stores the raw rate, its status, the algorithm label and a failure message
    (empty on success) on ``compare_column``; returns the compare result.
    """
    status = check_thousandth_rate(row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH])
    compare_column.rel_err_thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    compare_column.rel_err_thousandth_status = status
    compare_column.compare_result = status
    compare_column.compare_algorithm = "双千指标法"

    failed = compare_column.rel_err_thousandth_status == CompareConst.ERROR
    compare_column.compare_message = '' + "ERROR: 双千指标不达标\n" if failed else ''
    return compare_column.compare_result
592
+
593
+
594
def _api_precision_compare(parser=None):
    """Entry point: register the CLI options, parse ``sys.argv`` and run the comparison.

    Args:
        parser: Parser to register the options on; a fresh
            ``argparse.ArgumentParser`` is created when a falsy value is given.
    """
    cli_parser = parser if parser else argparse.ArgumentParser()
    _api_precision_compare_parser(cli_parser)
    cli_args = cli_parser.parse_args(sys.argv[1:])
    _api_precision_compare_command(cli_args)
600
+
601
+
602
def _api_precision_compare_command(args):
    """Validate the parsed CLI arguments, prepare the output directory and run the comparison.

    Args:
        args: Parsed namespace providing ``npu_csv_path``, ``gpu_csv_path`` and
            ``out_path``.
    """
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')

    # Fall back to the current directory when no output path was supplied.
    if args.out_path:
        target_dir = os.path.realpath(args.out_path)
    else:
        target_dir = "./"
    check_path_before_create(target_dir)
    create_directory(target_dir)

    # The checker returns the path to use from here on.
    target_dir = FileChecker(target_dir, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()

    compare_config = CompareConfig(
        npu_csv_path,
        gpu_csv_path,
        os.path.join(target_dir, API_PRECISION_COMPARE_RESULT_FILE_NAME),
        os.path.join(target_dir, API_PRECISION_COMPARE_DETAILS_FILE_NAME),
    )
    api_precision_compare(compare_config)
614
+
615
+
616
def _api_precision_compare_parser(parser):
    """Register the api-precision-compare command-line options on ``parser``.

    Options:
        -npu / --npu_csv_path: required, details CSV generated on the NPU.
        -gpu / --gpu_csv_path: required, details CSV generated on the GPU.
        -o / --out_path: optional, output directory (defaults to "").

    Args:
        parser: An ``argparse.ArgumentParser`` (or sub-parser) to extend in place.
    """
    parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
                        help="<Required> , Accuracy_checking_details.csv generated on the NPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    # The GPU csv is documented as required and the command validates it
    # unconditionally, so a missing value is rejected at parse time.
    parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
                        help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                        help="<optional> The api precision compare task result out path.",
                        required=False)
628
+
629
+
630
# Script entry point: run the comparison from the command line, then log completion.
if __name__ == "__main__":
    _api_precision_compare()
    logger.info("Compare task completed.")