mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,209 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from msprobe.core.overflow_check.checker import AnomalyDetector
16
+ from msprobe.visualization.graph.base_node import BaseNode
17
+ from msprobe.visualization.graph.node_op import NodeOp
18
+ from msprobe.visualization.utils import GraphConst
19
+ from msprobe.core.common.log import logger
20
+ from msprobe.core.common.const import Const
21
+
22
+
23
+ MAX_RECUR_LEVEL = 100  # NOTE(review): presumably a recursion-depth cap, but nothing in the visible part of this module reads it - confirm it is still needed
24
+
25
+
26
class Graph:
    """
    Model graph made of a root node plus a flat ``node id -> node`` lookup table.
    """

    def __init__(self, model_name, data_path='', dump_data=None):
        """
        Create a graph whose root is a module node with id ``model_name``.

        Args:
            model_name: id of the root node
            data_path: stored as-is and emitted by ``to_dict``
            dump_data: stored as-is and handed to ``AnomalyDetector`` by ``overflow_check``
        """
        self.node_map = {}
        # Per-id counters used by add_node(..., id_accumulation=True) to build id suffixes.
        self.node_id_map = {}
        self.add_node(NodeOp.module, model_name)
        self.root = self.get_node(model_name)
        self.data_path = data_path
        self.dump_data = dump_data

    def __str__(self):
        """Return the string form of every node, one node per line."""
        return "\n".join(str(node) for node in self.node_map.values())

    @staticmethod
    def match(graph_n, node_n, graph_b):
        """
        Find the node in ``graph_b`` that corresponds to ``node_n``.

        Precondition: the parent of ``node_n`` has already been matched.
        Matching is currently exact (same id, equal node, equal ancestor list);
        fuzzy logic may be added here later. ``graph_n`` is not used by the body.

        Returns:
            ``(matched_node, ancestors)`` on success, ``(None, [])`` otherwise.
        """
        if not node_n or node_n.id not in graph_b.node_map:
            return None, []
        node_b = graph_b.node_map.get(node_n.id)
        if node_n != node_b:
            return None, []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        if ancestors_n != ancestors_b:
            return None, []
        return node_b, ancestors_n

    @staticmethod
    def mapping_match(node_n, graph_b, mapping_dict):
        """
        Match ``node_n`` against ``graph_b`` using a mapping configuration.

        The id looked up in ``graph_b`` is ``mapping_dict[node_n.id]`` when present,
        otherwise ``node_n.id`` itself.

        Returns:
            ``(node_b, ancestors_n, ancestors_b)`` on success, ``(None, [], [])`` when
            ``node_n`` is empty or no node is found in ``graph_b``.
        """
        # Same empty-node guard as match() and fuzzy_match(), instead of an AttributeError.
        if not node_n:
            return None, [], []
        node_b = graph_b.node_map.get(mapping_dict.get(node_n.id, node_n.id))
        if not node_b:
            return None, [], []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        return node_b, ancestors_n, ancestors_b

    @staticmethod
    def fuzzy_match(node_n, node_b):
        """
        Match two nodes with ``node_n.fuzzy_eq(node_b)``.

        Returns:
            ``(node_b, ancestors_n, ancestors_b)`` on success, ``(None, [], [])`` when
            either node is empty or the fuzzy comparison fails.
        """
        if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
            return None, [], []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        return node_b, ancestors_n, ancestors_b

    @staticmethod
    def dfs(node, result):
        """
        Store ``to_dict()`` of ``node`` and of all its descendants in ``result``, keyed by node id.

        Nodes are visited in pre-order: a node first, then its sub-nodes from first to last.
        The walk keeps its own stack instead of recursing, so a deeply nested graph
        cannot run into the interpreter's recursion limit.

        Args:
            node: start node; must provide ``id``, ``subnodes`` and ``to_dict()``
            result: dict that is filled in place
        """
        pending = [node]
        while pending:
            current = pending.pop()
            result[current.id] = current.to_dict()
            # Push children reversed so that the first child is the next one popped.
            pending.extend(reversed(list(current.subnodes)))

    @staticmethod
    def split_nodes_by_micro_step(nodes):
        """
        Split the nodes of one step into micro steps, based on module names.

        A micro step must be one complete forward + backward pass.
        Example::
            =============== micro step0
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step1
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step2
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward

        A non-module node is assigned to the micro step of the preceding module node.

        Returns:
            dict mapping micro step index (starting at 0) to the list of its nodes.
        """
        forward_tag = f'{Const.SEP}{Const.FORWARD}{Const.SEP}'
        micro_step = 0
        result = {micro_step: []}
        backward_flag = False

        for node in nodes:
            if node.op == NodeOp.module:
                if forward_tag in node.id:
                    # A forward module seen after a backward one opens the next micro step.
                    if backward_flag:
                        micro_step += 1
                        result[micro_step] = []
                        backward_flag = False
                else:
                    backward_flag = True
            result[micro_step].append(node)
        return result

    def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
        """
        Add a node to the graph.

        Args:
            node_op: type of the node to add
            node_id: id of the node to add
            up_node: parent node of the new node
            id_accumulation: whether a numeric suffix is appended to ``node_id``
                (counted per id in ``node_id_map``) instead of reusing the id as-is

        Returns:
            The id under which the node is stored. When ``node_id`` already exists and
            ``id_accumulation`` is False, nothing is added and ``node_id`` is returned.
        """
        if node_id in self.node_map:
            if id_accumulation:
                # The bare id is already taken: restart its counter before suffixing.
                self.node_id_map[node_id] = 0
            else:
                return node_id
        if id_accumulation:
            if node_id in self.node_id_map:
                self.node_id_map[node_id] += 1
            else:
                self.node_id_map[node_id] = 0
            node_id = f'{node_id}.{self.node_id_map[node_id]}'
        node = BaseNode(node_op, node_id, up_node)
        self.node_map[node_id] = node
        return node_id

    def get_node(self, node_id):
        """
        Return the node stored under ``node_id``, or None when there is none.
        """
        return self.node_map.get(node_id)

    def to_dict(self):
        """
        Build the dict used for data output: root id, data path and every node's ``to_dict()``.
        """
        result = {}
        result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
        result[GraphConst.JSON_DATA_KEY] = self.data_path
        result[GraphConst.JSON_NODE_KEY] = {
            node_id: node.to_dict() for node_id, node in self.node_map.items()
        }
        return result

    def paging_by_micro_step(self, graph_other=None):
        """
        Tag the first-level nodes of the graph with a micro step id so the front end can page them.

        In a comparison scenario the micro step id of the matched node in ``graph_other``
        is updated as well.

        Args:
            graph_other: optional second graph
        Returns:
            Number of batches (micro steps) found in this graph.
        """
        batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
        for batch_number, nodes in batches_n.items():
            for node in nodes:
                node.micro_step_id = batch_number
                # Propagate the micro step id to the matched node in graph_other.
                if graph_other and node.matched_node_link:
                    node_other = graph_other.get_node(node.matched_node_link[-1])
                    if node_other:
                        node_other.micro_step_id = batch_number
        # Walk the first-level nodes of graph_other so unmatched nodes get an id too.
        if graph_other:
            for node in graph_other.root.subnodes:
                if node.micro_step_id is None:
                    try:
                        # Use the last id segment when it is an integer.
                        micro_step_id = int(node.id.split(Const.SEP)[-1])
                    except ValueError:
                        micro_step_id = 0
                    node.micro_step_id = micro_step_id
        return len(batches_n)

    def overflow_check(self):
        """
        Run ``AnomalyDetector`` on ``dump_data`` and set the overflow level on every flagged node.
        """
        detector = AnomalyDetector(self.dump_data)
        detector.analyze().filter()

        for node_id, _node in self.node_map.items():
            if detector.has_overflow(node_id):
                lv = detector.get_overflow_level(node_id)
                _node.set_overflow_level(lv)
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum
17
+ from msprobe.visualization.utils import GraphConst, ToolTip
18
+
19
+ SUMMARY_DESCRIPTION = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|"  # Summary-mode legend text: relative error of the node's input/output statistics, |(measured - benchmark) / benchmark|; larger means further from the benchmark
20
+ REAL_DATA_DESCRIPTION = (f"此节点所有输入的最小双千分之一和所有输出的最小双千分之一的差值的绝对值, 代表双千指标的变化情况, "  # Real-data-mode legend text: absolute difference between the smallest one-thousandth ratio over all inputs and over all outputs; larger means further from the benchmark
21
+ f"值越大代表测量值与标杆值的偏差越大, 双千分之一指标计算方式:{ToolTip.ONE_THOUSANDTH_ERR_RATIO}")
22
+ MD5_DESCRIPTION_N = "与标杆相比, 此节点任意输入输出的md5值不同"  # MD5-mode legend text: some input/output md5 differs from the benchmark
23
+ MD5_DESCRIPTION_Y = "与标杆相比, 此节点所有输入输出的md5值相同"  # MD5-mode legend text: every input/output md5 equals the benchmark
24
+ NOT_MATCHED = "比对过程中节点未匹配上"  # Legend text for a node that was not matched during comparison
25
+
26
+
27
+ class NodeColors(Enum):
28
+ # 枚举值后缀数字越小, 颜色越浅
29
+ # value值左闭右开, 两个值相同代表固定值
30
+ YELLOW_1 = ("#FFFCF3", {
31
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0, 0.2], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
32
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0, 0.05], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
33
+ GraphConst.MD5_COMPARE: {GraphConst.VALUE: [1, 1], GraphConst.DESCRIPTION: MD5_DESCRIPTION_Y},
34
+ })
35
+ YELLOW_2 = ("#FFEDBE", {
36
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.2, 0.4], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
37
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.05, 0.1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
38
+ })
39
+ ORANGE_1 = ("#FFDC7F", {
40
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.4, 0.6], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
41
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.1, 0.15], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
42
+ })
43
+ ORANGE_2 = ("#FFC62E", {
44
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.6, 0.8], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
45
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.15, 0.2], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
46
+ })
47
+ RED = ("#FF704D", {
48
+ GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.8, 1], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
49
+ GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.2, 1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
50
+ GraphConst.MD5_COMPARE: {GraphConst.VALUE: [0, 0], GraphConst.DESCRIPTION: MD5_DESCRIPTION_N},
51
+ })
52
+ GREY = ("#C7C7C7", {
53
+ GraphConst.VALUE: [], GraphConst.DESCRIPTION: NOT_MATCHED
54
+ })
55
+
56
+ def __init__(self, hex_value, mode_info):
57
+ self.hex_value = hex_value
58
+ self.mode_info = mode_info
59
+
60
+ @staticmethod
61
+ def get_node_colors(mode):
62
+ """
63
+ 获取不同比对模式下的颜色说明
64
+ Args:
65
+ mode: 比对模式
66
+ Returns: 颜色说明
67
+ """
68
+ return {
69
+ color.hex_value: color.get_info_by_mode(mode) for color in NodeColors if color.get_info_by_mode(mode)
70
+ }
71
+
72
+ @staticmethod
73
+ def get_node_error_status(mode, value):
74
+ """
75
+ 判断精度数据比对指标是否大于基准值
76
+ Args:
77
+ mode: 比对模式
78
+ value: 精度数据比对指标
79
+ Returns: bool
80
+ """
81
+ info = NodeColors.ORANGE_1.get_info_by_mode(mode)
82
+ if info and GraphConst.VALUE in info:
83
+ value_range = info[GraphConst.VALUE]
84
+ return value > value_range[0]
85
+ return False
86
+
87
+ def get_info_by_mode(self, mode):
88
+ if isinstance(self.mode_info, dict):
89
+ # 检查是否是模式特定的信息
90
+ if isinstance(next(iter(self.mode_info.values())), dict):
91
+ return self.mode_info.get(mode, {})
92
+ else:
93
+ # 所有模式共享相同的信息
94
+ return self.mode_info
95
+ return {}
@@ -0,0 +1,39 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum
17
+ import re
18
+ from msprobe.visualization.builder.msprobe_adapter import op_patterns
19
+
20
+
21
+ class NodeOp(Enum):
22
+ module = 0
23
+ function_api = 1
24
+ api_collection = 9
25
+
26
+
27
+ @staticmethod
28
+ def get_node_op(node_name: str):
29
+ """
30
+ 基于代表节点的字符串,解析节点种类
31
+ """
32
+ for op in NodeOp:
33
+ index = op.value
34
+ if index < 0 or index >= len(op_patterns):
35
+ raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match")
36
+ pattern = op_patterns[index]
37
+ if re.match(pattern, node_name):
38
+ return op
39
+ raise Exception(f"Cannot parse node_name {node_name} into NodeOp")
@@ -0,0 +1,288 @@
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import time
18
+ import json
19
+ from msprobe.core.common.file_utils import (FileOpen, check_file_type, create_directory, FileChecker,
20
+ check_file_or_directory_path)
21
+ from msprobe.core.common.const import FileCheckConst, Const
22
+ from msprobe.core.common.utils import CompareException
23
+ from msprobe.core.overflow_check.checker import AnomalyDetector
24
+ from msprobe.visualization.compare.graph_comparator import GraphComparator
25
+ from msprobe.visualization.utils import GraphConst, check_directory_content
26
+ from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig
27
+ from msprobe.core.common.log import logger
28
+ from msprobe.visualization.graph.node_colors import NodeColors
29
+ from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
30
+ from msprobe.core.compare.utils import check_and_return_dir_contents
31
+ from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
32
+
33
+ current_time = time.strftime("%Y%m%d%H%M%S")
34
+
35
+
36
def _compare_graph(input_param, args):
    """
    Build the NPU and bench model graphs from their dump directories and compare them.

    Args:
        input_param: dict-like; 'npu_path' and 'bench_path' are the two dump directories and
            'is_print_compare_log' (True when absent) is forwarded to the comparator
        args: command-line namespace; complete_stack, layer_mapping and overflow_check are
            read here and the whole namespace is passed on to GraphComparator
    Returns: CompareGraphResult holding both graphs, the comparator and the micro steps
    """
    def _readable_file(dump_dir, file_name):
        # Run the FILE / READ_ABLE check on one dump file and hand back what common_check returns
        return FileChecker(os.path.join(dump_dir, file_name), FileCheckConst.FILE,
                           FileCheckConst.READ_ABLE).common_check()

    logger.info('Start building model graphs...')
    # Build a graph for each of the two dumps
    npu_dir = input_param.get('npu_path')
    bench_dir = input_param.get('bench_path')
    construct_path_n = _readable_file(npu_dir, GraphConst.CONSTRUCT_FILE)
    construct_path_b = _readable_file(bench_dir, GraphConst.CONSTRUCT_FILE)
    data_path_n = _readable_file(npu_dir, GraphConst.DUMP_FILE)
    data_path_b = _readable_file(bench_dir, GraphConst.DUMP_FILE)
    stack_path_n = _readable_file(npu_dir, GraphConst.STACK_FILE)
    stack_path_b = _readable_file(bench_dir, GraphConst.STACK_FILE)
    graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack)
    graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack)
    logger.info('Model graphs built successfully, start Comparing graphs...')

    # Compare on the basis of graph, stack and data; only the NPU stack file is passed along
    dump_path_param = {
        'npu_json_path': data_path_n,
        'bench_json_path': data_path_b,
        'stack_json_path': stack_path_n,
        'is_print_compare_log': input_param.get("is_print_compare_log", True)
    }

    mapping_dict = None
    if args.layer_mapping:
        yaml_path = FileChecker(args.layer_mapping, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
        try:
            mapping_dict = generate_api_mapping_by_layer_mapping(data_path_n, data_path_b, yaml_path)
        except Exception:
            # A broken mapping file only disables the mapping; the comparison still runs
            logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.')

    graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args, mapping_dict=mapping_dict)
    graph_comparator.compare()
    micro_steps = graph_n.paging_by_micro_step(graph_b)

    # Overflow check enabled
    if args.overflow_check:
        for graph in (graph_n, graph_b):
            graph.overflow_check()

    return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps)
79
+
80
+
81
def _export_compare_graph_result(args, graphs, graph_comparator, micro_steps,
                                 output_file_name=f'compare_{current_time}.vis'):
    """
    Write the result of a graph comparison into a file under args.output_path.

    Args:
        args: command-line namespace; output_path and overflow_check are read
        graphs: sequence whose first two items are the graphs handed to GraphExportConfig
        graph_comparator: comparator whose ``ma`` attribute supplies the compare mode and tool tip
        micro_steps: micro step value forwarded to the export config
        output_file_name: name of the result file; the default is fixed when the module is
            loaded because it embeds the module-level current_time
    """
    create_directory(args.output_path)
    result_file = os.path.join(args.output_path, output_file_name)
    task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
    first_graph = graphs[0]
    second_graph = graphs[1]
    tool_tip = graph_comparator.ma.get_tool_tip()
    node_colors = NodeColors.get_node_colors(graph_comparator.ma.compare_mode)
    export_config = GraphExportConfig(first_graph, second_graph, tool_tip, node_colors, micro_steps, task,
                                      args.overflow_check)
    GraphBuilder.to_json(result_file, export_config)
    logger.info(f'Model graphs compared successfully, the result file is saved in {result_file}')
91
+
92
+
93
def _build_graph(dump_path, args):
    """
    Build the model graph for a single dump directory.

    Args:
        dump_path: directory expected to contain the construct, dump and stack files named by GraphConst
        args: command-line namespace; complete_stack and overflow_check are read
    Returns: BuildGraphResult holding the graph and its micro steps
    """
    logger.info('Start building model graph...')
    # The three files are checked in this order: construct, dump, stack
    construct_path, data_path, stack_path = [
        FileChecker(os.path.join(dump_path, file_name), FileCheckConst.FILE,
                    FileCheckConst.READ_ABLE).common_check()
        for file_name in (GraphConst.CONSTRUCT_FILE, GraphConst.DUMP_FILE, GraphConst.STACK_FILE)
    ]
    graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
    micro_steps = graph.paging_by_micro_step()
    # Overflow check enabled
    if args.overflow_check:
        graph.overflow_check()
    return BuildGraphResult(graph, micro_steps)
107
+
108
+
109
def _export_build_graph_result(out_path, graph, micro_steps, overflow_check,
                               output_file_name=f'build_{current_time}.vis'):
    """
    Write a built graph into a file under out_path.

    Args:
        out_path: output directory, passed to create_directory before writing
        graph: graph to export
        micro_steps: micro step value forwarded to the export config
        overflow_check: overflow-check flag forwarded to the export config
        output_file_name: name of the result file; the default is fixed when the module is
            loaded because it embeds the module-level current_time
    """
    create_directory(out_path)
    result_file = os.path.join(out_path, output_file_name)
    export_config = GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check)
    GraphBuilder.to_json(result_file, export_config)
    logger.info(f'Model graph built successfully, the result file is saved in {result_file}')
115
+
116
+
117
def _compare_graph_ranks(input_param, args, step=None):
    """
    Compare the NPU and bench dumps rank by rank and export one result file per rank.

    All ranks are compared first and kept in memory; when more than one rank was processed,
    DistributedAnalyzer.distributed_match() is run over the NPU graphs and over the bench
    graphs before anything is exported.

    Args:
        input_param: dict-like; 'npu_path' and 'bench_path' are directories holding the rank
            folders. It is not modified: every rank works on its own copy.
        args: command-line namespace; overflow_check is read here, the rest is forwarded
        step: optional step folder name that is embedded in the output file names
    Raises:
        CompareException: INVALID_PATH_ERROR when the rank folders of the two runs differ, or
            when a rank folder name is not the rank prefix followed by a number
    """
    dump_rank_n = input_param.get('npu_path')
    dump_rank_b = input_param.get('bench_path')
    npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK))
    bench_ranks = sorted(check_and_return_dir_contents(dump_rank_b, Const.RANK))
    if npu_ranks != bench_ranks:
        logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    compare_graph_results = []
    for nr, br in zip(npu_ranks, bench_ranks):
        logger.info(f'Start processing data for {nr}...')
        # Per-rank copy: the caller's dict must keep pointing at the directory of rank folders
        rank_param = {
            **input_param,
            'npu_path': os.path.join(dump_rank_n, nr),
            'bench_path': os.path.join(dump_rank_b, br),
        }
        output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
        result = _compare_graph(rank_param, args)
        result.output_file_name = output_file_name
        # A folder named exactly like the prefix carries no number; the result keeps its default rank
        if nr != Const.RANK:
            try:
                result.rank = int(nr.replace(Const.RANK, ""))
            except (TypeError, ValueError) as e:
                logger.error('The folder name format is incorrect, expected rank+number.')
                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
        # Keep the graphs of all ranks for matching distributed nodes across ranks
        compare_graph_results.append(result)

    # Match distributed nodes across ranks
    if len(compare_graph_results) > 1:
        DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
                            args.overflow_check).distributed_match()
        DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
                            args.overflow_check).distributed_match()

    for result in compare_graph_results:
        _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator,
                                     result.micro_steps, output_file_name=result.output_file_name)
152
+
153
+
154
def _compare_graph_steps(input_param, args):
    """
    Compare the NPU and bench dumps step by step; each step is delegated to _compare_graph_ranks.

    Args:
        input_param: dict-like; 'npu_path' and 'bench_path' are directories holding the step
            folders. It is not modified: every step is handed its own copy.
        args: command-line namespace, forwarded unchanged
    Raises:
        CompareException: INVALID_PATH_ERROR when the step folders of the two runs differ
    """
    dump_step_n = input_param.get('npu_path')
    dump_step_b = input_param.get('bench_path')

    npu_steps = sorted(check_and_return_dir_contents(dump_step_n, Const.STEP))
    bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))

    if npu_steps != bench_steps:
        logger.error('The number of steps in the two runs are different. Unable to match the steps.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    for folder_step in npu_steps:
        logger.info(f'Start processing data for {folder_step}...')
        # Per-step copy: the caller's dict must keep pointing at the directory of step folders
        step_param = {
            **input_param,
            'npu_path': os.path.join(dump_step_n, folder_step),
            'bench_path': os.path.join(dump_step_b, folder_step),
        }

        _compare_graph_ranks(step_param, args, step=folder_step)
171
+
172
+
173
def _build_graph_ranks(dump_ranks_path, args, step=None):
    """
    Build one graph per rank folder and export one result file per rank.

    All ranks are built first; when more than one rank was processed,
    DistributedAnalyzer.distributed_match() is run over the graphs before they are exported.

    Args:
        dump_ranks_path: directory holding the rank folders
        args: command-line namespace; output_path and overflow_check are read here
        step: optional step folder name that is embedded in the output file names
    Raises:
        CompareException: INVALID_PATH_ERROR when a rank folder name is not the rank prefix
            followed by a number
    """
    build_graph_results = []
    for rank in sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK)):
        logger.info(f'Start processing data for {rank}...')
        dump_path = os.path.join(dump_ranks_path, rank)
        if step:
            output_file_name = f'build_{step}_{rank}_{current_time}.vis'
        else:
            output_file_name = f'build_{rank}_{current_time}.vis'
        result = _build_graph(dump_path, args)
        result.output_file_name = output_file_name
        # A folder named exactly like the prefix carries no number; the result keeps its default rank
        if rank != Const.RANK:
            try:
                result.rank = int(rank.replace(Const.RANK, ""))
            except Exception as e:
                logger.error('The folder name format is incorrect, expected rank+number.')
                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
        build_graph_results.append(result)

    # Match distributed nodes across ranks
    if len(build_graph_results) > 1:
        graphs_by_rank = {}
        for built in build_graph_results:
            graphs_by_rank[built.rank] = built.graph
        DistributedAnalyzer(graphs_by_rank, args.overflow_check).distributed_match()

    for built in build_graph_results:
        _export_build_graph_result(args.output_path, built.graph, built.micro_steps, args.overflow_check,
                                   built.output_file_name)
197
+
198
+
199
def _build_graph_steps(dump_steps_path, args):
    """
    Build graphs for every step folder; each step is delegated to _build_graph_ranks.

    Args:
        dump_steps_path: directory holding the step folders
        args: command-line namespace, forwarded unchanged
    """
    step_folders = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP))
    for step in step_folders:
        logger.info(f'Start processing data for {step}...')
        _build_graph_ranks(os.path.join(dump_steps_path, step), args, step)
205
+
206
+
207
+ def _graph_service_parser(parser):
208
+ parser.add_argument("-i", "--input_path", dest="input_path", type=str,
209
+ help="<Required> The compare input path, a dict json.", required=True)
210
+ parser.add_argument("-o", "--output_path", dest="output_path", type=str,
211
+ help="<Required> The compare task result out path.", required=True)
212
+ parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
213
+ help="<Optional> The layer mapping file path.", required=False)
214
+ parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true",
215
+ help="<Optional> whether open overflow_check for graph.", required=False)
216
+ parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
217
+ help="<Optional> Whether to perform a fuzzy match on the api name.", required=False)
218
+ parser.add_argument("-cs", "--complete_stack", dest="complete_stack", action="store_true",
219
+ help="<Optional> Whether to use complete stack information.", required=False)
220
+
221
+
222
def _graph_service_command(args):
    """
    Entry point of the graph service: read the input json and build or compare graphs.

    The json named by args.input_path supplies 'npu_path' and, optionally, 'bench_path'.
    Without a bench path only the NPU graph is built; with both paths the two dumps are
    compared. In either case the directory content decides whether the work is done per
    rank, per step, or on the directory itself.

    Raises:
        ValueError: when npu_path and bench_path hold different kinds of content
        CompareException: INVALID_COMPARE_MODE when neither the build case nor the compare case applies
    """
    with FileOpen(args.input_path, "r") as file:
        input_param = json.load(file)
    npu_path = input_param.get("npu_path")
    bench_path = input_param.get("bench_path")
    check_file_or_directory_path(npu_path, isdir=True)
    if bench_path:
        check_file_or_directory_path(bench_path, isdir=True)

    # Only an NPU directory: build the graph(s)
    if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path:
        content = check_directory_content(npu_path)
        if content == GraphConst.RANKS:
            _build_graph_ranks(npu_path, args)
        elif content == GraphConst.STEPS:
            _build_graph_steps(npu_path, args)
        else:
            built = _build_graph(npu_path, args)
            _export_build_graph_result(args.output_path, built.graph, built.micro_steps, args.overflow_check)
        return

    # Two directories: compare the graphs
    if check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
        content_n = check_directory_content(npu_path)
        content_b = check_directory_content(bench_path)
        if content_n != content_b:
            raise ValueError('The directory structures of npu_path and bench_path are inconsistent.')
        if content_n == GraphConst.RANKS:
            _compare_graph_ranks(input_param, args)
        elif content_n == GraphConst.STEPS:
            _compare_graph_steps(input_param, args)
        else:
            compared = _compare_graph(input_param, args)
            _export_compare_graph_result(args, [compared.graph_n, compared.graph_b],
                                         compared.graph_comparator, compared.micro_steps)
        return

    logger.error("The npu_path or bench_path should be a folder.")
    raise CompareException(CompareException.INVALID_COMPARE_MODE)
255
+
256
+
257
def _pt_graph_service_parser(parser):
    """Register the graph service options on ``parser``; delegates to _graph_service_parser."""
    _graph_service_parser(parser)
259
+
260
+
261
def _pt_graph_service_command(args):
    """Run the graph service with the parsed ``args``; delegates to _graph_service_command."""
    _graph_service_command(args)
263
+
264
+
265
def _ms_graph_service_parser(parser):
    """Register the graph service options on ``parser``; delegates to _graph_service_parser."""
    _graph_service_parser(parser)
267
+
268
+
269
def _ms_graph_service_command(args):
    """Run the graph service with the parsed ``args``; delegates to _graph_service_command."""
    _graph_service_command(args)
271
+
272
+
273
class CompareGraphResult:
    """
    Plain holder for the outcome of comparing one NPU dump with one bench dump.

    Attributes are stored exactly as given: graph_n, graph_b, graph_comparator, micro_steps,
    rank (0 unless set) and output_file_name (empty unless set).
    """

    def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, output_file_name=''):
        self.graph_n, self.graph_b = graph_n, graph_b
        self.graph_comparator = graph_comparator
        self.micro_steps = micro_steps
        self.rank, self.output_file_name = rank, output_file_name
281
+
282
+
283
class BuildGraphResult:
    """
    Plain holder for the outcome of building the graph of one dump.

    Attributes are stored exactly as given: graph, micro_steps, rank (0 unless set) and
    output_file_name (empty unless set).
    """

    def __init__(self, graph, micro_steps, rank=0, output_file_name=''):
        self.graph, self.micro_steps = graph, micro_steps
        self.rank, self.output_file_name = rank, output_file_name