mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,36 +15,46 @@
15
15
 
16
16
  import os
17
17
  import re
18
+ import math
19
+ import zlib
20
+ from dataclasses import dataclass
21
+
18
22
  import numpy as np
19
- from msprobe.core.common.const import Const, CompareConst
20
- from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
23
+
24
+ from msprobe.core.common.const import Const, CompareConst, FileCheckConst
25
+ from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
21
26
  from msprobe.core.common.file_utils import check_file_or_directory_path
22
27
 
23
28
 
24
29
  def extract_json(dirname, stack_json=False):
25
30
  json_path = ''
26
- for fname in os.listdir(dirname):
27
- if fname == "construct.json":
28
- continue
29
- full_path = os.path.join(dirname, fname)
30
- if full_path.endswith('.json'):
31
- json_path = full_path
32
- if not stack_json and 'stack' not in json_path:
33
- break
34
- if stack_json and 'stack' in json_path:
35
- break
31
+ for filename in os.listdir(dirname):
32
+ target_file_name = 'stack.json' if stack_json else 'dump.json'
33
+ if filename == target_file_name:
34
+ json_path = os.path.join(dirname, filename)
35
+ break
36
36
 
37
37
  # Provide robustness on invalid directory inputs
38
38
  if not json_path:
39
- logger.error(f'No file is found in dump dir {dirname}. ')
40
- raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
39
+ if stack_json:
40
+ logger.warning(f'stack.json is not found in dump dir {dirname}.')
41
+ else:
42
+ logger.error(f'dump.json is not found in dump dir {dirname}.')
43
+ raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
41
44
  return json_path
42
45
 
43
46
 
47
+ def set_stack_json_path(input_param):
48
+ npu_data_dir = os.path.dirname(input_param.get("npu_json_path"))
49
+ stack_path = extract_json(npu_data_dir, stack_json=True)
50
+ input_param["stack_json_path"] = stack_path if stack_path else None
51
+ return bool(stack_path)
52
+
53
+
44
54
  def check_and_return_dir_contents(dump_dir, prefix):
45
55
  """
46
56
  check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
47
- pattern: ^{prefix}(?:0|[0-9][1-9]*)?$
57
+ pattern: ^{prefix}(?:0|[1-9][0-9]*)?$
48
58
 
49
59
  Args:
50
60
  dump_dir (str): dump dir
@@ -60,7 +70,7 @@ def check_and_return_dir_contents(dump_dir, prefix):
60
70
  check_regex_prefix_format_valid(prefix)
61
71
  check_file_or_directory_path(dump_dir, True)
62
72
  contents = os.listdir(dump_dir)
63
- pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
73
+ pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
64
74
  for name in contents:
65
75
  if not pattern.match(name):
66
76
  logger.error(
@@ -72,6 +82,10 @@ def check_and_return_dir_contents(dump_dir, prefix):
72
82
 
73
83
 
74
84
  def rename_api(npu_name, process):
85
+ """
86
+ 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
87
+ rename后: {api_type}.{api_name}.{input/output}.{参数序号}
88
+ """
75
89
  npu_split = npu_name.split(process)
76
90
  try:
77
91
  torch_func_index, in_out = npu_split[0], npu_split[1]
@@ -84,122 +98,89 @@ def rename_api(npu_name, process):
84
98
 
85
99
 
86
100
  def read_op(op_data, op_name):
87
- op_parsed_list = []
88
- if Const.FORWARD in op_name:
89
- if Const.INPUT_ARGS in op_data:
90
- input_item = op_data[Const.INPUT_ARGS]
91
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
92
- op_parsed_list = input_parsed_list.copy()
93
- input_parsed_list.clear()
94
- if Const.INPUT_KWARGS in op_data:
95
- kwargs_item = op_data[Const.INPUT_KWARGS]
96
- if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
97
- kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
98
- op_parsed_list += kwarg_parsed_list
99
- kwarg_parsed_list.clear()
100
- elif kwargs_item:
101
- for kwarg in kwargs_item:
102
- kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
103
- op_parsed_list += kwarg_parsed_list
104
- kwarg_parsed_list.clear()
105
- if Const.OUTPUT in op_data:
106
- output_item = op_data[Const.OUTPUT]
107
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
108
- op_parsed_list += output_parsed_list
109
- output_parsed_list.clear()
110
- if Const.BACKWARD in op_name:
111
- if Const.INPUT in op_data:
112
- input_item = op_data[Const.INPUT]
113
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
114
- op_parsed_list = input_parsed_list.copy()
115
- input_parsed_list.clear()
116
- if Const.OUTPUT in op_data:
117
- output_item = op_data[Const.OUTPUT]
118
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
119
- op_parsed_list += output_parsed_list
120
- output_parsed_list.clear()
101
+ if Const.PARAMS_GRAD in op_name.split(Const.SEP):
102
+ op_parsed_list = op_item_parse(op_data, op_name)
103
+ else:
104
+ op_parsed_list = []
105
+ for name in CompareConst.IO_NAME_MAPPING:
106
+ if name in op_data:
107
+ op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name]))
121
108
  return op_parsed_list
122
109
 
123
110
 
124
- def op_item_parse(item, op_name, index, item_list=None, top_bool=True, depth=0):
111
+ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
112
+ default_item = {
113
+ 'full_op_name': op_name,
114
+ 'type': None,
115
+ 'Max': None,
116
+ 'Min': None,
117
+ 'Mean': None,
118
+ 'Norm': None,
119
+ 'dtype': None,
120
+ 'shape': None,
121
+ 'md5': None,
122
+ 'value': None,
123
+ 'data_name': '-1'
124
+ }
125
+
125
126
  if depth > Const.MAX_DEPTH:
126
- logger.error(f"parse of api/module of {op_name} exceeds the recursion limit.")
127
+ logger.error(f'parse of api/module of {op_name} exceeds the recursion limit.')
127
128
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
128
- if item_list is None:
129
- item_list = []
130
- if item is None or (isinstance(item, dict) and not item):
131
- if not top_bool:
132
- tmp = {
133
- 'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
134
- 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'
135
- }
136
- else:
137
- tmp = {
138
- 'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
139
- 'shape': None, 'md5': None, 'data_name': '-1'
140
- }
141
- item_list.append(tmp)
142
- return item_list
143
- if index is None:
144
- if isinstance(item, dict):
145
- full_op_name = op_name + '.0'
146
- else:
147
- full_op_name = op_name
148
- else:
149
- full_op_name = op_name + Const.SEP + str(index)
150
- if isinstance(item, dict):
151
- if 'type' not in item:
152
- for kwarg in item:
153
- kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None, depth=depth+1)
154
- item_list += kwarg_parsed_list
155
- kwarg_parsed_list.clear()
156
- elif 'dtype' in item:
157
- parsed_item = item
158
- parsed_item['full_op_name'] = full_op_name
159
- item_list.append(parsed_item)
160
- elif 'type' in item:
161
- parsed_item = {}
162
- if item['type'] == 'torch.Size':
163
- parsed_item['full_op_name'] = full_op_name
164
- parsed_item['dtype'] = 'torch.Size'
165
- parsed_item['shape'] = str(item['value'])
166
- parsed_item['md5'] = None
167
- parsed_item['Max'] = None
168
- parsed_item['Min'] = None
169
- parsed_item['Mean'] = None
170
- parsed_item['Norm'] = None
171
- parsed_item['data_name'] = '-1'
172
- item_list.append(parsed_item)
173
- elif item['type'] == 'slice':
174
- parsed_item['full_op_name'] = full_op_name
175
- parsed_item['dtype'] = 'slice'
176
- parsed_item['shape'] = str(np.shape(np.array(item['value'])))
177
- parsed_item['md5'] = None
178
- parsed_item['Max'] = None
179
- parsed_item['Min'] = None
180
- parsed_item['Mean'] = None
181
- parsed_item['Norm'] = None
182
- parsed_item['data_name'] = '-1'
183
- item_list.append(parsed_item)
129
+
130
+ if op_data is None:
131
+ return [default_item]
132
+ elif not op_data:
133
+ return []
134
+
135
+ item_list = []
136
+ if isinstance(op_data, list):
137
+ for i, data in enumerate(op_data):
138
+ if Const.PARAMS_GRAD not in op_name.split(Const.SEP):
139
+ item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
184
140
  else:
185
- parsed_item['full_op_name'] = full_op_name
186
- parsed_item['dtype'] = str(type(item['value']))
187
- parsed_item['shape'] = '[]'
188
- parsed_item['md5'] = None
189
- parsed_item['Max'] = item['value']
190
- parsed_item['Min'] = item['value']
191
- parsed_item['Mean'] = item['value']
192
- parsed_item['Norm'] = item['value']
193
- parsed_item['data_name'] = '-1'
194
- item_list.append(parsed_item)
195
- else:
196
- resolve_api_special_parameters(item, full_op_name, item_list)
197
- else:
198
- for j, item_spec in enumerate(item):
199
- op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False, depth=depth+1)
141
+ item_list.extend(op_item_parse(data, op_name, depth + 1))
142
+ elif isinstance(op_data, dict):
143
+ if is_leaf_data(op_data):
144
+ return [gen_op_item(op_data, op_name)]
145
+ for sub_name, sub_data in op_data.items():
146
+ item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1))
200
147
  return item_list
201
148
 
202
149
 
150
+ def is_leaf_data(op_data):
151
+ return 'type' in op_data and isinstance(op_data['type'], str)
152
+
153
+
154
+ def gen_op_item(op_data, op_name):
155
+ op_item = {}
156
+ op_item.update(op_data)
157
+ data_name = op_data.get('data_name') if op_data.get('data_name') else '-1' # 如果是""也返回-1
158
+ op_item['data_name'] = data_name
159
+ op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name
160
+
161
+ params = ['Max', 'Min', 'Mean', 'Norm']
162
+ for i in params:
163
+ if i not in op_item:
164
+ op_item[i] = None
165
+
166
+ if not op_item.get('dtype'):
167
+ if op_item.get('type') == 'torch.Size':
168
+ op_item['dtype'] = op_data.get('type')
169
+ op_item['shape'] = str(op_data.get('value'))
170
+ elif op_item.get('type') == 'slice':
171
+ op_item['dtype'] = op_data.get('type')
172
+ op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
173
+ else:
174
+ op_item['dtype'] = str(type(op_data.get('value')))
175
+ op_item['shape'] = '[]'
176
+ for i in params:
177
+ op_item[i] = op_data.get('value')
178
+ if not op_item.get('md5'):
179
+ op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
180
+
181
+ return op_item
182
+
183
+
203
184
  def resolve_api_special_parameters(data_dict, full_op_name, item_list):
204
185
  """
205
186
  Function Description:
@@ -231,223 +212,387 @@ def resolve_api_special_parameters(data_dict, full_op_name, item_list):
231
212
  item_list.append(parsed_item)
232
213
 
233
214
 
234
- def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
215
+ def process_summary_data(summary_data):
216
+ """处理summary_data中的nan值,返回处理后的列表"""
217
+ return [CompareConst.NAN if isinstance(x, float) and math.isnan(x) else x for x in summary_data]
218
+
219
+
220
+ def get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg):
221
+ start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
222
+ warning_flag = False
223
+ for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
224
+ if all(isinstance(val, (float, int)) and not isinstance(val, bool) for val in [npu_val, bench_val]):
225
+ diff = npu_val - bench_val
226
+ if math.isnan(diff):
227
+ diff = CompareConst.NAN
228
+ relative = CompareConst.NAN
229
+ else:
230
+ if bench_val != 0:
231
+ relative = str(abs((diff / bench_val) * 100)) + '%'
232
+ else:
233
+ relative = CompareConst.N_A
234
+ magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + CompareConst.EPSILON)
235
+ if magnitude_diff > CompareConst.MAGNITUDE:
236
+ warning_flag = True
237
+ result_item[start_idx + i] = diff
238
+ result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = relative
239
+ else:
240
+ result_item[start_idx + i] = CompareConst.N_A
241
+ result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = CompareConst.N_A
242
+
243
+ accuracy_check = CompareConst.WARNING if warning_flag else ""
244
+ err_msg += "Need double check api accuracy." if warning_flag else ""
245
+ for i in range(start_idx, len(result_item)):
246
+ if str(result_item[i]) in ('inf', '-inf', 'nan'):
247
+ result_item[i] = f'{result_item[i]}\t'
248
+ return result_item, accuracy_check, err_msg
249
+
250
+
251
+ @dataclass
252
+ class ApiItemInfo:
253
+ name: str
254
+ struct: tuple
255
+ stack_info: list
256
+
257
+
258
+ def stack_column_process(result_item, has_stack, index, key, npu_stack_info):
259
+ if has_stack and index == 0 and key == CompareConst.INPUT_STRUCT:
260
+ result_item.extend(npu_stack_info)
261
+ else:
262
+ result_item.append(CompareConst.NONE)
263
+ return result_item
264
+
265
+
266
+ def result_item_init(n_info, b_info, dump_mode):
267
+ n_len = len(n_info.struct)
268
+ b_len = len(b_info.struct)
269
+ struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1)
270
+ if struct_long_enough:
271
+ result_item = [
272
+ n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1]
273
+ ]
274
+ if dump_mode == Const.MD5:
275
+ md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
276
+ result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
277
+ elif dump_mode == Const.SUMMARY:
278
+ result_item.extend([" "] * 8)
279
+ else:
280
+ result_item.extend([" "] * 5)
281
+ else:
282
+ err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
283
+ f"npu_info_struct is {n_info.struct}\n" \
284
+ f"bench_info_struct is {b_info.struct}"
285
+ logger.error(err_msg)
286
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
287
+ return result_item
288
+
289
+
290
+ def count_struct(op_dict):
291
+ parts = [
292
+ CompareConst.OP_NAME,
293
+ CompareConst.INPUT_STRUCT,
294
+ CompareConst.OUTPUT_STRUCT,
295
+ CompareConst.PARAMS_STRUCT,
296
+ CompareConst.PARAMS_GRAD_STRUCT
297
+ ]
298
+ lengths = [len(op_dict.get(part, [])) for part in parts]
299
+ num = lengths[0]
300
+ if num != sum(lengths[1:]):
301
+ logger.error(f"Length of names and structs of op_dict not match. Please check! op_dict: {op_dict}")
302
+ raise CompareException(CompareException.NAMES_STRUCTS_MATCH_ERROR)
303
+ return tuple(lengths)
304
+
305
+
306
+ def get_accuracy(result, n_dict, b_dict, dump_mode):
235
307
  def get_accuracy_core(n_start, n_len, b_start, b_len, key):
236
308
  min_len = min(n_len, b_len)
237
309
  npu_stack_info = n_dict.get("stack_info", None)
238
310
  bench_stack_info = b_dict.get("stack_info", None)
239
311
  has_stack = npu_stack_info and bench_stack_info
240
312
 
241
- all_mode_bool = not (summary_compare or md5_compare)
242
- if all_mode_bool:
313
+ if dump_mode == Const.ALL:
243
314
  npu_data_name = n_dict.get("data_name", None)
244
315
  bench_data_name = b_dict.get("data_name", None)
245
316
 
246
317
  for index in range(min_len):
247
-
248
- n_name = n_dict['op_name'][n_start + index]
249
- b_name = b_dict['op_name'][b_start + index]
250
- n_struct = n_dict[key][index]
251
- b_struct = b_dict[key][index]
318
+ n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
319
+ b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name")
320
+ n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
321
+ b_struct = safe_get_value(b_dict, index, "b_dict", key=key)
252
322
  err_msg = ""
253
- if md5_compare:
254
- result_item = [
255
- n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], n_struct[2], b_struct[2],
256
- CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF
257
- ]
258
- if has_stack and index == 0 and key == "input_struct":
259
- result_item.extend(npu_stack_info)
260
- else:
261
- result_item.append(CompareConst.NONE)
323
+
324
+ npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info)
325
+ bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info)
326
+ result_item = result_item_init(npu_info, bench_info, dump_mode)
327
+
328
+ if dump_mode == Const.MD5:
329
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
262
330
  result.append(result_item)
263
331
  continue
264
332
 
265
- if summary_compare:
266
- result_item = [
267
- n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
268
- " ", " ", " ", " ", " ", " ", " ", " "
269
- ]
270
- else:
271
- result_item = [
272
- n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
273
- " ", " ", " ", " ", " "
274
- ]
275
-
276
- npu_summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
277
- result_item.extend(npu_summary_data)
278
- bench_summary_data = b_dict.get(CompareConst.SUMMARY)[b_start + index]
279
- result_item.extend(bench_summary_data)
280
-
281
- if summary_compare:
282
- start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
283
- warning_flag = False
284
- for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
285
- if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
286
- diff = npu_val - bench_val
287
- if bench_val != 0:
288
- relative = str(abs((diff / bench_val) * 100)) + '%'
289
- else:
290
- relative = CompareConst.N_A
291
- result_item[start_idx + i] = diff
292
- result_item[start_idx + i + 4] = relative
293
- magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
294
- if magnitude_diff > 0.5:
295
- warning_flag = True
296
- else:
297
- result_item[start_idx + i] = CompareConst.NONE
298
- accuracy_check = CompareConst.WARNING if warning_flag else ""
299
- err_msg += "Need double check api accuracy." if warning_flag else ""
300
- for i in range(start_idx, len(result_item)):
301
- if str(result_item[i]) in ('inf', '-inf', 'nan'):
302
- result_item[i] = f'{result_item[i]}\t'
303
-
304
- result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
333
+ npu_summary_data = safe_get_value(n_dict, n_start + index, "n_dict", key=CompareConst.SUMMARY)
334
+ bench_summary_data = safe_get_value(b_dict, b_start + index, "b_dict", key=CompareConst.SUMMARY)
335
+ result_item.extend(process_summary_data(npu_summary_data))
336
+ result_item.extend(process_summary_data(bench_summary_data))
337
+
338
+ if dump_mode == Const.SUMMARY:
339
+ result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
340
+ bench_summary_data, err_msg)
341
+
342
+ result_item.append(accuracy_check if dump_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES)
305
343
  result_item.append(err_msg)
306
- if has_stack and index == 0 and key == "input_struct":
307
- result_item.extend(npu_stack_info)
308
- else:
309
- result_item.append(CompareConst.NONE)
310
- if all_mode_bool:
311
- result_item.append(npu_data_name[n_start + index])
344
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
345
+ if dump_mode == Const.ALL:
346
+ result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
312
347
 
313
348
  result.append(result_item)
314
349
 
315
350
  if n_len > b_len:
316
351
  for index in range(b_len, n_len):
317
- n_name = n_dict['op_name'][n_start + index]
318
- n_struct = n_dict[key][index]
319
- if md5_compare:
352
+ try:
353
+ n_name = n_dict['op_name'][n_start + index]
354
+ n_struct = n_dict[key][index]
355
+ if dump_mode == Const.MD5:
356
+ result_item = [
357
+ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
358
+ n_struct[2], CompareConst.NAN, CompareConst.NAN
359
+ ]
360
+ result.append(result_item)
361
+ continue
320
362
  result_item = [
321
363
  n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
322
- n_struct[2], CompareConst.NAN, CompareConst.NAN
364
+ " ", " ", " ", " ", " "
323
365
  ]
324
- result.append(result_item)
325
- continue
326
- result_item = [
327
- n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
328
- " ", " ", " ", " ", " "
329
- ]
330
- summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
331
- result_item.extend(summary_data)
332
- summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
333
- result_item.extend(summary_data)
366
+ summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
367
+ result_item.extend(summary_data)
368
+ summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
369
+ result_item.extend(summary_data)
370
+ except IndexError as e:
371
+ err_msg = "index out of bounds error occurs, please check!\n" \
372
+ f"n_dict is {n_dict}"
373
+ logger.error(err_msg)
374
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
334
375
 
335
376
  err_msg = ""
336
377
  result_item.append(CompareConst.ACCURACY_CHECK_YES)
337
378
  result_item.append(err_msg)
338
-
339
- if has_stack and index == 0 and key == "input_struct":
340
- result_item.extend(npu_stack_info)
341
- else:
342
- result_item.append(CompareConst.NONE)
343
- if all_mode_bool:
344
- result_item.append(npu_data_name[n_start + index])
379
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
380
+ if dump_mode == Const.ALL:
381
+ result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
345
382
 
346
383
  result.append(result_item)
347
384
 
348
- n_num = len(n_dict['op_name'])
349
- b_num = len(b_dict['op_name'])
350
- n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
351
- b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
352
- n_num_output = n_num - n_num_input
353
- b_num_output = b_num - b_num_input
354
- get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
355
- get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
385
+ n_num, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict)
386
+ b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict)
387
+
388
+ get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT)
389
+ get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params,
390
+ CompareConst.PARAMS_STRUCT)
391
+ get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT)
392
+ get_accuracy_core(n_num_input + n_num_output + n_num_params, n_num_params_grad,
393
+ b_num_input + b_num_output + b_num_params, b_num_params_grad,
394
+ CompareConst.PARAMS_GRAD_STRUCT)
356
395
 
357
396
 
358
- def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
359
- index_out = 0
397
+ def append_stack_info(result_item, npu_stack_info, index):
398
+ """添加堆栈信息到 result_item"""
399
+ if npu_stack_info and index == 0:
400
+ result_item.extend(npu_stack_info)
401
+ else:
402
+ result_item.append(CompareConst.NONE)
403
+
404
+
405
+ def get_un_match_accuracy(result, n_dict, dump_mode):
360
406
  npu_stack_info = n_dict.get("stack_info", None)
361
407
  bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
362
- err_msg = CompareConst.NO_BENCH
363
- accuracy_check_res = CompareConst.N_A
364
- for index, n_name in enumerate(n_dict["op_name"]):
365
- name_ele_list = n_name.split(Const.SEP)
366
- if "input" in name_ele_list:
367
- n_struct = n_dict["input_struct"][index]
368
- else:
369
- n_struct = n_dict["output_struct"][index_out]
370
- index_out += 1
371
408
 
372
- result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
373
- if md5_compare:
409
+ struct_to_index_mapping = {
410
+ CompareConst.INPUT_STRUCT: 0,
411
+ CompareConst.OUTPUT_STRUCT: 0,
412
+ CompareConst.PARAMS_STRUCT: 0,
413
+ CompareConst.PARAMS_GRAD_STRUCT: 0
414
+ }
415
+
416
+ op_name_list = n_dict.get(CompareConst.OP_NAME)
417
+ summary_list = n_dict.get(Const.SUMMARY)
418
+ data_name_list = n_dict.get('data_name')
419
+ op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list,
420
+ summary_list,
421
+ data_name_list)
422
+ for index, n_name in enumerate(op_name_reorder):
423
+ _, state = get_name_and_state(n_name)
424
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
425
+ if not struct_key:
426
+ continue
427
+ n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key)
428
+ struct_to_index_mapping[struct_key] += 1
429
+
430
+ try:
431
+ result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
432
+ except IndexError as e:
433
+ err_msg = "index out of bounds error occurs, please check!\n" \
434
+ f"op_name of n_dict is {n_dict['op_name']}\n" \
435
+ f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
436
+ f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
437
+ logger.error(err_msg)
438
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
439
+
440
+ if dump_mode == Const.MD5:
374
441
  result_item.extend([CompareConst.N_A] * 3)
375
- if npu_stack_info and index == 0:
376
- result_item.extend(npu_stack_info)
377
- else:
378
- result_item.append(CompareConst.NONE)
442
+ append_stack_info(result_item, npu_stack_info, index)
379
443
  result.append(result_item)
380
444
  continue
381
- if summary_compare:
445
+ if dump_mode == Const.SUMMARY:
382
446
  result_item.extend([CompareConst.N_A] * 8)
383
- else:
447
+ if dump_mode == Const.ALL:
384
448
  result_item.extend([CompareConst.N_A] * 5)
385
- npu_summary_data = n_dict.get("summary")[index]
386
- result_item.extend(npu_summary_data)
449
+
450
+ npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
387
451
  bench_summary_data = [CompareConst.N_A] * 4
452
+ result_item.extend(npu_summary_data)
388
453
  result_item.extend(bench_summary_data)
454
+ err_msg = CompareConst.NO_BENCH
455
+ accuracy_check_res = CompareConst.N_A
389
456
  result_item.append(accuracy_check_res)
390
457
  result_item.append(err_msg)
391
- if npu_stack_info and index == 0:
392
- result_item.extend(npu_stack_info)
393
- else:
394
- result_item.append(CompareConst.NONE)
395
- if not md5_compare and not summary_compare and result_item[1] == CompareConst.N_A:
458
+ append_stack_info(result_item, npu_stack_info, index)
459
+ if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
396
460
  result_item.extend(["-1"])
397
461
  result.append(result_item)
398
462
 
399
463
 
400
- def merge_tensor(tensor_list, summary_compare, md5_compare):
464
+ def merge_tensor(tensor_list, dump_mode):
401
465
  op_dict = {}
402
466
  op_dict["op_name"] = []
403
- op_dict["input_struct"] = []
404
- op_dict["kwargs_struct"] = []
405
- op_dict["output_struct"] = []
406
- op_dict["summary"] = []
467
+ op_dict[CompareConst.INPUT_STRUCT] = []
468
+ op_dict[CompareConst.KWARGS_STRUCT] = []
469
+ op_dict[CompareConst.OUTPUT_STRUCT] = []
470
+ op_dict[CompareConst.PARAMS_STRUCT] = []
471
+ op_dict[CompareConst.PARAMS_GRAD_STRUCT] = []
472
+ op_dict[Const.SUMMARY] = []
407
473
  op_dict["stack_info"] = []
408
474
 
409
- all_mode_bool = not (summary_compare or md5_compare)
410
- if all_mode_bool:
475
+ if dump_mode == Const.ALL:
411
476
  op_dict["data_name"] = []
412
477
 
413
478
  for tensor in tensor_list:
479
+ # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
414
480
  if len(tensor) == 2:
415
481
  op_dict['stack_info'].append(tensor['full_info'])
416
482
  break
483
+
417
484
  op_dict["op_name"].append(tensor['full_op_name'])
418
- name_ele_list = tensor['full_op_name'].split(Const.SEP)
419
- if not md5_compare:
420
- if "input" in name_ele_list:
421
- op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
422
- elif "kwarg" in name_ele_list:
423
- op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
424
- elif "output" in name_ele_list:
425
- op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
485
+
486
+ _, state = get_name_and_state(tensor['full_op_name'])
487
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
488
+ if not struct_key:
489
+ continue
490
+ if dump_mode == Const.MD5:
491
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
426
492
  else:
427
- if "input" in name_ele_list:
428
- op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
429
- if "kwarg" in name_ele_list:
430
- op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
431
- elif "output" in name_ele_list:
432
- op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
433
- op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
434
-
435
- if all_mode_bool:
493
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
494
+ op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
495
+
496
+ if dump_mode == Const.ALL:
436
497
  op_dict["data_name"].append(tensor['data_name'])
437
- data_name = op_dict["data_name"][-1].rsplit(Const.SEP, 1)[0]
438
- if data_name != "-1":
439
- op_dict["op_name"][-1] = data_name
440
498
 
441
- if not op_dict["kwargs_struct"]:
442
- del op_dict["kwargs_struct"]
499
+ if not op_dict[CompareConst.KWARGS_STRUCT]:
500
+ del op_dict[CompareConst.KWARGS_STRUCT]
443
501
  return op_dict if op_dict["op_name"] else {}
444
502
 
445
503
 
504
+ def print_compare_ends_info():
505
+ total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
506
+ logger.info('*' * total_len)
507
+ logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
508
+ logger.info('*' * total_len)
509
+
510
+
511
+ def table_value_is_valid(value: str) -> bool:
512
+ if not isinstance(value, str):
513
+ return True
514
+ try:
515
+ # -1.00 or +1.00 should be consdiered as digit numbers
516
+ float(value)
517
+ except ValueError:
518
+ # otherwise, they will be considered as formular injections
519
+ return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
520
+ return True
521
+
522
+
523
+ def get_name_and_state(name):
524
+ """
525
+ Get api/module name and state
526
+ example:
527
+ name = 'conv2d.forward.1.input.0'
528
+ return: ('conv2d.forward.1.', 'input')
529
+
530
+ name = 'Functional.pad.0.backward.output.0'
531
+ return: ('Functional.pad.0.backward.', 'output')
532
+
533
+ state type: input, output, kwargs, parameters, parameters_grad
534
+ """
535
+ if Const.PARAMS_GRAD in name.split(Const.SEP):
536
+ return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
537
+
538
+ split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
539
+ api = f'{split[0]}.{split[1]}.'
540
+ state_str = split[2]
541
+ match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
542
+ if not match:
543
+ raise CompareException(f'Invalid name string: {name}')
544
+ if match.group(1):
545
+ api = f'{api}{match.group(1)}'
546
+ state = match.group(2)
547
+ return api, state
548
+
549
+
550
+ def reorder_op_name_list(op_name_list):
551
+ if not op_name_list:
552
+ return op_name_list
553
+
554
+ parameters = []
555
+ output = []
556
+ parameters_grad = []
557
+ others = []
558
+ for x in op_name_list:
559
+ state = get_name_and_state(x)[1]
560
+ if state == Const.PARAMS:
561
+ parameters.append(x)
562
+ elif state == Const.OUTPUT:
563
+ output.append(x)
564
+ elif state == Const.PARAMS_GRAD:
565
+ parameters_grad.append(x)
566
+ else:
567
+ others.append(x)
568
+ # 合并others, parameters, 和output,确保parameters排在output前面
569
+ op_name_reorder = others + parameters + output + parameters_grad
570
+ return op_name_reorder
571
+
572
+
573
+ def reorder_op_x_list(op_name_list, summary_list, data_name_list):
574
+ """对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理"""
575
+ if not op_name_list or not summary_list:
576
+ return op_name_list, summary_list, data_name_list
577
+
578
+ index_map = {name: index for index, name in enumerate(op_name_list)}
579
+
580
+ op_name_reorder = reorder_op_name_list(op_name_list)
581
+ summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
582
+ if data_name_list:
583
+ data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
584
+ else:
585
+ data_name_reorder = data_name_list
586
+
587
+ return op_name_reorder, summary_reorder, data_name_reorder
588
+
589
+
446
590
  def _compare_parser(parser):
447
591
  parser.add_argument("-i", "--input_path", dest="input_path", type=str,
448
592
  help="<Required> The compare input path, a dict json.", required=True)
449
593
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
450
- help="<Required> The compare task result out path.", required=True)
594
+ help="<Required> The compare task result out path. Default path: ./output",
595
+ required=False, default="./output", nargs="?", const="./output")
451
596
  parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
452
597
  help="<optional> Whether to save stack info.", required=False)
453
598
  parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
@@ -457,8 +602,8 @@ def _compare_parser(parser):
457
602
  parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
458
603
  help="<optional> The cell mapping file path.", required=False)
459
604
  parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
460
- help="<optional> The api mapping file path.", required=False)
605
+ help="<optional> The api mapping file path.", required=False)
461
606
  parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
462
607
  help="<optional> The data mapping file path.", required=False)
463
- parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
608
+ parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
464
609
  help="<optional> The layer mapping file path.", required=False)