mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -29,12 +29,16 @@ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
29
29
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
30
30
  API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
31
31
  ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
32
- BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
33
- check_inf_or_nan
32
+ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage
33
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_input import PrecisionCompareInput
34
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
35
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpPrecisionCompare
36
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkPrecisionCompare
37
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
34
38
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
35
39
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
36
- from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
37
- from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory
40
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
41
+ from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
38
42
  from msprobe.pytorch.common.log import logger
39
43
  from msprobe.core.common.utils import CompareException
40
44
  from msprobe.core.common.const import Const, CompareConst, FileCheckConst
@@ -47,30 +51,6 @@ BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_va
47
51
  'eb_inf_nan_consistency'])
48
52
  UNSUPPORTED_MESSAGE = 'This data type does not support benchmark compare.'
49
53
 
50
- DEFAULT_THRESHOLD = 1
51
-
52
- benchmark_algorithms_thresholds = {
53
- 'small_value': {
54
- 'error_threshold': 2,
55
- 'warning_threshold': 1
56
- },
57
- 'rmse': {
58
- 'error_threshold': 2,
59
- 'warning_threshold': 1
60
- },
61
- 'max_rel_err': {
62
- 'error_threshold': 10,
63
- 'warning_threshold': 1
64
- },
65
- 'mean_rel_err': {
66
- 'error_threshold': 2,
67
- 'warning_threshold': 1
68
- },
69
- 'eb': {
70
- 'error_threshold': 2,
71
- 'warning_threshold': 1
72
- }
73
- }
74
54
 
75
55
  benchmark_message = {
76
56
  "small_value_err_status": {
@@ -92,189 +72,6 @@ benchmark_message = {
92
72
  }
93
73
 
94
74
 
95
- class Standard:
96
- @staticmethod
97
- def _calc_ratio(column_name, x, y, default_value):
98
- '''
99
- 计算npu侧和gpu侧统计量的比值
100
- 输入:
101
- column_name:统计量名称
102
- x:npu侧统计量
103
- y:gpu侧统计量
104
- default:当x不接近0,y接近0,设置的比值默认值
105
- 输出:
106
- ratio:统计量x和y的比值
107
- inf_nan_consistency:不出现inf或nan时为True,出现inf或nan时必须同时为inf或-inf或nan才为True,否则为False
108
- message:当出现inf或nan时的提示信息
109
- '''
110
- x, y = convert_str_to_float(x), convert_str_to_float(y)
111
-
112
- if is_inf_or_nan(x) or is_inf_or_nan(y):
113
- return check_inf_or_nan(x, y, column_name)
114
-
115
- inf_nan_consistency = True
116
- message = ""
117
- if math.isclose(y, 0.0):
118
- if math.isclose(x, 0.0):
119
- return 1.0, inf_nan_consistency, message
120
- else:
121
- return default_value, inf_nan_consistency, message
122
- else:
123
- return abs(x / y), inf_nan_consistency, message
124
-
125
-
126
- class BenchmarkStandard(Standard):
127
- def __init__(self, api_name, npu_precision, gpu_precision):
128
- self.api_name = api_name
129
- self.npu_precision = npu_precision
130
- self.gpu_precision = gpu_precision
131
- self.small_value_err_ratio = 1
132
- self.rmse_ratio = 1
133
- self.max_rel_err_ratio = 1
134
- self.mean_rel_err_ratio = 1
135
- self.eb_ratio = 1
136
- self.small_value_err_status = CompareConst.PASS
137
- self.rmse_status = CompareConst.PASS
138
- self.max_rel_err_status = CompareConst.PASS
139
- self.mean_rel_err_status = CompareConst.PASS
140
- self.eb_status = CompareConst.PASS
141
- self.check_result_list = []
142
- self.final_result = CompareConst.PASS
143
- self.compare_message = ""
144
-
145
- def __str__(self):
146
- return "%s" % (self.api_name)
147
-
148
- @staticmethod
149
- def _get_status(ratio, algorithm):
150
- if math.isnan(ratio) or math.isinf(ratio):
151
- return CompareConst.PASS
152
- error_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('error_threshold', DEFAULT_THRESHOLD)
153
- warning_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('warning_threshold',
154
- DEFAULT_THRESHOLD)
155
- if ratio > error_threshold:
156
- return CompareConst.ERROR
157
- elif ratio > warning_threshold:
158
- return CompareConst.WARNING
159
- return CompareConst.PASS
160
-
161
- def get_result(self):
162
- inf_nan_consistency = self._compare_ratio()
163
- small_value_inf_nan_consistency = inf_nan_consistency.small_value_inf_nan_consistency
164
- rmse_inf_nan_consistency = inf_nan_consistency.rmse_inf_nan_consistency
165
- max_rel_inf_nan_consistency = inf_nan_consistency.max_rel_inf_nan_consistency
166
- mean_rel_inf_nan_consistency = inf_nan_consistency.mean_rel_inf_nan_consistency
167
- eb_inf_nan_consistency = inf_nan_consistency.eb_inf_nan_consistency
168
- self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') if \
169
- small_value_inf_nan_consistency else CompareConst.ERROR
170
- self.check_result_list.append(self.small_value_err_status)
171
- self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
172
- else CompareConst.ERROR
173
- self.check_result_list.append(self.rmse_status)
174
- self.max_rel_err_status = self._get_status(
175
- self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency else CompareConst.ERROR
176
- self.check_result_list.append(self.max_rel_err_status)
177
- self.mean_rel_err_status = self._get_status(
178
- self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency else CompareConst.ERROR
179
- self.check_result_list.append(self.mean_rel_err_status)
180
- self.eb_status = self._get_status(self.eb_ratio, 'eb')
181
- if CompareConst.ERROR in self.check_result_list:
182
- self.final_result = CompareConst.ERROR
183
- elif CompareConst.WARNING in self.check_result_list:
184
- self.final_result = CompareConst.WARNING
185
-
186
- def to_column_value(self):
187
- return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
188
- self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
189
- self.mean_rel_err_status, self.eb_ratio, self.eb_status]
190
-
191
- def _compare_ratio(self):
192
-
193
- self.small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = self._calc_ratio(
194
- ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE,
195
- self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
196
- self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0)
197
- self.compare_message += small_value_message
198
- self.rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE,
199
- self.npu_precision.get(ApiPrecisionCompareColumn.RMSE),
200
- self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0)
201
- self.compare_message += rmse_message
202
- self.max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = self._calc_ratio(
203
- ApiPrecisionCompareColumn.MAX_REL_ERR,
204
- self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
205
- self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
206
- self.compare_message += max_rel_message
207
- self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(
208
- ApiPrecisionCompareColumn.MEAN_REL_ERR,
209
- self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
210
- self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
211
- self.compare_message += mean_rel_message
212
- self.eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB,
213
- self.npu_precision.get(ApiPrecisionCompareColumn.EB),
214
- self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
215
- self.compare_message += eb_message
216
-
217
- return BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
218
- max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
219
- eb_inf_nan_consistency)
220
-
221
-
222
- class ULPStandard(Standard):
223
- def __init__(self, api_name, npu_precision, gpu_precision):
224
- self.api_name = api_name
225
- self.npu_precision = npu_precision
226
- self.gpu_precision = gpu_precision
227
- self.mean_ulp_err = 0
228
- self.ulp_err_proportion = 0
229
- self.ulp_err_proportion_ratio = 1
230
- self.ulp_err_status = CompareConst.PASS
231
- self.compare_message = ""
232
-
233
- def __str__(self):
234
- return f"{self.api_name}"
235
-
236
- def get_result(self):
237
- self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
238
- gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
239
- inf_nan_consistency = True
240
- if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
241
- _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
242
- ApiPrecisionCompareColumn.MEAN_ULP_ERR)
243
- self.compare_message += message
244
- self.ulp_err_proportion = convert_str_to_float(
245
- self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
246
- self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
247
- ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
248
- self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
249
- self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
250
- inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
251
- self.compare_message += message
252
- if inf_nan_consistency:
253
- self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
254
- else:
255
- self.ulp_err_status = CompareConst.ERROR
256
-
257
- def _get_ulp_status(self, dtype):
258
- if dtype == torch.float32:
259
- if self.mean_ulp_err < 64:
260
- return CompareConst.PASS
261
- elif self.ulp_err_proportion < 0.05:
262
- return CompareConst.PASS
263
- elif self.ulp_err_proportion_ratio < 1:
264
- return CompareConst.PASS
265
- else:
266
- self.compare_message += "ERROR: ULP误差不满足标准\n"
267
- return CompareConst.ERROR
268
- else:
269
- if self.ulp_err_proportion < 0.001:
270
- return CompareConst.PASS
271
- elif self.ulp_err_proportion_ratio < 1:
272
- return CompareConst.PASS
273
- else:
274
- self.compare_message += "ERROR: ULP误差不满足标准\n"
275
- return CompareConst.ERROR
276
-
277
-
278
75
  def write_detail_csv(content, save_path):
279
76
  rows = []
280
77
  content = ["{:.{}f}".format(item, msCheckerConfig.precision) \
@@ -283,6 +80,17 @@ def write_detail_csv(content, save_path):
283
80
  write_csv(rows, save_path)
284
81
 
285
82
 
83
+ def register_compare_func():
84
+ registry = StandardRegistry()
85
+ registry.register(CompareConst.ABSOLUTE_THRESHOLD, record_absolute_threshold_result)
86
+ registry.register(CompareConst.BINARY_CONSISTENCY, record_binary_consistency_result)
87
+ registry.register(CompareConst.ULP_COMPARE, record_ulp_compare_result)
88
+ registry.register(CompareConst.THOUSANDTH_STANDARD, record_thousandth_threshold_result)
89
+ registry.register(CompareConst.BENCHMARK, record_benchmark_compare_result)
90
+ registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, record_accumulative_error_compare_result)
91
+ return registry
92
+
93
+
286
94
  def api_precision_compare(config):
287
95
  logger.info("Start compare task")
288
96
  logger.info(f"Compare task result will be saved in {config.result_csv_path}")
@@ -337,6 +145,8 @@ def analyse_csv(npu_data, gpu_data, config):
337
145
  forward_status, backward_status = [], []
338
146
  last_api_name, last_api_dtype, last_api_full_name = None, None, None
339
147
  last_api_skip_message = ''
148
+ registry = register_compare_func()
149
+
340
150
  for _, row_npu in npu_data.iterrows():
341
151
  message = ''
342
152
  compare_column = ApiPrecisionOutputColumn()
@@ -362,7 +172,7 @@ def analyse_csv(npu_data, gpu_data, config):
362
172
  row_gpu = row_gpu.iloc[0]
363
173
  new_status = CompareConst.SPACE
364
174
  try:
365
- new_status = get_api_status(row_npu, row_gpu, api_name, compare_column)
175
+ new_status = get_api_status(row_npu, row_gpu, api_name, compare_column, registry)
366
176
  except Exception as err:
367
177
  logger.error(f"Get api status error: {str(err)}")
368
178
  compare_column.api_name = full_api_name_with_direction_status
@@ -383,7 +193,8 @@ def analyse_csv(npu_data, gpu_data, config):
383
193
  else:
384
194
  forward_result = get_api_checker_result(forward_status)
385
195
  backward_result = get_api_checker_result(backward_status)
386
- message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
196
+ _, base_api_name = extract_basic_api_segments(last_api_name)
197
+ message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
387
198
  message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
388
199
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
389
200
  print_test_success(last_api_name, forward_result, backward_result)
@@ -415,37 +226,30 @@ def analyse_csv(npu_data, gpu_data, config):
415
226
  else:
416
227
  forward_result = get_api_checker_result(forward_status)
417
228
  backward_result = get_api_checker_result(backward_status)
418
- message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
229
+ _, base_api_name = extract_basic_api_segments(last_api_name)
230
+ message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
419
231
  message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
420
232
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
421
233
  print_test_success(last_api_name, forward_result, backward_result)
422
234
  last_api_skip_message = ''
423
235
 
424
236
 
425
- def get_api_status(row_npu, row_gpu, api_name, compare_column):
237
+ def get_api_status(row_npu, row_gpu, api_name, compare_column, registry):
426
238
  full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
427
239
  # 当前API的输出为空(例如反向过程中requires_grad=False),跳过比对
428
- if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace():
240
+ if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace() or \
241
+ row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in API_PRECISION_COMPARE_UNSUPPORT_LIST or \
242
+ row_npu[ApiPrecisionCompareColumn.SHAPE] == CompareConst.ZERO_SHAPE:
429
243
  compare_column.api_name = full_api_name_with_direction_status
430
244
  compare_column.compare_result = CompareConst.SKIP
431
245
  compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
432
246
  new_status = CompareConst.SKIP
433
247
  else:
434
248
  compare_column.api_name = full_api_name_with_direction_status
435
- if api_name in thousandth_standard_api:
436
- new_status = record_thousandth_threshold_result(compare_column, row_npu)
437
- elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or \
438
- api_name in binary_standard_api:
439
- new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
440
- elif api_name in absolute_standard_api:
441
- new_status = record_absolute_threshold_result(compare_column, row_npu)
442
- elif api_name in ulp_standard_api and \
443
- row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST:
444
- us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
445
- new_status = record_ulp_compare_result(compare_column, us)
446
- elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST:
447
- bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu)
448
- new_status = record_benchmark_compare_result(compare_column, bs)
249
+ dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
250
+ input_data = PrecisionCompareInput(row_npu, row_gpu, dtype, compare_column)
251
+ comparison_func = registry.get_comparison_function(api_name, dtype)
252
+ new_status = comparison_func(input_data)
449
253
  return new_status
450
254
 
451
255
 
@@ -505,21 +309,24 @@ def check_csv_columns(columns, csv_type):
505
309
  raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
506
310
 
507
311
 
508
- def record_binary_consistency_result(api_name, compare_column, row_npu):
312
+ def record_binary_consistency_result(input_data):
313
+ row_npu = input_data.row_npu
314
+ compare_column = input_data.compare_column
509
315
  new_status = check_error_rate(row_npu[ApiPrecisionCompareColumn.ERROR_RATE])
510
316
  compare_column.error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
511
317
  compare_column.error_rate_status = new_status
512
318
  compare_column.compare_result = new_status
513
- compare_column.compare_algorithm = "二进制一致法"
319
+ compare_column.compare_algorithm = CompareConst.BINARY_CONSISTENCY_ALGORITHM_NAME
514
320
  message = ''
515
321
  if compare_column.error_rate_status == CompareConst.ERROR:
516
322
  message += "ERROR: 二进制一致错误率超过阈值\n"
517
- message += CompareMessage.get(api_name, "")
518
323
  compare_column.compare_message = message
519
324
  return new_status
520
325
 
521
326
 
522
- def record_absolute_threshold_result(compare_column, row_npu):
327
+ def record_absolute_threshold_result(input_data):
328
+ row_npu = input_data.row_npu
329
+ compare_column = input_data.compare_column
523
330
  absolute_threshold_result = get_absolute_threshold_result(row_npu)
524
331
  compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
525
332
  compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
@@ -528,62 +335,88 @@ def record_absolute_threshold_result(compare_column, row_npu):
528
335
  compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
529
336
  compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
530
337
  compare_column.compare_result = absolute_threshold_result.get("absolute_threshold_result")
531
- compare_column.compare_algorithm = "绝对阈值法"
338
+ compare_column.compare_algorithm = CompareConst.ABSOLUTE_THRESHOLD_ALGORITHM_NAME
532
339
  message = ''
533
340
  if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
534
- message += "ERROR: inf/nan错误率超过阈值\n"
341
+ message += "ERROR: inf/nan错误率超过阈值"
535
342
  if compare_column.rel_err_ratio_status == CompareConst.ERROR:
536
- message += "ERROR: 相对误差错误率超过阈值\n"
343
+ message += "ERROR: 相对误差错误率超过阈值"
537
344
  if compare_column.abs_err_ratio_status == CompareConst.ERROR:
538
- message += "ERROR: 绝对误差错误率超过阈值\n"
345
+ message += "ERROR: 绝对误差错误率超过阈值"
539
346
  compare_column.compare_message = message
540
347
  return compare_column.compare_result
541
348
 
542
349
 
543
- def record_benchmark_compare_result(compare_column, bs):
544
- bs.get_result()
545
- compare_column.small_value_err_ratio = bs.small_value_err_ratio
546
- compare_column.small_value_err_status = bs.small_value_err_status
547
- compare_column.rmse_ratio = bs.rmse_ratio
548
- compare_column.rmse_status = bs.rmse_status
549
- compare_column.max_rel_err_ratio = bs.max_rel_err_ratio
550
- compare_column.max_rel_err_status = bs.max_rel_err_status
551
- compare_column.mean_rel_err_ratio = bs.mean_rel_err_ratio
552
- compare_column.mean_rel_err_status = bs.mean_rel_err_status
553
- compare_column.eb_ratio = bs.eb_ratio
554
- compare_column.eb_status = bs.eb_status
555
- compare_column.compare_result = bs.final_result
556
- compare_column.compare_algorithm = "标杆比对法"
557
- compare_column.compare_message = bs.compare_message
350
+ def record_benchmark_compare_result(input_data):
351
+ bs = BenchmarkPrecisionCompare(input_data)
352
+ compare_result = bs.compare()
558
353
  for status_attr, messages in benchmark_message.items():
559
- status_value = getattr(compare_column, status_attr)
354
+ status_value = getattr(input_data.compare_column, status_attr)
560
355
  if status_value in messages:
561
- compare_column.compare_message += messages[status_value]
562
- return compare_column.compare_result
356
+ input_data.compare_column.compare_message += messages[status_value]
357
+ return compare_result
358
+
563
359
 
360
+ def record_ulp_compare_result(input_data):
361
+ us = UlpPrecisionCompare(input_data)
362
+ compare_result = us.compare()
363
+ return compare_result
564
364
 
565
- def record_ulp_compare_result(compare_column, us):
566
- us.get_result()
567
- compare_column.mean_ulp_err = us.mean_ulp_err
568
- compare_column.ulp_err_proportion = us.ulp_err_proportion
569
- compare_column.ulp_err_proportion_ratio = us.ulp_err_proportion_ratio
570
- compare_column.ulp_err_status = us.ulp_err_status
571
- compare_column.compare_result = us.ulp_err_status
572
- compare_column.compare_algorithm = "ULP误差比对法"
573
- compare_column.compare_message = us.compare_message
365
+
366
+ def record_accumulative_error_compare_result(input_data):
367
+ row_npu = input_data.row_npu
368
+ compare_column = input_data.compare_column
369
+ absolute_threshold_result = get_absolute_threshold_result(row_npu)
370
+ threshold_result = absolute_threshold_result.get("absolute_threshold_result")
371
+ eb, eb_result = check_eb(row_npu)
372
+ accumulative_error_compare_result = CompareConst.PASS
373
+ if CompareConst.ERROR in [threshold_result, eb_result]:
374
+ accumulative_error_compare_result = CompareConst.ERROR
375
+
376
+ compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
377
+ compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
378
+ compare_column.rel_err_ratio = absolute_threshold_result.get("rel_err_ratio")
379
+ compare_column.rel_err_ratio_status = absolute_threshold_result.get("rel_err_result")
380
+ compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
381
+ compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
382
+ compare_column.eb_ratio = eb
383
+ compare_column.eb_status = eb_result
384
+ compare_column.compare_result = accumulative_error_compare_result
385
+ compare_column.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME
386
+ message = []
387
+ if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
388
+ message.append("ERROR: inf/nan错误率超过阈值\n")
389
+ if compare_column.rel_err_ratio_status == CompareConst.ERROR:
390
+ message.append("ERROR: 相对误差错误率超过阈值\n")
391
+ if compare_column.abs_err_ratio_status == CompareConst.ERROR:
392
+ message.append("ERROR: 绝对误差错误率超过阈值\n")
393
+ if compare_column.eb_status == CompareConst.ERROR:
394
+ message.append("ERROR: 误差均衡性超过阈值\n")
395
+ compare_column.compare_message = "\n".join(message)
574
396
  return compare_column.compare_result
575
397
 
576
398
 
399
+ def check_eb(row_npu):
400
+ eb = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.EB])
401
+ dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
402
+ eb_threshold = StandardConfig.get_accumulative_error_eb_threshold(dtype)
403
+ eb_result = CompareConst.PASS if eb <= eb_threshold else CompareConst.ERROR
404
+ return eb, eb_result
405
+
406
+
577
407
  def check_thousandth_rate(thousandth_rate):
578
- return CompareConst.PASS if convert_str_to_float(thousandth_rate) >= 0.999 else CompareConst.ERROR
408
+ return CompareConst.PASS if convert_str_to_float(thousandth_rate) >= CompareConst.THOUSANDTH_PASS_VALUE \
409
+ else CompareConst.ERROR
579
410
 
580
411
 
581
- def record_thousandth_threshold_result(compare_column, row_npu):
412
+ def record_thousandth_threshold_result(input_data):
413
+ row_npu = input_data.row_npu
414
+ compare_column = input_data.compare_column
582
415
  new_status = check_thousandth_rate(row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH])
583
416
  compare_column.rel_err_thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
584
417
  compare_column.rel_err_thousandth_status = new_status
585
418
  compare_column.compare_result = new_status
586
- compare_column.compare_algorithm = "双千指标法"
419
+ compare_column.compare_algorithm = CompareConst.THOUSANDTH_STANDARD_ALGORITHM_NAME
587
420
  message = ''
588
421
  if compare_column.rel_err_thousandth_status == CompareConst.ERROR:
589
422
  message += "ERROR: 双千指标不达标\n"
@@ -602,8 +435,7 @@ def _api_precision_compare(parser=None):
602
435
  def _api_precision_compare_command(args):
603
436
  npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
604
437
  gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
605
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
606
- check_path_before_create(out_path)
438
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
607
439
  create_directory(out_path)
608
440
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
609
441
  out_path = out_path_checker.common_check()
@@ -621,7 +453,7 @@ def _api_precision_compare_parser(parser):
621
453
  parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
622
454
  help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
623
455
  "api_accuracy_checker tool.",
624
- required=False)
456
+ required=True)
625
457
  parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
626
458
  help="<optional> The api precision compare task result out path.",
627
459
  required=False)
@@ -66,6 +66,7 @@ BinaryCompareStandard:
66
66
  - greater_
67
67
  - greater_equal
68
68
  - greater_equal_
69
+ - histc
69
70
  - isfinite
70
71
  - isnan
71
72
  - less
@@ -130,4 +131,6 @@ ULPStandard:
130
131
  ThousandthStandard:
131
132
  - conv1d
132
133
  - conv2d
133
-
134
+
135
+ AccumulativeErrorStandard:
136
+ - test_api