mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -14,26 +14,29 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
+
import copy
|
|
17
18
|
from dataclasses import dataclass
|
|
18
19
|
|
|
19
20
|
from msprobe.core.common.const import Const
|
|
20
|
-
from msprobe.core.common.file_utils import load_json,
|
|
21
|
+
from msprobe.core.common.file_utils import load_json, load_construct_json
|
|
21
22
|
from msprobe.core.common.utils import load_stack_json
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
22
24
|
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
23
|
-
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
24
25
|
from msprobe.visualization.graph.graph import Graph
|
|
25
26
|
from msprobe.visualization.graph.node_op import NodeOp
|
|
26
27
|
from msprobe.visualization.utils import GraphConst
|
|
28
|
+
from msprobe.visualization.db_utils import node_to_db, config_to_db
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
class GraphBuilder:
|
|
30
32
|
backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
|
|
31
33
|
forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
|
|
32
34
|
# 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
|
|
33
|
-
template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')
|
|
35
|
+
template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template|api_instance)\(')
|
|
36
|
+
micro_step_dict = {}
|
|
34
37
|
|
|
35
38
|
@staticmethod
|
|
36
|
-
def build(construct_path, data_path, stack_path, model_name='DefaultModel'
|
|
39
|
+
def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
|
|
37
40
|
"""
|
|
38
41
|
GraphBuilder的对外提供的构图方法
|
|
39
42
|
Args:
|
|
@@ -41,75 +44,38 @@ class GraphBuilder:
|
|
|
41
44
|
data_path: dump.json路径
|
|
42
45
|
stack_path: stack.json路径
|
|
43
46
|
model_name: 模型名字,依赖外部输入
|
|
44
|
-
complete_stack: 完整的堆栈信息
|
|
45
47
|
Returns: Graph,代表图的数据结构
|
|
46
48
|
"""
|
|
47
|
-
construct_dict =
|
|
49
|
+
construct_dict, micro_step_dict = load_construct_json(construct_path)
|
|
50
|
+
if not construct_dict:
|
|
51
|
+
logger.error("The content of 'construct.json' is empty, failed to build graph. "
|
|
52
|
+
"When dumping data, it is necessary to select level L0 or mix in order to "
|
|
53
|
+
"collect model structure data, that is, the content of 'construct.json' is not empty.")
|
|
54
|
+
raise RuntimeError
|
|
55
|
+
GraphBuilder.micro_step_dict = micro_step_dict
|
|
48
56
|
dump_dict = load_json(data_path)
|
|
49
57
|
stack_dict = load_stack_json(stack_path)
|
|
50
|
-
if not complete_stack:
|
|
51
|
-
GraphBuilder._simplify_stack(stack_dict)
|
|
52
58
|
data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
|
|
53
|
-
graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict
|
|
59
|
+
graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict,
|
|
60
|
+
micro_step_num=micro_step_dict.get(Const.MEGATRON_MICRO_STEP_NUMBER))
|
|
54
61
|
GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
|
|
62
|
+
GraphBuilder._handle_recompute(graph)
|
|
55
63
|
GraphBuilder._collect_apis_between_modules(graph)
|
|
56
64
|
GraphBuilder._add_parameters_grad(graph, data_dict)
|
|
57
65
|
return graph
|
|
58
66
|
|
|
59
67
|
@staticmethod
|
|
60
|
-
def
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
68
|
+
def to_db(filename, config):
|
|
69
|
+
config.graph_n.step = config.step
|
|
70
|
+
config.graph_n.rank = config.rank
|
|
71
|
+
config.graph_n.compare_mode = config.compare_mode
|
|
72
|
+
node_to_db(config.graph_n, filename)
|
|
65
73
|
if config.graph_b:
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
result[GraphConst.JSON_TIP_KEY] = config.tool_tip
|
|
72
|
-
if config.node_colors:
|
|
73
|
-
result[GraphConst.COLORS] = config.node_colors
|
|
74
|
-
if config.micro_steps:
|
|
75
|
-
result[GraphConst.MICRO_STEPS] = config.micro_steps
|
|
76
|
-
if config.task:
|
|
77
|
-
result[GraphConst.JSON_TASK_KEY] = config.task
|
|
78
|
-
result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
|
|
79
|
-
save_json(filename, result, indent=4)
|
|
80
|
-
|
|
81
|
-
@staticmethod
|
|
82
|
-
def _simplify_stack(stack_dict):
|
|
83
|
-
"""
|
|
84
|
-
精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈
|
|
85
|
-
|
|
86
|
-
例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(",
|
|
87
|
-
保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"
|
|
88
|
-
|
|
89
|
-
例如Api Tensor.__iadd__.4.forward,堆栈为:
|
|
90
|
-
"File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
|
|
91
|
-
"File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
|
|
92
|
-
匹配到第一行的"TensorOPTemplate(",保留下一行堆栈
|
|
93
|
-
"""
|
|
94
|
-
module_pattern = re.compile(op_patterns[0])
|
|
95
|
-
for dump_name, stack_list in stack_dict.items():
|
|
96
|
-
if not isinstance(stack_list, list):
|
|
97
|
-
continue
|
|
98
|
-
if module_pattern.match(dump_name):
|
|
99
|
-
parts = dump_name.split(Const.SEP)
|
|
100
|
-
if len(parts) < abs(Const.LAYER_NAME_INDEX):
|
|
101
|
-
continue
|
|
102
|
-
module_name = parts[Const.LAYER_NAME_INDEX]
|
|
103
|
-
for stack in stack_list:
|
|
104
|
-
if re.search(module_name + r'\(', stack):
|
|
105
|
-
stack_list = [stack]
|
|
106
|
-
break
|
|
107
|
-
else:
|
|
108
|
-
for index, stack in enumerate(stack_list):
|
|
109
|
-
if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
|
|
110
|
-
stack_list = [stack_list[index + 1]]
|
|
111
|
-
break
|
|
112
|
-
stack_dict[dump_name] = stack_list
|
|
74
|
+
config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
|
|
75
|
+
config.graph_b.step = config.step
|
|
76
|
+
config.graph_b.rank = config.rank
|
|
77
|
+
node_to_db(config.graph_b, filename)
|
|
78
|
+
config_to_db(config, filename)
|
|
113
79
|
|
|
114
80
|
@staticmethod
|
|
115
81
|
def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
|
|
@@ -132,10 +98,47 @@ class GraphBuilder:
|
|
|
132
98
|
return new_upnode_id
|
|
133
99
|
return upnode_id
|
|
134
100
|
|
|
101
|
+
@staticmethod
|
|
102
|
+
def _handle_backward_inplace(construct_dict, sub_node_id, up_node_id):
    """
    Repair the parent of a backward node whose recorded parent sits at the wrong hierarchy level.

    When the hierarchy prefix of the backward node's parent (the node id without its last
    four segments) differs from the hierarchy prefix of the recorded parent (the parent id
    without its last three segments), the parent is looked up through the forward node of
    the same name instead.
    Target scenario: an inplace layer cannot trigger the backward hook, which produces a
    wrong backward hierarchy.

    Example
        Correct relationship:
            parent: Module.layer4.1.BasicBlock.backward.0, hierarchy is Module.layer4.1
            child:  Module.layer4.1.conv2.Conv2d.backward.0, parent hierarchy is Module.layer4.1

        Wrong relationship:
            parent: Module.layer4.1.relu.ReLU.backward.1, hierarchy is Module.layer4.1.relu
            child:  Module.layer4.1.conv2.Conv2d.backward.0, parent hierarchy is Module.layer4.1

    Args:
        construct_dict: mapping of node id -> parent node id.
        sub_node_id: id of the node being attached.
        up_node_id: parent id recorded for ``sub_node_id``.

    Returns:
        The corrected parent id, the forward node's falsy parent value when the forward
        node has no parent, or ``up_node_id`` unchanged when no correction applies.
    """
    # Only backward nodes that already have a recorded parent are candidates.
    if not (GraphBuilder.backward_pattern.search(sub_node_id) and up_node_id):
        return up_node_id

    child_parts = sub_node_id.split(Const.SEP)
    if len(child_parts) < 5:
        return up_node_id
    parent_parts = up_node_id.split(Const.SEP)
    if len(parent_parts) < 4:
        return up_node_id

    expected_prefix = Const.SEP.join(child_parts[:-4])
    actual_prefix = Const.SEP.join(parent_parts[:-3])
    if expected_prefix == actual_prefix:
        return up_node_id

    # Hierarchy mismatch: fall back to the forward node of the same name.
    forward_sub_node_id = GraphBuilder.backward_pattern.sub(r".forward.\2", sub_node_id)
    if forward_sub_node_id not in construct_dict:
        return up_node_id

    forward_up_node_id = construct_dict.get(forward_sub_node_id)
    if not forward_up_node_id:
        # The forward node has no parent (null); hand that value back as-is.
        return forward_up_node_id

    candidate_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_up_node_id)
    return candidate_id if candidate_id in construct_dict else up_node_id
|
|
136
|
+
|
|
135
137
|
@staticmethod
|
|
136
138
|
def _init_nodes(graph, construct_dict, data_dict, stack_dict):
|
|
137
139
|
for subnode_id, upnode_id in construct_dict.items():
|
|
138
|
-
upnode_id = GraphBuilder.
|
|
140
|
+
upnode_id = GraphBuilder._handle_backward_inplace(construct_dict, subnode_id, upnode_id) if upnode_id \
|
|
141
|
+
else GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
|
|
139
142
|
if upnode_id:
|
|
140
143
|
upnode_op = NodeOp.get_node_op(upnode_id)
|
|
141
144
|
upnode = GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], upnode_op, upnode_id)
|
|
@@ -171,6 +174,8 @@ class GraphBuilder:
|
|
|
171
174
|
node_stack_info = forward_node.stack_info if forward_node \
|
|
172
175
|
else ['This backward node cannot find the forward node and cannot retrieve stack information.']
|
|
173
176
|
node.stack_info = node_stack_info
|
|
177
|
+
if GraphBuilder.micro_step_dict:
|
|
178
|
+
node.micro_step_id = GraphBuilder.micro_step_dict.get(node.id, 0)
|
|
174
179
|
# 添加节点
|
|
175
180
|
node.add_upnode(upnode)
|
|
176
181
|
return node
|
|
@@ -230,6 +235,8 @@ class GraphBuilder:
|
|
|
230
235
|
node.upnode = api_collection_node
|
|
231
236
|
api_collection_node.upnode = graph.root
|
|
232
237
|
output.append(api_collection_node)
|
|
238
|
+
if temp_nodes[0].micro_step_id is not None:
|
|
239
|
+
api_collection_node.micro_step_id = temp_nodes[0].micro_step_id
|
|
233
240
|
else:
|
|
234
241
|
# 如果连续的api节点不足2个,将它们原样添加到输出列表
|
|
235
242
|
output.extend(temp_nodes)
|
|
@@ -276,10 +283,128 @@ class GraphBuilder:
|
|
|
276
283
|
# 更新数据
|
|
277
284
|
graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
|
|
278
285
|
|
|
286
|
+
@staticmethod
|
|
287
|
+
def _handle_recompute(graph):
    """
    Re-attach recompute layers found at the first level of the graph to a proper parent.

    Steps:
    1. ``_get_recompute_map`` returns ``recompute_map`` (node_id -> node_id_prefix) and
       ``recompute_id_map`` (node_id -> index inside ``graph.root.subnodes``).
    2. ``_get_no_recompute_map`` returns ``no_recompute_map``
       (node_id_prefix -> list of non-recompute nodes).
    3. For every recompute node, the next non-recompute node sharing its prefix is popped;
       the backward counterpart of that node's parent becomes the recompute node's parent.
    4. Recompute nodes that received a new parent are removed from ``graph.root.subnodes``.

    Args:
        graph: graph exposing ``root.subnodes`` and ``node_map``; it is modified in place.
    """
    recompute_map, recompute_id_map = GraphBuilder._get_recompute_map(graph.root.subnodes)
    if not recompute_map:
        return
    id_prefixes = set(recompute_map.values())
    no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
    if not no_recompute_map:
        return
    # Backward recompute nodes consume their own copy of the candidate lists, so forward
    # and backward nodes with the same prefix do not steal each other's candidates.
    # NOTE(review): only `.upnode.id` of a popped node is read below, so copying just the
    # lists may be sufficient and cheaper than a deep copy - confirm before changing.
    no_recompute_ids_b = copy.deepcopy(no_recompute_map)

    del_indexes = []
    for node_id, id_prefix in recompute_map.items():
        if id_prefix not in no_recompute_map:
            continue
        if GraphBuilder.forward_pattern.search(node_id):
            node_list = no_recompute_map.get(id_prefix)
        else:
            node_list = no_recompute_ids_b.get(id_prefix)
        if not node_list:
            continue
        no_recompute_node = node_list.pop()
        recompute_node = graph.node_map.get(node_id)
        if not recompute_node:
            continue
        # From the parent of the non-recompute forward node, derive the matching backward parent.
        new_up_node = graph.node_map.get(
            GraphBuilder.forward_pattern.sub(r".backward.\2", no_recompute_node.upnode.id))
        if not new_up_node:
            continue

        # Update the node connections.
        recompute_node.upnode = new_up_node
        new_up_node.subnodes.append(recompute_node)

        # A missing index must never reach sort(): None cannot be ordered against int.
        index = recompute_id_map.get(node_id)
        if index is not None:
            del_indexes.append(index)

    # Delete from back to front so that earlier indexes stay valid.
    del_indexes.sort(reverse=True)
    for index in del_indexes:
        # Strict upper bound: index == len(...) is out of range and would raise IndexError.
        if 0 <= index < len(graph.root.subnodes):
            del graph.root.subnodes[index]
|
|
332
|
+
|
|
333
|
+
@staticmethod
|
|
334
|
+
def _get_recompute_map(node_list: list):
    """
    Find the recompute layers among the first-level nodes of the graph.

    A module node counts as a recompute layer when its id does not look like a regular
    first-level id (see ``_is_recompute_node_id``) and its forward/backward counterpart
    is also present in ``node_list``; recompute layers always come as such a pair.

    Args:
        node_list: first-level nodes of the graph; each item exposes ``id``.

    Returns:
        tuple: ``(recompute_map, recompute_id_map)`` where ``recompute_map`` is
        dict(node_id: node_id_prefix) and ``recompute_id_map`` is
        dict(node_id: index of the node in ``node_list``).

    Example:
        {Module.0.module.decoder.layers.0.TransformerLayer.forward.4: Module.0.module.decoder.layers.0.TransformerLayer}
    """
    recompute_map = {}
    recompute_id_map = {}
    first_level_ids = {item.id for item in node_list}
    # Ids already confirmed as the counterpart of a recompute node.
    confirmed_ids = set()
    for position, item in enumerate(node_list):
        if NodeOp.get_node_op(item.id) != NodeOp.module:
            continue
        segments = item.id.split(Const.SEP)
        prefix = Const.SEP.join(segments[:-2])
        if item.id not in confirmed_ids:
            if not GraphBuilder._is_recompute_node_id(segments):
                continue
            # Flip forward <-> backward to build the id of the counterpart node.
            if segments[-2] == Const.FORWARD:
                segments[-2] = Const.BACKWARD
            else:
                segments[-2] = Const.FORWARD
            counterpart_id = Const.SEP.join(segments)
            if counterpart_id not in first_level_ids:
                continue
            # Remember the counterpart so it is not evaluated again later.
            confirmed_ids.add(counterpart_id)
        recompute_map[item.id] = prefix
        recompute_id_map[item.id] = position
    return recompute_map, recompute_id_map
|
|
368
|
+
|
|
369
|
+
@staticmethod
|
|
370
|
+
def _is_recompute_node_id(id_segments):
    """
    Tell whether the segments of a first-level node id belong to a recompute node.

    A regular (non-recompute) first-level node is always named
    Module/Cell.{number (optional)}.module_name.{number (optional)}.class_name.forward/backward.number
    Any id that does not fit this shape is treated as a recompute node.

    Args:
        id_segments: the node id split on the separator.

    Returns:
        bool: True when the id is judged to be a recompute node id.
    """
    segment_count = len(id_segments)
    if segment_count > 7:
        return True
    if segment_count == 7:
        # With seven segments, both optional numeric segments must be present.
        return not (id_segments[1].isdigit() and id_segments[3].isdigit())
    if segment_count == 6:
        # With six segments, the second segment must be the numeric one.
        return not id_segments[1].isdigit()
    return False
|
|
382
|
+
|
|
383
|
+
@staticmethod
|
|
384
|
+
def _get_no_recompute_map(graph, recompute_id_prefixes):
    """
    Collect the non-recompute forward module nodes that share an id prefix with a recompute layer.

    Nodes are gathered in ``graph.node_map`` iteration order and each list is then reversed,
    so that callers can take them in order with ``pop()``. Recompute layers use, in that
    order, the backward node matching the parent of each collected forward node.

    Args:
        graph: graph exposing ``node_map`` and ``root``.
        recompute_id_prefixes: id prefixes of the recompute layers.

    Returns:
        dict: node_id_prefix -> list of nodes (reversed collection order).
    """
    no_recompute_map = {}
    for node_id, node in graph.node_map.items():
        if NodeOp.get_node_op(node_id) != NodeOp.module:
            continue
        if not GraphBuilder.forward_pattern.search(node_id):
            continue
        # Nodes without a parent, or hanging directly under the root, are not candidates.
        if not node.upnode or node.upnode.id == graph.root.id:
            continue
        id_prefix = GraphBuilder.forward_pattern.sub('', node_id)
        if id_prefix in recompute_id_prefixes:
            no_recompute_map.setdefault(id_prefix, []).append(node)
    # Reverse so that pop() hands the nodes out in their original order.
    for candidates in no_recompute_map.values():
        candidates.reverse()
    return no_recompute_map
|
|
403
|
+
|
|
279
404
|
|
|
280
405
|
class GraphExportConfig:
|
|
281
406
|
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
|
|
282
|
-
overflow_check=False, compare_mode=None):
|
|
407
|
+
overflow_check=False, compare_mode=None, step=0, rank=0, step_list=None, rank_list=None):
|
|
283
408
|
self.graph_n = graph_n
|
|
284
409
|
self.graph_b = graph_b
|
|
285
410
|
self.tool_tip = tool_tip
|
|
@@ -288,6 +413,10 @@ class GraphExportConfig:
|
|
|
288
413
|
self.task = task
|
|
289
414
|
self.overflow_check = overflow_check
|
|
290
415
|
self.compare_mode = compare_mode
|
|
416
|
+
self.step = step
|
|
417
|
+
self.rank = rank
|
|
418
|
+
self.step_list = step_list
|
|
419
|
+
self.rank_list = rank_list
|
|
291
420
|
|
|
292
421
|
|
|
293
422
|
@dataclass
|