mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,545 +1,606 @@
1
- import argparse
2
- import math
3
- import os
4
- import sys
5
- from collections import namedtuple
6
-
7
- import torch
8
- import pandas as pd
9
-
10
- from msprobe.pytorch.api_accuracy_checker.common.utils import write_csv
11
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
- API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
14
- ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, ThousandthStandardApi, \
15
- BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
16
- check_inf_or_nan
17
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path
19
- from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
20
- from msprobe.pytorch.common.log import logger
21
- from msprobe.core.common.utils import CompareException
22
- from msprobe.core.common.const import CompareConst, FileCheckConst
23
-
24
- CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
25
- BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency',
26
- 'rmse_inf_nan_consistency',
27
- 'max_rel_inf_nan_consistency',
28
- 'mean_rel_inf_nan_consistency',
29
- 'eb_inf_nan_consistency'])
30
- unsupported_message = 'This data type does not support benchmark compare.'
31
-
32
- DEFAULT_THRESHOLD = 1
33
-
34
- benchmark_algorithms_thresholds = {
35
- 'small_value': {
36
- 'error_threshold': 2,
37
- 'warning_threshold': 1
38
- },
39
- 'rmse': {
40
- 'error_threshold': 2,
41
- 'warning_threshold': 1
42
- },
43
- 'max_rel_err': {
44
- 'error_threshold': 10,
45
- 'warning_threshold': 1
46
- },
47
- 'mean_rel_err': {
48
- 'error_threshold': 2,
49
- 'warning_threshold': 1
50
- },
51
- 'eb': {
52
- 'error_threshold': 2,
53
- 'warning_threshold': 1
54
- }
55
- }
56
-
57
- benchmark_message = {
58
- "small_value_err_status": {
59
- CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n",
60
- CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n"
61
- },
62
- "rmse_status": {
63
- CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n",
64
- CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n"
65
- },
66
- "max_rel_err_status": {
67
- CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n",
68
- CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n"
69
- },
70
- "mean_rel_err_status": {
71
- CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n",
72
- CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n"
73
- }
74
- }
75
-
76
-
77
- class Standard:
78
- @staticmethod
79
- def _calc_ratio(column_name, x, y, default_value):
80
- '''
81
- 计算npu侧和gpu侧统计量的比值
82
- 输入:
83
- column_name:统计量名称
84
- x:npu侧统计量
85
- ygpu侧统计量
86
- default:当x不接近0,y接近0,设置的比值默认值
87
- 输出:
88
- ratio:统计量x和y的比值
89
- inf_nan_consistency:不出现inf或nan时为True,出现inf或nan时必须同时为inf或-inf或nan才为True,否则为False
90
- message:当出现inf或nan时的提示信息
91
- '''
92
- x, y = convert_str_to_float(x), convert_str_to_float(y)
93
-
94
- if is_inf_or_nan(x) or is_inf_or_nan(y):
95
- return check_inf_or_nan(x, y, column_name)
96
-
97
- inf_nan_consistency = True
98
- message = ""
99
- if math.isclose(y, 0.0):
100
- if math.isclose(x, 0.0):
101
- return 1.0, inf_nan_consistency, message
102
- else:
103
- return default_value, inf_nan_consistency, message
104
- else:
105
- return abs(x / y), inf_nan_consistency, message
106
-
107
-
108
- class BenchmarkStandard(Standard):
109
- def __init__(self, api_name, npu_precision, gpu_precision):
110
- self.api_name = api_name
111
- self.npu_precision = npu_precision
112
- self.gpu_precision = gpu_precision
113
- self.small_value_err_ratio = 1
114
- self.rmse_ratio = 1
115
- self.max_rel_err_ratio = 1
116
- self.mean_rel_err_ratio = 1
117
- self.eb_ratio = 1
118
- self.small_value_err_status = CompareConst.PASS
119
- self.rmse_status = CompareConst.PASS
120
- self.max_rel_err_status = CompareConst.PASS
121
- self.mean_rel_err_status = CompareConst.PASS
122
- self.eb_status = CompareConst.PASS
123
- self.check_result_list = []
124
- self.final_result = CompareConst.PASS
125
- self.compare_message = ""
126
-
127
- def __str__(self):
128
- return "%s" % (self.api_name)
129
-
130
- @staticmethod
131
- def _get_status(ratio, algorithm):
132
- if math.isnan(ratio) or math.isinf(ratio):
133
- return CompareConst.PASS
134
- error_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('error_threshold', DEFAULT_THRESHOLD)
135
- warning_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('warning_threshold',
136
- DEFAULT_THRESHOLD)
137
- if ratio > error_threshold:
138
- return CompareConst.ERROR
139
- elif ratio > warning_threshold:
140
- return CompareConst.WARNING
141
- return CompareConst.PASS
142
-
143
- def get_result(self):
144
- inf_nan_consistency = self._compare_ratio()
145
- small_value_inf_nan_consistency = inf_nan_consistency.small_value_inf_nan_consistency
146
- rmse_inf_nan_consistency = inf_nan_consistency.rmse_inf_nan_consistency
147
- max_rel_inf_nan_consistency = inf_nan_consistency.max_rel_inf_nan_consistency
148
- mean_rel_inf_nan_consistency = inf_nan_consistency.mean_rel_inf_nan_consistency
149
- eb_inf_nan_consistency = inf_nan_consistency.eb_inf_nan_consistency
150
- self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') if \
151
- small_value_inf_nan_consistency else CompareConst.ERROR
152
- self.check_result_list.append(self.small_value_err_status)
153
- self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
154
- else CompareConst.ERROR
155
- self.check_result_list.append(self.rmse_status)
156
- self.max_rel_err_status = self._get_status(self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency \
157
- else CompareConst.ERROR
158
- self.check_result_list.append(self.max_rel_err_status)
159
- self.mean_rel_err_status = self._get_status(self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency \
160
- else CompareConst.ERROR
161
- self.check_result_list.append(self.mean_rel_err_status)
162
- self.eb_status = self._get_status(self.eb_ratio, 'eb')
163
- if CompareConst.ERROR in self.check_result_list:
164
- self.final_result = CompareConst.ERROR
165
- elif CompareConst.WARNING in self.check_result_list:
166
- self.final_result = CompareConst.WARNING
167
-
168
- def to_column_value(self):
169
- return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
170
- self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
171
- self.mean_rel_err_status, self.eb_ratio, self.eb_status]
172
-
173
- def _compare_ratio(self):
174
-
175
- self.small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = self._calc_ratio(
176
- ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE,
177
- self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
178
- self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0)
179
- self.compare_message += small_value_message
180
- self.rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE,
181
- self.npu_precision.get(ApiPrecisionCompareColumn.RMSE),
182
- self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0)
183
- self.compare_message += rmse_message
184
- self.max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = self._calc_ratio(
185
- ApiPrecisionCompareColumn.MAX_REL_ERR,
186
- self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
187
- self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
188
- self.compare_message += max_rel_message
189
- self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(ApiPrecisionCompareColumn.MEAN_REL_ERR,
190
- self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
191
- self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
192
- self.compare_message += mean_rel_message
193
- self.eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB,
194
- self.npu_precision.get(ApiPrecisionCompareColumn.EB),
195
- self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
196
- self.compare_message += eb_message
197
-
198
- return BenchmarkInf_Nan_Consistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
199
- max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency, eb_inf_nan_consistency)
200
-
201
-
202
class ULPStandard(Standard):
    """ULP-error precision standard for a single API record.

    Holds the NPU-side and GPU-side precision records of one API and derives a
    PASS/ERROR status from the mean ULP error, the ULP error proportion and the
    NPU/GPU proportion ratio.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        """Store the inputs and initialise every result to its passing default.

        Args:
            api_name: Name of the API being judged (returned by ``str()``).
            npu_precision: NPU-side record, read with ``.get(column)``.
            gpu_precision: GPU-side record, read with ``.get(column)``.
        """
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.mean_ulp_err = 0
        self.ulp_err_proportion = 0
        self.ulp_err_proportion_ratio = 1
        self.ulp_err_status = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    def get_result(self):
        """Compute ``ulp_err_status`` and accumulate text in ``compare_message``.

        The status is ERROR when an inf/nan inconsistency is reported for the
        mean ULP error or for the ULP error proportion; otherwise it is decided
        by ``_get_ulp_status``.
        """
        self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        inf_nan_consistency = True
        if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
            # Only the consistency flag and the message are needed here; the
            # ratio returned first is discarded.
            _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
                                                               ApiPrecisionCompareColumn.MEAN_ULP_ERR)
            self.compare_message += message
        self.ulp_err_proportion = convert_str_to_float(
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
        self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
            ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
            self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
        inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
        self.compare_message += message
        if inf_nan_consistency:
            self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
        else:
            self.ulp_err_status = CompareConst.ERROR

    def _get_ulp_status(self, dtype):
        """Return PASS or ERROR for the already computed ULP statistics.

        Args:
            dtype: Device dtype of the record. It is read from a CSV cell, so it
                normally arrives as the string ``"torch.float32"``; a real
                ``torch.dtype`` is accepted as well.

        Returns:
            ``CompareConst.PASS`` when any criterion for the dtype is met,
            otherwise ``CompareConst.ERROR`` (an error line is then appended to
            ``compare_message``).
        """
        # Compare textual forms: a str never equals a torch.dtype object, so a
        # direct ``dtype == torch.float32`` would never select this branch for
        # values loaded from CSV.
        if str(dtype) == str(torch.float32):
            passed = (self.mean_ulp_err < 64
                      or self.ulp_err_proportion < 0.05
                      or self.ulp_err_proportion_ratio < 1)
        else:
            passed = self.ulp_err_proportion < 0.001 or self.ulp_err_proportion_ratio < 1
        if passed:
            return CompareConst.PASS
        self.compare_message += "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR
256
-
257
-
258
def write_detail_csv(content, save_path):
    """Write one detail row to ``save_path`` through ``write_csv``.

    Float items are rendered as fixed-point text with
    ``msCheckerConfig.precision`` decimals; all other items are passed through
    unchanged.
    """
    formatted_row = []
    for value in content:
        if isinstance(value, float):
            formatted_row.append("{:.{}f}".format(value, msCheckerConfig.precision))
        else:
            formatted_row.append(value)
    write_csv([formatted_row], save_path)
264
-
265
-
266
def api_precision_compare(config):
    """Run the offline NPU-vs-GPU precision comparison described by ``config``.

    Reads both detail CSV files, validates their columns, writes the title rows
    of the result and detail CSV files, analyses the data and finally adjusts
    the mode of both output files.

    Args:
        config: Object exposing ``npu_csv_path``, ``gpu_csv_path``,
            ``result_csv_path`` and ``details_csv_path``.

    Raises:
        Exception: Whatever ``pd.read_csv`` raised when an input CSV cannot be
            read (re-raised after logging).
        CompareException: Propagated from ``check_csv_columns`` when required
            columns are missing.
    """
    logger.info("Start compare task")
    logger.info(f"Compare task result will be saved in {config.result_csv_path}")
    logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
    try:
        npu_data = pd.read_csv(config.npu_csv_path)
    except Exception as err:
        logger.error(f"Open npu csv Error: {err}")
        # Without the data nothing below can run; re-raise instead of falling
        # through to an UnboundLocalError on ``npu_data``.
        raise
    check_csv_columns(npu_data.columns, "npu_csv")
    try:
        gpu_data = pd.read_csv(config.gpu_csv_path)
    except Exception as err:
        logger.error(f"Open gpu csv Error: {err}")
        raise
    check_csv_columns(gpu_data.columns, "gpu_csv")
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    write_csv(result_csv_title, config.result_csv_path)
    write_csv(detail_csv_title, config.details_csv_path)
    try:
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        # Best effort: keep whatever was written so far and still fix the file
        # modes below.
        logger.error(f"Analyse csv Error: {err}")
    change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
    change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
290
-
291
-
292
def analyse_csv(npu_data, gpu_data, config):
    """Compare every NPU record with its GPU counterpart and write the results.

    One detail row is written per compared NPU record to
    ``config.details_csv_path``. Consecutive records that share the same API
    name are summarised into one row (forward result, backward result, message)
    in ``config.result_csv_path``; the summary of a group is written when the
    next group starts, and the last group is written after the loop.

    Raises:
        CompareException: When the GPU data holds more than one record for an
            NPU record's API name.
    """
    forward_status, backward_status = [], []
    last_api_name, last_api_dtype = None, None
    for _, row_npu in npu_data.iterrows():
        message = ''
        compare_column = ApiPrecisionOutputColumn()
        full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
        row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
        # NOTE(review): the unpacking requires exactly six dot-separated
        # segments; any other name raises ValueError here.
        _, api_name, _, direction_status, _, _ = full_api_name_with_direction_status.split(".")
        if row_gpu.empty:
            logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
            continue
        if len(row_gpu) > 1:
            msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.'
            raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
        row_gpu = row_gpu.iloc[0]
        new_status = CompareConst.SPACE
        # The current API has no output (e.g. requires_grad=False in the
        # backward pass); skip the comparison.
        if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace():
            compare_column.api_name = full_api_name_with_direction_status
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
            new_status = CompareConst.SKIP
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
        else:
            compare_column.api_name = full_api_name_with_direction_status
            # The first matching standard wins; when none matches, new_status
            # keeps its CompareConst.SPACE default.
            if api_name in ThousandthStandardApi:
                new_status = record_thousandth_threshold_result(compare_column, row_npu)
            elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or \
                    api_name in BinaryStandardApi:
                new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
            elif api_name in AbsoluteStandardApi:
                new_status = record_absolute_threshold_result(compare_column, row_npu)
            elif api_name in ULPStandardApi and \
                    row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST:
                us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
                new_status = record_ulp_compare_result(compare_column, us)
            elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST:
                bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
                new_status = record_benchmark_compare_result(compare_column, bs)
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)

        # A new API name starts: flush the summary of the previous group and
        # reset the accumulators.
        if last_api_name is not None and api_name != last_api_name:
            if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
                message = unsupported_message
                write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
                forward_status, backward_status = [], []
                message = ''
            else:
                forward_result = get_api_checker_result(forward_status)
                backward_result = get_api_checker_result(backward_status)
                message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
                write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
                forward_status, backward_status = [], []
                message = ''

        is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
        last_api_name = api_name

        last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
        # Records of unsupported dtypes do not contribute to the status lists.
        if not is_supported:
            continue

        if direction_status == 'forward':
            forward_status.append(new_status)
        elif direction_status == 'backward':
            backward_status.append(new_status)
        else:
            logger.error(f"Invalid direction status: {direction_status}")

    # Flush the summary of the final group, which no later record triggers.
    if last_api_name is not None:
        if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            message = unsupported_message
            write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
        else:
            forward_result = get_api_checker_result(forward_status)
            backward_result = get_api_checker_result(backward_status)
            message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
            write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
371
-
372
-
373
def check_error_rate(npu_error_rate):
    """Return PASS when the converted error rate equals 0, otherwise ERROR."""
    if convert_str_to_float(npu_error_rate) == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
375
-
376
-
377
def get_absolute_threshold_result(row_npu):
    """Evaluate the absolute-threshold metrics of one NPU record.

    Each of the three ratios (inf/nan error, relative error, absolute error)
    passes only when it equals 0; the overall result is ERROR as soon as one of
    them is ERROR.

    Returns:
        dict with the three converted ratios, their individual statuses and the
        combined ``absolute_threshold_result``.
    """
    def _zero_means_pass(ratio):
        return CompareConst.PASS if ratio == 0 else CompareConst.ERROR

    inf_nan_error_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO])
    rel_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO])
    abs_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO])

    summary = {
        "inf_nan_error_ratio": inf_nan_error_ratio,
        "inf_nan_result": _zero_means_pass(inf_nan_error_ratio),
        "rel_err_ratio": rel_err_ratio,
        "rel_err_result": _zero_means_pass(rel_err_ratio),
        "abs_err_ratio": abs_err_ratio,
        "abs_err_result": _zero_means_pass(abs_err_ratio),
    }

    individual = (summary["inf_nan_result"], summary["rel_err_result"], summary["abs_err_result"])
    any_failed = any(status == CompareConst.ERROR for status in individual)
    summary["absolute_threshold_result"] = CompareConst.ERROR if any_failed else CompareConst.PASS
    return summary
400
-
401
-
402
def get_api_checker_result(status):
    """Collapse a list of per-record statuses into one overall status.

    Returns SPACE for an empty list, SKIP when every entry is SKIP, otherwise
    ERROR if present, else WARNING if present, else PASS.
    """
    if not status:
        return CompareConst.SPACE
    not_skipped = [item for item in status if item != CompareConst.SKIP]
    if not not_skipped:
        return CompareConst.SKIP
    if CompareConst.ERROR in status:
        return CompareConst.ERROR
    if CompareConst.WARNING in status:
        return CompareConst.WARNING
    return CompareConst.PASS
411
-
412
-
413
def check_csv_columns(columns, csv_type):
    """Verify that ``columns`` contains every required comparison column.

    Args:
        columns: Column collection of the loaded CSV (supports ``in``).
        csv_type: Label of the CSV being checked, used in the error message.

    Raises:
        CompareException: With ``INVALID_DATA_ERROR`` when at least one
            required column is missing.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # A space separates "in" from the CSV label so the message stays readable.
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
419
-
420
-
421
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Fill ``compare_column`` using the binary-consistency standard.

    The status comes from ``check_error_rate`` on the record's error rate. On
    ERROR the message holds the error line followed by the API-specific hint
    from ``CompareMessage`` (if any).

    Returns:
        The status written to ``compare_column.compare_result``.
    """
    error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    status = check_error_rate(error_rate)
    compare_column.error_rate = error_rate
    compare_column.error_rate_status = status
    compare_column.compare_result = status
    compare_column.compare_algorithm = "二进制一致法"
    if compare_column.error_rate_status == CompareConst.ERROR:
        compare_column.compare_message = "ERROR: 二进制一致错误率超过阈值\n" + CompareMessage.get(api_name, "")
    else:
        compare_column.compare_message = ''
    return status
433
-
434
-
435
def record_absolute_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` using the absolute-threshold standard.

    Copies the ratios and statuses computed by
    ``get_absolute_threshold_result`` and builds one error line per failed
    metric.

    Returns:
        The combined status stored in ``compare_column.compare_result``.
    """
    outcome = get_absolute_threshold_result(row_npu)
    # (attribute on compare_column, key in the outcome dict)
    attribute_sources = (
        ("inf_nan_error_ratio", "inf_nan_error_ratio"),
        ("inf_nan_error_ratio_status", "inf_nan_result"),
        ("rel_err_ratio", "rel_err_ratio"),
        ("rel_err_ratio_status", "rel_err_result"),
        ("abs_err_ratio", "abs_err_ratio"),
        ("abs_err_ratio_status", "abs_err_result"),
        ("compare_result", "absolute_threshold_result"),
    )
    for attribute, key in attribute_sources:
        setattr(compare_column, attribute, outcome.get(key))
    compare_column.compare_algorithm = "绝对阈值法"

    error_lines = []
    if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
        error_lines.append("ERROR: inf/nan错误率超过阈值\n")
    if compare_column.rel_err_ratio_status == CompareConst.ERROR:
        error_lines.append("ERROR: 相对误差错误率超过阈值\n")
    if compare_column.abs_err_ratio_status == CompareConst.ERROR:
        error_lines.append("ERROR: 绝对误差错误率超过阈值\n")
    compare_column.compare_message = "".join(error_lines)
    return compare_column.compare_result
454
-
455
-
456
def record_benchmark_compare_result(compare_column, bs):
    """Fill ``compare_column`` from a benchmark standard object.

    Runs ``bs.get_result()``, copies its ratios and statuses, then appends the
    ``benchmark_message`` text registered for each status value.

    Returns:
        The final result stored in ``compare_column.compare_result``.
    """
    bs.get_result()
    copied_attributes = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for attribute in copied_attributes:
        setattr(compare_column, attribute, getattr(bs, attribute))
    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"
    compare_column.compare_message = bs.compare_message
    for status_attr, texts in benchmark_message.items():
        current_status = getattr(compare_column, status_attr)
        if current_status not in texts:
            continue
        compare_column.compare_message += texts[current_status]
    return compare_column.compare_result
476
-
477
-
478
def record_ulp_compare_result(compare_column, us):
    """Fill ``compare_column`` from a ULP standard object.

    Runs ``us.get_result()`` and copies its statistics, status and message.

    Returns:
        The ULP status stored in ``compare_column.compare_result``.
    """
    us.get_result()
    for attribute in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, attribute, getattr(us, attribute))
    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
488
-
489
-
490
def check_thousandth_rate(thousandth_rate):
    """Return PASS when the converted rate is at least 0.999, otherwise ERROR."""
    if convert_str_to_float(thousandth_rate) >= 0.999:
        return CompareConst.PASS
    return CompareConst.ERROR
492
-
493
-
494
def record_thousandth_threshold_result(compare_column, row_npu):
    """Fill ``compare_column`` using the thousandth-rate standard.

    The status comes from ``check_thousandth_rate``; on ERROR a single error
    line is stored as the message.

    Returns:
        The status stored in ``compare_column.compare_result``.
    """
    thousandth_rate = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    status = check_thousandth_rate(thousandth_rate)
    compare_column.rel_err_thousandth = thousandth_rate
    compare_column.rel_err_thousandth_status = status
    compare_column.compare_result = status
    compare_column.compare_algorithm = "双千指标法"
    failed = compare_column.rel_err_thousandth_status == CompareConst.ERROR
    compare_column.compare_message = "ERROR: 双千指标不达标\n" if failed else ''
    return compare_column.compare_result
505
-
506
-
507
def _api_precision_compare(parser=None):
    """Command-line entry: register the options, parse ``sys.argv`` and run.

    Args:
        parser: Optional parser to extend; a new ``argparse.ArgumentParser`` is
            created when it is falsy.
    """
    parser = parser or argparse.ArgumentParser()
    _api_precision_compare_parser(parser)
    _api_precision_compare_command(parser.parse_args(sys.argv[1:]))
513
-
514
-
515
def _api_precision_compare_command(args):
    """Validate the parsed arguments, prepare the output directory and compare.

    ``args.out_path`` is resolved with ``os.path.realpath`` when given and
    falls back to ``"./"`` otherwise. The result and detail CSV files are
    placed inside that directory.
    """
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
    if args.out_path:
        out_path = os.path.realpath(args.out_path)
    else:
        out_path = "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    out_path = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()
    compare_config = CompareConfig(
        npu_csv_path,
        gpu_csv_path,
        os.path.join(out_path, API_PRECISION_COMPARE_RESULT_FILE_NAME),
        os.path.join(out_path, API_PRECISION_COMPARE_DETAILS_FILE_NAME),
    )
    api_precision_compare(compare_config)
527
-
528
-
529
- def _api_precision_compare_parser(parser):
530
- parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
531
- help="<Required> , Accuracy_checking_details.csv generated on the NPU by using the "
532
- "api_accuracy_checker tool.",
533
- required=True)
534
- parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
535
- help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
536
- "api_accuracy_checker tool.",
537
- required=False)
538
- parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
539
- help="<optional> The api precision compare task result out path.",
540
- required=False)
541
-
542
-
543
# Script entry point: run the command-line comparison, then log completion.
if __name__ == '__main__':
    _api_precision_compare()
    logger.info("Compare task completed.")
1
+ import argparse
2
+ import math
3
+ import os
4
+ import sys
5
+ from collections import namedtuple
6
+
7
+ import torch
8
+ import pandas as pd
9
+
10
+ from msprobe.core.common.file_utils import write_csv
11
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
+ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
14
+ ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
15
+ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
16
+ check_inf_or_nan
17
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
19
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
20
+ from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory
21
+ from msprobe.pytorch.common.log import logger
22
+ from msprobe.core.common.utils import CompareException
23
+ from msprobe.core.common.const import Const, CompareConst, FileCheckConst
24
+
25
# Input and output CSV paths of one comparison run.
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
# Per-metric inf/nan consistency flags, as returned by BenchmarkStandard._compare_ratio.
BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency',
                                                                           'rmse_inf_nan_consistency',
                                                                           'max_rel_inf_nan_consistency',
                                                                           'mean_rel_inf_nan_consistency',
                                                                           'eb_inf_nan_consistency'])
# Message written to the result CSV for records whose dtype is in the unsupported list.
unsupported_message = 'This data type does not support benchmark compare.'

# Fallback threshold used by BenchmarkStandard._get_status when an algorithm has no entry below.
DEFAULT_THRESHOLD = 1

# NPU/GPU ratio limits per benchmark metric: a ratio above 'error_threshold' yields ERROR,
# otherwise a ratio above 'warning_threshold' yields WARNING (see BenchmarkStandard._get_status).
benchmark_algorithms_thresholds = {
    'small_value': {
        'error_threshold': 2,
        'warning_threshold': 1
    },
    'rmse': {
        'error_threshold': 2,
        'warning_threshold': 1
    },
    'max_rel_err': {
        'error_threshold': 10,
        'warning_threshold': 1
    },
    'mean_rel_err': {
        'error_threshold': 2,
        'warning_threshold': 1
    },
    'eb': {
        'error_threshold': 2,
        'warning_threshold': 1
    }
}
57
+
58
# Message text per benchmark status: each top-level key is the name of a status attribute and
# the inner dict maps a status constant (ERROR / WARNING) to the line reported for it.
# There is no entry for "eb_status".
benchmark_message = {
    "small_value_err_status": {
        CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n"
    },
    "rmse_status": {
        CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n"
    },
    "max_rel_err_status": {
        CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n"
    },
    "mean_rel_err_status": {
        CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n",
        CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n"
    }
}
76
+
77
+
78
class Standard:
    """Base class of the precision standards; provides the shared ratio helper."""

    @staticmethod
    def _calc_ratio(column_name, x, y, default_value):
        """Compute the ratio between an NPU-side and a GPU-side statistic.

        Args:
            column_name: Name of the statistic, forwarded to ``check_inf_or_nan``.
            x: NPU-side statistic (converted with ``convert_str_to_float``).
            y: GPU-side statistic (converted with ``convert_str_to_float``).
            default_value: Ratio returned when ``y`` is zero but ``x`` is not.

        Returns:
            Tuple ``(ratio, inf_nan_consistency, message)``. When either value
            is inf or nan the tuple comes straight from ``check_inf_or_nan``;
            otherwise the consistency flag is True and the message is empty.
        """
        npu_value = convert_str_to_float(x)
        gpu_value = convert_str_to_float(y)

        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)

        # With its default tolerances math.isclose(v, 0.0) is true only for v == 0.
        if not math.isclose(gpu_value, 0.0):
            ratio = abs(npu_value / gpu_value)
        elif math.isclose(npu_value, 0.0):
            ratio = 1.0
        else:
            ratio = default_value
        return ratio, True, ""
107
+
108
+
109
class BenchmarkStandard(Standard):
    """Benchmark precision standard for a single API record.

    Compares five NPU statistics (small-value error rate, RMSE, max relative
    error, mean relative error, EB) with their GPU counterparts as ratios and
    grades each ratio as PASS, WARNING or ERROR.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        """Store the inputs and initialise ratios to 1 and statuses to PASS.

        Args:
            api_name: Name of the API being judged (returned by ``str()``).
            npu_precision: NPU-side record, read with ``.get(column)``.
            gpu_precision: GPU-side record, read with ``.get(column)``.
        """
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.small_value_err_ratio = 1
        self.rmse_ratio = 1
        self.max_rel_err_ratio = 1
        self.mean_rel_err_ratio = 1
        self.eb_ratio = 1
        self.small_value_err_status = CompareConst.PASS
        self.rmse_status = CompareConst.PASS
        self.max_rel_err_status = CompareConst.PASS
        self.mean_rel_err_status = CompareConst.PASS
        self.eb_status = CompareConst.PASS
        self.check_result_list = []
        self.final_result = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    @staticmethod
    def _get_status(ratio, algorithm):
        """Grade ``ratio`` against the thresholds registered for ``algorithm``.

        A nan or infinite ratio is graded PASS. Algorithms without an entry in
        ``benchmark_algorithms_thresholds`` fall back to ``DEFAULT_THRESHOLD``.
        """
        if not math.isfinite(ratio):
            return CompareConst.PASS
        limits = benchmark_algorithms_thresholds.get(algorithm, {})
        error_threshold = limits.get('error_threshold', DEFAULT_THRESHOLD)
        warning_threshold = limits.get('warning_threshold', DEFAULT_THRESHOLD)
        if ratio > error_threshold:
            return CompareConst.ERROR
        if ratio > warning_threshold:
            return CompareConst.WARNING
        return CompareConst.PASS

    def get_result(self):
        """Compute all ratios, per-metric statuses and ``final_result``.

        A metric whose inf/nan values are inconsistent is graded ERROR without
        looking at its ratio. The EB status is computed from its ratio only and
        is not added to ``check_result_list``, so it does not influence
        ``final_result``.
        """
        consistency = self._compare_ratio()
        graded_metrics = (
            ("small_value_err_status", self.small_value_err_ratio, 'small_value',
             consistency.small_value_inf_nan_consistency),
            ("rmse_status", self.rmse_ratio, 'rmse', consistency.rmse_inf_nan_consistency),
            ("max_rel_err_status", self.max_rel_err_ratio, 'max_rel_err',
             consistency.max_rel_inf_nan_consistency),
            ("mean_rel_err_status", self.mean_rel_err_ratio, 'mean_rel_err',
             consistency.mean_rel_inf_nan_consistency),
        )
        for status_attr, ratio, algorithm, is_consistent in graded_metrics:
            if is_consistent:
                status = self._get_status(ratio, algorithm)
            else:
                status = CompareConst.ERROR
            setattr(self, status_attr, status)
            self.check_result_list.append(status)
        self.eb_status = self._get_status(self.eb_ratio, 'eb')
        if CompareConst.ERROR in self.check_result_list:
            self.final_result = CompareConst.ERROR
        elif CompareConst.WARNING in self.check_result_list:
            self.final_result = CompareConst.WARNING

    def to_column_value(self):
        """Return the ratios and statuses as one flat list (ratio, status pairs)."""
        return [
            self.small_value_err_ratio, self.small_value_err_status,
            self.rmse_ratio, self.rmse_status,
            self.max_rel_err_ratio, self.max_rel_err_status,
            self.mean_rel_err_ratio, self.mean_rel_err_status,
            self.eb_ratio, self.eb_status,
        ]

    def _compare_ratio(self):
        """Compute the five NPU/GPU ratios and gather their inf/nan flags.

        Every ratio uses 10000.0 as the default for a zero GPU value; messages
        returned by ``_calc_ratio`` are appended to ``compare_message`` in
        metric order.

        Returns:
            ``BenchmarkInf_Nan_Consistency`` with one flag per metric.
        """
        metric_columns = (
            ("small_value_err_ratio", ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
            ("rmse_ratio", ApiPrecisionCompareColumn.RMSE),
            ("max_rel_err_ratio", ApiPrecisionCompareColumn.MAX_REL_ERR),
            ("mean_rel_err_ratio", ApiPrecisionCompareColumn.MEAN_REL_ERR),
            ("eb_ratio", ApiPrecisionCompareColumn.EB),
        )
        consistency_flags = []
        for ratio_attr, column in metric_columns:
            ratio, is_consistent, message = self._calc_ratio(
                column, self.npu_precision.get(column), self.gpu_precision.get(column), 10000.0)
            setattr(self, ratio_attr, ratio)
            self.compare_message += message
            consistency_flags.append(is_consistent)

        return BenchmarkInf_Nan_Consistency(*consistency_flags)
201
+
202
+
203
class ULPStandard(Standard):
    """ULP-error precision standard for a single API record.

    Holds the NPU-side and GPU-side precision records of one API and derives a
    PASS/ERROR status from the mean ULP error, the ULP error proportion and the
    NPU/GPU proportion ratio.
    """

    def __init__(self, api_name, npu_precision, gpu_precision):
        """Store the inputs and initialise every result to its passing default.

        Args:
            api_name: Name of the API being judged (returned by ``str()``).
            npu_precision: NPU-side record, read with ``.get(column)``.
            gpu_precision: GPU-side record, read with ``.get(column)``.
        """
        self.api_name = api_name
        self.npu_precision = npu_precision
        self.gpu_precision = gpu_precision
        self.mean_ulp_err = 0
        self.ulp_err_proportion = 0
        self.ulp_err_proportion_ratio = 1
        self.ulp_err_status = CompareConst.PASS
        self.compare_message = ""

    def __str__(self):
        return f"{self.api_name}"

    def get_result(self):
        """Compute ``ulp_err_status`` and accumulate text in ``compare_message``.

        The status is ERROR when an inf/nan inconsistency is reported for the
        mean ULP error or for the ULP error proportion; otherwise it is decided
        by ``_get_ulp_status``.
        """
        self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
        inf_nan_consistency = True
        if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
            # Only the consistency flag and the message are needed here; the
            # ratio returned first is discarded.
            _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
                                                               ApiPrecisionCompareColumn.MEAN_ULP_ERR)
            self.compare_message += message
        self.ulp_err_proportion = convert_str_to_float(
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
        self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
            ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
            self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
            self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
        inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
        self.compare_message += message
        if inf_nan_consistency:
            self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
        else:
            self.ulp_err_status = CompareConst.ERROR

    def _get_ulp_status(self, dtype):
        """Return PASS or ERROR for the already computed ULP statistics.

        Args:
            dtype: Device dtype of the record. It is read from a CSV cell, so it
                normally arrives as the string ``"torch.float32"``; a real
                ``torch.dtype`` is accepted as well.

        Returns:
            ``CompareConst.PASS`` when any criterion for the dtype is met,
            otherwise ``CompareConst.ERROR`` (an error line is then appended to
            ``compare_message``).
        """
        # Compare textual forms: a str never equals a torch.dtype object, so a
        # direct ``dtype == torch.float32`` would never select this branch for
        # values loaded from CSV.
        if str(dtype) == str(torch.float32):
            passed = (self.mean_ulp_err < 64
                      or self.ulp_err_proportion < 0.05
                      or self.ulp_err_proportion_ratio < 1)
        else:
            passed = self.ulp_err_proportion < 0.001 or self.ulp_err_proportion_ratio < 1
        if passed:
            return CompareConst.PASS
        self.compare_message += "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR
257
+
258
+
259
def write_detail_csv(content, save_path):
    """Write one detail row to ``save_path`` through ``write_csv``.

    Float items are rendered as fixed-point text with
    ``msCheckerConfig.precision`` decimals; all other items are passed through
    unchanged.
    """
    formatted_row = []
    for value in content:
        if isinstance(value, float):
            formatted_row.append("{:.{}f}".format(value, msCheckerConfig.precision))
        else:
            formatted_row.append(value)
    write_csv([formatted_row], save_path)
265
+
266
+
267
def api_precision_compare(config):
    """Run the offline NPU-vs-GPU precision comparison described by ``config``.

    Reads both detail CSV files, validates their columns, writes the title rows
    of the result and detail CSV files, analyses the data and finally adjusts
    the mode of both output files.

    Args:
        config: Object exposing ``npu_csv_path``, ``gpu_csv_path``,
            ``result_csv_path`` and ``details_csv_path``.

    Raises:
        Exception: Whatever ``pd.read_csv`` raised when an input CSV cannot be
            read (re-raised after logging).
        CompareException: Propagated from ``check_csv_columns`` when required
            columns are missing.
    """
    logger.info("Start compare task")
    logger.info(f"Compare task result will be saved in {config.result_csv_path}")
    logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
    try:
        npu_data = pd.read_csv(config.npu_csv_path)
    except Exception as err:
        logger.error(f"Open npu csv Error: {err}")
        # Without the data nothing below can run; re-raise instead of falling
        # through to an UnboundLocalError on ``npu_data``.
        raise
    check_csv_columns(npu_data.columns, "npu_csv")
    try:
        gpu_data = pd.read_csv(config.gpu_csv_path)
    except Exception as err:
        logger.error(f"Open gpu csv Error: {err}")
        raise
    check_csv_columns(gpu_data.columns, "gpu_csv")
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    write_csv(result_csv_title, config.result_csv_path)
    write_csv(detail_csv_title, config.details_csv_path)
    try:
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        # Best effort: keep whatever was written so far and still fix the file
        # modes below.
        logger.error(f"Analyse csv Error: {err}")
    change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
    change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
291
+
292
+
293
def online_api_precision_compare(online_config):
    """Run the precision comparison on in-memory data for one rank.

    The output paths are built under ``Const.DEFAULT_PATH`` with the
    ``_rank*.csv`` placeholder replaced by the rank of ``online_config``. Title
    rows are written only when the corresponding file does not exist yet. Any
    error raised while checking or analysing the data is logged, and the file
    modes are adjusted afterwards in every case.
    """
    rank_suffix = f"_rank{online_config.rank}.csv"

    def _for_rank(template_path):
        return os.path.join(Const.DEFAULT_PATH, template_path).replace("_rank*.csv", rank_suffix)

    result_csv_path = _for_rank(online_config.result_csv_path)
    details_csv_path = _for_rank(online_config.details_csv_path)
    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
    for title, path in ((result_csv_title, result_csv_path), (detail_csv_title, details_csv_path)):
        if not os.path.exists(path):
            write_csv(title, path)
    config = CompareConfig("", "", result_csv_path, details_csv_path)
    try:
        npu_data = online_config.npu_data
        gpu_data = online_config.gpu_data
        check_csv_columns(npu_data.columns, "npu_csv")
        check_csv_columns(gpu_data.columns, "gpu_csv")
        analyse_csv(npu_data, gpu_data, config)
    except Exception as err:
        logger.error(f"Online api precision compare Error: {str(err)}")
    for path in (result_csv_path, details_csv_path):
        change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
313
+
314
+
315
def analyse_csv(npu_data, gpu_data, config):
    """Compare every NPU record with its GPU counterpart and write the results.

    One detail row is written per processed NPU record to
    ``config.details_csv_path``. Consecutive records that share the same full
    API name are summarised into one row (forward result, backward result,
    message) in ``config.result_csv_path``; the summary of a group is written
    when the next group starts, and the last group is written after the loop.
    Records with an invalid name, or whose status computation raises, get a
    SKIP row in both files and take no part in the grouping.

    Raises:
        CompareException: When the GPU data holds more than one record for an
            NPU record's API name.
    """
    forward_status, backward_status = [], []
    # NOTE(review): last_api_full_name is assigned here but never read in this function.
    last_api_name, last_api_dtype, last_api_full_name = None, None, None
    for _, row_npu in npu_data.iterrows():
        message = ''
        compare_column = ApiPrecisionOutputColumn()
        full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
        row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
        api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
        # A falsy full name is treated as an invalid API name: record a SKIP row and move on.
        if not api_full_name:
            err_message = f"The API name {full_api_name_with_direction_status} is invalid."
            logger.error(err_message)
            compare_column.api_name = full_api_name_with_direction_status
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = err_message
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
            write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, err_message]],
                      config.result_csv_path)
            continue
        if row_gpu.empty:
            logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
            continue
        if len(row_gpu) > 1:
            msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.'
            raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
        row_gpu = row_gpu.iloc[0]
        new_status = CompareConst.SPACE
        try:
            new_status = get_api_status(row_npu, row_gpu, api_name, compare_column)
        except Exception as err:
            # A failure for one record must not abort the run: record it as SKIP with the error text.
            logger.error(f"Get api status error: {str(err)}")
            compare_column.api_name = full_api_name_with_direction_status
            compare_column.compare_result = CompareConst.SKIP
            compare_column.compare_message = str(err)
            write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
            write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, str(err)]],
                      config.result_csv_path)
            continue

        write_detail_csv(compare_column.to_column_value(), config.details_csv_path)

        # A new full API name starts: flush the summary of the previous group and reset the accumulators.
        if last_api_name is not None and api_full_name != last_api_name:
            if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
                message = unsupported_message
                write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
                print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
                forward_status, backward_status = [], []
                message = ''
            else:
                forward_result = get_api_checker_result(forward_status)
                backward_result = get_api_checker_result(backward_status)
                # NOTE(review): the lookup key is the full API name held in last_api_name — confirm
                # that CompareMessage is keyed this way.
                message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
                write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
                print_test_success(last_api_name, forward_result, backward_result)
                forward_status, backward_status = [], []
                message = ''

        is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
        last_api_name = api_full_name

        last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
        # Records of unsupported dtypes do not contribute to the status lists.
        if not is_supported:
            continue

        if direction_status == 'forward':
            forward_status.append(new_status)
        elif direction_status == 'backward':
            backward_status.append(new_status)
        else:
            logger.error(f"Invalid direction status: {direction_status}")

    # Flush the summary of the final group, which no later record triggers.
    if last_api_name is not None:
        if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
            message = unsupported_message
            write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
            print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
        else:
            forward_result = get_api_checker_result(forward_status)
            backward_result = get_api_checker_result(backward_status)
            message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
            write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
            print_test_success(last_api_name, forward_result, backward_result)
397
+
398
+
399
def get_api_status(row_npu, row_gpu, api_name, compare_column):
    """Select the comparison standard for one API record and record its verdict.

    The standard is chosen from the API name and the NPU device dtype, in this
    priority order: thousandth threshold, binary consistency, absolute
    threshold, ULP, benchmark. The chosen ``record_*`` helper fills
    ``compare_column`` in place.

    Args:
        row_npu: NPU detail row, indexable by ``ApiPrecisionCompareColumn`` names.
        row_gpu: GPU detail row for the same API (used by the ULP and benchmark
            standards only).
        api_name: API name looked up in the per-standard API collections.
        compare_column: Result record that is populated in place.

    Returns:
        The status produced by the selected standard, or ``CompareConst.SKIP``
        when the device dtype field is whitespace only.

    Raises:
        CompareException: If no comparison standard applies to this API/dtype
            combination. Previously this case surfaced as an
            ``UnboundLocalError`` from the final ``return``.
    """
    full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
    device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
    has_no_output = device_dtype.isspace()
    compare_column.api_name = full_api_name_with_direction_status

    # The API produced no output (e.g. requires_grad=False during backward),
    # so the comparison is skipped.
    if has_no_output:
        compare_column.compare_result = CompareConst.SKIP
        compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
        return CompareConst.SKIP

    if api_name in thousandth_standard_api:
        return record_thousandth_threshold_result(compare_column, row_npu)
    if device_dtype not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in binary_standard_api:
        return record_binary_consistency_result(api_name, compare_column, row_npu)
    if api_name in absolute_standard_api:
        return record_absolute_threshold_result(compare_column, row_npu)
    if api_name in ulp_standard_api and device_dtype in ULP_COMPARE_SUPPORT_LIST:
        us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
        return record_ulp_compare_result(compare_column, us)
    if device_dtype in BENCHMARK_COMPARE_SUPPORT_LIST:
        bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
        return record_benchmark_compare_result(compare_column, bs)

    # No standard matched: fail with an explicit, readable error instead of
    # falling through to an unbound local variable.
    msg = (f'No comparison standard is applicable to API {full_api_name_with_direction_status} '
           f'with dtype {device_dtype}.')
    raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
424
+
425
+
426
def print_test_success(api_full_name, forward_result, backward_result):
    """Log whether the forward and backward comparisons of an API succeeded.

    Forward succeeds only on ``CompareConst.PASS``; backward also counts
    ``CompareConst.SPACE`` as success.
    """
    is_fwd_success = forward_result == CompareConst.PASS
    is_bwd_success = backward_result in (CompareConst.PASS, CompareConst.SPACE)
    summary = (f"running api_full_name {api_full_name} compare, "
               f"is_fwd_success: {is_fwd_success}, "
               f"is_bwd_success: {is_bwd_success}")
    logger.info(summary)
432
+
433
+
434
def check_error_rate(npu_error_rate):
    """Return ``CompareConst.PASS`` when the converted error rate equals 0, else ``CompareConst.ERROR``."""
    if convert_str_to_float(npu_error_rate) == 0:
        return CompareConst.PASS
    return CompareConst.ERROR
436
+
437
+
438
def get_absolute_threshold_result(row_npu):
    """Evaluate the absolute-threshold standard for one NPU row.

    Each of the three ratios (inf/nan, relative error, absolute error) passes
    only when it equals 0; the overall result is ERROR if any of them is ERROR.

    Returns:
        dict with each converted ratio, its individual result, and the overall
        ``absolute_threshold_result``.
    """
    def _verdict(ratio):
        # A ratio passes only when it is exactly zero.
        return CompareConst.PASS if ratio == 0 else CompareConst.ERROR

    inf_nan_error_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO])
    rel_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO])
    abs_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO])

    inf_nan_result = _verdict(inf_nan_error_ratio)
    rel_err_result = _verdict(rel_err_ratio)
    abs_err_result = _verdict(abs_err_ratio)

    has_error = CompareConst.ERROR in [inf_nan_result, rel_err_result, abs_err_result]
    absolute_threshold_result = CompareConst.ERROR if has_error else CompareConst.PASS

    return {
        "inf_nan_error_ratio": inf_nan_error_ratio,
        "inf_nan_result": inf_nan_result,
        "rel_err_ratio": rel_err_ratio,
        "rel_err_result": rel_err_result,
        "abs_err_ratio": abs_err_ratio,
        "abs_err_result": abs_err_result,
        "absolute_threshold_result": absolute_threshold_result,
    }
461
+
462
+
463
def get_api_checker_result(status):
    """Fold a list of per-record statuses into one API-level result.

    Empty input yields SPACE; all-SKIP yields SKIP; otherwise ERROR takes
    precedence over WARNING, and PASS is returned when neither is present.
    """
    if not status:
        return CompareConst.SPACE
    if all(item == CompareConst.SKIP for item in status):
        return CompareConst.SKIP
    if CompareConst.ERROR in status:
        return CompareConst.ERROR
    if CompareConst.WARNING in status:
        return CompareConst.WARNING
    return CompareConst.PASS
472
+
473
+
474
def check_csv_columns(columns, csv_type):
    """Verify that a CSV header contains every required comparison column.

    Args:
        columns: Column names present in the CSV.
        csv_type: Label of the CSV being checked; used only in the error message.

    Raises:
        CompareException: With ``INVALID_DATA_ERROR`` when one or more of the
            columns from ``ApiPrecisionCompareColumn.to_required_columns()``
            are absent.
    """
    required_columns = ApiPrecisionCompareColumn.to_required_columns()
    missing_columns = [column for column in required_columns if column not in columns]
    if missing_columns:
        # A space separates "in" from the CSV label so the message reads correctly.
        msg = f"The following columns {','.join(missing_columns)} are missing in {csv_type}"
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
480
+
481
+
482
def record_binary_consistency_result(api_name, compare_column, row_npu):
    """Apply the binary-consistency standard and store the outcome in ``compare_column``.

    Returns:
        The status derived from the row's error rate via ``check_error_rate``.
    """
    error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    new_status = check_error_rate(error_rate)

    compare_column.error_rate = error_rate
    compare_column.error_rate_status = new_status
    compare_column.compare_result = new_status
    compare_column.compare_algorithm = "二进制一致法"

    if compare_column.error_rate_status == CompareConst.ERROR:
        # Append any API-specific hint registered in CompareMessage.
        message = "ERROR: 二进制一致错误率超过阈值\n" + CompareMessage.get(api_name, "")
    else:
        message = ''
    compare_column.compare_message = message
    return new_status
494
+
495
+
496
def record_absolute_threshold_result(compare_column, row_npu):
    """Apply the absolute-threshold standard and store the outcome in ``compare_column``.

    Returns:
        The overall absolute-threshold result that was written to
        ``compare_column.compare_result``.
    """
    outcome = get_absolute_threshold_result(row_npu)

    compare_column.inf_nan_error_ratio = outcome.get("inf_nan_error_ratio")
    compare_column.inf_nan_error_ratio_status = outcome.get("inf_nan_result")
    compare_column.rel_err_ratio = outcome.get("rel_err_ratio")
    compare_column.rel_err_ratio_status = outcome.get("rel_err_result")
    compare_column.abs_err_ratio = outcome.get("abs_err_ratio")
    compare_column.abs_err_ratio_status = outcome.get("abs_err_result")
    compare_column.compare_result = outcome.get("absolute_threshold_result")
    compare_column.compare_algorithm = "绝对阈值法"

    # One message line per failed indicator, in a fixed order.
    failure_notes = (
        (compare_column.inf_nan_error_ratio_status, "ERROR: inf/nan错误率超过阈值\n"),
        (compare_column.rel_err_ratio_status, "ERROR: 相对误差错误率超过阈值\n"),
        (compare_column.abs_err_ratio_status, "ERROR: 绝对误差错误率超过阈值\n"),
    )
    compare_column.compare_message = ''.join(
        note for indicator_status, note in failure_notes if indicator_status == CompareConst.ERROR
    )
    return compare_column.compare_result
515
+
516
+
517
def record_benchmark_compare_result(compare_column, bs):
    """Run the benchmark standard ``bs`` and copy its results into ``compare_column``.

    After copying, the message is extended with the texts from
    ``benchmark_message`` that match each indicator's status.

    Returns:
        The final benchmark result written to ``compare_column.compare_result``.
    """
    bs.get_result()

    # These attributes carry the same name on the standard and on the column.
    mirrored_fields = (
        "small_value_err_ratio", "small_value_err_status",
        "rmse_ratio", "rmse_status",
        "max_rel_err_ratio", "max_rel_err_status",
        "mean_rel_err_ratio", "mean_rel_err_status",
        "eb_ratio", "eb_status",
    )
    for field_name in mirrored_fields:
        setattr(compare_column, field_name, getattr(bs, field_name))

    compare_column.compare_result = bs.final_result
    compare_column.compare_algorithm = "标杆比对法"
    compare_column.compare_message = bs.compare_message

    extra_notes = []
    for status_attr, messages in benchmark_message.items():
        indicator_status = getattr(compare_column, status_attr)
        if indicator_status in messages:
            extra_notes.append(messages[indicator_status])
    if extra_notes:
        compare_column.compare_message += ''.join(extra_notes)
    return compare_column.compare_result
537
+
538
+
539
def record_ulp_compare_result(compare_column, us):
    """Run the ULP standard ``us`` and copy its results into ``compare_column``.

    Returns:
        The ULP error status, which is also stored as the compare result.
    """
    us.get_result()

    # These attributes carry the same name on the standard and on the column.
    for field_name in ("mean_ulp_err", "ulp_err_proportion", "ulp_err_proportion_ratio", "ulp_err_status"):
        setattr(compare_column, field_name, getattr(us, field_name))

    compare_column.compare_result = us.ulp_err_status
    compare_column.compare_algorithm = "ULP误差比对法"
    compare_column.compare_message = us.compare_message
    return compare_column.compare_result
549
+
550
+
551
def check_thousandth_rate(thousandth_rate):
    """Return ``CompareConst.PASS`` when the converted rate is at least 0.999, else ``CompareConst.ERROR``."""
    if convert_str_to_float(thousandth_rate) >= 0.999:
        return CompareConst.PASS
    return CompareConst.ERROR
553
+
554
+
555
def record_thousandth_threshold_result(compare_column, row_npu):
    """Apply the thousandth-threshold standard and store the outcome in ``compare_column``.

    Returns:
        The status written to ``compare_column.compare_result``.
    """
    thousandth_rate = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
    new_status = check_thousandth_rate(thousandth_rate)

    compare_column.rel_err_thousandth = thousandth_rate
    compare_column.rel_err_thousandth_status = new_status
    compare_column.compare_result = new_status
    compare_column.compare_algorithm = "双千指标法"

    if compare_column.rel_err_thousandth_status == CompareConst.ERROR:
        compare_column.compare_message = "ERROR: 双千指标不达标\n"
    else:
        compare_column.compare_message = ''
    return compare_column.compare_result
566
+
567
+
568
def _api_precision_compare(parser=None):
    """Command-line entry: register the options, parse ``sys.argv`` and run the comparison.

    Args:
        parser: Optional argument parser to extend; a new
            ``argparse.ArgumentParser`` is created when none is supplied.
    """
    parser = parser or argparse.ArgumentParser()
    _api_precision_compare_parser(parser)
    cli_args = parser.parse_args(sys.argv[1:])
    _api_precision_compare_command(cli_args)
574
+
575
+
576
def _api_precision_compare_command(args):
    """Validate the parsed arguments, prepare the output directory and run the comparison.

    Args:
        args: Parsed namespace providing ``npu_csv_path``, ``gpu_csv_path``
            and ``out_path``.
    """
    npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
    gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')

    # An empty out_path falls back to the current directory.
    if args.out_path:
        out_path = os.path.realpath(args.out_path)
    else:
        out_path = "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    out_path = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE).common_check()

    compare_config = CompareConfig(
        npu_csv_path,
        gpu_csv_path,
        os.path.join(out_path, API_PRECISION_COMPARE_RESULT_FILE_NAME),
        os.path.join(out_path, API_PRECISION_COMPARE_DETAILS_FILE_NAME),
    )
    api_precision_compare(compare_config)
588
+
589
+
590
def _api_precision_compare_parser(parser):
    """Register the api precision compare command-line options on ``parser``.

    Options:
        -npu / --npu_csv_path: required, detail CSV generated on the NPU.
        -gpu / --gpu_csv_path: required, detail CSV generated on the GPU.
        -o / --out_path: optional, output directory for the comparison results.

    Args:
        parser: ``argparse.ArgumentParser`` (or compatible) to extend in place.
    """
    parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str,
                        help="<Required> Accuracy_checking_details.csv generated on the NPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    # The GPU csv is documented as required and the command handler always
    # validates it, so argparse enforces it here instead of letting an empty
    # path reach the validation step.
    parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
                        help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
                             "api_accuracy_checker tool.",
                        required=True)
    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                        help="<optional> The api precision compare task result out path.",
                        required=False)
602
+
603
+
604
# Script entry point: run the comparison from the command line, then log completion.
if __name__ == "__main__":
    _api_precision_compare()
    logger.info("Compare task completed.")