mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -0,0 +1,986 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
import math
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.visualization.graph.graph import Graph, BaseNode
|
|
21
|
+
from msprobe.visualization.graph.node_op import NodeOp
|
|
22
|
+
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.visualization.utils import GraphConst
|
|
24
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
25
|
+
from msprobe.core.common.parallel_state import get_tp_pp_default_groups
|
|
26
|
+
|
|
27
|
+
# Message prefixes naming the merging method of each statistic
# (Max / Min / Mean / Norm); they are not used in the visible part of this file.
MAX_INFO = "The Max value merging method for "
MIN_INFO = "The Min value merging method for "
MEAN_INFO = "The Mean value merging method for "
NORM_INFO = "The Norm value merging method for "
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class GraphMerger:
    """Entry point that picks a graph-merge strategy from the parallel configuration.

    The strategy is chosen once, at construction, from ``parallel_param.tp``,
    ``parallel_param.pp`` and ``parallel_param.rank_size`` (plus ``vpp`` when a
    pipeline strategy is selected).
    """

    def __init__(self, build_graph_results, parallel_param, is_bench=False):
        self.strategy = self._select_strategy(build_graph_results, parallel_param, is_bench)

    @staticmethod
    def _select_strategy(results, param, is_bench):
        """Instantiate the merger that matches ``param``.

        The checks run in a fixed priority order; the first matching rule wins.
        """
        if param.tp == param.pp == param.rank_size == 1:
            strategy_cls = NoParallelMerger
        elif param.tp == param.rank_size:
            strategy_cls = TPMerger
        elif param.pp == param.rank_size:
            strategy_cls = PPMerger if param.vpp == 1 else VPPMerger
        elif param.pp == 1:
            strategy_cls = TPMerger
        elif param.tp == 1:
            strategy_cls = PPMerger if param.vpp == 1 else VPPMerger
        elif param.tp * param.pp == param.rank_size:
            strategy_cls = TPPPMerger
        else:
            strategy_cls = FullMerger
        return strategy_cls(results, param, is_bench)

    def merge_graph(self):
        """Run the selected strategy and return whatever its ``merge_graphs`` returns."""
        return self.strategy.merge_graphs()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class BaseGraphMerger:
    """Shared machinery for strategies that merge per-rank graph build results.

    Subclasses implement ``merge_graphs``. On construction every node of every
    input graph is tagged (``node.rank``) with the rank of the result it came from.
    """

    def __init__(self, build_graph_results, parallel_param, is_bench):
        # Module ids excluded from merging.
        self.unmerged_module = [Const.CLIP_GRAD, Const.OPTIMIZER]
        # Only parameters whose dtype is listed here take part in data comparison.
        self.dtype_list = Const.TORCH_INT_DTYPE + Const.TORCH_FLOAT_DTYPE + [Const.FLOAT16, Const.FLOAT32,
                                                                             Const.BFLOAT16]
        self.build_graph_results = build_graph_results
        self.parallel_param = parallel_param
        self.is_bench = is_bench
        self.log_prefix = '[Bench]' if self.is_bench else '[NPU]'
        self._add_all_nodes_rank()

    @staticmethod
    def sort_merged_api_collection(graph):
        """Sort the children of each all-ranks API collection under the root by rank.

        Applies to root children of op ``NodeOp.api_collection`` whose id starts with
        ``GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS``; the sort key is the number in
        the child's ``_Rank<N>`` id marker. Children without such a marker are
        placed last instead of making the sort fail on a ``None`` key.
        """
        def extract_rank(node):
            match = re.search(r'_Rank(\d+)', node.id)
            return int(match.group(1)) if match else math.inf

        for sub_node in graph.root.subnodes:
            if sub_node.op == NodeOp.api_collection and sub_node.id.startswith(
                    GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS):
                sub_node.subnodes = sorted(sub_node.subnodes, key=extract_rank)

    @staticmethod
    def _update_node_data_key(old_id, new_id, data_dict):
        """Return a copy of ``data_dict`` with ``old_id`` replaced by ``new_id`` in every key.

        The ``full_op_name`` entry of each value, when present, is rewritten the
        same way; note that the value dicts themselves are mutated in place.
        """
        new_dict = {}
        for key, value in data_dict.items():
            new_key = key.replace(old_id, new_id)
            if 'full_op_name' in value:
                value['full_op_name'] = value.get('full_op_name').replace(old_id, new_id)
            new_dict[new_key] = value
        return new_dict

    @staticmethod
    def _compare_value_same(main_value, other_value, has_uncertainty=False):
        """Return True if two statistic values should be treated as the same.

        - Non-numeric values (anything that is not int/float) never count as a mismatch.
        - NaN equals only NaN; an infinity equals only an infinity of the same sign.
        - Otherwise, without ``has_uncertainty`` the values must be exactly equal.
        - With ``has_uncertainty`` the values must be within
          ``GraphConst.UNCERTAINTY_THRESHOLD`` relative error (absolute error when
          ``main_value`` is 0). Without deterministic computation, mean/norm differ
          slightly between ranks.
        """
        if not isinstance(main_value, (int, float)) or not isinstance(other_value, (int, float)):
            return True
        # isinstance(float) guards keep math.isnan/isinf away from huge ints (OverflowError).
        main_is_nan = isinstance(main_value, float) and math.isnan(main_value)
        other_is_nan = isinstance(other_value, float) and math.isnan(other_value)
        if main_is_nan or other_is_nan:
            return main_is_nan and other_is_nan
        # abs(inf - inf) is NaN and abs(inf - -inf) is inf, so infinities must be
        # compared exactly rather than through their difference.
        has_inf = any(isinstance(value, float) and math.isinf(value) for value in (main_value, other_value))
        if not has_uncertainty or has_inf:
            return main_value == other_value
        diff = abs(main_value - other_value)
        if main_value == 0:
            return diff < GraphConst.UNCERTAINTY_THRESHOLD
        return abs(diff / main_value) < GraphConst.UNCERTAINTY_THRESHOLD

    def merge_graphs(self):
        raise NotImplementedError("This method should be implemented by subclasses.")

    def merge_graph_api_collection(self, results: list):
        """
        Merge the free API collections of all ranks into a single all-ranks collection.

        The first result's graph is modified in place; nodes from the other results
        are re-parented into it and registered in its ``node_map``.

        example:
            rank0: Apis_Between_Modules.0                 rank1: Apis_Between_Modules.0
                   Module.module.Float16Module.forward.0         Module.module.Float16Module.forward.0
                   Apis_Between_Modules.1                        Apis_Between_Modules.1

            merged: Apis_Between_Modules_All_Ranks.0
                    |_ Apis_Between_Modules_Rank0.0
                    |_ Apis_Between_Modules_Rank1.0
                    Module.module.Float16Module.forward.0
                    Apis_Between_Modules_All_Ranks.1
                    |_ Apis_Between_Modules_Rank0.1
                    |_ Apis_Between_Modules_Rank1.1
        """
        main_graph_result = results[0]
        main_root_sub_nodes = main_graph_result.graph.root.subnodes
        new_main_root_sub_nodes = []
        for main_node in main_root_sub_nodes:
            # Collections already merged into all-ranks collections must still be merged with each other.
            if main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS):
                new_main_root_sub_nodes.append(main_node)
                for other_graph_result in results[1:]:
                    other_node = other_graph_result.graph.get_node(main_node.id)
                    if not other_node:
                        continue
                    for sub_node in other_node.subnodes:
                        sub_node.upnode = main_node
                        main_graph_result.graph.node_map[sub_node.id] = sub_node
                        for sub_sub_node in sub_node.subnodes:
                            main_graph_result.graph.node_map[sub_sub_node.id] = sub_sub_node
                    main_node.subnodes.extend(other_node.subnodes)
            # Wrap the per-rank free API collections into one all-ranks collection.
            elif main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES):
                all_collection_node_id = main_graph_result.graph.add_node(NodeOp.api_collection,
                                                                          GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS,
                                                                          id_accumulation=True)
                all_collection_node = main_graph_result.graph.get_node(all_collection_node_id)
                new_main_root_sub_nodes.append(all_collection_node)
                # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0
                origin_main_node_id = main_node.id
                main_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{main_graph_result.rank}.' + \
                    main_node.id.split(Const.SEP)[-1]
                all_collection_node.subnodes = [main_node]
                main_node.upnode = all_collection_node
                main_graph_result.graph.node_map[main_node.id] = main_node
                del main_graph_result.graph.node_map[origin_main_node_id]
                for other_graph_result in results[1:]:
                    other_node = other_graph_result.graph.get_node(origin_main_node_id)
                    if not other_node:
                        continue
                    # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank1.0
                    other_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{other_graph_result.rank}.' + \
                        other_node.id.split(Const.SEP)[-1]
                    main_graph_result.graph.node_map[other_node.id] = other_node
                    for sub_node in other_node.subnodes:
                        # API nodes: tag the API name (second id segment) with the rank.
                        old_id = sub_node.id
                        parts = sub_node.id.split(Const.SEP)
                        parts[1] += f'_rank{other_graph_result.rank}'
                        sub_node.id = Const.SEP.join(parts)
                        sub_node.input_data = self._update_node_data_key(old_id, sub_node.id, sub_node.input_data)
                        sub_node.output_data = self._update_node_data_key(old_id, sub_node.id, sub_node.output_data)
                        main_graph_result.graph.node_map[sub_node.id] = sub_node
                    all_collection_node.subnodes.append(other_node)
                    other_node.upnode = all_collection_node
            else:
                new_main_root_sub_nodes.append(main_node)
        main_graph_result.graph.root.subnodes = new_main_root_sub_nodes

    def split_graph_results_by_groups(self, groups):
        """
        Split the graph results to merge according to pp or tp groups.

        :param groups: list of rank lists
        :return: list of result lists in the same shape; a rank with no result maps to None
        """
        rank_results_mapping = {result.rank: result for result in self.build_graph_results}
        return [[rank_results_mapping.get(rank) for rank in ranks] for ranks in groups]

    def compare_node_param_data(self, main_node, other_nodes, compare_data=True):
        """
        Compare input/output parameter data of the current node against several other nodes.

        Parameters whose dtype is not in ``self.dtype_list`` are skipped. Mismatching
        parameters are collected, keyed by the parameter key with the ``main_node.id``
        prefix stripped.
        :param main_node: the current node
        :param other_nodes: list of other nodes
        :param compare_data: whether to compare data; if False every eligible parameter counts as mismatching
        :return: (input mismatch dict, output mismatch dict); each value is the list of
                 the main parameter followed by the matching other parameters. Two empty
                 dicts mean the inputs and outputs are fully consistent.
        """
        if not other_nodes:
            return {}, {}
        data_types = {'input_data': {}, 'output_data': {}}
        for data_type, data_dict in data_types.items():
            main_data_dict = getattr(main_node, data_type)
            for key, main_param in main_data_dict.items():
                same_flag = compare_data
                if main_param.get(Const.DTYPE) not in self.dtype_list:
                    continue
                tp_need_merge_params = [main_param]
                for other_node in other_nodes:
                    param_key = key.replace(main_node.id, other_node.id) if main_node.id != other_node.id else key
                    other_param = getattr(other_node, data_type).get(param_key, {})
                    if other_param.get(Const.DTYPE) not in self.dtype_list:
                        break
                    tp_need_merge_params.append(other_param)
                    if compare_data and not self.compare_param_same(main_param, other_param, has_uncertainty=True):
                        same_flag = False
                if not same_flag:
                    data_dict[key.replace(main_node.id + Const.SEP, '')] = tp_need_merge_params
        return data_types.get('input_data'), data_types.get('output_data')

    def compare_param_same(self, main_param, other_param, has_uncertainty=False):
        """Return True if max/min match exactly and mean/norm match (with tolerance if requested)."""
        if not self._compare_value_same(main_param.get(Const.MAX), other_param.get(Const.MAX)):
            return False
        if not self._compare_value_same(main_param.get(Const.MIN), other_param.get(Const.MIN)):
            return False
        if not self._compare_value_same(main_param.get(Const.MEAN), other_param.get(Const.MEAN), has_uncertainty):
            return False
        if not self._compare_value_same(main_param.get(Const.NORM), other_param.get(Const.NORM), has_uncertainty):
            return False
        return True

    def get_default_groups(self):
        """
        Build parallel groups from the total rank count, TP size and PP size.

        return:
            tp_groups: list of tensor-parallel groups, each a list of ranks
            pp_groups: list of pipeline-parallel groups, each a list of ranks
        """
        tp_groups, pp_groups = get_tp_pp_default_groups(self.parallel_param.rank_size, self.parallel_param.tp,
                                                        self.parallel_param.pp, order=self.parallel_param.order)

        return tp_groups, pp_groups

    def _add_all_nodes_rank(self):
        """Tag every node of every result's graph with that result's rank."""
        for result in self.build_graph_results:
            for node in result.graph.node_map.values():
                node.rank = result.rank
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class PPMerger(BaseGraphMerger):
|
|
249
|
+
LAYERS_PATTERN = re.compile(r"(layers\.|layer\.)\d+(\.)")
|
|
250
|
+
MARK_PATTERN = re.compile(r"%(\d+)%(\d+)$")
|
|
251
|
+
MARK = '%'
|
|
252
|
+
|
|
253
|
+
@staticmethod
|
|
254
|
+
def _trace_p2p_mapping(p2p_mapping: dict):
|
|
255
|
+
"""
|
|
256
|
+
将字典分组为独立的链,每个链都从未访问过的键开始,按照字典中的映射关系进行追踪
|
|
257
|
+
p2p_mapping内容为p2p通信的send映射,追踪映射关系建立pp域
|
|
258
|
+
example: p2p_mapping={0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5}, return=[[0, 2, 4, 6], [1, 3, 5, 7]]
|
|
259
|
+
"""
|
|
260
|
+
visited = set()
|
|
261
|
+
result = []
|
|
262
|
+
|
|
263
|
+
def collect_keys(start_key):
|
|
264
|
+
"""
|
|
265
|
+
追踪从某一个键开始的所有“连续”键,直到无法再找到下一个键为止
|
|
266
|
+
"""
|
|
267
|
+
current_key = start_key
|
|
268
|
+
chain = []
|
|
269
|
+
while current_key in p2p_mapping and current_key not in visited:
|
|
270
|
+
chain.append(current_key)
|
|
271
|
+
visited.add(current_key)
|
|
272
|
+
current_key = p2p_mapping[current_key]
|
|
273
|
+
return chain
|
|
274
|
+
|
|
275
|
+
for key in p2p_mapping:
|
|
276
|
+
if key not in visited:
|
|
277
|
+
chain_result = collect_keys(key)
|
|
278
|
+
if chain_result:
|
|
279
|
+
result.append(chain_result)
|
|
280
|
+
return result
|
|
281
|
+
|
|
282
|
+
@recursion_depth_decorator("msprobe.visualization.builder.graph_merger.PPMerger._merge_nodes", 1000)
def _merge_nodes(self, main_graph, main_node, other_graphs):
    """
    Merge into main_graph the nodes of the other rank graphs that pp split off.

    - If any other graph lacks ``main_node.id``, the node exists only in the main
      graph and nothing is done.
    - If the counterparts differ in both inputs and outputs and the id matches
      ``LAYERS_PATTERN``, the module is treated as pp-split. The same applies to a
      backward node whose forward node was pp-merged, since backward may have no
      output. In that case the counterparts are id-marked, registered and attached
      as siblings, and the subtree is not descended further.
    - A forward module with equal inputs but different outputs takes the last
      rank's outputs; a backward module with equal outputs but different inputs
      takes the last rank's inputs.
    - Afterwards nodes unique to other ranks are merged and child modules recursed into.
    """
    counterparts = []
    for graph in other_graphs:
        counterpart = graph.get_node(main_node.id)
        if not counterpart:
            return
        counterparts.append(counterpart)

    if counterparts:
        forward_tag = Const.SEP + Const.FORWARD + Const.SEP
        backward_tag = Const.SEP + Const.BACKWARD + Const.SEP
        param_in, param_out = self.compare_node_param_data(main_node, counterparts)
        is_pp_split = param_in and param_out and self.LAYERS_PATTERN.search(main_node.id)
        if backward_tag in main_node.id:
            forward_node = main_graph.node_map.get(main_node.id.replace(backward_tag, forward_tag))
            if forward_node and hasattr(forward_node, 'is_pp_merged'):
                is_pp_split = True

        if is_pp_split:
            main_node.is_pp_merged = True
            parent = main_node.upnode
            for counterpart in counterparts:
                # Split layers share names across ranks; tag them with position and rank.
                self._mark_node_id_position_rank(counterpart, counterpart.rank)
                self._add_node_to_main_graph(main_graph, counterpart)
                counterpart.upnode = parent
                parent.subnodes.append(counterpart)
            return

        if not param_in and param_out and forward_tag in main_node.id:
            main_node.output_data = counterparts[-1].output_data
        elif param_in and not param_out and backward_tag in main_node.id:
            main_node.input_data = counterparts[-1].input_data

    self._merge_other_unique_nodes(main_graph, main_node, counterparts)
    for child in main_node.subnodes:
        if child.op == NodeOp.module:
            self._merge_nodes(main_graph, child, other_graphs)
|
|
327
|
+
|
|
328
|
+
def merge_graphs(self):
    """Merge the graphs of every pp group and return the concatenated merged results."""
    merged = []
    for group_results in self.split_graph_results_by_groups(self.get_groups()):
        self.merge_graph_api_collection(group_results)
        merged += self.merge_pp_graphs(group_results)
    return merged
|
|
335
|
+
|
|
336
|
+
def merge_pp_graphs(self, results):
    """Merge one pp group's graph results into the first result's graph.

    Returns ``results`` untouched when it is falsy or holds fewer than two entries.
    Otherwise every root-level module of the first graph (except ids in
    ``self.unmerged_module``) is merged with the other graphs and re-sorted, and a
    one-element list holding the first result is returned.
    """
    if not results or len(results) < 2:
        return results
    main_result = results[0]
    other_graphs = [result.graph for result in results[1:]]
    main_graph = main_result.graph
    # The live root.subnodes list is iterated on purpose: merging may append
    # pp-split siblings to it, and those appended nodes are visited as well.
    for root_child in main_graph.root.subnodes:
        if root_child.op != NodeOp.module or root_child.id in self.unmerged_module:
            continue
        self._merge_nodes(main_graph, root_child, other_graphs)
        self._sort_nodes(main_graph, root_child)
    return [main_result]
|
|
346
|
+
|
|
347
|
+
def get_groups(self):
    """
    Derive pp groups from the p2p send targets found in each rank's distributed nodes.

    For each rank, the first send target found is used: either the peer of an
    'isend' entry in a batch_isend_irecv node (peer different from the rank
    itself), or the ``dst`` input of a send/isend node. The rank-to-target mapping
    is then traced into chains. If no chain is found, the default groups computed
    from rank_size/tp/pp are used instead.
    """
    def find_send_target(result):
        # Returns the first p2p send target of this rank, or None.
        for node in result.graph.node_map.values():
            if not node.id.startswith(Const.DISTRIBUTED + Const.SEP):
                continue
            if '.batch_isend_irecv.' in node.id:
                for p2p_info in node.batch_p2p_info:
                    peer = p2p_info.get(GraphConst.PEER)
                    if peer is not None and peer != result.rank and p2p_info.get(GraphConst.OP) == 'isend':
                        return peer
            elif '.send.' in node.id or '.isend.' in node.id:
                # example: Distributed.isend.0.forward --> Distributed.isend.0.forward.input.dst
                dst_kwarg = f'{node.id}{Const.SEP}{Const.INPUT}{Const.SEP}{GraphConst.DST}'
                dst = node.input_data.get(dst_kwarg, {}).get('value')
                if dst is not None:
                    return dst
        return None

    p2p_mapping = {}
    for result in self.build_graph_results:
        send_target = find_send_target(result)
        if send_target is not None:
            p2p_mapping[result.rank] = send_target

    pp_groups = self._trace_p2p_mapping(p2p_mapping)
    if not pp_groups:
        logger.info('Unable to get pp groups based on Distributed Api (batch_isend_irecv, send, or isend), '
                    'generate pp groups using parallel param "rank_size", "tp" and "pp".')
        _, pp_groups = self.get_default_groups()
    logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.')
    return pp_groups
|
|
382
|
+
|
|
383
|
+
def _merge_other_unique_nodes(self, main_graph, main_node, other_nodes):
    """
    Attach to main_node the children of other_nodes whose ids appear in no other child list.

    An id is unique when it occurs in exactly one of the child lists (main_node's
    and each other node's). Unique children of the other nodes are appended in
    order (by node, then by original child order). Each is first id-marked with
    its position and rank, then registered in main_graph and re-parented.
    """
    id_maps = [{child.id: child for child in owner.subnodes} for owner in [main_node, *other_nodes]]
    occurrences = {}
    for id_map in id_maps:
        for node_id in id_map:
            occurrences[node_id] = occurrences.get(node_id, 0) + 1

    unique_nodes = [
        id_map[node_id]
        for id_map in id_maps[1:]
        for node_id in id_map
        if occurrences[node_id] == 1
    ]
    for unique_node in unique_nodes:
        # Marking reads the node's original parent, so it must precede re-parenting.
        self._mark_node_id_position_rank(unique_node, unique_node.rank)
        self._add_node_to_main_graph(main_graph, unique_node)
        main_node.subnodes.append(unique_node)
        unique_node.upnode = main_node
|
|
414
|
+
|
|
415
|
+
def _sort_nodes(self, main_graph, start_node):
    """
    Reorder siblings of position/rank-marked nodes in the subtree rooted at start_node.

    For every visited node whose id carries a ``%position%rank`` mark, its parent's
    children are split into unmarked and marked ones.
    - Forward: marked children are sorted by ascending rank and appended after the
      unmarked ones.
    - Backward: marked children are sorted by descending rank and placed before
      the unmarked ones.
    Each marked child then gets ``pp_index`` = number of LAYERS_PATTERN siblings
    seen so far minus one (counted front-to-back for forward, back-to-front for
    backward). Its id is rewritten via ``_update_node_id`` before the new order is
    stored on the parent.
    """
    pending = [start_node]
    while pending:
        current = pending.pop()
        if self.MARK_PATTERN.search(current.id):
            forward = (Const.SEP + Const.FORWARD + Const.SEP in current.id or
                       Const.SEP + Const.FORWARD + self.MARK in current.id)
            unmarked, marked = [], []
            for sibling in current.upnode.subnodes:
                if self.MARK_PATTERN.search(sibling.id):
                    marked.append(sibling)
                else:
                    unmarked.append(sibling)

            marked.sort(key=lambda n: self._get_node_sort_rule(n, rank_ascending=forward))
            if forward:
                reordered = unmarked + marked
                walk = reordered
            else:
                reordered = marked + unmarked
                walk = reversed(reordered)

            layer_index = -1
            for sibling in walk:
                if self.LAYERS_PATTERN.search(sibling.id):
                    layer_index += 1
                if self.MARK_PATTERN.search(sibling.id):
                    sibling.pp_index = layer_index

            for sibling in marked:
                self._update_node_id(main_graph, sibling)

            current.upnode.subnodes = reordered

        pending.extend(current.subnodes)
|
|
444
|
+
|
|
445
|
+
def _add_node_to_main_graph(self, main_graph: Graph, node: BaseNode):
|
|
446
|
+
if node.id in main_graph.node_map:
|
|
447
|
+
logger.warning(f'{node.id} is exist!')
|
|
448
|
+
else:
|
|
449
|
+
main_graph.node_map[node.id] = node
|
|
450
|
+
for sub_node in node.subnodes:
|
|
451
|
+
self._add_node_to_main_graph(main_graph, sub_node)
|
|
452
|
+
|
|
453
|
+
def _get_node_sort_rule(self, node, rank_ascending=True):
|
|
454
|
+
match = self.MARK_PATTERN.search(node.id)
|
|
455
|
+
if match:
|
|
456
|
+
# position代表当前节点在父节点中的位置序号
|
|
457
|
+
position, rank = int(match.group(1)), int(match.group(2))
|
|
458
|
+
if rank_ascending:
|
|
459
|
+
return rank, position
|
|
460
|
+
else:
|
|
461
|
+
return -rank, position
|
|
462
|
+
return (float('inf'), float('inf')) if rank_ascending else (-float('inf'), -float('inf'))
|
|
463
|
+
|
|
464
|
+
def _mark_node_id_position_rank(self, node: BaseNode, rank):
|
|
465
|
+
position = 0
|
|
466
|
+
for index, item in enumerate(node.upnode.subnodes):
|
|
467
|
+
if item.id == node.id:
|
|
468
|
+
position = index
|
|
469
|
+
break
|
|
470
|
+
# 各rank重复节点添加所处层级位置排序信息position和rank号,用%分隔
|
|
471
|
+
node.id = node.id + f'{self.MARK}{position}' + f'{self.MARK}{rank}'
|
|
472
|
+
for sub_node in node.subnodes:
|
|
473
|
+
self._mark_node_id_position_rank(sub_node, rank)
|
|
474
|
+
|
|
475
|
+
def _update_node_id(self, graph, start_node: BaseNode, pp_index=""):
|
|
476
|
+
stack = [(start_node, pp_index)]
|
|
477
|
+
while stack:
|
|
478
|
+
node, pp_index = stack.pop()
|
|
479
|
+
# 修改节点id之前删除node_map的信息,修改完再添加回去
|
|
480
|
+
if node.id not in graph.node_map:
|
|
481
|
+
logger.warning(f'Update node id {node.id} fail!')
|
|
482
|
+
else:
|
|
483
|
+
del graph.node_map[node.id]
|
|
484
|
+
old_id = self.MARK_PATTERN.sub("", node.id)
|
|
485
|
+
if node.op == NodeOp.module:
|
|
486
|
+
# 被pp切分的模块节点,基于位置和rank信息修改模块名称计数信息
|
|
487
|
+
if self.LAYERS_PATTERN.search(node.id) and self.MARK_PATTERN.search(node.id):
|
|
488
|
+
if hasattr(node, 'pp_index'):
|
|
489
|
+
pp_index = str(node.pp_index)
|
|
490
|
+
node.id = self.LAYERS_PATTERN.sub(r"\g<1>" + pp_index + r"\g<2>", node.id)
|
|
491
|
+
else:
|
|
492
|
+
# api节点,在api名称上添加rank信息
|
|
493
|
+
parts = node.id.split(Const.SEP)
|
|
494
|
+
parts[1] += f'_rank{node.id.split(PPMerger.MARK)[-1]}'
|
|
495
|
+
node.id = Const.SEP.join(parts)
|
|
496
|
+
# 把之前添加的位置和rank信息删掉
|
|
497
|
+
node.id = self.MARK_PATTERN.sub("", node.id)
|
|
498
|
+
# node id更新了,那么data的key中包含node id也要更新
|
|
499
|
+
node.input_data = self._update_node_data_key(old_id, node.id, node.input_data)
|
|
500
|
+
node.output_data = self._update_node_data_key(old_id, node.id, node.output_data)
|
|
501
|
+
graph.node_map[node.id] = node
|
|
502
|
+
# 将子节点加入栈中
|
|
503
|
+
for sub_node in node.subnodes:
|
|
504
|
+
stack.append((sub_node, pp_index))
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
class TPMerger(BaseGraphMerger):
|
|
508
|
+
RANK_PATTERN = re.compile(r"_rank(\d+)\.")
|
|
509
|
+
OPERATION_TABLE = {
|
|
510
|
+
Const.MAX: {
|
|
511
|
+
'initial': lambda p: p.get(Const.MAX),
|
|
512
|
+
'merge': lambda current, other: max(current, other.get(Const.MAX)),
|
|
513
|
+
'finalize': lambda current, count: current,
|
|
514
|
+
'formula': lambda key, values: f'{MAX_INFO}{key} is: max({", ".join(map(str, values))})'
|
|
515
|
+
},
|
|
516
|
+
Const.MIN: {
|
|
517
|
+
'initial': lambda p: p.get(Const.MIN),
|
|
518
|
+
'merge': lambda current, other: min(current, other.get(Const.MIN)),
|
|
519
|
+
'finalize': lambda current, count: current,
|
|
520
|
+
'formula': lambda key, values: f'{MIN_INFO}{key} is: min({", ".join(map(str, values))})'
|
|
521
|
+
},
|
|
522
|
+
Const.MEAN: {
|
|
523
|
+
'initial': lambda p: p.get(Const.MEAN),
|
|
524
|
+
'merge': lambda current, other: current + other.get(Const.MEAN),
|
|
525
|
+
'finalize': lambda current, count: current / count,
|
|
526
|
+
'formula': lambda key, values: f'{MEAN_INFO}{key} is: ({" + ".join(map(str, values))}) / {len(values)}'
|
|
527
|
+
},
|
|
528
|
+
Const.NORM: {
|
|
529
|
+
'initial': lambda p: pow(p.get(Const.NORM), 2.0),
|
|
530
|
+
'merge': lambda current, other: current + pow(other.get(Const.NORM), 2.0),
|
|
531
|
+
'finalize': lambda current, count: pow(current, 1 / 2.0),
|
|
532
|
+
'formula': lambda key, values: f'{NORM_INFO}{key} is: ({" + ".join([f"{v} ** 2" for v in values])}) ** 0.5'
|
|
533
|
+
}
|
|
534
|
+
}
|
|
535
|
+
TP_MERGED_INFO = f'This data is the merged data after tensor parallelism(TP), and the data is merged from rank '
|
|
536
|
+
|
|
537
|
+
@staticmethod
|
|
538
|
+
def _merge_params(tp_need_merge_param: dict):
|
|
539
|
+
"""
|
|
540
|
+
合并tp切分的各rank参数统计值
|
|
541
|
+
tp_need_merge_param: {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]}
|
|
542
|
+
return: 计算详情
|
|
543
|
+
"""
|
|
544
|
+
merge_info = []
|
|
545
|
+
for key, param_list in tp_need_merge_param.items():
|
|
546
|
+
if len(param_list) < 2:
|
|
547
|
+
continue
|
|
548
|
+
main_param = param_list[0]
|
|
549
|
+
|
|
550
|
+
for stat, ops in TPMerger.OPERATION_TABLE.items():
|
|
551
|
+
current_value = ops['initial'](main_param)
|
|
552
|
+
value_list = [current_value if stat != Const.NORM else main_param.get(Const.NORM)]
|
|
553
|
+
|
|
554
|
+
for other_param in param_list[1:]:
|
|
555
|
+
current_value = ops['merge'](current_value, other_param)
|
|
556
|
+
value_list.append(other_param.get(stat) if stat != Const.NORM else other_param.get(Const.NORM))
|
|
557
|
+
|
|
558
|
+
final_value = ops['finalize'](current_value, len(param_list))
|
|
559
|
+
main_param[stat] = final_value
|
|
560
|
+
formula_base = f'{ops["formula"](key, value_list)}' + f' = {final_value}'
|
|
561
|
+
|
|
562
|
+
merge_info.append(formula_base)
|
|
563
|
+
|
|
564
|
+
return merge_info
|
|
565
|
+
|
|
566
|
+
@staticmethod
|
|
567
|
+
def _get_need_merge_node(main_node, other_graphs, tp_merge_mapping):
|
|
568
|
+
"""
|
|
569
|
+
获取需要TP合并的节点列表
|
|
570
|
+
如果是TP+PP的混合并行,此时数据已经被PP合并过,一些node_id被标记上rank信息,此时需要基于rank映射才能获取到需要TP合并的节点列表,例如:
|
|
571
|
+
main_node = Torch.matmul_rank4.32.forward other_node = Torch.matmul_rank5.32.forward
|
|
572
|
+
需要建立4->5的映射,才能基于Torch.matmul_rank4.32.forward找到Torch.matmul_rank5.32.forward
|
|
573
|
+
"""
|
|
574
|
+
other_nodes = []
|
|
575
|
+
match = TPMerger.RANK_PATTERN.search(main_node.id)
|
|
576
|
+
# 节点名称被标记rank信息,且提供了映射
|
|
577
|
+
if match and tp_merge_mapping:
|
|
578
|
+
rank = int(match.group(1))
|
|
579
|
+
tp_mapping_ranks = tp_merge_mapping.get(rank)
|
|
580
|
+
if not tp_mapping_ranks:
|
|
581
|
+
return other_nodes
|
|
582
|
+
if len(tp_mapping_ranks) != len(other_graphs):
|
|
583
|
+
return other_nodes
|
|
584
|
+
for i, graph in enumerate(other_graphs):
|
|
585
|
+
# 基于映射得到目标rank,替换node_id当前rank信息后去目标graph取node
|
|
586
|
+
tp_mapping_id = TPMerger.RANK_PATTERN.sub(f"_rank{tp_mapping_ranks[i]}.", main_node.id)
|
|
587
|
+
other_node = graph.node_map.get(tp_mapping_id)
|
|
588
|
+
if not other_node or main_node.get_ancestors() != other_node.get_ancestors():
|
|
589
|
+
other_nodes.clear()
|
|
590
|
+
break
|
|
591
|
+
other_nodes.append(other_node)
|
|
592
|
+
else:
|
|
593
|
+
for graph in other_graphs:
|
|
594
|
+
other_node = graph.node_map.get(main_node.id)
|
|
595
|
+
if not other_node or main_node.get_ancestors() != other_node.get_ancestors():
|
|
596
|
+
other_nodes.clear()
|
|
597
|
+
break
|
|
598
|
+
other_nodes.append(other_node)
|
|
599
|
+
|
|
600
|
+
return other_nodes
|
|
601
|
+
|
|
602
|
+
@staticmethod
|
|
603
|
+
def _slice_list_at_id(node_list, target_id1, target_id2):
|
|
604
|
+
start_index, end_index = -1, -1
|
|
605
|
+
for index, node in enumerate(node_list):
|
|
606
|
+
if target_id1 in node.id:
|
|
607
|
+
start_index = index
|
|
608
|
+
elif target_id2 in node.id:
|
|
609
|
+
end_index = index
|
|
610
|
+
return [] if start_index == -1 or end_index == -1 else node_list[start_index:end_index + 1]
|
|
611
|
+
|
|
612
|
+
def merge_graphs(self):
|
|
613
|
+
results_groups = self.split_graph_results_by_groups(self.get_groups())
|
|
614
|
+
results = []
|
|
615
|
+
for result_groups in results_groups:
|
|
616
|
+
self.merge_graph_api_collection(result_groups)
|
|
617
|
+
results.extend(self.merge_tp_graphs(result_groups))
|
|
618
|
+
return results
|
|
619
|
+
|
|
620
|
+
def merge_tp_graphs(self, results, tp_merge_mapping=None):
|
|
621
|
+
if not results or len(results) < 2:
|
|
622
|
+
return results
|
|
623
|
+
graphs = [x.graph for x in results]
|
|
624
|
+
main_graph_result = results[0]
|
|
625
|
+
for main_node in main_graph_result.graph.node_map.values():
|
|
626
|
+
should_continue = (
|
|
627
|
+
not main_node.upnode or main_node.upnode.op != NodeOp.module or
|
|
628
|
+
main_node.upnode.id in self.unmerged_module or main_node.id.startswith(Const.DISTRIBUTED) or
|
|
629
|
+
main_node.parallel_merge_info != [])
|
|
630
|
+
if should_continue:
|
|
631
|
+
continue
|
|
632
|
+
self._handle_tp_matmul_reduce(main_node, graphs[1:], tp_merge_mapping)
|
|
633
|
+
other_nodes = self._get_need_merge_node(main_node, graphs[1:], tp_merge_mapping)
|
|
634
|
+
tp_need_merge_param_in, tp_need_merge_param_out = self.compare_node_param_data(main_node, other_nodes)
|
|
635
|
+
if tp_need_merge_param_in or tp_need_merge_param_out:
|
|
636
|
+
ranks = [main_node.rank]
|
|
637
|
+
for other_node in other_nodes:
|
|
638
|
+
ranks.append(other_node.rank)
|
|
639
|
+
main_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.')
|
|
640
|
+
merge_info_in = self._merge_params(tp_need_merge_param_in)
|
|
641
|
+
merge_info_out = self._merge_params(tp_need_merge_param_out)
|
|
642
|
+
main_node.parallel_merge_info.extend(merge_info_in + merge_info_out)
|
|
643
|
+
for main_node in main_graph_result.graph.node_map.values():
|
|
644
|
+
self._merge_tp_megatron_column_row_parallel(main_node, graphs[1:], tp_merge_mapping)
|
|
645
|
+
return [main_graph_result]
|
|
646
|
+
|
|
647
|
+
def get_groups(self):
|
|
648
|
+
tp_groups = []
|
|
649
|
+
for result in self.build_graph_results:
|
|
650
|
+
for node in result.graph.node_map.values():
|
|
651
|
+
if any(op in node.id for op in GraphConst.REDUCE_OPERATIONS):
|
|
652
|
+
group_ranks = node.input_data.get(f'{node.id}.input.group', {}).get('group_ranks')
|
|
653
|
+
if group_ranks and group_ranks not in tp_groups:
|
|
654
|
+
tp_groups.append(group_ranks)
|
|
655
|
+
break
|
|
656
|
+
if not tp_groups:
|
|
657
|
+
logger.info('Unable to get tp groups based on Distributed Api (reduce_scatter or all_reduce), '
|
|
658
|
+
'generate tp groups using parallel param "rank_size", "tp" and "pp".')
|
|
659
|
+
tp_groups, _ = self.get_default_groups()
|
|
660
|
+
logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.')
|
|
661
|
+
return tp_groups
|
|
662
|
+
|
|
663
|
+
def _handle_tp_matmul_reduce(self, node, other_graphs, tp_merge_mapping):
|
|
664
|
+
"""
|
|
665
|
+
前向RowParallel和反向ColumnParallel层的matmul输出需要替换成matmul计算完成后all_reduce/reduce_scatter的输出
|
|
666
|
+
"""
|
|
667
|
+
if node.op != NodeOp.module:
|
|
668
|
+
return
|
|
669
|
+
splits = node.id.split(Const.SEP)
|
|
670
|
+
if len(splits) < 4:
|
|
671
|
+
return
|
|
672
|
+
is_forward_with_row_parallel = splits[-2] == Const.FORWARD and 'RowParallelLinear' in splits[-3]
|
|
673
|
+
is_backward_with_column_parallel = splits[-2] == Const.BACKWARD and 'ColumnParallelLinear' in splits[-3]
|
|
674
|
+
if not is_forward_with_row_parallel and not is_backward_with_column_parallel:
|
|
675
|
+
return
|
|
676
|
+
matmul_list = []
|
|
677
|
+
reduce_list = []
|
|
678
|
+
for sub_node in node.subnodes:
|
|
679
|
+
if 'matmul' in sub_node.id:
|
|
680
|
+
matmul_list.append(sub_node)
|
|
681
|
+
if ('_reduce_scatter_base' in sub_node.id or 'reduce_scatter_tensor' in sub_node.id or
|
|
682
|
+
'all_reduce' in sub_node.id):
|
|
683
|
+
reduce_list.append(sub_node)
|
|
684
|
+
if not matmul_list or not reduce_list:
|
|
685
|
+
return
|
|
686
|
+
for matmul_node in matmul_list:
|
|
687
|
+
if not matmul_node.output_data:
|
|
688
|
+
continue
|
|
689
|
+
# matmul的output0,将传递给all_reduce/reduce_scatter,作为all_reduce的input0,或作为reduce_scatter的input1
|
|
690
|
+
matmul_node_output_param = list(matmul_node.output_data.values())[0]
|
|
691
|
+
for reduce_node in reduce_list:
|
|
692
|
+
if not reduce_node.output_data:
|
|
693
|
+
continue
|
|
694
|
+
if 'all_reduce' in reduce_node.id:
|
|
695
|
+
if not reduce_node.input_data:
|
|
696
|
+
continue
|
|
697
|
+
reduce_node_input_param = list(reduce_node.input_data.values())[0]
|
|
698
|
+
else:
|
|
699
|
+
if len(reduce_node.input_data) < 2:
|
|
700
|
+
continue
|
|
701
|
+
reduce_node_input_param = list(reduce_node.input_data.values())[1]
|
|
702
|
+
if not self.compare_param_same(matmul_node_output_param, reduce_node_input_param):
|
|
703
|
+
continue
|
|
704
|
+
# matmul的input统计值与其他rank的数据进行合并
|
|
705
|
+
other_nodes = self._get_need_merge_node(matmul_node, other_graphs, tp_merge_mapping)
|
|
706
|
+
tp_need_merge_param_in, _ = self.compare_node_param_data(matmul_node, other_nodes)
|
|
707
|
+
if tp_need_merge_param_in:
|
|
708
|
+
ranks = [matmul_node.rank]
|
|
709
|
+
for other_node in other_nodes:
|
|
710
|
+
ranks.append(other_node.rank)
|
|
711
|
+
matmul_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.')
|
|
712
|
+
merge_info_in = self._merge_params(tp_need_merge_param_in)
|
|
713
|
+
matmul_node.parallel_merge_info.extend(merge_info_in)
|
|
714
|
+
# matmul的output0替换为all_reduce/reduce_scatter的output0
|
|
715
|
+
reduce_node_output_param = list(reduce_node.output_data.values())[0]
|
|
716
|
+
keys = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
|
|
717
|
+
matmul_node_output_param.update({k: reduce_node_output_param.get(k) for k in keys})
|
|
718
|
+
full_op_name = reduce_node_output_param.get('full_op_name')
|
|
719
|
+
param_name = full_op_name if full_op_name else reduce_node.id
|
|
720
|
+
matmul_node.parallel_merge_info.append(f'The output of this data is merged from {param_name}')
|
|
721
|
+
reduce_list.remove(reduce_node)
|
|
722
|
+
break
|
|
723
|
+
|
|
724
|
+
def _merge_tp_megatron_column_row_parallel(self, node, other_graphs, tp_merge_mapping):
|
|
725
|
+
if node.op != NodeOp.module or node.parallel_merge_info:
|
|
726
|
+
return
|
|
727
|
+
splits = node.id.split(Const.SEP)
|
|
728
|
+
if len(splits) < 4:
|
|
729
|
+
return
|
|
730
|
+
is_forward_with_column_parallel = splits[-2] == Const.FORWARD and 'ColumnParallelLinear' in splits[-3]
|
|
731
|
+
if not is_forward_with_column_parallel:
|
|
732
|
+
return
|
|
733
|
+
if not node.upnode:
|
|
734
|
+
return
|
|
735
|
+
# 获取[ColumnParallelLinear, RowParallelLinear]结构
|
|
736
|
+
nodes = self._slice_list_at_id(node.upnode.subnodes, node.id, 'RowParallelLinear')
|
|
737
|
+
if len(nodes) < 2:
|
|
738
|
+
return
|
|
739
|
+
stack = nodes[:]
|
|
740
|
+
while stack:
|
|
741
|
+
current_node = stack.pop()
|
|
742
|
+
stack.extend(reversed(current_node.subnodes))
|
|
743
|
+
|
|
744
|
+
if current_node.parallel_merge_info or current_node.id.startswith(Const.DISTRIBUTED):
|
|
745
|
+
continue
|
|
746
|
+
|
|
747
|
+
other_nodes = self._get_need_merge_node(current_node, other_graphs, tp_merge_mapping)
|
|
748
|
+
param_in, param_out = self.compare_node_param_data(current_node, other_nodes, False)
|
|
749
|
+
|
|
750
|
+
if param_in or param_out:
|
|
751
|
+
ranks = [current_node.rank]
|
|
752
|
+
for other_node in other_nodes:
|
|
753
|
+
ranks.append(other_node.rank)
|
|
754
|
+
current_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.')
|
|
755
|
+
# ColumnParallelLinear层的输入、其中的matmul输入不需要合并
|
|
756
|
+
if current_node == nodes[0] or ('matmul' in current_node.id and current_node.upnode == nodes[0]):
|
|
757
|
+
param_in.pop('input.0', None)
|
|
758
|
+
# RowParallelLinear层的输出、其中的matmul输出不需要合并, bias不需要合并
|
|
759
|
+
elif current_node == nodes[-1] or ('matmul' in current_node.id and current_node.upnode == nodes[-1]):
|
|
760
|
+
param_out = {}
|
|
761
|
+
param_in.pop('parameters.bias', None)
|
|
762
|
+
|
|
763
|
+
merge_info_in = self._merge_params(param_in)
|
|
764
|
+
merge_info_out = self._merge_params(param_out)
|
|
765
|
+
current_node.parallel_merge_info.extend(merge_info_in + merge_info_out)
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
class NoParallelMerger(BaseGraphMerger):
|
|
769
|
+
def merge_graphs(self):
|
|
770
|
+
self.merge_graph_api_collection(self.build_graph_results)
|
|
771
|
+
return self.build_graph_results
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
class TPPPMerger(BaseGraphMerger):
|
|
775
|
+
def merge_graphs(self):
|
|
776
|
+
tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
|
|
777
|
+
pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) \
|
|
778
|
+
if self.parallel_param.vpp == 1 else VPPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
|
|
779
|
+
pp_groups = pp_merger.get_groups()
|
|
780
|
+
tp_groups = tp_merger.get_groups()
|
|
781
|
+
# 进入TP+PP混合处理器,PP和TP必然大于1
|
|
782
|
+
tp_merge_mapping = {}
|
|
783
|
+
for tp_group in tp_groups[1:]:
|
|
784
|
+
tp_merge_mapping[tp_group[0]] = tp_group[1:]
|
|
785
|
+
self.merge_graph_api_collection(self.build_graph_results)
|
|
786
|
+
# 先合并pp,需要知道pp域,在各自pp域中合并
|
|
787
|
+
results_groups_pp = self.split_graph_results_by_groups(pp_groups)
|
|
788
|
+
pp_results = []
|
|
789
|
+
for results in results_groups_pp:
|
|
790
|
+
pp_results.extend(pp_merger.merge_pp_graphs(results))
|
|
791
|
+
# pp合并完成后,直接进行tp合并,最终得到一个graph
|
|
792
|
+
tp_result = tp_merger.merge_tp_graphs(pp_results, tp_merge_mapping)
|
|
793
|
+
self.sort_merged_api_collection(tp_result[0].graph)
|
|
794
|
+
return tp_result
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
class FullMerger(BaseGraphMerger):
|
|
798
|
+
def merge_graphs(self):
|
|
799
|
+
tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
|
|
800
|
+
pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) \
|
|
801
|
+
if self.parallel_param.vpp == 1 else VPPMerger(self.build_graph_results, self.parallel_param, self.is_bench)
|
|
802
|
+
pp_groups = pp_merger.get_groups()
|
|
803
|
+
tp_groups = tp_merger.get_groups()
|
|
804
|
+
tp_merge_mapping = {}
|
|
805
|
+
if len(tp_groups) < 1:
|
|
806
|
+
raise RuntimeError(f'Graph merged error, and tp_groups is {tp_groups}.')
|
|
807
|
+
for tp_group in tp_groups[1:]:
|
|
808
|
+
if len(tp_group) < 1:
|
|
809
|
+
raise RuntimeError(f'Graph merged error, and tp_group is {tp_group}.')
|
|
810
|
+
tp_merge_mapping[tp_group[0]] = tp_group[1:]
|
|
811
|
+
# 先合并pp,需要知道pp域,在各自pp域中合并
|
|
812
|
+
results_groups_pp = self.split_graph_results_by_groups(pp_groups)
|
|
813
|
+
pp_results = {}
|
|
814
|
+
for pp_result in results_groups_pp:
|
|
815
|
+
self.merge_graph_api_collection(pp_result)
|
|
816
|
+
pp_result = pp_merger.merge_pp_graphs(pp_result)[0]
|
|
817
|
+
pp_results[pp_result.rank] = pp_result
|
|
818
|
+
# pp合并完成后,基于tp域划分pp合并结果
|
|
819
|
+
lists_to_be_tp_merged = []
|
|
820
|
+
for tp_group in tp_groups:
|
|
821
|
+
list_to_be_tp_merged = []
|
|
822
|
+
for rank in tp_group:
|
|
823
|
+
pp_result = pp_results.get(rank)
|
|
824
|
+
if pp_result:
|
|
825
|
+
list_to_be_tp_merged.append(pp_result)
|
|
826
|
+
if list_to_be_tp_merged:
|
|
827
|
+
lists_to_be_tp_merged.append(list_to_be_tp_merged)
|
|
828
|
+
tp_results = []
|
|
829
|
+
for list_to_be_tp_merged in lists_to_be_tp_merged:
|
|
830
|
+
self.merge_graph_api_collection(list_to_be_tp_merged)
|
|
831
|
+
tp_merged_result = tp_merger.merge_tp_graphs(list_to_be_tp_merged, tp_merge_mapping)
|
|
832
|
+
self.sort_merged_api_collection(tp_merged_result[0].graph)
|
|
833
|
+
tp_results.extend(tp_merged_result)
|
|
834
|
+
return tp_results
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
class VPPMerger(PPMerger):
|
|
838
|
+
LAYERS_NUM_PATTERN = re.compile(r"(layers\.|layer\.)(\d+)(\.)")
|
|
839
|
+
FORWARD_PATTERN = re.compile(r'\.forward\.\d+$')
|
|
840
|
+
|
|
841
|
+
@staticmethod
|
|
842
|
+
def _replace_vpp_id(s, vpp_id):
|
|
843
|
+
parts = s.split(Const.SEP)
|
|
844
|
+
if len(parts) < 2 or not parts[1].isdigit():
|
|
845
|
+
return s
|
|
846
|
+
parts[1] = str(vpp_id)
|
|
847
|
+
return Const.SEP.join(parts)
|
|
848
|
+
|
|
849
|
+
def merge_pp_graphs(self, results):
|
|
850
|
+
if not results or len(results) < 2:
|
|
851
|
+
return results
|
|
852
|
+
graphs = [x.graph for x in results]
|
|
853
|
+
main_graph_result = results[0]
|
|
854
|
+
for main_node in main_graph_result.graph.root.subnodes:
|
|
855
|
+
if main_node.op == NodeOp.module and main_node.id not in self.unmerged_module:
|
|
856
|
+
self._merge_nodes(main_graph_result.graph, main_node, graphs[1:])
|
|
857
|
+
self._sort_nodes(main_graph_result.graph, main_node)
|
|
858
|
+
self._merge_vpp_data(main_graph_result.graph)
|
|
859
|
+
self._merge_vpp_chunks(main_graph_result.graph)
|
|
860
|
+
return [main_graph_result]
|
|
861
|
+
|
|
862
|
+
def _merge_vpp_data(self, graph):
|
|
863
|
+
"""
|
|
864
|
+
所有chunk的数据都合并到chunk0,前向chunk0的输出使用最后一个chunk的输出,反向chunk0的输入使用最后一个chunk的输入
|
|
865
|
+
"""
|
|
866
|
+
module_list = []
|
|
867
|
+
for node in reversed(graph.root.subnodes):
|
|
868
|
+
parts = node.id.split(Const.SEP)
|
|
869
|
+
if len(parts) < 2:
|
|
870
|
+
continue
|
|
871
|
+
if parts[1] in [GraphConst.VPP_CHUNK_0, str(self.parallel_param.vpp - 1)]:
|
|
872
|
+
module_list.append(node)
|
|
873
|
+
if not module_list:
|
|
874
|
+
return
|
|
875
|
+
stack = module_list[:]
|
|
876
|
+
while stack:
|
|
877
|
+
current_node = stack.pop()
|
|
878
|
+
if hasattr(current_node, 'is_pp_merged') or hasattr(current_node,
|
|
879
|
+
'pp_index') or current_node.op != NodeOp.module:
|
|
880
|
+
continue
|
|
881
|
+
is_forward = self.FORWARD_PATTERN.search(current_node.id)
|
|
882
|
+
stack.extend(reversed(current_node.subnodes))
|
|
883
|
+
target_id = self._replace_vpp_id(current_node.id, self.parallel_param.vpp - 1)
|
|
884
|
+
target_node = graph.node_map.get(target_id)
|
|
885
|
+
if not target_node:
|
|
886
|
+
continue
|
|
887
|
+
if is_forward:
|
|
888
|
+
current_node.output_data = self._update_node_data_key(target_node.id, current_node.id,
|
|
889
|
+
target_node.output_data)
|
|
890
|
+
else:
|
|
891
|
+
current_node.input_data = self._update_node_data_key(target_node.id, current_node.id,
|
|
892
|
+
target_node.input_data)
|
|
893
|
+
|
|
894
|
+
def _merge_vpp_chunks(self, graph):
|
|
895
|
+
"""
|
|
896
|
+
所有chunk都合并到chunk0,layers层搬到chunk0并重排序号
|
|
897
|
+
"""
|
|
898
|
+
chunk_id_list = [i for i in range(1, self.parallel_param.vpp)]
|
|
899
|
+
chunk_0_list = []
|
|
900
|
+
for node in reversed(graph.root.subnodes):
|
|
901
|
+
parts = node.id.split(Const.SEP)
|
|
902
|
+
if len(parts) < 2:
|
|
903
|
+
continue
|
|
904
|
+
if parts[1] == GraphConst.VPP_CHUNK_0:
|
|
905
|
+
chunk_0_list.append(node)
|
|
906
|
+
if not chunk_0_list:
|
|
907
|
+
return
|
|
908
|
+
stack = chunk_0_list[:]
|
|
909
|
+
layers_need_merge_dict = {}
|
|
910
|
+
while stack:
|
|
911
|
+
current_node = stack.pop()
|
|
912
|
+
if hasattr(current_node, 'is_pp_merged') or hasattr(current_node, 'pp_index') \
|
|
913
|
+
and current_node.upnode.id not in layers_need_merge_dict:
|
|
914
|
+
layers_need_merge_dict[current_node.upnode.id] = current_node.upnode
|
|
915
|
+
continue
|
|
916
|
+
stack.extend(reversed(current_node.subnodes))
|
|
917
|
+
for node in layers_need_merge_dict.values():
|
|
918
|
+
is_forward = self.FORWARD_PATTERN.search(node.id)
|
|
919
|
+
for vpp_id in chunk_id_list:
|
|
920
|
+
target_node = graph.node_map.get(self._replace_vpp_id(node.id, vpp_id))
|
|
921
|
+
if not target_node:
|
|
922
|
+
continue
|
|
923
|
+
# 其他chunk的layers都搬到chunk0,forward追加到后面,backward追加到前面
|
|
924
|
+
if is_forward:
|
|
925
|
+
node.subnodes.extend(target_node.subnodes)
|
|
926
|
+
else:
|
|
927
|
+
node.subnodes = target_node.subnodes + node.subnodes
|
|
928
|
+
for sub_node in target_node.subnodes:
|
|
929
|
+
sub_node.upnode = node
|
|
930
|
+
# 获取其他chunk的层级链路,删除所有父节点,不在前端展示已合并的其他chunk节点
|
|
931
|
+
ancestors = target_node.get_ancestors()
|
|
932
|
+
if len(ancestors) < 2:
|
|
933
|
+
continue
|
|
934
|
+
for module_id in ancestors[1:]:
|
|
935
|
+
graph.node_map.pop(module_id, None)
|
|
936
|
+
graph.root.subnodes = [node for node in graph.root.subnodes if node.id != ancestors[1]]
|
|
937
|
+
# layers层重排序号
|
|
938
|
+
self._sort_layers(node.subnodes, graph, is_forward)
|
|
939
|
+
|
|
940
|
+
def _sort_layers(self, node_list, graph, is_forward):
|
|
941
|
+
if not is_forward:
|
|
942
|
+
node_list = list(reversed(node_list))
|
|
943
|
+
index = -1
|
|
944
|
+
for node in node_list:
|
|
945
|
+
match = self.LAYERS_NUM_PATTERN.search(node.id)
|
|
946
|
+
if match:
|
|
947
|
+
index += 1
|
|
948
|
+
parts = node.id.split(Const.SEP)
|
|
949
|
+
# Module.0.xxx代表第一个chunk,不必重排序
|
|
950
|
+
if len(parts) < 2 or parts[1] == GraphConst.VPP_CHUNK_0:
|
|
951
|
+
continue
|
|
952
|
+
# layers层修改chunk号和layers序号,非layers层修改chunk号
|
|
953
|
+
new_node_id_prefix = ''
|
|
954
|
+
if match:
|
|
955
|
+
prefix, number, dot = match.groups()
|
|
956
|
+
new_string = prefix + str(index) + dot
|
|
957
|
+
start, end = match.span()
|
|
958
|
+
new_node_id_prefix = node.id[:start] + new_string
|
|
959
|
+
new_node_id_prefix = self._replace_vpp_id(new_node_id_prefix, GraphConst.VPP_CHUNK_0)
|
|
960
|
+
new_node_id = new_node_id_prefix + node.id[end:]
|
|
961
|
+
else:
|
|
962
|
+
new_node_id = self._replace_vpp_id(node.id, GraphConst.VPP_CHUNK_0)
|
|
963
|
+
graph.node_map.pop(node.id, None)
|
|
964
|
+
node.input_data = self._update_node_data_key(node.id, new_node_id, node.input_data)
|
|
965
|
+
node.output_data = self._update_node_data_key(node.id, new_node_id, node.output_data)
|
|
966
|
+
node.id = new_node_id
|
|
967
|
+
graph.node_map[new_node_id] = node
|
|
968
|
+
stack = node.subnodes[:]
|
|
969
|
+
while stack:
|
|
970
|
+
current_node = stack.pop()
|
|
971
|
+
if current_node.op != NodeOp.module:
|
|
972
|
+
continue
|
|
973
|
+
stack.extend(reversed(current_node.subnodes))
|
|
974
|
+
match = self.LAYERS_NUM_PATTERN.search(current_node.id)
|
|
975
|
+
if match:
|
|
976
|
+
_, e = match.span()
|
|
977
|
+
new_current_node_id = new_node_id_prefix + current_node.id[e:]
|
|
978
|
+
else:
|
|
979
|
+
new_current_node_id = self._replace_vpp_id(current_node.id, GraphConst.VPP_CHUNK_0)
|
|
980
|
+
current_node.input_data = self._update_node_data_key(current_node.id, new_current_node_id,
|
|
981
|
+
current_node.input_data)
|
|
982
|
+
current_node.output_data = self._update_node_data_key(current_node.id, new_current_node_id,
|
|
983
|
+
current_node.output_data)
|
|
984
|
+
graph.node_map.pop(current_node.id, None)
|
|
985
|
+
current_node.id = new_current_node_id
|
|
986
|
+
graph.node_map[new_current_node_id] = current_node
|