mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
- msprobe/README.md +7 -5
- msprobe/core/common/const.py +6 -0
- msprobe/core/common/db_manager.py +35 -4
- msprobe/core/common/file_utils.py +105 -27
- msprobe/core/common/framework_adapter.py +7 -6
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/utils.py +14 -3
- msprobe/core/compare/find_first/analyzer.py +8 -7
- msprobe/core/compare/find_first/graph.py +11 -3
- msprobe/core/compare/find_first/utils.py +2 -1
- msprobe/core/compare/highlight.py +13 -6
- msprobe/core/compare/multiprocessing_compute.py +17 -10
- msprobe/core/compare/utils.py +14 -5
- msprobe/core/data_dump/data_collector.py +18 -21
- msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
- msprobe/core/data_dump/json_writer.py +18 -8
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +37 -3
- msprobe/core/service.py +18 -5
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +7 -5
- msprobe/docs/02.config_introduction.md +14 -1
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/06.data_dump_MindSpore.md +1 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +15 -80
- msprobe/docs/22.visualization_MindSpore.md +20 -104
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/cell_processor.py +33 -5
- msprobe/mindspore/compare/common_dir_compare.py +22 -26
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/mindspore/debugger/precision_debugger.py +1 -1
- msprobe/mindspore/dump/cell_dump_process.py +73 -62
- msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
- msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/common/utils.py +22 -2
- msprobe/pytorch/compare/utils.py +3 -3
- msprobe/pytorch/debugger/debugger_config.py +10 -0
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
- msprobe/pytorch/hook_module/api_register.py +6 -1
- msprobe/pytorch/monitor/module_hook.py +28 -9
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/pt_config.py +57 -2
- msprobe/pytorch/pytorch_service.py +11 -2
- msprobe/visualization/builder/graph_builder.py +170 -64
- msprobe/visualization/builder/graph_merger.py +0 -1
- msprobe/visualization/builder/msprobe_adapter.py +1 -1
- msprobe/visualization/db_utils.py +25 -2
- msprobe/visualization/graph/base_node.py +0 -24
- msprobe/visualization/graph/graph.py +5 -14
- msprobe/visualization/graph_service.py +29 -53
- msprobe/visualization/utils.py +11 -1
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
|
@@ -14,14 +14,14 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
+
import copy
|
|
17
18
|
from dataclasses import dataclass
|
|
18
19
|
|
|
19
20
|
from msprobe.core.common.const import Const
|
|
20
|
-
from msprobe.core.common.file_utils import load_json,
|
|
21
|
+
from msprobe.core.common.file_utils import load_json, load_construct_json
|
|
21
22
|
from msprobe.core.common.utils import load_stack_json
|
|
22
23
|
from msprobe.core.common.log import logger
|
|
23
24
|
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
24
|
-
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
25
25
|
from msprobe.visualization.graph.graph import Graph
|
|
26
26
|
from msprobe.visualization.graph.node_op import NodeOp
|
|
27
27
|
from msprobe.visualization.utils import GraphConst
|
|
@@ -33,9 +33,10 @@ class GraphBuilder:
|
|
|
33
33
|
forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
|
|
34
34
|
# 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
|
|
35
35
|
template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template|api_instance)\(')
|
|
36
|
+
micro_step_dict = {}
|
|
36
37
|
|
|
37
38
|
@staticmethod
|
|
38
|
-
def build(construct_path, data_path, stack_path, model_name='DefaultModel'
|
|
39
|
+
def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
|
|
39
40
|
"""
|
|
40
41
|
GraphBuilder的对外提供的构图方法
|
|
41
42
|
Args:
|
|
@@ -43,48 +44,26 @@ class GraphBuilder:
|
|
|
43
44
|
data_path: dump.json路径
|
|
44
45
|
stack_path: stack.json路径
|
|
45
46
|
model_name: 模型名字,依赖外部输入
|
|
46
|
-
complete_stack: 完整的堆栈信息
|
|
47
47
|
Returns: Graph,代表图的数据结构
|
|
48
48
|
"""
|
|
49
|
-
construct_dict =
|
|
49
|
+
construct_dict, micro_step_dict = load_construct_json(construct_path)
|
|
50
50
|
if not construct_dict:
|
|
51
51
|
logger.error("The content of 'construct.json' is empty, failed to build graph. "
|
|
52
52
|
"When dumping data, it is necessary to select level L0 or mix in order to "
|
|
53
53
|
"collect model structure data, that is, the content of 'construct.json' is not empty.")
|
|
54
54
|
raise RuntimeError
|
|
55
|
+
GraphBuilder.micro_step_dict = micro_step_dict
|
|
55
56
|
dump_dict = load_json(data_path)
|
|
56
57
|
stack_dict = load_stack_json(stack_path)
|
|
57
|
-
if not complete_stack:
|
|
58
|
-
GraphBuilder._simplify_stack(stack_dict)
|
|
59
58
|
data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
|
|
60
|
-
graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict
|
|
59
|
+
graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict,
|
|
60
|
+
micro_step_num=micro_step_dict.get(Const.MEGATRON_MICRO_STEP_NUMBER))
|
|
61
61
|
GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
|
|
62
|
+
GraphBuilder._handle_recompute(graph)
|
|
62
63
|
GraphBuilder._collect_apis_between_modules(graph)
|
|
63
64
|
GraphBuilder._add_parameters_grad(graph, data_dict)
|
|
64
65
|
return graph
|
|
65
66
|
|
|
66
|
-
@staticmethod
|
|
67
|
-
def to_json(filename, config):
|
|
68
|
-
"""
|
|
69
|
-
将graph导出成.vis文件的接口
|
|
70
|
-
"""
|
|
71
|
-
result = {}
|
|
72
|
-
if config.graph_b:
|
|
73
|
-
result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict(config.compare_mode)
|
|
74
|
-
result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict(config.compare_mode)
|
|
75
|
-
else:
|
|
76
|
-
result = config.graph_n.to_dict(config.compare_mode)
|
|
77
|
-
if config.tool_tip:
|
|
78
|
-
result[GraphConst.JSON_TIP_KEY] = config.tool_tip
|
|
79
|
-
if config.node_colors:
|
|
80
|
-
result[GraphConst.COLORS] = config.node_colors
|
|
81
|
-
if config.micro_steps:
|
|
82
|
-
result[GraphConst.MICRO_STEPS] = config.micro_steps
|
|
83
|
-
if config.task:
|
|
84
|
-
result[GraphConst.JSON_TASK_KEY] = config.task
|
|
85
|
-
result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
|
|
86
|
-
save_json(filename, result, indent=4)
|
|
87
|
-
|
|
88
67
|
@staticmethod
|
|
89
68
|
def to_db(filename, config):
|
|
90
69
|
config.graph_n.step = config.step
|
|
@@ -95,42 +74,10 @@ class GraphBuilder:
|
|
|
95
74
|
config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
|
|
96
75
|
config.graph_b.step = config.step
|
|
97
76
|
config.graph_b.rank = config.rank
|
|
77
|
+
config.graph_b.compare_mode = config.compare_mode
|
|
98
78
|
node_to_db(config.graph_b, filename)
|
|
99
79
|
config_to_db(config, filename)
|
|
100
80
|
|
|
101
|
-
@staticmethod
|
|
102
|
-
def _simplify_stack(stack_dict):
|
|
103
|
-
"""
|
|
104
|
-
精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈
|
|
105
|
-
|
|
106
|
-
例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(",
|
|
107
|
-
保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"
|
|
108
|
-
|
|
109
|
-
例如Api Tensor.__iadd__.4.forward,堆栈为:
|
|
110
|
-
"File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
|
|
111
|
-
"File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
|
|
112
|
-
匹配到第一行的"TensorOPTemplate(",保留下一行堆栈
|
|
113
|
-
"""
|
|
114
|
-
module_pattern = re.compile(op_patterns[0])
|
|
115
|
-
for dump_name, stack_list in stack_dict.items():
|
|
116
|
-
if not isinstance(stack_list, list):
|
|
117
|
-
continue
|
|
118
|
-
if module_pattern.match(dump_name):
|
|
119
|
-
parts = dump_name.split(Const.SEP)
|
|
120
|
-
if len(parts) < abs(Const.LAYER_NAME_INDEX):
|
|
121
|
-
continue
|
|
122
|
-
module_name = parts[Const.LAYER_NAME_INDEX]
|
|
123
|
-
for stack in stack_list:
|
|
124
|
-
if re.search(module_name + r'\(', stack):
|
|
125
|
-
stack_list = [stack]
|
|
126
|
-
break
|
|
127
|
-
else:
|
|
128
|
-
for index, stack in enumerate(stack_list):
|
|
129
|
-
if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
|
|
130
|
-
stack_list = [stack_list[index + 1]]
|
|
131
|
-
break
|
|
132
|
-
stack_dict[dump_name] = stack_list
|
|
133
|
-
|
|
134
81
|
@staticmethod
|
|
135
82
|
def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
|
|
136
83
|
"""
|
|
@@ -152,10 +99,47 @@ class GraphBuilder:
|
|
|
152
99
|
return new_upnode_id
|
|
153
100
|
return upnode_id
|
|
154
101
|
|
|
102
|
+
@staticmethod
|
|
103
|
+
def _handle_backward_inplace(construct_dict, sub_node_id, up_node_id):
|
|
104
|
+
"""
|
|
105
|
+
如果当前backward节点的父层级信息不等于其父级节点的层级信息,则尝试从同名的forward节点寻找父级节点
|
|
106
|
+
主要针对的场景:inplace层会无法触发backward hook导致反向层级错误
|
|
107
|
+
|
|
108
|
+
example:
|
|
109
|
+
正确的层级关系:
|
|
110
|
+
父层:Module.layer4.1.BasicBlock.backward.0的层级信息为Module.layer4.1
|
|
111
|
+
子层:Module.layer4.1.conv2.Conv2d.backward.0的父层级信息为Module.layer4.1
|
|
112
|
+
|
|
113
|
+
错误的层级关系:
|
|
114
|
+
父层:Module.layer4.1.relu.ReLU.backward.1的层级信息为Module.layer4.1.relu
|
|
115
|
+
子层:Module.layer4.1.conv2.Conv2d.backward.0的父层级信息为Module.layer4.1
|
|
116
|
+
"""
|
|
117
|
+
if GraphBuilder.backward_pattern.search(sub_node_id) and up_node_id:
|
|
118
|
+
sub_split = sub_node_id.split(Const.SEP)
|
|
119
|
+
if len(sub_split) < 5:
|
|
120
|
+
return up_node_id
|
|
121
|
+
up_split = up_node_id.split(Const.SEP)
|
|
122
|
+
if len(up_split) < 4:
|
|
123
|
+
return up_node_id
|
|
124
|
+
sub_node_prefix = Const.SEP.join(sub_split[:-4])
|
|
125
|
+
up_node_prefix = Const.SEP.join(up_split[:-3])
|
|
126
|
+
if sub_node_prefix != up_node_prefix:
|
|
127
|
+
forward_sub_node_id = GraphBuilder.backward_pattern.sub(r".forward.\2", sub_node_id)
|
|
128
|
+
if forward_sub_node_id in construct_dict:
|
|
129
|
+
forward_up_node_id = construct_dict.get(forward_sub_node_id)
|
|
130
|
+
# forward_up_node_id ---> null
|
|
131
|
+
if not forward_up_node_id:
|
|
132
|
+
return forward_up_node_id
|
|
133
|
+
new_up_node_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_up_node_id)
|
|
134
|
+
if new_up_node_id in construct_dict:
|
|
135
|
+
return new_up_node_id
|
|
136
|
+
return up_node_id
|
|
137
|
+
|
|
155
138
|
@staticmethod
|
|
156
139
|
def _init_nodes(graph, construct_dict, data_dict, stack_dict):
|
|
157
140
|
for subnode_id, upnode_id in construct_dict.items():
|
|
158
|
-
upnode_id = GraphBuilder.
|
|
141
|
+
upnode_id = GraphBuilder._handle_backward_inplace(construct_dict, subnode_id, upnode_id) if upnode_id \
|
|
142
|
+
else GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
|
|
159
143
|
if upnode_id:
|
|
160
144
|
upnode_op = NodeOp.get_node_op(upnode_id)
|
|
161
145
|
upnode = GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], upnode_op, upnode_id)
|
|
@@ -191,6 +175,8 @@ class GraphBuilder:
|
|
|
191
175
|
node_stack_info = forward_node.stack_info if forward_node \
|
|
192
176
|
else ['This backward node cannot find the forward node and cannot retrieve stack information.']
|
|
193
177
|
node.stack_info = node_stack_info
|
|
178
|
+
if GraphBuilder.micro_step_dict:
|
|
179
|
+
node.micro_step_id = GraphBuilder.micro_step_dict.get(node.id, 0)
|
|
194
180
|
# 添加节点
|
|
195
181
|
node.add_upnode(upnode)
|
|
196
182
|
return node
|
|
@@ -250,6 +236,8 @@ class GraphBuilder:
|
|
|
250
236
|
node.upnode = api_collection_node
|
|
251
237
|
api_collection_node.upnode = graph.root
|
|
252
238
|
output.append(api_collection_node)
|
|
239
|
+
if temp_nodes[0].micro_step_id is not None:
|
|
240
|
+
api_collection_node.micro_step_id = temp_nodes[0].micro_step_id
|
|
253
241
|
else:
|
|
254
242
|
# 如果连续的api节点不足2个,将它们原样添加到输出列表
|
|
255
243
|
output.extend(temp_nodes)
|
|
@@ -296,6 +284,124 @@ class GraphBuilder:
|
|
|
296
284
|
# 更新数据
|
|
297
285
|
graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
|
|
298
286
|
|
|
287
|
+
@staticmethod
|
|
288
|
+
def _handle_recompute(graph):
|
|
289
|
+
"""
|
|
290
|
+
1. 通过_get_recompute_map获得重计算节点映射recompute_map: dict(node_id: node_id_prefix)
|
|
291
|
+
2. 通过_get_no_recompute_map获得非重计算节点映射no_recompute_map: dict(node_id_prefix: list(node_id))
|
|
292
|
+
3. 遍历recompute_map,通过node_id_prefix与no_recompute_map建立连接,通过非重计算节点找到自身的父节点
|
|
293
|
+
"""
|
|
294
|
+
recompute_map, recompute_id_map = GraphBuilder._get_recompute_map(graph.root.subnodes)
|
|
295
|
+
if not recompute_map:
|
|
296
|
+
return
|
|
297
|
+
id_prefixes = set(recompute_map.values())
|
|
298
|
+
no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
|
|
299
|
+
if not no_recompute_map:
|
|
300
|
+
return
|
|
301
|
+
# 深拷贝非重计算节点字典用于反向模式
|
|
302
|
+
no_recompute_ids_b = copy.deepcopy(no_recompute_map)
|
|
303
|
+
|
|
304
|
+
del_indexes = []
|
|
305
|
+
for node_id, id_prefix in recompute_map.items():
|
|
306
|
+
if id_prefix not in no_recompute_map:
|
|
307
|
+
continue
|
|
308
|
+
node_list = no_recompute_map.get(id_prefix) if GraphBuilder.forward_pattern.search(node_id) else \
|
|
309
|
+
no_recompute_ids_b.get(id_prefix)
|
|
310
|
+
if not node_list:
|
|
311
|
+
continue
|
|
312
|
+
no_recompute_node = node_list.pop()
|
|
313
|
+
recompute_node = graph.node_map.get(node_id)
|
|
314
|
+
if not recompute_node:
|
|
315
|
+
continue
|
|
316
|
+
# 通过非重计算forward节点的父节点,找到对应的backward父节点
|
|
317
|
+
new_up_node = graph.node_map.get(
|
|
318
|
+
GraphBuilder.forward_pattern.sub(r".backward.\2", no_recompute_node.upnode.id))
|
|
319
|
+
if not new_up_node:
|
|
320
|
+
continue
|
|
321
|
+
|
|
322
|
+
# 更新节点连接关系
|
|
323
|
+
recompute_node.upnode = new_up_node
|
|
324
|
+
new_up_node.subnodes.append(recompute_node)
|
|
325
|
+
|
|
326
|
+
del_indexes.append(recompute_id_map.get(node_id))
|
|
327
|
+
|
|
328
|
+
# 从后往前删除graph首层中已更新父节点的重计算节点
|
|
329
|
+
del_indexes.sort(reverse=True)
|
|
330
|
+
for index in del_indexes:
|
|
331
|
+
if 0 <= index <= len(graph.root.subnodes):
|
|
332
|
+
del graph.root.subnodes[index]
|
|
333
|
+
|
|
334
|
+
@staticmethod
|
|
335
|
+
def _get_recompute_map(node_list: list):
|
|
336
|
+
"""
|
|
337
|
+
找到graph首层的重计算层
|
|
338
|
+
|
|
339
|
+
return: dict(node_id: node_id_prefix), dict(node_id: index)
|
|
340
|
+
|
|
341
|
+
example:
|
|
342
|
+
{Module.0.module.decoder.layers.0.TransformerLayer.forward.4: Module.0.module.decoder.layers.0.TransformerLayer}
|
|
343
|
+
"""
|
|
344
|
+
recompute_map = {}
|
|
345
|
+
recompute_id_map = {}
|
|
346
|
+
node_id_set = set([node.id for node in node_list])
|
|
347
|
+
node_id_cache = set()
|
|
348
|
+
for i, node in enumerate(node_list):
|
|
349
|
+
if NodeOp.get_node_op(node.id) != NodeOp.module:
|
|
350
|
+
continue
|
|
351
|
+
id_segments = node.id.split(Const.SEP)
|
|
352
|
+
prefix = Const.SEP.join(id_segments[:-2])
|
|
353
|
+
if node.id in node_id_cache:
|
|
354
|
+
recompute_map[node.id] = prefix
|
|
355
|
+
recompute_id_map[node.id] = i
|
|
356
|
+
continue
|
|
357
|
+
is_recompute = GraphBuilder._is_recompute_node_id(id_segments)
|
|
358
|
+
if not is_recompute:
|
|
359
|
+
continue
|
|
360
|
+
# 重计算层必然是一组对应的前反向节点
|
|
361
|
+
id_segments[-2] = Const.BACKWARD if id_segments[-2] == Const.FORWARD else Const.FORWARD
|
|
362
|
+
relative_node_id = Const.SEP.join(id_segments)
|
|
363
|
+
if relative_node_id in node_id_set:
|
|
364
|
+
recompute_map[node.id] = prefix
|
|
365
|
+
recompute_id_map[node.id] = i
|
|
366
|
+
# 对应节点id放入缓存避免后续重复判断
|
|
367
|
+
node_id_cache.add(relative_node_id)
|
|
368
|
+
return recompute_map, recompute_id_map
|
|
369
|
+
|
|
370
|
+
@staticmethod
|
|
371
|
+
def _is_recompute_node_id(id_segments):
|
|
372
|
+
"""
|
|
373
|
+
非重计算首层节点命名必然是:Module/Cell.{number(可选)}.module_name.{number(可选)}.class_name.forward/backward.number
|
|
374
|
+
如果不符合,则判断为重计算节点
|
|
375
|
+
"""
|
|
376
|
+
if len(id_segments) > 7:
|
|
377
|
+
return True
|
|
378
|
+
if len(id_segments) == 7 and not (id_segments[1].isdigit() and id_segments[3].isdigit()):
|
|
379
|
+
return True
|
|
380
|
+
if len(id_segments) == 6 and not id_segments[1].isdigit():
|
|
381
|
+
return True
|
|
382
|
+
return False
|
|
383
|
+
|
|
384
|
+
@staticmethod
|
|
385
|
+
def _get_no_recompute_map(graph, recompute_id_prefixes):
|
|
386
|
+
"""
|
|
387
|
+
寻找与重计算层id前缀相同的非重计算forward层,按顺序排列,重计算层按照顺序使用非重计算forward层的父节点对应的backward节点
|
|
388
|
+
|
|
389
|
+
return: dict(node_id_prefix: list(node_id))
|
|
390
|
+
"""
|
|
391
|
+
no_recompute_map = {}
|
|
392
|
+
for node_id, node in graph.node_map.items():
|
|
393
|
+
if NodeOp.get_node_op(node_id) == NodeOp.module and GraphBuilder.forward_pattern.search(node_id):
|
|
394
|
+
if not node.upnode or node.upnode.id == graph.root.id:
|
|
395
|
+
continue
|
|
396
|
+
id_prefix = GraphBuilder.forward_pattern.sub('', node_id)
|
|
397
|
+
if id_prefix not in recompute_id_prefixes:
|
|
398
|
+
continue
|
|
399
|
+
no_recompute_map.setdefault(id_prefix, []).append(node)
|
|
400
|
+
for node_list in no_recompute_map.values():
|
|
401
|
+
# 方便按顺序pop弹出
|
|
402
|
+
node_list.reverse()
|
|
403
|
+
return no_recompute_map
|
|
404
|
+
|
|
299
405
|
|
|
300
406
|
class GraphExportConfig:
|
|
301
407
|
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
|
|
@@ -212,7 +212,6 @@ class BaseGraphMerger:
|
|
|
212
212
|
if compare_data and not self.compare_param_same(main_param, other_param, has_uncertainty=True):
|
|
213
213
|
same_flag = False
|
|
214
214
|
if not same_flag:
|
|
215
|
-
# {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]}
|
|
216
215
|
data_dict[key.replace(main_node.id + Const.SEP, '')] = tp_need_merge_params
|
|
217
216
|
return data_types.get('input_data'), data_types.get('output_data')
|
|
218
217
|
|
|
@@ -28,7 +28,7 @@ op_patterns = [
|
|
|
28
28
|
# NodeOp.module
|
|
29
29
|
r'^(Module.|Cell.|optimizer|clip_grad)',
|
|
30
30
|
# NodeOp.function_api
|
|
31
|
-
r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
|
|
31
|
+
r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.|MindSpeed.)'
|
|
32
32
|
]
|
|
33
33
|
|
|
34
34
|
|
|
@@ -41,7 +41,7 @@ node_columns = {
|
|
|
41
41
|
'overflow_level': TEXT,
|
|
42
42
|
'micro_step_id': INTEGER_NOT_NULL,
|
|
43
43
|
'matched_node_link': TEXT,
|
|
44
|
-
'
|
|
44
|
+
'stack_id': TEXT,
|
|
45
45
|
'parallel_merge_info': TEXT,
|
|
46
46
|
'matched_distributed': TEXT,
|
|
47
47
|
'modified': INTEGER_NOT_NULL,
|
|
@@ -65,6 +65,11 @@ config_columns = {
|
|
|
65
65
|
'step_list': TEXT_NOT_NULL
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
+
stack_columns = {
|
|
69
|
+
'id': TEXT_PRIMARY_KEY,
|
|
70
|
+
'stack_info': TEXT
|
|
71
|
+
}
|
|
72
|
+
|
|
68
73
|
indexes = {
|
|
69
74
|
"index1": ["step", "rank", "data_source", "up_node", "node_order"],
|
|
70
75
|
"index2": ["step", "rank", "data_source", "node_name"],
|
|
@@ -197,19 +202,24 @@ def node_to_db(graph, db_name):
|
|
|
197
202
|
create_table_sql = create_table_sql_from_dict('tb_nodes', node_columns)
|
|
198
203
|
insert_sql = create_insert_sql_from_dict('tb_nodes', node_columns)
|
|
199
204
|
data = []
|
|
205
|
+
stack_dict = {}
|
|
200
206
|
for i, node in enumerate(graph.get_sorted_nodes()):
|
|
207
|
+
stack_info_text = json.dumps(node.stack_info)
|
|
208
|
+
if stack_info_text not in stack_dict:
|
|
209
|
+
stack_dict[stack_info_text] = get_stack_unique_id(graph, stack_dict)
|
|
201
210
|
data.append((get_node_unique_id(graph, node), get_graph_unique_id(graph), i, node.id, node.op.value,
|
|
202
211
|
node.upnode.id if node.upnode else '',
|
|
203
212
|
json.dumps([node.id for node in node.subnodes]) if node.subnodes else '',
|
|
204
213
|
node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL),
|
|
205
214
|
node.micro_step_id if node.micro_step_id is not None else 0, json.dumps(node.matched_node_link),
|
|
206
|
-
|
|
215
|
+
stack_dict.get(stack_info_text),
|
|
207
216
|
json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '',
|
|
208
217
|
json.dumps(node.matched_distributed), 0,
|
|
209
218
|
json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)),
|
|
210
219
|
json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)),
|
|
211
220
|
graph.data_source, graph.data_path, graph.step, graph.rank))
|
|
212
221
|
to_db(db_name, create_table_sql, insert_sql, data)
|
|
222
|
+
stack_to_db(stack_dict, db_name)
|
|
213
223
|
|
|
214
224
|
|
|
215
225
|
def config_to_db(config, db_name):
|
|
@@ -221,9 +231,22 @@ def config_to_db(config, db_name):
|
|
|
221
231
|
to_db(db_name, create_table_sql, insert_sql, data)
|
|
222
232
|
|
|
223
233
|
|
|
234
|
+
def stack_to_db(stack_dict, db_name):
|
|
235
|
+
create_table_sql = create_table_sql_from_dict('tb_stack', stack_columns)
|
|
236
|
+
insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns)
|
|
237
|
+
data = []
|
|
238
|
+
for stack_info_text, unique_id in stack_dict.items():
|
|
239
|
+
data.append((unique_id, stack_info_text))
|
|
240
|
+
to_db(db_name, create_table_sql, insert_sql, data)
|
|
241
|
+
|
|
242
|
+
|
|
224
243
|
def get_graph_unique_id(graph):
|
|
225
244
|
return f'{graph.data_source}_{graph.step}_{graph.rank}'
|
|
226
245
|
|
|
227
246
|
|
|
228
247
|
def get_node_unique_id(graph, node):
|
|
229
248
|
return f'{get_graph_unique_id(graph)}_{node.id}'
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def get_stack_unique_id(graph, stack_dict):
|
|
252
|
+
return f'{get_graph_unique_id(graph)}_{len(stack_dict)}'
|
|
@@ -89,30 +89,6 @@ class BaseNode:
|
|
|
89
89
|
self.matched_node_link = ancestors
|
|
90
90
|
node.matched_node_link = ancestors
|
|
91
91
|
|
|
92
|
-
def to_dict(self, compare_mode=None):
|
|
93
|
-
"""
|
|
94
|
-
输出数据
|
|
95
|
-
"""
|
|
96
|
-
result = {
|
|
97
|
-
'id': self.id,
|
|
98
|
-
'node_type': self.op.value,
|
|
99
|
-
'output_data': format_node_data(self.output_data, self.id, compare_mode),
|
|
100
|
-
'input_data': format_node_data(self.input_data, self.id, compare_mode),
|
|
101
|
-
'upnode': self.upnode.id if self.upnode else 'None',
|
|
102
|
-
'subnodes': [node.id for node in self.subnodes],
|
|
103
|
-
'matched_node_link': self.matched_node_link,
|
|
104
|
-
'suggestions': self.suggestions,
|
|
105
|
-
'stack_info': self.stack_info
|
|
106
|
-
}
|
|
107
|
-
if self.micro_step_id is not None:
|
|
108
|
-
result['micro_step_id'] = self.micro_step_id
|
|
109
|
-
result['data'] = self.data
|
|
110
|
-
if self.matched_distributed:
|
|
111
|
-
result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
|
|
112
|
-
if self.parallel_merge_info:
|
|
113
|
-
result['parallel_merge_info'] = self.parallel_merge_info
|
|
114
|
-
return result
|
|
115
|
-
|
|
116
92
|
def get_ancestors(self):
|
|
117
93
|
"""
|
|
118
94
|
获取节点所有祖先的列表
|
|
@@ -22,7 +22,7 @@ from msprobe.core.common.decorator import recursion_depth_decorator
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class Graph:
|
|
25
|
-
def __init__(self, model_name, data_path='', dump_data=None):
|
|
25
|
+
def __init__(self, model_name, data_path='', dump_data=None, micro_step_num=None):
|
|
26
26
|
self.node_map = {}
|
|
27
27
|
self.node_id_map = {}
|
|
28
28
|
self.add_node(NodeOp.module, model_name)
|
|
@@ -33,6 +33,7 @@ class Graph:
|
|
|
33
33
|
self.step = 0
|
|
34
34
|
self.rank = 0
|
|
35
35
|
self.compare_mode = GraphConst.SUMMARY_COMPARE
|
|
36
|
+
self.micro_step_num = micro_step_num
|
|
36
37
|
|
|
37
38
|
def __str__(self):
|
|
38
39
|
infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map]
|
|
@@ -172,19 +173,6 @@ class Graph:
|
|
|
172
173
|
"""
|
|
173
174
|
return self.node_map.get(node_id, None)
|
|
174
175
|
|
|
175
|
-
def to_dict(self, compare_mode=None):
|
|
176
|
-
"""
|
|
177
|
-
用于数据输出
|
|
178
|
-
"""
|
|
179
|
-
result = {}
|
|
180
|
-
result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
|
|
181
|
-
result[GraphConst.JSON_DATA_KEY] = self.data_path
|
|
182
|
-
result[GraphConst.JSON_NODE_KEY] = {}
|
|
183
|
-
for node_id in self.node_map:
|
|
184
|
-
info = self.node_map.get(node_id).to_dict(compare_mode)
|
|
185
|
-
result[GraphConst.JSON_NODE_KEY][node_id] = info
|
|
186
|
-
return result
|
|
187
|
-
|
|
188
176
|
def paging_by_micro_step(self, graph_other=None):
|
|
189
177
|
"""
|
|
190
178
|
给graph首层节点增加micro step标记,供前端分页展示,有助于在处理大规模图数据时进行优化和管理
|
|
@@ -203,6 +191,9 @@ class Graph:
|
|
|
203
191
|
for sub_node in node.subnodes:
|
|
204
192
|
propagate_micro_step_id(sub_node)
|
|
205
193
|
|
|
194
|
+
if self.micro_step_num is not None:
|
|
195
|
+
return self.micro_step_num + 1
|
|
196
|
+
|
|
206
197
|
batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
|
|
207
198
|
for batch_number, nodes in batches_n.items():
|
|
208
199
|
for node in nodes:
|