mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/visualization/builder/graph_merger.py (new file)
@@ -0,0 +1,986 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ import math
+
+ from msprobe.core.common.const import Const
+ from msprobe.visualization.graph.graph import Graph, BaseNode
+ from msprobe.visualization.graph.node_op import NodeOp
+ from msprobe.core.common.log import logger
+ from msprobe.visualization.utils import GraphConst
+ from msprobe.core.common.decorator import recursion_depth_decorator
+ from msprobe.core.common.parallel_state import get_tp_pp_default_groups
+
+ MAX_INFO = 'The Max value merging method for '
+ MIN_INFO = 'The Min value merging method for '
+ MEAN_INFO = 'The Mean value merging method for '
+ NORM_INFO = 'The Norm value merging method for '
+
+
+ class GraphMerger:
+     def __init__(self, build_graph_results, parallel_param, is_bench=False):
+         self.strategy = self._select_strategy(build_graph_results, parallel_param, is_bench)
+
+     @staticmethod
+     def _select_strategy(results, param, is_bench):
+         if param.tp == param.pp == param.rank_size == 1:
+             return NoParallelMerger(results, param, is_bench)
+         elif param.tp == param.rank_size:
+             return TPMerger(results, param, is_bench)
+         elif param.pp == param.rank_size:
+             return PPMerger(results, param, is_bench) if param.vpp == 1 else VPPMerger(results, param, is_bench)
+         elif param.pp == 1:
+             return TPMerger(results, param, is_bench)
+         elif param.tp == 1:
+             return PPMerger(results, param, is_bench) if param.vpp == 1 else VPPMerger(results, param, is_bench)
+         elif param.tp * param.pp == param.rank_size:
+             return TPPPMerger(results, param, is_bench)
+         else:
+             return FullMerger(results, param, is_bench)
+
+     def merge_graph(self):
+         return self.strategy.merge_graphs()
+
+
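For orientation, a minimal sketch of how the dispatch above maps a parallel layout to a merger. The SimpleNamespace object is a hypothetical stand-in; the real parallel-param object only needs the tp/pp/vpp/rank_size attributes read by _select_strategy. This is an illustration, not part of the diff:

    from types import SimpleNamespace

    # 8 ranks split as tp=2 x pp=4: tp * pp == rank_size, so _select_strategy
    # returns a TPPPMerger. tp=8/pp=1 would give TPMerger, pp=8 with vpp>1
    # VPPMerger, and tp=pp=rank_size=1 NoParallelMerger.
    param = SimpleNamespace(tp=2, pp=4, vpp=1, rank_size=8, order='tp-pp')
    merger = GraphMerger(build_graph_results, param)  # build_graph_results: per-rank results
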
+ class BaseGraphMerger:
59
+ def __init__(self, build_graph_results, parallel_param, is_bench):
60
+ self.unmerged_module = [Const.CLIP_GRAD, Const.OPTIMIZER]
61
+ self.dtype_list = Const.TORCH_INT_DTYPE + Const.TORCH_FLOAT_DTYPE + [Const.FLOAT16, Const.FLOAT32,
62
+ Const.BFLOAT16]
63
+ self.build_graph_results = build_graph_results
64
+ self.parallel_param = parallel_param
65
+ self.is_bench = is_bench
66
+ self.log_prefix = '[Bench]' if self.is_bench else '[NPU]'
67
+ self._add_all_nodes_rank()
68
+
69
+ @staticmethod
70
+ def sort_merged_api_collection(graph):
71
+ def extract_rank(node):
72
+ match = re.search(r'_Rank(\d+)', node.id)
73
+ return int(match.group(1)) if match else None
74
+
75
+ for sub_node in graph.root.subnodes:
76
+ if sub_node.op == NodeOp.api_collection and sub_node.id.startswith(
77
+ GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS):
78
+ sub_node.subnodes = sorted(sub_node.subnodes, key=extract_rank)
79
+
80
+ @staticmethod
81
+ def _update_node_data_key(old_id, new_id, data_dict):
82
+ new_dict = {}
83
+ for key, value in data_dict.items():
84
+ new_key = key.replace(old_id, new_id)
85
+ if 'full_op_name' in value:
86
+ value['full_op_name'] = value.get('full_op_name').replace(old_id, new_id)
87
+ new_dict[new_key] = value
88
+ return new_dict
89
+
90
+ @staticmethod
91
+ def _compare_value_same(main_value, other_value, has_uncertainty=False):
92
+ if not isinstance(main_value, (int, float)) or not isinstance(other_value, (int, float)):
93
+ return True
94
+ # 没开启确定性计算,各rank的mean和norm有细微差异,如果相对误差在阈值内则认为是相同的
95
+ if has_uncertainty:
96
+ diff = abs(main_value - other_value)
97
+ if math.isnan(diff):
98
+ return math.isnan(main_value) and math.isnan(other_value)
99
+ elif math.isinf(diff):
100
+ return math.isinf(main_value) and math.isinf(other_value)
101
+ else:
102
+ return diff < GraphConst.UNCERTAINTY_THRESHOLD if main_value == 0 else \
103
+ abs(diff / main_value) < GraphConst.UNCERTAINTY_THRESHOLD
104
+ else:
105
+ return main_value == other_value
106
+
107
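A worked instance of the tolerance branch above, for illustration only. GraphConst.UNCERTAINTY_THRESHOLD is defined elsewhere in msprobe, so the 1e-3 used here is an assumed stand-in:

    # Mean of the same tensor on two ranks without deterministic computation:
    main_value, other_value = 2.7182818, 2.7182501
    diff = abs(main_value - other_value)   # 3.17e-05
    rel_err = abs(diff / main_value)       # ~1.17e-05
    same = rel_err < 1e-3                  # True -> treated as the same value
    # NaN only matches NaN, inf only matches inf, and main_value == 0 falls
    # back to the absolute check diff < threshold.
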
+     def merge_graphs(self):
+         raise NotImplementedError("This method should be implemented by subclasses.")
+
+     def merge_graph_api_collection(self, results: list):
+         """
+         When merging graphs, combine each rank's free-floating API collections into one overall API collection.
+         example:
+         rank0: Apis_Between_Modules.0                      rank1: Apis_Between_Modules.0
+                Module.module.Float16Module.forward.0              Module.module.Float16Module.forward.0
+                Apis_Between_Modules.1                             Apis_Between_Modules.1
+
+         merged: Apis_Between_Modules_All_Ranks.0
+                 |_ Apis_Between_Modules_Rank0.0
+                 |_ Apis_Between_Modules_Rank1.0
+                 Module.module.Float16Module.forward.0
+                 Apis_Between_Modules_All_Ranks.1
+                 |_ Apis_Between_Modules_Rank0.1
+                 |_ Apis_Between_Modules_Rank1.1
+         """
+         main_graph_result = results[0]
+         main_root_sub_nodes = main_graph_result.graph.root.subnodes
+         new_main_root_sub_nodes = []
+         for main_node in main_root_sub_nodes:
+             # If the free-floating API collections were already merged into overall collections,
+             # the overall collections themselves still have to be merged with each other.
+             if main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS):
+                 new_main_root_sub_nodes.append(main_node)
+                 for other_graph_result in results[1:]:
+                     other_node = other_graph_result.graph.get_node(main_node.id)
+                     if not other_node:
+                         continue
+                     for sub_node in other_node.subnodes:
+                         sub_node.upnode = main_node
+                         main_graph_result.graph.node_map[sub_node.id] = sub_node
+                         for sub_sub_node in sub_node.subnodes:
+                             main_graph_result.graph.node_map[sub_sub_node.id] = sub_sub_node
+                     main_node.subnodes.extend(other_node.subnodes)
+             # Merge the free-floating API collections into one overall API collection.
+             elif main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES):
+                 all_collection_node_id = main_graph_result.graph.add_node(NodeOp.api_collection,
+                                                                           GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS,
+                                                                           id_accumulation=True)
+                 all_collection_node = main_graph_result.graph.get_node(all_collection_node_id)
+                 new_main_root_sub_nodes.append(all_collection_node)
+                 # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0
+                 origin_main_node_id = main_node.id
+                 main_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{main_graph_result.rank}.' + \
+                                main_node.id.split(Const.SEP)[-1]
+                 all_collection_node.subnodes = [main_node]
+                 main_node.upnode = all_collection_node
+                 main_graph_result.graph.node_map[main_node.id] = main_node
+                 del main_graph_result.graph.node_map[origin_main_node_id]
+                 for other_graph_result in results[1:]:
+                     other_node = other_graph_result.graph.get_node(origin_main_node_id)
+                     if not other_node:
+                         continue
+                     # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank1.0
+                     other_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{other_graph_result.rank}.' + \
+                                     other_node.id.split(Const.SEP)[-1]
+                     main_graph_result.graph.node_map[other_node.id] = other_node
+                     for sub_node in other_node.subnodes:
+                         # API node: append the rank information to the API name.
+                         old_id = sub_node.id
+                         parts = sub_node.id.split(Const.SEP)
+                         parts[1] += f'_rank{other_graph_result.rank}'
+                         sub_node.id = Const.SEP.join(parts)
+                         sub_node.input_data = self._update_node_data_key(old_id, sub_node.id, sub_node.input_data)
+                         sub_node.output_data = self._update_node_data_key(old_id, sub_node.id, sub_node.output_data)
+                         main_graph_result.graph.node_map[sub_node.id] = sub_node
+                     all_collection_node.subnodes.append(other_node)
+                     other_node.upnode = all_collection_node
+             else:
+                 new_main_root_sub_nodes.append(main_node)
+         main_graph_result.graph.root.subnodes = new_main_root_sub_nodes
+
+     def split_graph_results_by_groups(self, groups):
+         """
+         Split the graphs awaiting merging by pp or tp group.
+         """
+         rank_results_mapping = {result.rank: result for result in self.build_graph_results}
+         return [[rank_results_mapping.get(rank) for rank in ranks] for ranks in groups]
+
+     def compare_node_param_data(self, main_node, other_nodes, compare_data=True):
+         """
+         Compare the input/output parameter data of the current node against several other nodes;
+         parameters that do not match are collected into lists.
+         :param main_node: the current node
+         :param other_nodes: the list of other nodes
+         :param compare_data: whether to compare the data; with compare_data=False the data is treated as different
+         :return: dict of mismatched input parameters and dict of mismatched output parameters;
+                  both dicts empty means the nodes' inputs and outputs are fully consistent
+         """
+         if not other_nodes:
+             return {}, {}
+         data_types = {'input_data': {}, 'output_data': {}}
+         for data_type, data_dict in data_types.items():
+             main_data_dict = getattr(main_node, data_type)
+             for key, main_param in main_data_dict.items():
+                 same_flag = compare_data
+                 if main_param.get(Const.DTYPE) not in self.dtype_list:
+                     continue
+                 tp_need_merge_params = [main_param]
+                 for other_node in other_nodes:
+                     param_key = key.replace(main_node.id, other_node.id) if main_node.id != other_node.id else key
+                     other_param = getattr(other_node, data_type).get(param_key, {})
+                     if other_param.get(Const.DTYPE) not in self.dtype_list:
+                         break
+                     tp_need_merge_params.append(other_param)
+                     if compare_data and not self.compare_param_same(main_param, other_param, has_uncertainty=True):
+                         same_flag = False
+                 if not same_flag:
+                     data_dict[key.replace(main_node.id + Const.SEP, '')] = tp_need_merge_params
+         return data_types.get('input_data'), data_types.get('output_data')
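For reference, a sketch of the shape the two returned dicts take (it matches the tp_need_merge_param example in TPMerger._merge_params below); all values here are hypothetical:

    param_in = {
        'input.0': [
            {'Max': 1.0, 'Min': -1.0, 'Mean': 0.1, 'Norm': 3.0},   # main_node, rank 0
            {'Max': 2.0, 'Min': -0.5, 'Mean': 0.3, 'Norm': 4.0},   # other_node, rank 1
        ],
    }
    param_out = {}   # empty: every output parameter matched across ranks
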
+
+     def compare_param_same(self, main_param, other_param, has_uncertainty=False):
+         if not self._compare_value_same(main_param.get(Const.MAX), other_param.get(Const.MAX)):
+             return False
+         if not self._compare_value_same(main_param.get(Const.MIN), other_param.get(Const.MIN)):
+             return False
+         if not self._compare_value_same(main_param.get(Const.MEAN), other_param.get(Const.MEAN), has_uncertainty):
+             return False
+         if not self._compare_value_same(main_param.get(Const.NORM), other_param.get(Const.NORM), has_uncertainty):
+             return False
+         return True
+
+     def get_default_groups(self):
+         """
+         Initialize the parallel groups from the total number of GPUs, the TP size, and the PP size.
+
+         return:
+             tp_groups: list of tensor-parallel groups, each a list of the ranks in that group
+             pp_groups: list of pipeline-parallel groups, each a list of the ranks in that group
+         """
+         tp_groups, pp_groups = get_tp_pp_default_groups(self.parallel_param.rank_size, self.parallel_param.tp,
+                                                         self.parallel_param.pp, order=self.parallel_param.order)
+
+         return tp_groups, pp_groups
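get_tp_pp_default_groups lives in msprobe/core/common/parallel_state.py (also added in this release). Assuming it follows the conventional Megatron-style ordering, where ranks within a tp group are contiguous and pp groups stride by rank_size // pp, the fallback for rank_size=8, tp=2, pp=2 would look like this (an assumption, not verified against parallel_state.py):

    tp_groups = [[0, 1], [2, 3], [4, 5], [6, 7]]   # assumed: contiguous tp groups
    pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]   # assumed: stride of rank_size // pp = 4
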
+
+     def _add_all_nodes_rank(self):
+         for result in self.build_graph_results:
+             for node in result.graph.node_map.values():
+                 node.rank = result.rank
+
+
+ class PPMerger(BaseGraphMerger):
+     LAYERS_PATTERN = re.compile(r"(layers\.|layer\.)\d+(\.)")
+     MARK_PATTERN = re.compile(r"%(\d+)%(\d+)$")
+     MARK = '%'
+
+     @staticmethod
+     def _trace_p2p_mapping(p2p_mapping: dict):
+         """
+         Split the dict into independent chains; each chain starts from an unvisited key and is traced
+         by following the mapping. p2p_mapping holds the send mapping of the p2p communication, and
+         tracing it establishes the pp groups.
+         example: p2p_mapping={0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5}, return=[[0, 2, 4, 6], [1, 3, 5, 7]]
+         """
+         visited = set()
+         result = []
+
+         def collect_keys(start_key):
+             """
+             Trace all "consecutive" keys starting from a given key, until no next key can be found.
+             """
+             current_key = start_key
+             chain = []
+             while current_key in p2p_mapping and current_key not in visited:
+                 chain.append(current_key)
+                 visited.add(current_key)
+                 current_key = p2p_mapping[current_key]
+             return chain
+
+         for key in p2p_mapping:
+             if key not in visited:
+                 chain_result = collect_keys(key)
+                 if chain_result:
+                     result.append(chain_result)
+         return result
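The docstring example can be exercised directly, since the method is a staticmethod:

    chains = PPMerger._trace_p2p_mapping({0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5})
    # chains == [[0, 2, 4, 6], [1, 3, 5, 7]]
    # The extra sends 6 -> 4 and 7 -> 5 end their chains early, because
    # ranks 4 and 5 are already in the visited set by then.
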
+
+     @recursion_depth_decorator("msprobe.visualization.builder.graph_merger.PPMerger._merge_nodes", 1000)
+     def _merge_nodes(self, main_graph, main_node, other_graphs):
+         """
+         Nodes that were split by pp in the other ranks' graphs must be merged into the main graph.
+         """
+         other_nodes = []
+         for other_graph in other_graphs:
+             other_node = other_graph.get_node(main_node.id)
+             # Indicates that only the main graph has this node.
+             if not other_node:
+                 other_nodes.clear()
+                 return
+             other_nodes.append(other_node)
+         if other_nodes:
+             param_in, param_out = self.compare_node_param_data(main_node, other_nodes)
+             # A module that exists on every rank, whose inputs and outputs both differ, and whose node
+             # id matches the pattern is judged to be a pp-split module whose structure must be merged.
+             pp_merged_condition = param_in and param_out and self.LAYERS_PATTERN.search(main_node.id)
+             # backward may have no output; whether to pp-merge is decided from the corresponding forward node.
+             if Const.SEP + Const.BACKWARD + Const.SEP in main_node.id:
+                 f_node = main_graph.node_map.get(
+                     main_node.id.replace(Const.SEP + Const.BACKWARD + Const.SEP, Const.SEP + Const.FORWARD + Const.SEP))
+                 if f_node and hasattr(f_node, 'is_pp_merged'):
+                     pp_merged_condition = True
+             if pp_merged_condition:
+                 main_node.is_pp_merged = True
+                 main_up_node = main_node.upnode
+                 for other_node in other_nodes:
+                     # pp-split layers have the same name on every rank; tag the other ranks' layers of
+                     # the same name with position and rank markers.
+                     self._mark_node_id_position_rank(other_node, other_node.rank)
+                     self._add_node_to_main_graph(main_graph, other_node)
+                     # Add the other ranks' pp-split module nodes to the current rank's graph.
+                     other_node.upnode = main_up_node
+                     main_up_node.subnodes.append(other_node)
+                 # The pp-split module node has been found; do not recurse into it.
+                 return
+             # A forward module that exists on every rank with consistent inputs but differing outputs
+             # contains pp-split modules inside; its output must use the last rank's output.
+             elif not param_in and param_out and Const.SEP + Const.FORWARD + Const.SEP in main_node.id:
+                 main_node.output_data = other_nodes[-1].output_data
+             # A backward module that exists on every rank with consistent outputs but differing inputs
+             # contains pp-split modules inside; its input must use the last rank's input.
+             elif param_in and not param_out and Const.SEP + Const.BACKWARD + Const.SEP in main_node.id:
+                 main_node.input_data = other_nodes[-1].input_data
+             self._merge_other_unique_nodes(main_graph, main_node, other_nodes)
+         for sub_node in main_node.subnodes:
+             if sub_node.op == NodeOp.module:
+                 self._merge_nodes(main_graph, sub_node, other_graphs)
+
+     def merge_graphs(self):
+         results_groups = self.split_graph_results_by_groups(self.get_groups())
+         results = []
+         for result_groups in results_groups:
+             self.merge_graph_api_collection(result_groups)
+             results.extend(self.merge_pp_graphs(result_groups))
+         return results
+
+     def merge_pp_graphs(self, results):
+         if not results or len(results) < 2:
+             return results
+         graphs = [x.graph for x in results]
+         main_graph_result = results[0]
+         for main_node in main_graph_result.graph.root.subnodes:
+             if main_node.op == NodeOp.module and main_node.id not in self.unmerged_module:
+                 self._merge_nodes(main_graph_result.graph, main_node, graphs[1:])
+                 self._sort_nodes(main_graph_result.graph, main_node)
+         return [main_graph_result]
+
+     def get_groups(self):
+         """
+         Find the p2p communication nodes on each rank and build the p2p mapping between ranks.
+         """
+         p2p_mapping = {}
+         for result in self.build_graph_results:
+             rank = result.rank
+             pp_rank = None
+             for node in result.graph.node_map.values():
+                 if not node.id.startswith(Const.DISTRIBUTED + Const.SEP):
+                     continue
+                 if '.batch_isend_irecv.' in node.id:
+                     for p2p_info in node.batch_p2p_info:
+                         target_rank = p2p_info.get(GraphConst.PEER)
+                         if target_rank is not None and target_rank != rank and p2p_info.get(GraphConst.OP) == 'isend':
+                             pp_rank = target_rank
+                             break
+                 elif '.send.' in node.id or '.isend.' in node.id:
+                     # example: Distributed.isend.0.forward --> Distributed.isend.0.forward.input.dst
+                     dst_kwarg = f'{node.id}{Const.SEP}{Const.INPUT}{Const.SEP}{GraphConst.DST}'
+                     dst = node.input_data.get(dst_kwarg, {}).get('value')
+                     if dst is not None:
+                         pp_rank = dst
+                         break
+                 if pp_rank is not None:
+                     break
+             if pp_rank is not None:
+                 p2p_mapping[rank] = pp_rank
+         pp_groups = self._trace_p2p_mapping(p2p_mapping)
+         if not pp_groups:
+             logger.info('Unable to get pp groups based on the Distributed api (batch_isend_irecv, send, or isend); '
+                         'generating pp groups from the parallel params "rank_size", "tp" and "pp".')
+             _, pp_groups = self.get_default_groups()
+         logger.info(f'{self.log_prefix} All pp groups are {pp_groups}.')
+         return pp_groups
+
+     def _merge_other_unique_nodes(self, main_graph, main_node, other_nodes):
+         """
+         If the subnode lists of other_node in the other ranks' graphs contain nodes unique to them,
+         those nodes must be merged into the main graph.
+         """
+         lists = [main_node.subnodes]
+         for other_node in other_nodes:
+             lists.append(other_node.subnodes)
+         dicts = [{node.id: node for node in lst} for lst in lists]
+         unique_node_ids = {}
+         # Compute the elements unique to each set.
+         for i, current_dict in enumerate(dicts):
+             other_ids = set()
+             for j, other_dict in enumerate(dicts):
+                 if i != j:
+                     # Extend the union with the elements of the set currently being traversed.
+                     other_ids.update(other_dict.keys())
+             result = set(current_dict.keys()) - other_ids
+             if i != 0 and result:
+                 # The difference between the current set and the union of the others, i.e. its unique
+                 # elements, preserving the original order.
+                 unique_node_ids[i] = [node_id for node_id in current_dict if node_id in result]
+         unique_nodes = []
+         if unique_node_ids:
+             for i, items in unique_node_ids.items():
+                 for item in items:
+                     unique_nodes.append(dicts[i].get(item))
+         if unique_nodes:
+             for unique_node in unique_nodes:
+                 self._mark_node_id_position_rank(unique_node, unique_node.rank)
+                 self._add_node_to_main_graph(main_graph, unique_node)
+                 main_node.subnodes.append(unique_node)
+                 unique_node.upnode = main_node
+
+     def _sort_nodes(self, main_graph, start_node):
+         stack = [start_node]
+         while stack:
+             node = stack.pop()
+             if self.MARK_PATTERN.search(node.id):
+                 is_forward = (Const.SEP + Const.FORWARD + Const.SEP in node.id or
+                               Const.SEP + Const.FORWARD + self.MARK in node.id)
+                 new_sub_nodes1, new_sub_nodes2 = [], []
+                 for item in node.upnode.subnodes:
+                     new_sub_nodes2.append(item) if self.MARK_PATTERN.search(item.id) else new_sub_nodes1.append(item)
+
+                 order = True if is_forward else False
+                 new_sub_nodes2.sort(key=lambda n: self._get_node_sort_rule(n, rank_ascending=order))
+                 new_sub_nodes = new_sub_nodes1 + new_sub_nodes2 if is_forward else new_sub_nodes2 + new_sub_nodes1
+
+                 index = -1
+                 node_iter = new_sub_nodes if is_forward else reversed(new_sub_nodes)
+                 for item in node_iter:
+                     if self.LAYERS_PATTERN.search(item.id):
+                         index += 1
+                         if self.MARK_PATTERN.search(item.id):
+                             item.pp_index = index
+
+                 for item in new_sub_nodes2:
+                     self._update_node_id(main_graph, item)
+
+                 node.upnode.subnodes = new_sub_nodes
+
+             stack.extend(node.subnodes)
+
+     def _add_node_to_main_graph(self, main_graph: Graph, node: BaseNode):
+         if node.id in main_graph.node_map:
+             logger.warning(f'{node.id} already exists!')
+         else:
+             main_graph.node_map[node.id] = node
+         for sub_node in node.subnodes:
+             self._add_node_to_main_graph(main_graph, sub_node)
+
+     def _get_node_sort_rule(self, node, rank_ascending=True):
+         match = self.MARK_PATTERN.search(node.id)
+         if match:
+             # position is the node's index within its parent node.
+             position, rank = int(match.group(1)), int(match.group(2))
+             if rank_ascending:
+                 return rank, position
+             else:
+                 return -rank, position
+         return (float('inf'), float('inf')) if rank_ascending else (-float('inf'), -float('inf'))
+
+     def _mark_node_id_position_rank(self, node: BaseNode, rank):
+         position = 0
+         for index, item in enumerate(node.upnode.subnodes):
+             if item.id == node.id:
+                 position = index
+                 break
+         # Duplicated nodes from each rank get their position within the parent and their rank number
+         # appended, separated by %.
+         node.id = node.id + f'{self.MARK}{position}' + f'{self.MARK}{rank}'
+         for sub_node in node.subnodes:
+             self._mark_node_id_position_rank(sub_node, rank)
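A sketch of the resulting marker format, using a hypothetical module id:

    # A node that is the 4th subnode (position 3) of its parent on rank 2:
    #   'Module.module.module.layers.0.forward.0'
    # becomes
    #   'Module.module.module.layers.0.forward.0%3%2'
    # MARK_PATTERN (r"%(\d+)%(\d+)$") later recovers (position, rank) = (3, 2)
    # in _get_node_sort_rule, and _update_node_id strips the suffix again.
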
+
+     def _update_node_id(self, graph, start_node: BaseNode, pp_index=""):
+         stack = [(start_node, pp_index)]
+         while stack:
+             node, pp_index = stack.pop()
+             # Remove the node_map entry before changing the node id, then add it back afterwards.
+             if node.id not in graph.node_map:
+                 logger.warning(f'Failed to update node id {node.id}!')
+             else:
+                 del graph.node_map[node.id]
+                 old_id = self.MARK_PATTERN.sub("", node.id)
+                 if node.op == NodeOp.module:
+                     # pp-split module node: rewrite the counter in the module name based on the
+                     # position and rank information.
+                     if self.LAYERS_PATTERN.search(node.id) and self.MARK_PATTERN.search(node.id):
+                         if hasattr(node, 'pp_index'):
+                             pp_index = str(node.pp_index)
+                         node.id = self.LAYERS_PATTERN.sub(r"\g<1>" + pp_index + r"\g<2>", node.id)
+                 else:
+                     # API node: append the rank information to the API name.
+                     parts = node.id.split(Const.SEP)
+                     parts[1] += f'_rank{node.id.split(PPMerger.MARK)[-1]}'
+                     node.id = Const.SEP.join(parts)
+                 # Strip the position and rank markers that were added earlier.
+                 node.id = self.MARK_PATTERN.sub("", node.id)
+                 # The node id changed, so the data keys containing the node id must be updated too.
+                 node.input_data = self._update_node_data_key(old_id, node.id, node.input_data)
+                 node.output_data = self._update_node_data_key(old_id, node.id, node.output_data)
+                 graph.node_map[node.id] = node
+             # Push the subnodes onto the stack.
+             for sub_node in node.subnodes:
+                 stack.append((sub_node, pp_index))
+
+
+ class TPMerger(BaseGraphMerger):
+     RANK_PATTERN = re.compile(r"_rank(\d+)\.")
+     OPERATION_TABLE = {
+         Const.MAX: {
+             'initial': lambda p: p.get(Const.MAX),
+             'merge': lambda current, other: max(current, other.get(Const.MAX)),
+             'finalize': lambda current, count: current,
+             'formula': lambda key, values: f'{MAX_INFO}{key} is: max({", ".join(map(str, values))})'
+         },
+         Const.MIN: {
+             'initial': lambda p: p.get(Const.MIN),
+             'merge': lambda current, other: min(current, other.get(Const.MIN)),
+             'finalize': lambda current, count: current,
+             'formula': lambda key, values: f'{MIN_INFO}{key} is: min({", ".join(map(str, values))})'
+         },
+         Const.MEAN: {
+             'initial': lambda p: p.get(Const.MEAN),
+             'merge': lambda current, other: current + other.get(Const.MEAN),
+             'finalize': lambda current, count: current / count,
+             'formula': lambda key, values: f'{MEAN_INFO}{key} is: ({" + ".join(map(str, values))}) / {len(values)}'
+         },
+         Const.NORM: {
+             'initial': lambda p: pow(p.get(Const.NORM), 2.0),
+             'merge': lambda current, other: current + pow(other.get(Const.NORM), 2.0),
+             'finalize': lambda current, count: pow(current, 1 / 2.0),
+             'formula': lambda key, values: f'{NORM_INFO}{key} is: ({" + ".join([f"{v} ** 2" for v in values])}) ** 0.5'
+         }
+     }
+     TP_MERGED_INFO = 'This data is the merged data after tensor parallelism (TP), and the data is merged from rank '
+
+     @staticmethod
+     def _merge_params(tp_need_merge_param: dict):
+         """
+         Merge the parameter statistics of the tp-split ranks.
+         tp_need_merge_param: {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]}
+         return: details of the merge calculations
+         """
+         merge_info = []
+         for key, param_list in tp_need_merge_param.items():
+             if len(param_list) < 2:
+                 continue
+             main_param = param_list[0]
+
+             for stat, ops in TPMerger.OPERATION_TABLE.items():
+                 current_value = ops['initial'](main_param)
+                 value_list = [current_value if stat != Const.NORM else main_param.get(Const.NORM)]
+
+                 for other_param in param_list[1:]:
+                     current_value = ops['merge'](current_value, other_param)
+                     value_list.append(other_param.get(stat) if stat != Const.NORM else other_param.get(Const.NORM))
+
+                 final_value = ops['finalize'](current_value, len(param_list))
+                 main_param[stat] = final_value
+                 formula_base = f'{ops["formula"](key, value_list)}' + f' = {final_value}'
+
+                 merge_info.append(formula_base)
+
+         return merge_info
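A worked run of the table above, assuming Const.MAX/MIN/MEAN/NORM resolve to the 'Max'/'Min'/'Mean'/'Norm' keys shown in the docstring. The first param in each list is updated in place with the merged statistics:

    stats = {
        'input.0': [
            {'Max': 1.0, 'Min': -1.0, 'Mean': 0.1, 'Norm': 3.0},   # rank 0, merged into
            {'Max': 2.0, 'Min': -0.5, 'Mean': 0.3, 'Norm': 4.0},   # rank 1
        ],
    }
    TPMerger._merge_params(stats)
    # Max  -> max(1.0, 2.0)                = 2.0
    # Min  -> min(-1.0, -0.5)              = -1.0
    # Mean -> (0.1 + 0.3) / 2              = 0.2
    # Norm -> (3.0 ** 2 + 4.0 ** 2) ** 0.5 = 5.0
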
+
+     @staticmethod
+     def _get_need_merge_node(main_node, other_graphs, tp_merge_mapping):
+         """
+         Get the list of nodes that need TP merging.
+         Under hybrid TP+PP parallelism the data has already been PP-merged and some node ids are tagged
+         with rank information, so the nodes that need TP merging can only be found through a rank
+         mapping, for example:
+         main_node = Torch.matmul_rank4.32.forward    other_node = Torch.matmul_rank5.32.forward
+         A 4->5 mapping is required to find Torch.matmul_rank5.32.forward from Torch.matmul_rank4.32.forward.
+         """
+         other_nodes = []
+         match = TPMerger.RANK_PATTERN.search(main_node.id)
+         # The node name is tagged with rank information and a mapping is provided.
+         if match and tp_merge_mapping:
+             rank = int(match.group(1))
+             tp_mapping_ranks = tp_merge_mapping.get(rank)
+             if not tp_mapping_ranks:
+                 return other_nodes
+             if len(tp_mapping_ranks) != len(other_graphs):
+                 return other_nodes
+             for i, graph in enumerate(other_graphs):
+                 # Get the target rank from the mapping, replace the current rank information in the
+                 # node id, then fetch the node from the target graph.
+                 tp_mapping_id = TPMerger.RANK_PATTERN.sub(f"_rank{tp_mapping_ranks[i]}.", main_node.id)
+                 other_node = graph.node_map.get(tp_mapping_id)
+                 if not other_node or main_node.get_ancestors() != other_node.get_ancestors():
+                     other_nodes.clear()
+                     break
+                 other_nodes.append(other_node)
+         else:
+             for graph in other_graphs:
+                 other_node = graph.node_map.get(main_node.id)
+                 if not other_node or main_node.get_ancestors() != other_node.get_ancestors():
+                     other_nodes.clear()
+                     break
+                 other_nodes.append(other_node)
+
+         return other_nodes
+
+     @staticmethod
+     def _slice_list_at_id(node_list, target_id1, target_id2):
+         start_index, end_index = -1, -1
+         for index, node in enumerate(node_list):
+             if target_id1 in node.id:
+                 start_index = index
+             elif target_id2 in node.id:
+                 end_index = index
+         return [] if start_index == -1 or end_index == -1 else node_list[start_index:end_index + 1]
+
+     def merge_graphs(self):
+         results_groups = self.split_graph_results_by_groups(self.get_groups())
+         results = []
+         for result_groups in results_groups:
+             self.merge_graph_api_collection(result_groups)
+             results.extend(self.merge_tp_graphs(result_groups))
+         return results
+
+     def merge_tp_graphs(self, results, tp_merge_mapping=None):
+         if not results or len(results) < 2:
+             return results
+         graphs = [x.graph for x in results]
+         main_graph_result = results[0]
+         for main_node in main_graph_result.graph.node_map.values():
+             should_continue = (
+                     not main_node.upnode or main_node.upnode.op != NodeOp.module or
+                     main_node.upnode.id in self.unmerged_module or main_node.id.startswith(Const.DISTRIBUTED) or
+                     main_node.parallel_merge_info != [])
+             if should_continue:
+                 continue
+             self._handle_tp_matmul_reduce(main_node, graphs[1:], tp_merge_mapping)
+             other_nodes = self._get_need_merge_node(main_node, graphs[1:], tp_merge_mapping)
+             tp_need_merge_param_in, tp_need_merge_param_out = self.compare_node_param_data(main_node, other_nodes)
+             if tp_need_merge_param_in or tp_need_merge_param_out:
+                 ranks = [main_node.rank]
+                 for other_node in other_nodes:
+                     ranks.append(other_node.rank)
+                 main_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.')
+                 merge_info_in = self._merge_params(tp_need_merge_param_in)
+                 merge_info_out = self._merge_params(tp_need_merge_param_out)
+                 main_node.parallel_merge_info.extend(merge_info_in + merge_info_out)
+         for main_node in main_graph_result.graph.node_map.values():
+             self._merge_tp_megatron_column_row_parallel(main_node, graphs[1:], tp_merge_mapping)
+         return [main_graph_result]
+
+     def get_groups(self):
+         tp_groups = []
+         for result in self.build_graph_results:
+             for node in result.graph.node_map.values():
+                 if any(op in node.id for op in GraphConst.REDUCE_OPERATIONS):
+                     group_ranks = node.input_data.get(f'{node.id}.input.group', {}).get('group_ranks')
+                     if group_ranks and group_ranks not in tp_groups:
+                         tp_groups.append(group_ranks)
+                     break
+         if not tp_groups:
+             logger.info('Unable to get tp groups based on the Distributed api (reduce_scatter or all_reduce); '
+                         'generating tp groups from the parallel params "rank_size", "tp" and "pp".')
+             tp_groups, _ = self.get_default_groups()
+         logger.info(f'{self.log_prefix} All tp groups are {tp_groups}.')
+         return tp_groups
+
+     def _handle_tp_matmul_reduce(self, node, other_graphs, tp_merge_mapping):
+         """
+         The matmul output of a forward RowParallel layer or a backward ColumnParallel layer must be
+         replaced with the output of the all_reduce/reduce_scatter that runs after the matmul finishes.
+         """
+         if node.op != NodeOp.module:
+             return
+         splits = node.id.split(Const.SEP)
+         if len(splits) < 4:
+             return
+         is_forward_with_row_parallel = splits[-2] == Const.FORWARD and 'RowParallelLinear' in splits[-3]
+         is_backward_with_column_parallel = splits[-2] == Const.BACKWARD and 'ColumnParallelLinear' in splits[-3]
+         if not is_forward_with_row_parallel and not is_backward_with_column_parallel:
+             return
+         matmul_list = []
+         reduce_list = []
+         for sub_node in node.subnodes:
+             if 'matmul' in sub_node.id:
+                 matmul_list.append(sub_node)
+             if ('_reduce_scatter_base' in sub_node.id or 'reduce_scatter_tensor' in sub_node.id or
+                     'all_reduce' in sub_node.id):
+                 reduce_list.append(sub_node)
+         if not matmul_list or not reduce_list:
+             return
+         for matmul_node in matmul_list:
+             if not matmul_node.output_data:
+                 continue
+             # matmul's output0 is passed to all_reduce/reduce_scatter, becoming all_reduce's input0 or
+             # reduce_scatter's input1.
+             matmul_node_output_param = list(matmul_node.output_data.values())[0]
+             for reduce_node in reduce_list:
+                 if not reduce_node.output_data:
+                     continue
+                 if 'all_reduce' in reduce_node.id:
+                     if not reduce_node.input_data:
+                         continue
+                     reduce_node_input_param = list(reduce_node.input_data.values())[0]
+                 else:
+                     if len(reduce_node.input_data) < 2:
+                         continue
+                     reduce_node_input_param = list(reduce_node.input_data.values())[1]
+                 if not self.compare_param_same(matmul_node_output_param, reduce_node_input_param):
+                     continue
+                 # Merge matmul's input statistics with the other ranks' data.
+                 other_nodes = self._get_need_merge_node(matmul_node, other_graphs, tp_merge_mapping)
+                 tp_need_merge_param_in, _ = self.compare_node_param_data(matmul_node, other_nodes)
+                 if tp_need_merge_param_in:
+                     ranks = [matmul_node.rank]
+                     for other_node in other_nodes:
+                         ranks.append(other_node.rank)
+                     matmul_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.')
+                     merge_info_in = self._merge_params(tp_need_merge_param_in)
+                     matmul_node.parallel_merge_info.extend(merge_info_in)
+                 # Replace matmul's output0 with all_reduce/reduce_scatter's output0.
+                 reduce_node_output_param = list(reduce_node.output_data.values())[0]
+                 keys = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
+                 matmul_node_output_param.update({k: reduce_node_output_param.get(k) for k in keys})
+                 full_op_name = reduce_node_output_param.get('full_op_name')
+                 param_name = full_op_name if full_op_name else reduce_node.id
+                 matmul_node.parallel_merge_info.append(f'The output of this data is merged from {param_name}')
+                 reduce_list.remove(reduce_node)
+                 break
+
+     def _merge_tp_megatron_column_row_parallel(self, node, other_graphs, tp_merge_mapping):
+         if node.op != NodeOp.module or node.parallel_merge_info:
+             return
+         splits = node.id.split(Const.SEP)
+         if len(splits) < 4:
+             return
+         is_forward_with_column_parallel = splits[-2] == Const.FORWARD and 'ColumnParallelLinear' in splits[-3]
+         if not is_forward_with_column_parallel:
+             return
+         if not node.upnode:
+             return
+         # Get the [ColumnParallelLinear, RowParallelLinear] structure.
+         nodes = self._slice_list_at_id(node.upnode.subnodes, node.id, 'RowParallelLinear')
+         if len(nodes) < 2:
+             return
+         stack = nodes[:]
+         while stack:
+             current_node = stack.pop()
+             stack.extend(reversed(current_node.subnodes))
+
+             if current_node.parallel_merge_info or current_node.id.startswith(Const.DISTRIBUTED):
+                 continue
+
+             other_nodes = self._get_need_merge_node(current_node, other_graphs, tp_merge_mapping)
+             param_in, param_out = self.compare_node_param_data(current_node, other_nodes, False)
+
+             if param_in or param_out:
+                 ranks = [current_node.rank]
+                 for other_node in other_nodes:
+                     ranks.append(other_node.rank)
+                 current_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.')
+                 # The ColumnParallelLinear layer's input, and the matmul input inside it, need no merging.
+                 if current_node == nodes[0] or ('matmul' in current_node.id and current_node.upnode == nodes[0]):
+                     param_in.pop('input.0', None)
+                 # The RowParallelLinear layer's output, the matmul output inside it, and the bias need no merging.
+                 elif current_node == nodes[-1] or ('matmul' in current_node.id and current_node.upnode == nodes[-1]):
+                     param_out = {}
+                     param_in.pop('parameters.bias', None)
+
+                 merge_info_in = self._merge_params(param_in)
+                 merge_info_out = self._merge_params(param_out)
+                 current_node.parallel_merge_info.extend(merge_info_in + merge_info_out)
+
+
+ class NoParallelMerger(BaseGraphMerger):
+     def merge_graphs(self):
+         self.merge_graph_api_collection(self.build_graph_results)
+         return self.build_graph_results
+
+
+ class TPPPMerger(BaseGraphMerger):
+     def merge_graphs(self):
+         tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
+         pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) \
+             if self.parallel_param.vpp == 1 else VPPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
+         pp_groups = pp_merger.get_groups()
+         tp_groups = tp_merger.get_groups()
+         # In the TP+PP hybrid handler, PP and TP are both necessarily greater than 1.
+         tp_merge_mapping = {}
+         for tp_group in tp_groups[1:]:
+             tp_merge_mapping[tp_group[0]] = tp_group[1:]
+         self.merge_graph_api_collection(self.build_graph_results)
+         # Merge pp first; the pp groups must be known so that merging happens within each pp group.
+         results_groups_pp = self.split_graph_results_by_groups(pp_groups)
+         pp_results = []
+         for results in results_groups_pp:
+             pp_results.extend(pp_merger.merge_pp_graphs(results))
+         # Once pp merging is done, do tp merging directly, ending with a single graph.
+         tp_result = tp_merger.merge_tp_graphs(pp_results, tp_merge_mapping)
+         self.sort_merged_api_collection(tp_result[0].graph)
+         return tp_result
+
+
+ class FullMerger(BaseGraphMerger):
+     def merge_graphs(self):
+         tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
+         pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) \
+             if self.parallel_param.vpp == 1 else VPPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
+         pp_groups = pp_merger.get_groups()
+         tp_groups = tp_merger.get_groups()
+         tp_merge_mapping = {}
+         if len(tp_groups) < 1:
+             raise RuntimeError(f'Graph merge error: tp_groups is {tp_groups}.')
+         for tp_group in tp_groups[1:]:
+             if len(tp_group) < 1:
+                 raise RuntimeError(f'Graph merge error: tp_group is {tp_group}.')
+             tp_merge_mapping[tp_group[0]] = tp_group[1:]
+         # Merge pp first; the pp groups must be known so that merging happens within each pp group.
+         results_groups_pp = self.split_graph_results_by_groups(pp_groups)
+         pp_results = {}
+         for pp_result in results_groups_pp:
+             self.merge_graph_api_collection(pp_result)
+             pp_result = pp_merger.merge_pp_graphs(pp_result)[0]
+             pp_results[pp_result.rank] = pp_result
+         # Once pp merging is done, partition the pp-merged results by tp group.
+         lists_to_be_tp_merged = []
+         for tp_group in tp_groups:
+             list_to_be_tp_merged = []
+             for rank in tp_group:
+                 pp_result = pp_results.get(rank)
+                 if pp_result:
+                     list_to_be_tp_merged.append(pp_result)
+             if list_to_be_tp_merged:
+                 lists_to_be_tp_merged.append(list_to_be_tp_merged)
+         tp_results = []
+         for list_to_be_tp_merged in lists_to_be_tp_merged:
+             self.merge_graph_api_collection(list_to_be_tp_merged)
+             tp_merged_result = tp_merger.merge_tp_graphs(list_to_be_tp_merged, tp_merge_mapping)
+             self.sort_merged_api_collection(tp_merged_result[0].graph)
+             tp_results.extend(tp_merged_result)
+         return tp_results
+
+
+ class VPPMerger(PPMerger):
+     LAYERS_NUM_PATTERN = re.compile(r"(layers\.|layer\.)(\d+)(\.)")
+     FORWARD_PATTERN = re.compile(r'\.forward\.\d+$')
+
+     @staticmethod
+     def _replace_vpp_id(s, vpp_id):
+         parts = s.split(Const.SEP)
+         if len(parts) < 2 or not parts[1].isdigit():
+             return s
+         parts[1] = str(vpp_id)
+         return Const.SEP.join(parts)
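The second dot-separated field of a node id is the vpp chunk index (Const.SEP is the '.' separator used in the node ids throughout this file):

    VPPMerger._replace_vpp_id('Module.1.layers.3.forward.0', 0)
    # -> 'Module.0.layers.3.forward.0'
    VPPMerger._replace_vpp_id('Module.module.Float16Module.forward.0', 0)
    # -> unchanged, because the second field is not a digit
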
+
+     def merge_pp_graphs(self, results):
+         if not results or len(results) < 2:
+             return results
+         graphs = [x.graph for x in results]
+         main_graph_result = results[0]
+         for main_node in main_graph_result.graph.root.subnodes:
+             if main_node.op == NodeOp.module and main_node.id not in self.unmerged_module:
+                 self._merge_nodes(main_graph_result.graph, main_node, graphs[1:])
+                 self._sort_nodes(main_graph_result.graph, main_node)
+         self._merge_vpp_data(main_graph_result.graph)
+         self._merge_vpp_chunks(main_graph_result.graph)
+         return [main_graph_result]
+
+     def _merge_vpp_data(self, graph):
+         """
+         All chunks' data is merged into chunk0: forward chunk0's output uses the last chunk's output,
+         and backward chunk0's input uses the last chunk's input.
+         """
+         module_list = []
+         for node in reversed(graph.root.subnodes):
+             parts = node.id.split(Const.SEP)
+             if len(parts) < 2:
+                 continue
+             if parts[1] in [GraphConst.VPP_CHUNK_0, str(self.parallel_param.vpp - 1)]:
+                 module_list.append(node)
+         if not module_list:
+             return
+         stack = module_list[:]
+         while stack:
+             current_node = stack.pop()
+             if (hasattr(current_node, 'is_pp_merged') or hasattr(current_node, 'pp_index')
+                     or current_node.op != NodeOp.module):
+                 continue
+             is_forward = self.FORWARD_PATTERN.search(current_node.id)
+             stack.extend(reversed(current_node.subnodes))
+             target_id = self._replace_vpp_id(current_node.id, self.parallel_param.vpp - 1)
+             target_node = graph.node_map.get(target_id)
+             if not target_node:
+                 continue
+             if is_forward:
+                 current_node.output_data = self._update_node_data_key(target_node.id, current_node.id,
+                                                                       target_node.output_data)
+             else:
+                 current_node.input_data = self._update_node_data_key(target_node.id, current_node.id,
+                                                                      target_node.input_data)
+
+     def _merge_vpp_chunks(self, graph):
+         """
+         All chunks are merged into chunk0: the layers are moved into chunk0 and renumbered.
+         """
+         chunk_id_list = [i for i in range(1, self.parallel_param.vpp)]
+         chunk_0_list = []
+         for node in reversed(graph.root.subnodes):
+             parts = node.id.split(Const.SEP)
+             if len(parts) < 2:
+                 continue
+             if parts[1] == GraphConst.VPP_CHUNK_0:
+                 chunk_0_list.append(node)
+         if not chunk_0_list:
+             return
+         stack = chunk_0_list[:]
+         layers_need_merge_dict = {}
+         while stack:
+             current_node = stack.pop()
+             if hasattr(current_node, 'is_pp_merged') or hasattr(current_node, 'pp_index') \
+                     and current_node.upnode.id not in layers_need_merge_dict:
+                 layers_need_merge_dict[current_node.upnode.id] = current_node.upnode
+                 continue
+             stack.extend(reversed(current_node.subnodes))
+         for node in layers_need_merge_dict.values():
+             is_forward = self.FORWARD_PATTERN.search(node.id)
+             for vpp_id in chunk_id_list:
+                 target_node = graph.node_map.get(self._replace_vpp_id(node.id, vpp_id))
+                 if not target_node:
+                     continue
+                 # Move the other chunks' layers into chunk0: forward appends them at the end,
+                 # backward prepends them at the front.
+                 if is_forward:
+                     node.subnodes.extend(target_node.subnodes)
+                 else:
+                     node.subnodes = target_node.subnodes + node.subnodes
+                 for sub_node in target_node.subnodes:
+                     sub_node.upnode = node
+                 # Fetch the other chunk's ancestor chain and delete all of its parents, so the merged
+                 # nodes of other chunks are not shown in the front end.
+                 ancestors = target_node.get_ancestors()
+                 if len(ancestors) < 2:
+                     continue
+                 for module_id in ancestors[1:]:
+                     graph.node_map.pop(module_id, None)
+                 graph.root.subnodes = [node for node in graph.root.subnodes if node.id != ancestors[1]]
+             # Renumber the layers.
+             self._sort_layers(node.subnodes, graph, is_forward)
+
+     def _sort_layers(self, node_list, graph, is_forward):
+         if not is_forward:
+             node_list = list(reversed(node_list))
+         index = -1
+         for node in node_list:
+             match = self.LAYERS_NUM_PATTERN.search(node.id)
+             if match:
+                 index += 1
+             parts = node.id.split(Const.SEP)
+             # Module.0.xxx is the first chunk and needs no renumbering.
+             if len(parts) < 2 or parts[1] == GraphConst.VPP_CHUNK_0:
+                 continue
+             # For layers, rewrite both the chunk id and the layer index; for non-layers, rewrite the chunk id only.
+             new_node_id_prefix = ''
+             if match:
+                 prefix, number, dot = match.groups()
+                 new_string = prefix + str(index) + dot
+                 start, end = match.span()
+                 new_node_id_prefix = node.id[:start] + new_string
+                 new_node_id_prefix = self._replace_vpp_id(new_node_id_prefix, GraphConst.VPP_CHUNK_0)
+                 new_node_id = new_node_id_prefix + node.id[end:]
+             else:
+                 new_node_id = self._replace_vpp_id(node.id, GraphConst.VPP_CHUNK_0)
+             graph.node_map.pop(node.id, None)
+             node.input_data = self._update_node_data_key(node.id, new_node_id, node.input_data)
+             node.output_data = self._update_node_data_key(node.id, new_node_id, node.output_data)
+             node.id = new_node_id
+             graph.node_map[new_node_id] = node
+             stack = node.subnodes[:]
+             while stack:
+                 current_node = stack.pop()
+                 if current_node.op != NodeOp.module:
+                     continue
+                 stack.extend(reversed(current_node.subnodes))
+                 match = self.LAYERS_NUM_PATTERN.search(current_node.id)
+                 if match:
+                     _, e = match.span()
+                     new_current_node_id = new_node_id_prefix + current_node.id[e:]
+                 else:
+                     new_current_node_id = self._replace_vpp_id(current_node.id, GraphConst.VPP_CHUNK_0)
+                 current_node.input_data = self._update_node_data_key(current_node.id, new_current_node_id,
+                                                                      current_node.input_data)
+                 current_node.output_data = self._update_node_data_key(current_node.id, new_current_node_id,
+                                                                       current_node.output_data)
+                 graph.node_map.pop(current_node.id, None)
+                 current_node.id = new_current_node_id
+                 graph.node_map[new_current_node_id] = current_node
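Taken together, GraphMerger is the entry point of this new module, presumably driven from msprobe/visualization/builder/graph_builder.py (heavily changed in this same release). A minimal sketch of the call, with a hypothetical parallel-param object carrying only the attributes read in this file, and build_graph_results standing for the per-rank build results (each must expose .graph and .rank, as used throughout):

    from types import SimpleNamespace

    parallel_param = SimpleNamespace(tp=2, pp=2, vpp=1, rank_size=4, order='tp-pp')
    merged_results = GraphMerger(build_graph_results, parallel_param, is_bench=False).merge_graph()
    # tp * pp == rank_size here, so this dispatches to TPPPMerger.merge_graphs().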