mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,222 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.file_utils import load_json
20
+ from msprobe.visualization.builder.msprobe_adapter import get_input_output
21
+ from msprobe.visualization.builder.msprobe_adapter import op_patterns
22
+ from msprobe.visualization.graph.graph import Graph
23
+ from msprobe.visualization.graph.node_op import NodeOp
24
+ from msprobe.visualization.utils import save_json_file, GraphConst
25
+
26
+
27
class GraphBuilder:
    # Matches a trailing ".backward.<digits>" suffix, e.g. "Module.module.GPTModel.backward.0"
    backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
    # Matches a word that starts with an uppercase letter, continues with letters and ends with "Template("
    template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')

    @staticmethod
    def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
        """
        Build a Graph from msprobe dump files (public graph-building entry point of GraphBuilder).

        Args:
            construct_path: path of construct.json; each entry maps a node name to its parent node name
            data_path: path of dump.json
            stack_path: path of stack.json
            model_name: model name, supplied by the caller
            complete_stack: keep the full stack info; when False the stacks are shortened by _simplify_stack
        Returns: Graph, the data structure representing the graph
        """
        construct_dict = load_json(construct_path)
        dump_dict = load_json(data_path)
        stack_dict = load_json(stack_path)
        if not complete_stack:
            GraphBuilder._simplify_stack(stack_dict)
        data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
        graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
        GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
        GraphBuilder._collect_apis_between_modules(graph)
        return graph

    @staticmethod
    def to_json(filename, config):
        """
        Export the graph(s) held by `config` (a GraphExportConfig) to a .vis file.

        When a bench graph is present, the NPU and bench graphs are stored under GraphConst.JSON_NPU_KEY
        and GraphConst.JSON_BENCH_KEY; otherwise the NPU graph's dict is the top-level object. Tool tips,
        node colors, micro steps and the task are added only when truthy; the overflow-check flag is
        always written.
        """
        result = {}
        if config.graph_b:
            result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
            result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
        else:
            result = config.graph_n.to_dict()
        if config.tool_tip:
            result[GraphConst.JSON_TIP_KEY] = config.tool_tip
        if config.node_colors:
            result[GraphConst.COLORS] = config.node_colors
        if config.micro_steps:
            result[GraphConst.MICRO_STEPS] = config.micro_steps
        if config.task:
            result[GraphConst.JSON_TASK_KEY] = config.task
        result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
        save_json_file(filename, result)

    @staticmethod
    def _simplify_stack(stack_dict):
        """
        Shorten the stack info in `stack_dict` in place.

        Module-level entries keep the first stack line containing "<module name>(";
        API-level entries keep the line right after the first line that matches "XxxTemplate(".
        Entries whose value is not a list, or in which nothing matches, are left unchanged.

        Module example: for Module.layer3.0.bn2.BatchNorm2d.forward.0 the module name is bn2, so the
        line "File /home/models/resnet.py, line 97, in forward, ... out = self.bn2(out)" is kept.

        API example: for Tensor.__iadd__.4.forward with the stack
            "File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
            "File /home/torchvision/models/resnet.py, line 102, in forward, ... out += identity",
        the first line matches "TensorOPTemplate(", so the second line is kept.
        """
        module_pattern = re.compile(op_patterns[0])
        for dump_name, stack_list in stack_dict.items():
            if not isinstance(stack_list, list):
                continue
            if module_pattern.match(dump_name):
                parts = dump_name.split(Const.SEP)
                if len(parts) < abs(Const.LAYER_NAME_INDEX):
                    continue
                # Literal substring test (not a regex), so a module name is always matched verbatim
                # even if it contains regex metacharacters.
                call_marker = parts[Const.LAYER_NAME_INDEX] + '('
                for stack in stack_list:
                    if call_marker in stack:
                        stack_list = [stack]
                        break
            else:
                for index, stack in enumerate(stack_list):
                    if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
                        stack_list = [stack_list[index + 1]]
                        break
            stack_dict[dump_name] = stack_list

    @staticmethod
    def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
        """
        If a backward node's parent is missing (null), try to derive it from the same-named forward node.

        The forward node's parent is looked up in construct_dict and its ".forward.<n>" suffix is rewritten
        to ".backward.<n>"; that name is returned only if it is itself a key of construct_dict.
        Otherwise upnode_id is returned unchanged.
        """
        # Names ending in ".backward." / ".forward." followed by one or more digits
        backward_pattern = r"(\.backward\.)(\d+)$"
        forward_pattern = r"(\.forward\.)(\d+)$"
        if re.search(backward_pattern, subnode_id) and not upnode_id:
            forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id))
            if forward_upnode_id:
                new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id)
                if new_upnode_id in construct_dict:
                    return new_upnode_id
        return upnode_id

    @staticmethod
    def _init_nodes(graph, construct_dict, data_dict, stack_dict):
        """
        Create a node for every entry of construct_dict and attach it to its parent.

        A missing parent falls back to the one derived by _handle_backward_upnode_missing; nodes that still
        have no parent are attached to graph.root.
        """
        for subnode_id, upnode_id in construct_dict.items():
            upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
            if upnode_id:
                upnode_op = NodeOp.get_node_op(upnode_id)
                upnode = GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], upnode_op, upnode_id)
            else:
                upnode = graph.root
            node_op = NodeOp.get_node_op(subnode_id)
            GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], node_op, subnode_id, upnode)

    @staticmethod
    def _create_or_get_node(graph, data_stack_list, op, name, upnode=None):
        """
        Return the node `name`, adding it to the graph first if it does not exist yet, and (re)populate
        its input/output data, stack info and parent link.

        Args:
            graph: the Graph being built
            data_stack_list: [data_dict, stack_dict], dump data and stack info keyed by node name
            op: NodeOp of the node
            name: node name
            upnode: parent node (None when the caller has none to give)
        Returns: the node
        """
        if name not in graph.node_map:
            graph.add_node(op, name, upnode)
        node = graph.get_node(name)
        node_data = data_stack_list[0].get(name, {})
        node_stack_info = data_stack_list[1].get(name, [])
        # Split the dump data into inputs and outputs and store them on the node
        input_data, output_data = get_input_output(node_data, node.id)
        node.set_input_output(input_data, output_data)
        # Backward nodes without stack info reuse the stack of the matching forward node.
        # Module naming e.g. Module.module.module.GPTModel.backward.0; API naming e.g. Tensor.permute.1.backward
        if (not node_stack_info and
                (GraphBuilder.backward_pattern.search(name) or name.endswith(f'{Const.SEP}{Const.BACKWARD}'))):
            forward_node = GraphBuilder._find_forward_node(graph, name)
            node_stack_info = forward_node.stack_info if forward_node \
                else ['This backward node cannot find the forward node and cannot retrieve stack information.']
        node.stack_info = node_stack_info
        node.add_upnode(upnode)
        return node

    @staticmethod
    def _find_forward_node(graph, name):
        """
        Return the forward node whose stack info the backward node `name` should reuse, or whatever
        graph.get_node returns when it does not exist.

        `name` must either match backward_pattern or end with ".<Const.BACKWARD>".
        """
        if GraphBuilder.backward_pattern.search(name):
            # Same-named modules are globally unique and share the same stack on every call, so the
            # forward call numbered 0 is used directly to avoid missing stack info.
            return graph.get_node(
                GraphBuilder.backward_pattern.sub(f'{Const.SEP}{Const.FORWARD}{Const.SEP}0', name))
        # API naming: swap only the trailing Const.BACKWARD so any other occurrence of the word inside
        # the name is left intact.
        return graph.get_node(name[:-len(Const.BACKWARD)] + Const.FORWARD)

    @staticmethod
    def _collect_apis_between_modules(graph):
        """
        Group runs of top-level APIs into collection nodes.

        On first expansion the graph shows its top-level nodes, which mix modules and many APIs; the APIs
        stretch the graph and make it hard to read. Every run of two or more consecutive top-level
        function_api nodes is therefore moved under a new GraphConst.APIS_BETWEEN_MODULES node.
        Args:
            graph: model structure

        Returns: None
        """
        i = 0
        output = []
        node_list = graph.root.subnodes
        while i < len(node_list):
            current_node = node_list[i]

            # Current node is an API: gather the consecutive APIs that follow it
            if current_node.op == NodeOp.function_api:
                temp_nodes = [current_node]
                i += 1
                while i < len(node_list) and node_list[i].op == NodeOp.function_api:
                    temp_nodes.append(node_list[i])
                    i += 1

                # Only runs of at least two APIs are grouped
                if len(temp_nodes) >= 2:
                    # Create the collection node and move the APIs into its subnodes
                    node_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
                                             id_accumulation=True)
                    api_collection_node = graph.get_node(node_id)
                    api_collection_node.subnodes = temp_nodes
                    # Re-establish the parent/child links
                    for node in temp_nodes:
                        node.upnode = api_collection_node
                    api_collection_node.upnode = graph.root
                    output.append(api_collection_node)
                else:
                    # A single API is kept as is
                    output.extend(temp_nodes)
            else:
                # Modules are kept as is
                output.append(current_node)
                i += 1

        graph.root.subnodes = output
211
+
212
+
213
class GraphExportConfig:
    """
    Container for everything needed to export one graph, or an NPU/bench pair of graphs.

    Attributes (all taken verbatim from the constructor arguments):
        graph_n: the NPU graph
        graph_b: optional bench graph
        tool_tip: optional tool tip payload
        node_colors: optional node color payload
        micro_steps: optional micro step info
        task: task name, '' by default
        overflow_check: overflow-check flag, False by default
    """

    def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
                 overflow_check=False):
        # Graphs to export
        self.graph_n, self.graph_b = graph_n, graph_b
        # Optional display payloads
        self.tool_tip, self.node_colors, self.micro_steps = tool_tip, node_colors, micro_steps
        # Task name and overflow-check flag
        self.task, self.overflow_check = task, overflow_check
@@ -0,0 +1,227 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+ import math
17
+ from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
18
+ from msprobe.core.common.utils import set_dump_path, get_dump_mode
19
+ from msprobe.visualization.utils import GraphConst
20
+ from msprobe.core.common.const import Const
21
+ from msprobe.core.compare.acc_compare import ModeConfig
22
+
23
# Rules used to map a node name to its NodeOp; each pattern is matched from the start of the name.
# The dots are escaped so a prefix only matches when followed by a literal '.' separator
# (an unescaped '.' would also accept names such as "ModuleX...").
op_patterns = [
    # NodeOp.module
    r'^(Module\.|Cell\.)',
    # NodeOp.function_api
    r'^(Tensor\.|Torch\.|Functional\.|NPU\.|VF\.|Distributed\.|Aten\.|Mint\.|Primitive\.|Jit\.|MintFunctional\.)'
]
30
+
31
+
32
def get_compare_mode(dump_path_param):
    """
    Determine the graph comparison mode for a dump: summary, MD5 or real-data.

    The dump path is registered via set_dump_path first, then the dump mode is detected with
    get_dump_mode and mapped through GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING.
    Args:
        dump_path_param: the parameters the acc_compare interface depends on
    Returns: 0 summary mode, 1 md5 mode, 2 true data mode (None for an unmapped dump mode)
    """
    set_dump_path(dump_path_param)
    detected_dump_mode = get_dump_mode(dump_path_param)
    return GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING.get(detected_dump_mode)
43
+
44
+
45
def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
    """
    Run the multi-process real-data comparison.

    Args:
        dump_path_param: the parameters the acc_compare interface depends on
        csv_path: path of the generated result file
        framework: framework type, pytorch or mindspore (anything other than Const.PT_FRAMEWORK
            is treated as mindspore)
        is_cross_frame: cross-framework comparison; only mindspore vs pytorch is supported,
            with pytorch as the benchmark
    Returns: whatever the comparator's do_multi_process returns
    """
    mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)

    if framework != Const.PT_FRAMEWORK:
        # Imported lazily so only the selected framework's comparator is loaded
        from msprobe.mindspore.compare.ms_compare import MSComparator
        comparator = MSComparator(mode_config)
        comparator.cross_frame = is_cross_frame
    else:
        from msprobe.pytorch.compare.pt_compare import PTComparator
        comparator = PTComparator(mode_config)
    return comparator.do_multi_process(dump_path_param, csv_path)
64
+
65
+
66
def get_input_output(node_data, node_id):
    """
    Split a node's raw dump data into an input dict and an output dict.

    An entry counts as output when its full_op_name contains GraphConst.OUTPUT but not GraphConst.INPUT;
    everything else is an input. Entries without a full_op_name are skipped.
    Args:
        node_data: the dump data of a single node
        node_id: node name
    Returns: (input_data, output_data); outputs are keyed by full_op_name, inputs by the data file name
        with its last Const.SEP-separated segment removed (presumably the file extension) when a usable
        data_name exists, else by full_op_name
    """
    inputs, outputs = {}, {}
    for parsed in read_op(node_data, node_id):
        op_name = parsed.get('full_op_name', '')
        if not op_name:
            continue
        if GraphConst.OUTPUT in op_name and GraphConst.INPUT not in op_name:
            outputs[op_name] = parsed
            continue
        data_name = parsed.get('data_name')
        # Prefer the on-disk data name as the parameter name
        key = data_name.rsplit(Const.SEP, 1)[0] if isinstance(data_name, str) and data_name != '-1' else op_name
        inputs[key] = parsed
    return inputs, outputs
90
+
91
+
92
def compare_data(data_dict_list1, data_dict_list2):
    """
    Check whether two get_input_output results have the same structure.

    Both dicts must have the same number of entries, and the entries paired up in iteration order must
    agree on 'type' and 'shape' (a missing field counts as None). Keys themselves are not compared.
    Returns True when consistent, False otherwise.
    """
    if len(data_dict_list1) != len(data_dict_list2):
        return False
    # Fields that decide whether two parameters are considered equal
    tag_keys = ('type', 'shape')
    paired_entries = zip(data_dict_list1.values(), data_dict_list2.values())
    return not any(
        left.get(tag_key, None) != right.get(tag_key, None)
        for left, right in paired_entries
        for tag_key in tag_keys
    )
109
+
110
+
111
def compare_data_fuzzy(data_dict_list1, data_dict_list2):
    """
    Fuzzy match: only check that paired parameters have the same shape (Const.SHAPE).

    Entries are paired in iteration order.
    NOTE(review): zip stops at the shorter dict, so a difference in the number of parameters is not
    detected here; confirm that this is intended for fuzzy matching.
    """
    return not any(
        left.get(Const.SHAPE) != right.get(Const.SHAPE)
        for left, right in zip(data_dict_list1.values(), data_dict_list2.values())
    )
121
+
122
+
123
def format_node_data(data_dict):
    """
    Format every parameter entry of a node's data in place, for output.

    For each value that is a dict, the 'requires_grad' and 'full_op_name' fields are dropped and the rest
    is normalized by _format_data. Non-dict values are left untouched.
    Returns: the same (mutated) data_dict
    """
    hidden_fields = ('requires_grad', 'full_op_name')
    for entry in data_dict.values():
        if isinstance(entry, dict):
            for field in hidden_fields:
                entry.pop(field, None)
            _format_data(entry)
    return data_dict
136
+
137
+
138
def compare_node(node_ids, data_dicts, stack_json_data, compare_mode):
    """
    Compute accuracy metrics for an NPU/bench node pair with get_accuracy from acc_compare.py.

    node_ids and data_dicts hold the NPU entry at index 0 and the bench entry at index 1.
    In real-data mode the metrics cannot be obtained here; the multi-process compare interface is
    needed instead.
    Returns: a list with the parameter info and comparison metrics (except in real-data mode)
    """
    parsed_npu = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode)
    parsed_bench = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode)
    accuracy_rows = []
    get_accuracy(accuracy_rows, parsed_npu, parsed_bench,
                 GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode))
    return accuracy_rows
150
+
151
+
152
def _parse_node(node_id, data_dict, stack_json_data, compare_mode):
    """
    Convert a node into the merged form expected by get_accuracy in acc_compare.py.

    The node's dump data is parsed with read_op, an extra entry carrying the node's stack info
    ('full_info', None when the node has no stack entry) is appended, and the list is merged with
    merge_tensor using the dump mode mapped from compare_mode. An empty merge result gets an empty
    'op_name' list.
    """
    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
    parsed_ops = read_op(data_dict.get(node_id, {}), node_id)
    stack_info = stack_json_data[node_id] if node_id in stack_json_data else None
    parsed_ops.append({'full_op_name': node_id, 'full_info': stack_info})
    merged = merge_tensor(parsed_ops, dump_mode)
    if not merged:
        merged['op_name'] = []
    return merged
167
+
168
+
169
def _format_decimal_string(s):
    """
    Round the decimal numbers embedded in a string.

    Every "<digits>.<digits>" token (each digit run at most 20 long, optionally followed by '%') whose
    fractional part has more than GraphConst.ROUND_TH digits is rewritten with GraphConst.ROUND_TH
    decimals, keeping the '%' if present. Other text is returned unchanged.

    Replacement is done in a single re.sub pass, so each token is rewritten only where it was matched.
    A per-token str.replace would also hit other numbers that merely contain the token as a substring
    (e.g. "1.1234567 1.12345678" would turn the second number into "1.1234578").
    """
    pattern = re.compile(r'\d{1,20}\.\d{1,20}%?')

    def _round_token(match):
        token = match.group(0)
        is_percent = token.endswith('%')
        number_str = token.rstrip('%')
        decimal_part = number_str.split('.')[1]
        # Only numbers with more than ROUND_TH decimals are reformatted
        if len(decimal_part) <= GraphConst.ROUND_TH:
            return token
        formatted_number = f"{float(number_str):.{GraphConst.ROUND_TH}f}"
        # Restore the percent sign for percentages
        return formatted_number + '%' if is_percent else formatted_number

    return pattern.sub(_round_token, s)
189
+
190
+
191
def _format_data(data_dict):
    """
    Normalize one parameter's data dict in place for display (round decimals, handle abnormal values).

    - 'torch.ProcessGroup' entries keep only 'type', 'group_ranks', 'group_id' and 'data_name'.
    - Strings: single quotes are removed, GraphConst.NONE is replaced by GraphConst.NULL (so the front end
      can parse them), and long decimals are rounded by _format_decimal_string.
    - None becomes GraphConst.NULL.
    - Floats whose str() is in scientific notation and shorter than GraphConst.STR_MAX_LEN are rendered
      with "{:.6e}" (e.g. 1.123123123123e-11 -> 1.123123e-11); other floats are rounded to
      GraphConst.ROUND_TH decimals.
    - Every value except the one under GraphConst.ERROR_KEY is finally converted to str (this also covers
      inf and any unexpected type).
    - If the value under Const.MAX ends up as GraphConst.NULL, the parameter is treated as null and the
      dict is collapsed to {GraphConst.VALUE: GraphConst.NULL}.

    NOTE(review): a ' ' value is a str, so it is handled by the string branch and never reaches the
    None/' ' check; confirm whether ' ' was meant to map to null.
    """
    sci_notation_pattern = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'

    if data_dict.get('type') == 'torch.ProcessGroup':
        keys_to_keep = ['type', 'group_ranks', 'group_id', 'data_name']
        for obsolete_key in [k for k in data_dict if k not in keys_to_keep]:
            del data_dict[obsolete_key]

    max_is_null = False
    for key, value in data_dict.items():
        if isinstance(value, str):
            cleaned = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
            value = _format_decimal_string(cleaned)
        elif value is None or value == ' ':
            value = GraphConst.NULL
        elif isinstance(value, float):
            text = str(value)
            if len(text) < GraphConst.STR_MAX_LEN and re.match(sci_notation_pattern, text):
                value = "{:.6e}".format(value)
            else:
                value = round(value, GraphConst.ROUND_TH)
        # The error key keeps its type; everything else becomes str to avoid front-end parse errors
        if key != GraphConst.ERROR_KEY:
            value = str(value)
        if key == Const.MAX and value == GraphConst.NULL:
            max_is_null = True
        data_dict[key] = value

    if max_is_null:
        data_dict.clear()
        data_dict[GraphConst.VALUE] = GraphConst.NULL
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,180 @@
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
18
+ from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
19
+ from msprobe.visualization.graph.graph import Graph, NodeOp
20
+ from msprobe.visualization.graph.node_colors import NodeColors
21
+ from msprobe.visualization.compare.mode_adapter import ModeAdapter
22
+ from msprobe.core.common.const import Const
23
+
24
+
25
class GraphComparator:
    """
    Compare an NPU graph (``graph_n``) against a bench graph (``graph_b``).

    Matching bench nodes are looked up for every NPU node, the pair is linked, and
    the comparison results are written onto the nodes of ``graph_n``.
    """

    def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
        """
        Args:
            graphs: sequence whose first item is the NPU graph and second item the bench graph.
            dump_path_param: dict with 'npu_json_path', 'bench_json_path' and 'stack_json_path'
                entries; it is also passed to get_compare_mode and run_real_data.
            args: parsed arguments; ``output_path``, ``framework`` and ``fuzzy_match`` are read.
            mapping_dict: optional mapping; when truthy, nodes are matched with
                Graph.mapping_match instead of Graph.match.
        """
        self.graph_n = graphs[0]
        self.graph_b = graphs[1]
        self._parse_param(dump_path_param, args.output_path)
        self.framework = args.framework
        self.mapping_dict = mapping_dict
        self.fuzzy_match = args.fuzzy_match
        # Matches the ".<digits>." dump call-count segment of a node id (see _recount_api_node).
        self.pattern = re.compile(r'\.\d+\.')

    def compare(self):
        """
        Run the comparison. Call it separately after construction; results are written into graph_n.
        """
        if self.fuzzy_match:
            self._compare_nodes_fuzzy(self.graph_n.root)
        else:
            self._compare_nodes(self.graph_n.root)
        self._postcompare()

    def add_compare_result_to_node(self, node, compare_result_list):
        """
        Attach comparison results to the node's input and output data.

        Args:
            node: the NPU node.
            compare_result_list: list of rows holding parameter info and comparison
                metrics (metrics are absent in real-data comparison mode).
        """
        # Real-data comparison: park the node now; its metrics are attached in
        # _postcompare once the multiprocess comparison has produced them.
        if self.ma.prepare_real_data(node):
            return
        compare_in_dict = {}
        compare_out_dict = {}
        # Split input and output rows, keyed by the row's first field (its name).
        for item in compare_result_list:
            if not isinstance(item, (list, tuple)) or not item:
                continue
            if '.output.' in item[0]:
                compare_out_dict[item[0]] = item
            else:
                compare_in_dict[item[0]] = item
        precision_index, other_dict = (
            self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
        node.data[GraphConst.JSON_INDEX_KEY] = precision_index
        node.data.update(other_dict)

    def _parse_param(self, dump_path_param, output_path):
        """Store paths, build the mode adapter and load the NPU/bench/stack json data."""
        self.dump_path_param = dump_path_param
        self.output_path = output_path
        compare_mode = get_compare_mode(self.dump_path_param)
        self.ma = ModeAdapter(compare_mode)
        self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
        self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
        self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))

    def _postcompare(self):
        """
        Finish the comparison after all nodes have been visited.

        Aggregates the api_collection indices. Then, in real-data mode only, runs the
        real-data comparison over the buffered CSV data and writes the resulting
        precision index onto every node collected in ``self.ma.compare_nodes``.
        """
        self._handle_api_collection_index()
        if self.ma.compare_mode != GraphConst.REAL_DATA_COMPARE:
            return
        df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
        df = run_real_data(self.dump_path_param, df, self.framework, bool(self.mapping_dict))
        # Key each row by its first column. .iloc makes the positional access explicit;
        # integer keys in Series.__getitem__ are a deprecated positional fallback in pandas >= 2.1.
        compare_data_dict = {row.iloc[0]: row.tolist() for _, row in df.iterrows()}
        for node in self.ma.compare_nodes:
            precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index

    def _handle_api_collection_index(self):
        """
        Set the precision index of each api_collection node directly under the root.

        md5 mode takes the minimum index over the collection's apis; statistics and
        tensor modes take the maximum. In md5 mode an index of 0 is the worst, and in
        statistics/tensor modes an index of 1 is the worst. Either way, the collection
        reports its worst api.
        """
        is_md5 = self.ma.compare_mode == GraphConst.MD5_COMPARE
        reduce_fn = min if is_md5 else max
        # Starting value, also used for apis that carry no index: the bound the
        # reduction moves away from (MAX for min, MIN for max).
        neutral = GraphConst.MAX_INDEX_KEY if is_md5 else GraphConst.MIN_INDEX_KEY
        for node in self.graph_n.root.subnodes:
            if node.op != NodeOp.api_collection:
                continue
            precision_index = neutral
            for api in node.subnodes:
                precision_index = reduce_fn(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, neutral))
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index

    def _compare_nodes(self, node_n):
        """
        Recursively walk the NPU tree.

        For each node, look up its bench counterpart (Graph.mapping_match when a
        mapping_dict is given, otherwise Graph.match). If one is found, link the pair
        and compare their precision data.

        The traversal is pre-order: when a node is compared, its ancestors have already
        been matched, which provides useful information for later fuzzy matching.
        """
        if self.mapping_dict:
            node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
            if node_b:
                self._link_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
        else:
            node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
            if node_b:
                ancestors.append(node_b.id)
                node_n.add_link(node_b, ancestors)
        if node_b:
            # Real-data comparison only yields basic info here; the precision metrics
            # come from the multiprocess comparison invoked in _postcompare.
            self._get_and_add_result(node_n, node_b)
        for subnode in node_n.subnodes:
            self._compare_nodes(subnode)

    def _compare_nodes_fuzzy(self, node_n):
        """
        Recursively walk the NPU tree using fuzzy matching.

        Non-function_api (module) nodes are fuzzy-matched against the bench node with
        the same id. When a module matches, the function_api children of both modules
        are also paired and compared (see _compare_module_apis_fuzzy).
        """
        if node_n.op != NodeOp.function_api:
            node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
            if node_b:
                self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
                self._compare_module_apis_fuzzy(node_n, node_b)
        for sub_node in node_n.subnodes:
            self._compare_nodes_fuzzy(sub_node)

    def _compare_module_apis_fuzzy(self, module_n, module_b):
        """
        Pair and compare the apis of two matched modules.

        The dump call count is ignored: apis are paired by name plus their call order
        inside the module.
        """
        recount_result_n = self._recount_api_node(module_n)
        recount_result_b = self._recount_api_node(module_b)
        for recount_node_id, node_id_n in recount_result_n.items():
            api_node_n = self.graph_n.node_map.get(node_id_n)
            if not api_node_n:
                continue
            api_node_b, api_ancestors_n, api_ancestors_b = Graph.fuzzy_match(
                api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
            if api_node_b:
                self._process_matched_nodes(api_node_n, api_node_b, api_ancestors_n, api_ancestors_b)

    def _get_and_add_result(self, node_n, node_b):
        """Compare a matched node pair, buffer the CSV rows and attach results to node_n."""
        compare_result_list = compare_node([node_n.id, node_b.id],
                                           [self.data_n_dict, self.data_b_dict],
                                           self.stack_json_data, self.ma.compare_mode)
        if compare_result_list:
            self.ma.add_csv_data(compare_result_list)
            self.add_compare_result_to_node(node_n, compare_result_list)

    def _recount_api_node(self, node):
        """
        Renumber the function_api children of a matched module.

        The dump call count of each api is ignored, and the api is instead numbered by
        its call order within the module.

        Returns:
            {node id numbered by in-module call order: original node id}
        """
        recount_result = {}
        node_count = {}
        for sub_node in node.subnodes:
            if sub_node.op == NodeOp.function_api:
                # Drop the dump call count
                count_removed_id = self.pattern.sub(Const.SEP, sub_node.id)
                node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1
                # Number by call order inside the module
                recount_node_id = count_removed_id + str(node_count.get(count_removed_id))
                recount_result[recount_node_id] = sub_node.id
        return recount_result

    def _process_matched_nodes(self, node_n, node_b, ancestors_n, ancestors_b):
        """Link a matched NPU/bench node pair and compare their data."""
        self._link_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
        self._get_and_add_result(node_n, node_b)

    @staticmethod
    def _link_matched_nodes(node_n, node_b, ancestors_n, ancestors_b):
        """
        Record a match on both nodes.

        Each node's matched_node_link is set to the other side's ancestor id list, with
        the matched node's own id appended. ancestors_n and ancestors_b are mutated in place.
        """
        ancestors_n.append(node_n.id)
        ancestors_b.append(node_b.id)
        node_n.matched_node_link = ancestors_b
        node_b.matched_node_link = ancestors_n