mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -0,0 +1,395 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from enum import Enum
16
+ from msprobe.visualization.utils import GraphConst
17
+ from msprobe.core.common.const import Const, CompareConst
18
+ from msprobe.core.common.log import logger
19
+
20
+
21
class CommunicationType(Enum):
    """Direction of a distributed communication op: send, receive, or both."""
    SEND = 'send'
    RECEIVE = 'receive'
    SEND_RECEIVE = 'send_receive'
28
+
29
+
30
class DistributedType(Enum):
    """Category of a distributed op: point-to-point or collective."""
    P2P = 'p2p'
    COLLECTIVE = 'collective'
36
+
37
# Shared suffix for "match failed" warning logs; callers append the rank
# number directly after it (e.g. f'{CANNOT_MATCH}{rank}').
CANNOT_MATCH = 'cannot match distributed node in rank'
39
+
40
+
41
class DistributedAnalyzer:
    """
    Matches distributed communication nodes (p2p send/recv, collectives,
    batched p2p) across the per-rank graphs and annotates each matched node
    with a ``matched_distributed`` record describing its communication peers.
    """

    def __init__(self, graphs: dict, overflow_check: bool):
        """
        :param graphs: mapping of rank -> graph object; each graph exposes a
                       ``node_map`` of node_id -> node
        :param overflow_check: when True, peer info records the overflow level
                               instead of the comparison index
        """
        self.graphs = graphs
        self.overflow_check = overflow_check
        self.config = {
            # current api name: [peer api name, positional index or kwarg name
            # holding the rank info, communication type, distributed type]
            'send': ['recv', GraphConst.DST, CommunicationType.SEND.value, DistributedType.P2P],
            'isend': ['irecv', GraphConst.DST, CommunicationType.SEND.value, DistributedType.P2P],
            'recv': ['send', GraphConst.SRC, CommunicationType.RECEIVE.value, DistributedType.P2P],
            'irecv': ['isend', GraphConst.SRC, CommunicationType.RECEIVE.value, DistributedType.P2P],
            'broadcast': ['broadcast', '1', CommunicationType.SEND.value, DistributedType.COLLECTIVE],
            'scatter': ['scatter', GraphConst.SRC, CommunicationType.SEND.value, DistributedType.COLLECTIVE],
            'gather': ['gather', GraphConst.DST, CommunicationType.RECEIVE.value, DistributedType.COLLECTIVE],
            'reduce': ['reduce', '1', CommunicationType.RECEIVE.value, DistributedType.COLLECTIVE]
        }
        self.group_node_mapping = {}
        self._make_group_node_mapping()

    @staticmethod
    def _get_opposite_communication_type(action):
        """Return the reversed direction for send/receive; other values pass through."""
        if action == CommunicationType.SEND.value:
            return CommunicationType.RECEIVE.value
        elif action == CommunicationType.RECEIVE.value:
            return CommunicationType.SEND.value
        return action

    @staticmethod
    def _node_output_all_equal(data, target_data):
        """
        Compare two output-data dicts on dtype/shape/max/min/mean/norm.

        A missing output on either side (None) counts as a mismatch instead of
        raising AttributeError, so callers fall through to their existing
        "cannot match" warning path.
        """
        if data is None or target_data is None:
            return False
        keys_to_compare = [Const.DTYPE, Const.SHAPE, Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
        return all(data.get(key) == target_data.get(key) for key in keys_to_compare)

    @staticmethod
    def _get_target_rank(node, rank, parameter):
        """
        For p2p communication, read the peer rank from the src/dst parameter of
        the node's input data; for one-to-many / many-to-one collectives, read
        the sending or receiving root rank from the src/dst/positional parameter.
        :param node: current node
        :param rank: rank the node belongs to
        :param parameter: input-data parameter name (or positional index)
        :return: target rank, or None if the parameter is absent
        """
        target_rank = node.input_data.get(f'{node.id}{GraphConst.INPUT}{parameter}', {}).get('value')
        if target_rank is None:
            logger.warning(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
        return target_rank

    @staticmethod
    def _get_group_info(node, rank):
        """
        Extract group_ranks and group_id from the node's ``group`` kwarg.
        :param node: current communication node
        :param rank: rank the node belongs to (for log messages only)
        :return: (group_ranks, group_id), or (None, None) when anything is missing
        """
        group = node.input_data.get(f'{node.id}{GraphConst.INPUT}group', {})
        if not group:
            logger.warning(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
            return None, None
        group_ranks = group.get('group_ranks')
        if not group_ranks:
            logger.warning(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
            return None, None
        group_id = group.get('group_id')
        if not group_id:
            logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
            return None, None
        return group_ranks, group_id

    @staticmethod
    def _get_batch_group_info(node, rank):
        """Return the first group_id found among the node's input-data entries, else None."""
        for data in node.input_data.values():
            group_id = data.get('group_id')
            if group_id is not None:
                return group_id
        logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
        return None

    def distributed_match(self):
        """Walk every rank's graph and match all unmatched distributed nodes."""
        for rank, graph in self.graphs.items():
            nodes = graph.node_map
            for node_id, node in nodes.items():
                # Skip non-communication nodes and nodes already matched.
                if not node_id.startswith(Const.DISTRIBUTED) or node.matched_distributed:
                    continue
                api_name, distributed_type = self._get_distributed_name_and_type(node_id)
                if api_name == GraphConst.BATCH_P2P:
                    self._batch_p2p_match(node, rank)
                elif distributed_type == DistributedType.P2P:
                    self._p2p_match(node, rank, api_name)
                else:
                    self._collective_match(node, rank, api_name)

    def _make_group_node_mapping(self):
        """
        Build a globally unique id mapping for communication nodes.
        key: rank, value: bidirectional mapping between unique_group_id and node_id
        {
            "0": {
                "unique_group_id1": "node_id1",
                "unique_group_id2": "node_id2",
                "node_id1": "unique_group_id1",
                "node_id2": "unique_group_id2"
            },
            "1": {},
            "2": {}
        }
        """
        for rank, graph in self.graphs.items():
            group_count = {}
            group_info = {}
            batch_p2p_count = {}
            nodes = graph.node_map
            for node_id, node in nodes.items():
                if not node_id.startswith(Const.DISTRIBUTED):
                    continue
                api_name, distributed_type = self._get_distributed_name_and_type(node_id)
                if api_name == GraphConst.BATCH_P2P:
                    self._make_batch_p2p_mapping(node, rank, batch_p2p_count)
                    continue
                elif distributed_type == DistributedType.P2P:
                    config_info = self.config.get(api_name)
                    target_rank = self._get_target_rank(node, rank, config_info[1])
                    if target_rank is None:
                        continue
                    # p2p node: api name + target rank forms the group_id
                    group_id = api_name + Const.RANK + str(target_rank)
                else:
                    # Other communication nodes: take group_id directly and append the api name.
                    _, group_id = self._get_group_info(node, rank)
                    if not group_id:
                        continue
                    group_id += api_name
                # Count calls sharing the same group_id ...
                group_count[group_id] = group_count.get(group_id, 0) + 1
                # ... so group_id + call count is a unique id per occurrence.
                unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(group_count.get(group_id))
                group_info[unique_group_id] = node_id
                group_info[node_id] = unique_group_id
            if rank not in self.group_node_mapping:
                self.group_node_mapping[rank] = {}
            self.group_node_mapping[rank].update(group_info)

    def _make_batch_p2p_mapping(self, node, rank, batch_p2p_count):
        """
        Assign a unique id to every p2p entry of a batch_isend_irecv node.
        """
        if rank not in self.group_node_mapping:
            self.group_node_mapping[rank] = {}
        params = []
        for info_dict in node.batch_p2p_info:
            op = info_dict.get(GraphConst.OP)
            target_rank = info_dict.get(GraphConst.PEER)
            if op is None or target_rank is None:
                logger.warning('Cannot get param op or peer.')
                continue
            group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \
                       Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '')
            batch_p2p_count[group_id] = batch_p2p_count.get(group_id, 0) + 1
            # e.g. isend_rank0_5a4d31ad765260ba50eb190f1f9fd163_1
            unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(batch_p2p_count.get(group_id))
            params.append(unique_group_id)
            self.group_node_mapping.get(rank)[unique_group_id] = node.id
        if params:
            self.group_node_mapping.get(rank)[node.id] = params

    def _get_distributed_name_and_type(self, node_id):
        """
        Split a node id such as ``Distributed.isend.0.forward`` into the api
        name and its DistributedType (unknown apis default to COLLECTIVE).
        :raises ValueError: when the node id contains no separator
        """
        if Const.SEP not in node_id:
            raise ValueError(f'Invalid node id {node_id}.')
        api_name = node_id.split(Const.SEP)[1]
        if api_name in self.config:
            return api_name, self.config.get(api_name)[3]
        return api_name, DistributedType.COLLECTIVE

    def _get_target_node(self, rank, unique_group_id, api_name, target_rank, target_api_name=None):
        """
        Look up the peer node matched by name.
        :param rank: current rank
        :param unique_group_id: unique group id of the current node
        :param api_name: api name of the current node, e.g. for
                         Distributed.isend.0.forward the api name is isend
        :param target_rank: rank that communicates with the current node
        :param target_api_name: peer api name; only needed for p2p communication
        :return: the peer node, or None when it cannot be found
        """
        target_graph = self.graphs.get(target_rank)
        if not target_graph:
            logger.warning(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}')
            return None
        target_group_mapping = self.group_node_mapping.get(target_rank)
        # For p2p, finding the peer requires swapping rank and api name inside
        # unique_group_id: e.g. isend to rank1 corresponds to irecv from rank0,
        # so isend_rank1 maps to irecv_rank0.
        target_unique_group_id = (unique_group_id
                                  .replace(Const.RANK + str(target_rank), Const.RANK + str(rank))
                                  .replace(api_name, target_api_name)) if target_api_name else unique_group_id
        target_node_id = target_group_mapping.get(target_unique_group_id, '')
        target_node = target_graph.node_map.get(target_node_id)
        if not target_node:
            logger.warning(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}')
            return None
        return target_node

    def _add_node_matched_distributed(self, node, target_node, api_name, target_rank, reversal_type=False):
        """
        Attach a matched_distributed record to the given node.
        :param node: node to annotate
        :param target_node: matched peer node
        :param api_name: api name of the annotated node
        :param target_rank: rank of the matched peer
        :param reversal_type: reverse the communication direction, e.g. broadcast
                              sends on the root rank but receives on all others
        """
        communications_type = self.config.get(api_name)[2]
        communications_type = self._get_opposite_communication_type(communications_type) if reversal_type \
            else communications_type
        index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
            else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
        matched_distributed = {
            'communications_type': communications_type,
            'nodes_info': {target_rank: [str(index), target_node.id]}
        }
        node.matched_distributed = matched_distributed

    def _p2p_match(self, node, rank, api_name):
        """
        Point-to-point matching.

        Determine the target rank from the src/dst parameter in the current
        node's data, find the corresponding p2p node in the target rank's
        graph, verify the outputs agree, and on success annotate both nodes.
        Args:
            node: current p2p communication node
            rank: rank the node belongs to
            api_name: api name of the current node
        Returns:
        """
        config_info = self.config.get(api_name)
        target_api_name = config_info[0]
        target_rank = self._get_target_rank(node, rank, config_info[1])
        if target_rank is None:
            return
        unique_group_id = self.group_node_mapping.get(rank, {}).get(node.id, '')
        target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
        if not target_node:
            return
        target_config_info = self.config.get(target_api_name)
        source_rank = (target_node.input_data.get(f'{target_node.id}{GraphConst.INPUT}{target_config_info[1]}', {})
                       .get('value'))
        if source_rank is None:
            # Fixed: previously interpolated source_rank (always None here),
            # logging "in rankNone"; the rank that cannot be matched is target_rank.
            logger.warning(
                f'The kwarg {target_config_info[1]} of node {target_node.id} does not exist, '
                f'{CANNOT_MATCH}{target_rank}')
            return
        if source_rank != rank:
            # For p2p, the rank recorded in the candidate peer must equal the current rank.
            logger.warning(
                f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}, '
                f'but the data shows that {target_node.id} communicates with rank{source_rank}.'
                f'The rank is inconsistent, cannot match distributed node')
            return

        # For p2p, the outputs of the two matched nodes must be identical.
        if not DistributedAnalyzer._node_output_all_equal(node.output_data.get(node.id + '.output.0'),
                                                          target_node.output_data.get(target_node.id + '.output.0')):
            logger.warning(f'{node.id} output of rank{rank} is different from the {target_node.id} '
                           f'output of rank{target_rank}, cannot match distributed node')
            return

        self._add_node_matched_distributed(node, target_node, api_name, target_rank)
        self._add_node_matched_distributed(target_node, node, target_api_name, rank)

    def _collective_match(self, node, rank, api_name):
        """
        Collective matching.

        One-to-many and many-to-one ops first read the src/dst/positional
        parameter from the node's data to determine the root rank; many-to-many
        ops do not need it.
        :param node: current collective communication node
        :param rank: rank the node belongs to
        :param api_name: api name of the current node
        :return:
        """
        communications_type = CommunicationType.SEND_RECEIVE.value
        config_info = self.config.get(api_name)
        if config_info:
            # One-to-many or many-to-one communication: only process on the root rank.
            source_rank = self._get_target_rank(node, rank, config_info[1])
            if source_rank is None or str(source_rank) != str(rank):
                return
            communications_type = config_info[2]
        group_ranks, group_id = self._get_group_info(node, rank)
        if not group_ranks or not group_id:
            return
        unique_group_id = self.group_node_mapping.get(rank, {}).get(node.id, '')
        matched_distributed = {'communications_type': communications_type}
        nodes_info = {}
        for target_rank in group_ranks:
            if str(target_rank) == str(rank):
                continue
            target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank)
            if not target_node:
                continue
            _, target_group_id = self._get_group_info(target_node, target_rank)
            if not target_group_id:
                continue
            if group_id != target_group_id:
                logger.warning(
                    f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}'
                    f', but the data shows that the group id of the two nodes are different, '
                    f'cannot match distributed node')
                continue
            # Record the peer on the current node.
            index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
                else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
            nodes_info[target_rank] = [str(index), target_node.id]
            if config_info:
                # Also annotate the matched peer node (with reversed direction).
                self._add_node_matched_distributed(target_node, node, api_name, rank, True)
        if nodes_info:
            matched_distributed['nodes_info'] = nodes_info
            node.matched_distributed = matched_distributed

    def _batch_p2p_match(self, node, rank):
        """
        Batched point-to-point matching.

        torch.distributed.batch_isend_irecv takes a collection of p2p ops;
        iterate the collection and match every p2p entry individually.
        :param node: current batch communication node
        :param rank: rank the node belongs to
        :return:
        """
        unique_group_ids = self.group_node_mapping.get(rank, {}).get(node.id)
        if not unique_group_ids:
            return
        # A list for multiple p2p entries, a flat dict for a single one.
        matched_distributed = [] if len(unique_group_ids) > 1 else {}
        for unique_group_id in unique_group_ids:
            try:
                id_info = unique_group_id.split(Const.REPLACEMENT_CHARACTER)
                api_name = id_info[0]
                target_api_name = self.config.get(api_name)[0]
                target_rank = int(id_info[1].replace(Const.RANK, ''))
            except Exception as e:
                logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
                continue
            target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
            if not target_node:
                continue
            communications_type = self.config.get(api_name)[2]
            index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
                else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
            matched_info = {
                'communications_type': communications_type,
                'nodes_info': {target_rank: [str(index), target_node.id]}
            }
            # Plain statement instead of a side-effecting conditional expression.
            if isinstance(matched_distributed, list):
                matched_distributed.append(matched_info)
            else:
                matched_distributed.update(matched_info)
        if matched_distributed:
            node.matched_distributed = matched_distributed
@@ -67,6 +67,15 @@ class Graph:
67
67
  ancestors_b = node_b.get_ancestors()
68
68
  return node_b, ancestors_n, ancestors_b
69
69
 
70
+
71
+ @staticmethod
72
+ def fuzzy_match(node_n, node_b):
73
+ if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
74
+ return None, [], []
75
+ ancestors_n = node_n.get_ancestors()
76
+ ancestors_b = node_b.get_ancestors()
77
+ return node_b, ancestors_n, ancestors_b
78
+
70
79
  @staticmethod
71
80
  def dfs(node, result):
72
81
  info = node.to_dict()
@@ -16,6 +16,7 @@
16
16
  from enum import Enum
17
17
  import re
18
18
  from msprobe.visualization.builder.msprobe_adapter import op_patterns
19
+ from msprobe.core.common.log import logger
19
20
 
20
21
 
21
22
  class NodeOp(Enum):
@@ -32,8 +33,9 @@ class NodeOp(Enum):
32
33
  for op in NodeOp:
33
34
  index = op.value
34
35
  if index < 0 or index >= len(op_patterns):
35
- raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match")
36
+ continue
36
37
  pattern = op_patterns[index]
37
38
  if re.match(pattern, node_name):
38
39
  return op
39
- raise Exception(f"Cannot parse node_name {node_name} into NodeOp")
40
+ logger.warning(f"Cannot parsing node_name {node_name} into NodeOp, default parsing as module.")
41
+ return NodeOp.module
@@ -16,8 +16,8 @@
16
16
  import os
17
17
  import time
18
18
  import json
19
- from msprobe.core.common.file_utils import (FileOpen, check_file_type, create_directory, FileChecker,
20
- check_file_or_directory_path)
19
+ from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
20
+ check_file_or_directory_path, load_json)
21
21
  from msprobe.core.common.const import FileCheckConst, Const
22
22
  from msprobe.core.common.utils import CompareException
23
23
  from msprobe.core.overflow_check.checker import AnomalyDetector
@@ -28,11 +28,12 @@ from msprobe.core.common.log import logger
28
28
  from msprobe.visualization.graph.node_colors import NodeColors
29
29
  from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
30
30
  from msprobe.core.compare.utils import check_and_return_dir_contents
31
+ from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
31
32
 
32
33
  current_time = time.strftime("%Y%m%d%H%M%S")
33
34
 
34
35
 
35
- def _compare_graph(input_param, args, output_file_name=f'compare_{current_time}.vis'):
36
+ def _compare_graph(input_param, args):
36
37
  logger.info('Start building model graphs...')
37
38
  # 对两个数据进行构图
38
39
  dump_path_n = input_param.get('npu_path')
@@ -49,8 +50,8 @@ def _compare_graph(input_param, args, output_file_name=f'compare_{current_time}.
49
50
  FileCheckConst.READ_ABLE).common_check()
50
51
  stack_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.STACK_FILE), FileCheckConst.FILE,
51
52
  FileCheckConst.READ_ABLE).common_check()
52
- graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n)
53
- graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b)
53
+ graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack)
54
+ graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack)
54
55
  logger.info('Model graphs built successfully, start Comparing graphs...')
55
56
  # 基于graph、stack和data进行比较
56
57
  dump_path_param = {
@@ -66,8 +67,7 @@ def _compare_graph(input_param, args, output_file_name=f'compare_{current_time}.
66
67
  mapping_dict = generate_api_mapping_by_layer_mapping(data_path_n, data_path_b, yaml_path)
67
68
  except Exception:
68
69
  logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.')
69
- graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args.output_path, args.framework,
70
- mapping_dict=mapping_dict)
70
+ graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args, mapping_dict=mapping_dict)
71
71
  graph_comparator.compare()
72
72
  micro_steps = graph_n.paging_by_micro_step(graph_b)
73
73
  # 开启溢出检测
@@ -75,16 +75,22 @@ def _compare_graph(input_param, args, output_file_name=f'compare_{current_time}.
75
75
  graph_n.overflow_check()
76
76
  graph_b.overflow_check()
77
77
 
78
+ return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps)
79
+
80
+
81
+ def _export_compare_graph_result(args, graphs, graph_comparator, micro_steps,
82
+ output_file_name=f'compare_{current_time}.vis'):
78
83
  create_directory(args.output_path)
79
84
  output_path = os.path.join(args.output_path, output_file_name)
80
85
  task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
81
- export_config = GraphExportConfig(graph_n, graph_b, graph_comparator.ma.get_tool_tip(),
82
- NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task)
86
+ export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(),
87
+ NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task,
88
+ args.overflow_check)
83
89
  GraphBuilder.to_json(output_path, export_config)
84
90
  logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}')
85
91
 
86
92
 
87
- def _build_graph(dump_path, out_path, overflow_check=False, output_file_name=f'build_{current_time}.vis'):
93
+ def _build_graph(dump_path, args):
88
94
  logger.info('Start building model graph...')
89
95
  construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE,
90
96
  FileCheckConst.READ_ABLE).common_check()
@@ -92,14 +98,19 @@ def _build_graph(dump_path, out_path, overflow_check=False, output_file_name=f'b
92
98
  FileCheckConst.READ_ABLE).common_check()
93
99
  stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
94
100
  FileCheckConst.READ_ABLE).common_check()
95
- create_directory(out_path)
96
- output_path = os.path.join(out_path, output_file_name)
97
- graph = GraphBuilder.build(construct_path, data_path, stack_path)
101
+ graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
98
102
  micro_steps = graph.paging_by_micro_step()
99
103
  # 开启溢出检测
100
- if overflow_check:
104
+ if args.overflow_check:
101
105
  graph.overflow_check()
102
- GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps))
106
+ return BuildGraphResult(graph, micro_steps)
107
+
108
+
109
+ def _export_build_graph_result(out_path, graph, micro_steps, overflow_check,
110
+ output_file_name=f'build_{current_time}.vis'):
111
+ create_directory(out_path)
112
+ output_path = os.path.join(out_path, output_file_name)
113
+ GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check))
103
114
  logger.info(f'Model graph built successfully, the result file is saved in {output_path}')
104
115
 
105
116
 
@@ -111,12 +122,33 @@ def _compare_graph_ranks(input_param, args, step=None):
111
122
  if npu_ranks != bench_ranks:
112
123
  logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
113
124
  raise CompareException(CompareException.INVALID_PATH_ERROR)
125
+ compare_graph_results = []
114
126
  for nr, br in zip(npu_ranks, bench_ranks):
115
127
  logger.info(f'Start processing data for {nr}...')
116
128
  input_param['npu_path'] = os.path.join(dump_rank_n, nr)
117
129
  input_param['bench_path'] = os.path.join(dump_rank_b, br)
118
130
  output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
119
- _compare_graph(input_param, args, output_file_name=output_file_name)
131
+ result = _compare_graph(input_param, args)
132
+ result.output_file_name = output_file_name
133
+ if nr != Const.RANK:
134
+ try:
135
+ result.rank = int(nr.replace(Const.RANK, ""))
136
+ except Exception as e:
137
+ logger.error('The folder name format is incorrect, expected rank+number.')
138
+ raise CompareException(CompareException.INVALID_PATH_ERROR) from e
139
+ # 暂存所有rank的graph,用于匹配rank间的分布式节点
140
+ compare_graph_results.append(result)
141
+
142
+ # 匹配rank间的分布式节点
143
+ if len(compare_graph_results) > 1:
144
+ DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
145
+ args.overflow_check).distributed_match()
146
+ DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
147
+ args.overflow_check).distributed_match()
148
+
149
+ for result in compare_graph_results:
150
+ _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator,
151
+ result.micro_steps, output_file_name=result.output_file_name)
120
152
 
121
153
 
122
154
  def _compare_graph_steps(input_param, args):
@@ -138,21 +170,38 @@ def _compare_graph_steps(input_param, args):
138
170
  _compare_graph_ranks(input_param, args, step=folder_step)
139
171
 
140
172
 
141
- def _build_graph_ranks(dump_ranks_path, out_path, overflow_check=False, step=None):
173
+ def _build_graph_ranks(dump_ranks_path, args, step=None):
142
174
  ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
175
+ build_graph_results = []
143
176
  for rank in ranks:
144
177
  logger.info(f'Start processing data for {rank}...')
145
178
  dump_path = os.path.join(dump_ranks_path, rank)
146
179
  output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
147
- _build_graph(dump_path, out_path, overflow_check, output_file_name)
180
+ result = _build_graph(dump_path, args)
181
+ result.output_file_name = output_file_name
182
+ if rank != Const.RANK:
183
+ try:
184
+ result.rank = int(rank.replace(Const.RANK, ""))
185
+ except Exception as e:
186
+ logger.error('The folder name format is incorrect, expected rank+number.')
187
+ raise CompareException(CompareException.INVALID_PATH_ERROR) from e
188
+ build_graph_results.append(result)
189
+
190
+ if len(build_graph_results) > 1:
191
+ DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
192
+ args.overflow_check).distributed_match()
193
+
194
+ for result in build_graph_results:
195
+ _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check,
196
+ result.output_file_name)
148
197
 
149
198
 
150
- def _build_graph_steps(dump_steps_path, out_path, overflow_check=False):
199
+ def _build_graph_steps(dump_steps_path, args):
151
200
  steps = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP))
152
201
  for step in steps:
153
202
  logger.info(f'Start processing data for {step}...')
154
203
  dump_ranks_path = os.path.join(dump_steps_path, step)
155
- _build_graph_ranks(dump_ranks_path, out_path, overflow_check, step)
204
+ _build_graph_ranks(dump_ranks_path, args, step)
156
205
 
157
206
 
158
207
  def _graph_service_parser(parser):
@@ -161,14 +210,17 @@ def _graph_service_parser(parser):
161
210
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
162
211
  help="<Required> The compare task result out path.", required=True)
163
212
  parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
164
- help="<optional> The layer mapping file path.", required=False)
213
+ help="<Optional> The layer mapping file path.", required=False)
165
214
  parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true",
166
215
  help="<Optional> whether open overflow_check for graph.", required=False)
216
+ parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
217
+ help="<Optional> Whether to perform a fuzzy match on the api name.", required=False)
218
+ parser.add_argument("-cs", "--complete_stack", dest="complete_stack", action="store_true",
219
+ help="<Optional> Whether to use complete stack information.", required=False)
167
220
 
168
221
 
169
222
  def _graph_service_command(args):
170
- with FileOpen(args.input_path, "r") as file:
171
- input_param = json.load(file)
223
+ input_param = load_json(args.input_path)
172
224
  npu_path = input_param.get("npu_path")
173
225
  bench_path = input_param.get("bench_path")
174
226
  check_file_or_directory_path(npu_path, isdir=True)
@@ -177,11 +229,12 @@ def _graph_service_command(args):
177
229
  if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path:
178
230
  content = check_directory_content(npu_path)
179
231
  if content == GraphConst.RANKS:
180
- _build_graph_ranks(npu_path, args.output_path, args.overflow_check)
232
+ _build_graph_ranks(npu_path, args)
181
233
  elif content == GraphConst.STEPS:
182
- _build_graph_steps(npu_path, args.output_path, args.overflow_check)
234
+ _build_graph_steps(npu_path, args)
183
235
  else:
184
- _build_graph(npu_path, args.output_path, args.overflow_check)
236
+ result = _build_graph(npu_path, args)
237
+ _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check)
185
238
  elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
186
239
  content_n = check_directory_content(npu_path)
187
240
  content_b = check_directory_content(bench_path)
@@ -192,7 +245,9 @@ def _graph_service_command(args):
192
245
  elif content_n == GraphConst.STEPS:
193
246
  _compare_graph_steps(input_param, args)
194
247
  else:
195
- _compare_graph(input_param, args)
248
+ result = _compare_graph(input_param, args)
249
+ _export_compare_graph_result(args, [result.graph_n, result.graph_b],
250
+ result.graph_comparator, result.micro_steps)
196
251
  else:
197
252
  logger.error("The npu_path or bench_path should be a folder.")
198
253
  raise CompareException(CompareException.INVALID_COMPARE_MODE)
@@ -212,3 +267,21 @@ def _ms_graph_service_parser(parser):
212
267
 
213
268
  def _ms_graph_service_command(args):
214
269
  _graph_service_command(args)
270
+
271
+
272
+ class CompareGraphResult:
273
+ def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, output_file_name=''):
274
+ self.graph_n = graph_n
275
+ self.graph_b = graph_b
276
+ self.graph_comparator = graph_comparator
277
+ self.micro_steps = micro_steps
278
+ self.rank = rank
279
+ self.output_file_name = output_file_name
280
+
281
+
282
+ class BuildGraphResult:
283
+ def __init__(self, graph, micro_steps, rank=0, output_file_name=''):
284
+ self.graph = graph
285
+ self.micro_steps = micro_steps
286
+ self.rank = rank
287
+ self.output_file_name = output_file_name