mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -28,7 +28,7 @@ op_patterns = [
|
|
|
28
28
|
# NodeOp.module
|
|
29
29
|
r'^(Module.|Cell.|optimizer|clip_grad)',
|
|
30
30
|
# NodeOp.function_api
|
|
31
|
-
r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
|
|
31
|
+
r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.|MindSpeed.)'
|
|
32
32
|
]
|
|
33
33
|
|
|
34
34
|
|
|
@@ -54,7 +54,13 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
|
|
|
54
54
|
framework: 框架类型, pytorch或mindspore
|
|
55
55
|
is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆
|
|
56
56
|
"""
|
|
57
|
-
|
|
57
|
+
config_dict = {
|
|
58
|
+
'stack_mode': False,
|
|
59
|
+
'auto_analyze': True,
|
|
60
|
+
'fuzzy_match': False,
|
|
61
|
+
'dump_mode': Const.ALL
|
|
62
|
+
}
|
|
63
|
+
mode_config = ModeConfig(**config_dict)
|
|
58
64
|
|
|
59
65
|
if framework == Const.PT_FRAMEWORK:
|
|
60
66
|
from msprobe.pytorch.compare.pt_compare import read_real_data
|
|
@@ -125,7 +131,7 @@ def format_node_data(data_dict, node_id=None, compare_mode=None):
|
|
|
125
131
|
"""
|
|
126
132
|
删除节点数据中不需要展示的字段
|
|
127
133
|
"""
|
|
128
|
-
del_list = ['
|
|
134
|
+
del_list = ['state', 'full_op_name']
|
|
129
135
|
if GraphConst.MD5_COMPARE != compare_mode:
|
|
130
136
|
del_list.append(Const.MD5)
|
|
131
137
|
if node_id and GraphConst.BATCH_P2P in node_id:
|
|
@@ -140,31 +146,27 @@ def format_node_data(data_dict, node_id=None, compare_mode=None):
|
|
|
140
146
|
return data_dict
|
|
141
147
|
|
|
142
148
|
|
|
143
|
-
def compare_node(
|
|
149
|
+
def compare_node(node_n, node_b, compare_mode):
|
|
144
150
|
"""
|
|
145
151
|
调用acc_compare.py中的get_accuracy获得精度对比指标
|
|
146
152
|
真实数据对比模式无法获得精度对比指标,需要调用多进程比对接口
|
|
147
153
|
Returns: 包含参数信息和对比指标(真实数据对比模式除外)的list
|
|
148
154
|
"""
|
|
149
|
-
merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode)
|
|
150
|
-
merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode)
|
|
151
|
-
result = []
|
|
152
155
|
dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
|
|
156
|
+
merge_n = _parse_node(node_n, dump_mode)
|
|
157
|
+
merge_b = _parse_node(node_b, dump_mode)
|
|
158
|
+
result = []
|
|
153
159
|
get_accuracy(result, merge_n, merge_b, dump_mode)
|
|
154
160
|
return result
|
|
155
161
|
|
|
156
162
|
|
|
157
|
-
def _parse_node(
|
|
163
|
+
def _parse_node(node, dump_mode):
|
|
158
164
|
"""
|
|
159
165
|
转换节点,使其能够作为acc_compare.py中的get_accuracy的入参
|
|
160
166
|
"""
|
|
161
|
-
|
|
162
|
-
op_parsed_list
|
|
163
|
-
|
|
164
|
-
op_parsed_list.append(
|
|
165
|
-
{'full_op_name': node_id, 'full_info': stack_json_data[node_id]})
|
|
166
|
-
else:
|
|
167
|
-
op_parsed_list.append({'full_op_name': node_id, 'full_info': None})
|
|
167
|
+
op_parsed_list = []
|
|
168
|
+
op_parsed_list.extend(node.input_data.values())
|
|
169
|
+
op_parsed_list.extend(node.output_data.values())
|
|
168
170
|
result = merge_tensor(op_parsed_list, dump_mode)
|
|
169
171
|
if not result:
|
|
170
172
|
result['op_name'] = []
|
|
@@ -35,13 +35,15 @@ class GraphComparator:
|
|
|
35
35
|
self.fuzzy_match = args.fuzzy_match
|
|
36
36
|
self.pattern = re.compile(r'\.\d+\.')
|
|
37
37
|
self.is_cross_framework = is_cross_framework
|
|
38
|
+
self.parallel_merge = args.parallel_merge if hasattr(args, 'parallel_merge') else False
|
|
39
|
+
self.rank_pattern = re.compile(r"_rank\d+")
|
|
38
40
|
|
|
39
41
|
def compare(self):
|
|
40
42
|
"""
|
|
41
43
|
比较函数,初始化结束后单独调用。比较结果写入graph_n
|
|
42
44
|
"""
|
|
43
45
|
if self.fuzzy_match:
|
|
44
|
-
self._compare_nodes_fuzzy(self.graph_n.root)
|
|
46
|
+
self._compare_nodes_fuzzy(self.graph_n.root, False if self.parallel_merge else True)
|
|
45
47
|
else:
|
|
46
48
|
self._compare_nodes(self.graph_n.root)
|
|
47
49
|
self._postcompare()
|
|
@@ -98,11 +100,12 @@ class GraphComparator:
|
|
|
98
100
|
while node_list:
|
|
99
101
|
compare_single_node(node_list.pop(0))
|
|
100
102
|
|
|
101
|
-
def _compare_nodes_fuzzy(self, node_root):
|
|
103
|
+
def _compare_nodes_fuzzy(self, node_root, check_shape=True):
|
|
102
104
|
def compare_single_nodes_fuzzy(node_n):
|
|
103
105
|
if node_n.op != NodeOp.function_api:
|
|
104
106
|
# 模块经过模糊匹配
|
|
105
|
-
node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id)
|
|
107
|
+
node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id),
|
|
108
|
+
check_shape)
|
|
106
109
|
if node_b:
|
|
107
110
|
self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
|
|
108
111
|
# 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
|
|
@@ -113,7 +116,7 @@ class GraphComparator:
|
|
|
113
116
|
if not api_node_n:
|
|
114
117
|
continue
|
|
115
118
|
api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
|
|
116
|
-
api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
|
|
119
|
+
api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)), check_shape)
|
|
117
120
|
if api_node_b:
|
|
118
121
|
self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
|
|
119
122
|
node_list.extend(node_n.subnodes)
|
|
@@ -147,21 +150,26 @@ class GraphComparator:
|
|
|
147
150
|
api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
|
|
148
151
|
md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
|
|
149
152
|
"""
|
|
153
|
+
def handle_api_collection_index(api_collection_node):
|
|
154
|
+
precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
|
|
155
|
+
else GraphConst.MIN_INDEX_KEY
|
|
156
|
+
for api in api_collection_node.subnodes:
|
|
157
|
+
precision_index = min(precision_index,
|
|
158
|
+
api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
|
|
159
|
+
if self.ma.compare_mode == GraphConst.MD5_COMPARE \
|
|
160
|
+
else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
|
|
161
|
+
api_collection_node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
162
|
+
|
|
150
163
|
for node in self.graph_n.root.subnodes:
|
|
151
|
-
if node.op == NodeOp.api_collection:
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
if self.ma.compare_mode == GraphConst.MD5_COMPARE \
|
|
158
|
-
else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
|
|
159
|
-
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
164
|
+
if node.op == NodeOp.api_collection and node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS):
|
|
165
|
+
for sub_node in node.subnodes:
|
|
166
|
+
handle_api_collection_index(sub_node)
|
|
167
|
+
handle_api_collection_index(node)
|
|
168
|
+
elif node.op == NodeOp.api_collection:
|
|
169
|
+
handle_api_collection_index(node)
|
|
160
170
|
|
|
161
171
|
def _get_and_add_result(self, node_n, node_b):
|
|
162
|
-
compare_result_list = compare_node(
|
|
163
|
-
[self.data_n_dict, self.data_b_dict],
|
|
164
|
-
self.stack_json_data, self.ma.compare_mode)
|
|
172
|
+
compare_result_list = compare_node(node_n, node_b, self.ma.compare_mode)
|
|
165
173
|
if compare_result_list:
|
|
166
174
|
self.ma.add_csv_data(compare_result_list)
|
|
167
175
|
self.add_compare_result_to_node(node_n, compare_result_list)
|
|
@@ -178,6 +186,8 @@ class GraphComparator:
|
|
|
178
186
|
if sub_node.op == NodeOp.function_api:
|
|
179
187
|
# 忽略dump调用次数
|
|
180
188
|
count_removed_id = self.pattern.sub(Const.SEP, sub_node.id)
|
|
189
|
+
if self.rank_pattern.search(count_removed_id):
|
|
190
|
+
count_removed_id = self.rank_pattern.sub('', count_removed_id)
|
|
181
191
|
node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1
|
|
182
192
|
# 赋予模块中的调用顺序
|
|
183
193
|
recount_node_id = count_removed_id + str(node_count.get(count_removed_id))
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import sqlite3
|
|
18
|
+
import json
|
|
19
|
+
import re
|
|
20
|
+
from msprobe.core.common.log import logger
|
|
21
|
+
from msprobe.core.common.file_utils import change_mode, check_path_before_create, FileChecker
|
|
22
|
+
from msprobe.core.common.const import FileCheckConst
|
|
23
|
+
from msprobe.visualization.utils import GraphConst
|
|
24
|
+
from msprobe.visualization.builder.msprobe_adapter import format_node_data
|
|
25
|
+
|
|
26
|
+
TEXT_PRIMARY_KEY = 'TEXT PRIMARY KEY'
|
|
27
|
+
TEXT_NOT_NULL = 'TEXT NOT NULL'
|
|
28
|
+
INTEGER_NOT_NULL = 'INTEGER NOT NULL'
|
|
29
|
+
TEXT = 'TEXT'
|
|
30
|
+
INTEGER = 'INTEGER'
|
|
31
|
+
|
|
32
|
+
node_columns = {
|
|
33
|
+
'id': TEXT_PRIMARY_KEY,
|
|
34
|
+
'graph_id': TEXT_NOT_NULL,
|
|
35
|
+
'node_order': INTEGER_NOT_NULL,
|
|
36
|
+
'node_name': TEXT_NOT_NULL,
|
|
37
|
+
'node_type': TEXT_NOT_NULL,
|
|
38
|
+
'up_node': TEXT,
|
|
39
|
+
'sub_nodes': TEXT,
|
|
40
|
+
'precision_index': INTEGER,
|
|
41
|
+
'overflow_level': TEXT,
|
|
42
|
+
'micro_step_id': INTEGER_NOT_NULL,
|
|
43
|
+
'matched_node_link': TEXT,
|
|
44
|
+
'stack_id': TEXT,
|
|
45
|
+
'parallel_merge_info': TEXT,
|
|
46
|
+
'matched_distributed': TEXT,
|
|
47
|
+
'modified': INTEGER_NOT_NULL,
|
|
48
|
+
'input_data': TEXT,
|
|
49
|
+
'output_data': TEXT,
|
|
50
|
+
'data_source': TEXT,
|
|
51
|
+
'dump_data_dir': TEXT,
|
|
52
|
+
'step': INTEGER_NOT_NULL,
|
|
53
|
+
'rank': INTEGER_NOT_NULL
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
config_columns = {
|
|
57
|
+
'id': TEXT_PRIMARY_KEY,
|
|
58
|
+
'graph_type': TEXT_NOT_NULL,
|
|
59
|
+
'task': TEXT,
|
|
60
|
+
'tool_tip': TEXT,
|
|
61
|
+
'micro_steps': INTEGER,
|
|
62
|
+
'overflow_check': INTEGER,
|
|
63
|
+
'node_colors': TEXT_NOT_NULL,
|
|
64
|
+
'rank_list': TEXT_NOT_NULL,
|
|
65
|
+
'step_list': TEXT_NOT_NULL
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
stack_columns = {
|
|
69
|
+
'id': TEXT_PRIMARY_KEY,
|
|
70
|
+
'stack_info': TEXT
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
indexes = {
|
|
74
|
+
"index1": ["step", "rank", "data_source", "up_node", "node_order"],
|
|
75
|
+
"index2": ["step", "rank", "data_source", "node_name"],
|
|
76
|
+
"index3": ["step", "rank", "data_source", "node_order"],
|
|
77
|
+
"index4": ["step", "rank", "node_order"],
|
|
78
|
+
"index5": ["step", "rank", "micro_step_id", "node_order"],
|
|
79
|
+
"index6": ["step", "rank", "modified", "matched_node_link"]
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
SAFE_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def is_safe_identifier(name):
|
|
86
|
+
"""验证标识符是否安全(防止SQL注入)"""
|
|
87
|
+
return isinstance(name, str) and SAFE_NAME_PATTERN.match(name) is not None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def create_table_sql_from_dict(table_name, columns_dict):
|
|
91
|
+
"""
|
|
92
|
+
根据提供的表名和列定义字典生成CREATE TABLE SQL语句。
|
|
93
|
+
"""
|
|
94
|
+
if not is_safe_identifier(table_name):
|
|
95
|
+
raise ValueError(f"Invalid table name: {table_name} - potential SQL injection risk!")
|
|
96
|
+
|
|
97
|
+
sql = f"CREATE TABLE IF NOT EXISTS {table_name} (\n"
|
|
98
|
+
|
|
99
|
+
column_definitions = []
|
|
100
|
+
for column_name, column_type in columns_dict.items():
|
|
101
|
+
if not is_safe_identifier(column_name):
|
|
102
|
+
raise ValueError(f"Invalid column name: {column_name} - potential SQL injection risk!")
|
|
103
|
+
|
|
104
|
+
column_definitions.append(f" {column_name} {column_type}")
|
|
105
|
+
|
|
106
|
+
sql += ",\n".join(column_definitions)
|
|
107
|
+
sql += "\n);"
|
|
108
|
+
|
|
109
|
+
return sql
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def create_insert_sql_from_dict(table_name, columns_dict, ignore_insert=False):
|
|
113
|
+
"""
|
|
114
|
+
根据提供的表名和数据字典生成INSERT INTO SQL语句。
|
|
115
|
+
"""
|
|
116
|
+
if not is_safe_identifier(table_name):
|
|
117
|
+
raise ValueError(f"Invalid table name: {table_name} - potential SQL injection risk!")
|
|
118
|
+
|
|
119
|
+
columns = list(columns_dict.keys())
|
|
120
|
+
|
|
121
|
+
for column_name in columns:
|
|
122
|
+
if not is_safe_identifier(column_name):
|
|
123
|
+
raise ValueError(f"Invalid column name: {column_name} - potential SQL injection risk!")
|
|
124
|
+
|
|
125
|
+
placeholders = ["?"] * len(columns)
|
|
126
|
+
|
|
127
|
+
columns_string = ", ".join(columns)
|
|
128
|
+
placeholders_string = ", ".join(placeholders)
|
|
129
|
+
|
|
130
|
+
sql_prefix = "INSERT OR IGNORE INTO" if ignore_insert else "INSERT INTO"
|
|
131
|
+
sql = f"{sql_prefix} {table_name} ({columns_string}) VALUES ({placeholders_string})"
|
|
132
|
+
return sql
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def to_db(db_path, create_table_sql, insert_sql, data, db_insert_size=1000):
|
|
136
|
+
if not os.path.exists(db_path):
|
|
137
|
+
check_path_before_create(db_path)
|
|
138
|
+
else:
|
|
139
|
+
FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE,
|
|
140
|
+
FileCheckConst.DB_SUFFIX).common_check()
|
|
141
|
+
try:
|
|
142
|
+
conn = sqlite3.connect(db_path)
|
|
143
|
+
except sqlite3.Error as e:
|
|
144
|
+
logger.error(f"Unable to create database connection: {e}")
|
|
145
|
+
raise RuntimeError("Unable to create database connection") from e
|
|
146
|
+
|
|
147
|
+
try:
|
|
148
|
+
cursor = conn.cursor()
|
|
149
|
+
cursor.execute(create_table_sql)
|
|
150
|
+
if len(data) == 1:
|
|
151
|
+
cursor.execute(insert_sql, data[0])
|
|
152
|
+
conn.commit()
|
|
153
|
+
else:
|
|
154
|
+
for i in range(0, len(data), db_insert_size):
|
|
155
|
+
batch = data[i:i + db_insert_size]
|
|
156
|
+
cursor.executemany(insert_sql, batch)
|
|
157
|
+
conn.commit()
|
|
158
|
+
except sqlite3.Error as e:
|
|
159
|
+
logger.error(f"An sqlite3 error occurred: {e}")
|
|
160
|
+
raise RuntimeError("An sqlite3 error occurred") from e
|
|
161
|
+
finally:
|
|
162
|
+
conn.close()
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def add_table_index(db_path):
|
|
166
|
+
FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE, FileCheckConst.DB_SUFFIX).common_check()
|
|
167
|
+
try:
|
|
168
|
+
conn = sqlite3.connect(db_path)
|
|
169
|
+
except sqlite3.Error as e:
|
|
170
|
+
logger.error(f"Unable to create database connection: {e}")
|
|
171
|
+
raise RuntimeError("Unable to create database connection") from e
|
|
172
|
+
|
|
173
|
+
try:
|
|
174
|
+
cursor = conn.cursor()
|
|
175
|
+
for index_name, columns in indexes.items():
|
|
176
|
+
if not is_safe_identifier(index_name):
|
|
177
|
+
raise ValueError(f"Invalid index name: {index_name} - potential SQL injection risk!")
|
|
178
|
+
|
|
179
|
+
for column in columns:
|
|
180
|
+
if not is_safe_identifier(column):
|
|
181
|
+
raise ValueError(f"Invalid column name in index: {column} - potential SQL injection risk!")
|
|
182
|
+
|
|
183
|
+
columns_str = ', '.join(columns)
|
|
184
|
+
index_sql = f'''
|
|
185
|
+
CREATE INDEX IF NOT EXISTS {index_name} ON tb_nodes ({columns_str});
|
|
186
|
+
'''
|
|
187
|
+
cursor.execute(index_sql)
|
|
188
|
+
conn.commit()
|
|
189
|
+
except sqlite3.Error as e:
|
|
190
|
+
logger.error(f"Failed to add table index: {e}")
|
|
191
|
+
raise RuntimeError("Failed to add table index") from e
|
|
192
|
+
finally:
|
|
193
|
+
conn.close()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def post_process_db(db_path):
|
|
197
|
+
add_table_index(db_path)
|
|
198
|
+
change_mode(db_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def node_to_db(graph, db_name):
|
|
202
|
+
create_table_sql = create_table_sql_from_dict('tb_nodes', node_columns)
|
|
203
|
+
insert_sql = create_insert_sql_from_dict('tb_nodes', node_columns)
|
|
204
|
+
data = []
|
|
205
|
+
stack_dict = {}
|
|
206
|
+
for i, node in enumerate(graph.get_sorted_nodes()):
|
|
207
|
+
stack_info_text = json.dumps(node.stack_info)
|
|
208
|
+
if stack_info_text not in stack_dict:
|
|
209
|
+
stack_dict[stack_info_text] = get_stack_unique_id(graph, stack_dict)
|
|
210
|
+
data.append((get_node_unique_id(graph, node), get_graph_unique_id(graph), i, node.id, node.op.value,
|
|
211
|
+
node.upnode.id if node.upnode else '',
|
|
212
|
+
json.dumps([node.id for node in node.subnodes]) if node.subnodes else '',
|
|
213
|
+
node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL),
|
|
214
|
+
node.micro_step_id if node.micro_step_id is not None else 0, json.dumps(node.matched_node_link),
|
|
215
|
+
stack_dict.get(stack_info_text),
|
|
216
|
+
json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '',
|
|
217
|
+
json.dumps(node.matched_distributed), 0,
|
|
218
|
+
json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)),
|
|
219
|
+
json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)),
|
|
220
|
+
graph.data_source, graph.data_path, graph.step, graph.rank))
|
|
221
|
+
to_db(db_name, create_table_sql, insert_sql, data)
|
|
222
|
+
stack_to_db(stack_dict, db_name)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def config_to_db(config, db_name):
|
|
226
|
+
create_table_sql = create_table_sql_from_dict('tb_config', config_columns)
|
|
227
|
+
insert_sql = create_insert_sql_from_dict('tb_config', config_columns, ignore_insert=True)
|
|
228
|
+
data = [("1", "compare" if config.graph_b else "build", config.task, config.tool_tip, config.micro_steps,
|
|
229
|
+
config.overflow_check, json.dumps(config.node_colors), json.dumps(config.rank_list),
|
|
230
|
+
json.dumps(config.step_list))]
|
|
231
|
+
to_db(db_name, create_table_sql, insert_sql, data)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def stack_to_db(stack_dict, db_name):
|
|
235
|
+
create_table_sql = create_table_sql_from_dict('tb_stack', stack_columns)
|
|
236
|
+
insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns)
|
|
237
|
+
data = []
|
|
238
|
+
for stack_info_text, unique_id in stack_dict.items():
|
|
239
|
+
data.append((unique_id, stack_info_text))
|
|
240
|
+
to_db(db_name, create_table_sql, insert_sql, data)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def get_graph_unique_id(graph):
|
|
244
|
+
return f'{graph.data_source}_{graph.step}_{graph.rank}'
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def get_node_unique_id(graph, node):
|
|
248
|
+
return f'{get_graph_unique_id(graph)}_{node.id}'
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def get_stack_unique_id(graph, stack_dict):
|
|
252
|
+
return f'{get_graph_unique_id(graph)}_{len(stack_dict)}'
|
|
@@ -36,6 +36,8 @@ class BaseNode:
|
|
|
36
36
|
self.overflow_level = None
|
|
37
37
|
self.matched_distributed = {}
|
|
38
38
|
self.batch_p2p_info = []
|
|
39
|
+
self.rank = 0
|
|
40
|
+
self.parallel_merge_info = []
|
|
39
41
|
|
|
40
42
|
def __str__(self):
|
|
41
43
|
info = f'id:\t{self.id}'
|
|
@@ -87,28 +89,6 @@ class BaseNode:
|
|
|
87
89
|
self.matched_node_link = ancestors
|
|
88
90
|
node.matched_node_link = ancestors
|
|
89
91
|
|
|
90
|
-
def to_dict(self, compare_mode=None):
|
|
91
|
-
"""
|
|
92
|
-
输出数据
|
|
93
|
-
"""
|
|
94
|
-
result = {
|
|
95
|
-
'id': self.id,
|
|
96
|
-
'node_type': self.op.value,
|
|
97
|
-
'output_data': format_node_data(self.output_data, self.id, compare_mode),
|
|
98
|
-
'input_data': format_node_data(self.input_data, self.id, compare_mode),
|
|
99
|
-
'upnode': self.upnode.id if self.upnode else 'None',
|
|
100
|
-
'subnodes': [node.id for node in self.subnodes],
|
|
101
|
-
'matched_node_link': self.matched_node_link,
|
|
102
|
-
'suggestions': self.suggestions,
|
|
103
|
-
'stack_info': self.stack_info
|
|
104
|
-
}
|
|
105
|
-
if self.micro_step_id is not None:
|
|
106
|
-
result['micro_step_id'] = self.micro_step_id
|
|
107
|
-
result['data'] = self.data
|
|
108
|
-
if self.matched_distributed:
|
|
109
|
-
result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
|
|
110
|
-
return result
|
|
111
|
-
|
|
112
92
|
def get_ancestors(self):
|
|
113
93
|
"""
|
|
114
94
|
获取节点所有祖先的列表
|
|
@@ -82,7 +82,7 @@ class DistributedAnalyzer:
|
|
|
82
82
|
"""
|
|
83
83
|
target_rank = node.input_data.get(f'{node.id}{GraphConst.INPUT}{parameter}', {}).get('value')
|
|
84
84
|
if target_rank is None:
|
|
85
|
-
logger.
|
|
85
|
+
logger.debug(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
|
|
86
86
|
return target_rank
|
|
87
87
|
|
|
88
88
|
@staticmethod
|
|
@@ -95,15 +95,15 @@ class DistributedAnalyzer:
|
|
|
95
95
|
"""
|
|
96
96
|
group = node.input_data.get(f'{node.id}{GraphConst.INPUT}group', {})
|
|
97
97
|
if not group:
|
|
98
|
-
logger.
|
|
98
|
+
logger.debug(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
|
|
99
99
|
return None, None
|
|
100
100
|
group_ranks = group.get('group_ranks')
|
|
101
101
|
if not group_ranks:
|
|
102
|
-
logger.
|
|
102
|
+
logger.debug(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
|
|
103
103
|
return None, None
|
|
104
104
|
group_id = group.get('group_id')
|
|
105
105
|
if not group_id:
|
|
106
|
-
logger.
|
|
106
|
+
logger.debug(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
|
|
107
107
|
return None, None
|
|
108
108
|
return group_ranks, group_id
|
|
109
109
|
|
|
@@ -183,7 +183,7 @@ class DistributedAnalyzer:
|
|
|
183
183
|
op = info_dict.get(GraphConst.OP)
|
|
184
184
|
target_rank = info_dict.get(GraphConst.PEER)
|
|
185
185
|
if op is None or target_rank is None:
|
|
186
|
-
logger.
|
|
186
|
+
logger.debug('Cannot get param op or peer.')
|
|
187
187
|
continue
|
|
188
188
|
group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \
|
|
189
189
|
Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '')
|
|
@@ -215,7 +215,7 @@ class DistributedAnalyzer:
|
|
|
215
215
|
"""
|
|
216
216
|
target_graph = self.graphs.get(target_rank)
|
|
217
217
|
if not target_graph:
|
|
218
|
-
logger.
|
|
218
|
+
logger.debug(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}')
|
|
219
219
|
return None
|
|
220
220
|
target_group_mapping = self.group_node_mapping.get(target_rank)
|
|
221
221
|
# p2p通信,想要获取目标节点,需要替换unique_group_id中的rank和api name,
|
|
@@ -226,7 +226,7 @@ class DistributedAnalyzer:
|
|
|
226
226
|
target_node_id = target_group_mapping.get(target_unique_group_id, '')
|
|
227
227
|
target_node = target_graph.node_map.get(target_node_id)
|
|
228
228
|
if not target_node:
|
|
229
|
-
logger.
|
|
229
|
+
logger.debug(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}')
|
|
230
230
|
return None
|
|
231
231
|
return target_node
|
|
232
232
|
|
|
@@ -276,13 +276,13 @@ class DistributedAnalyzer:
|
|
|
276
276
|
source_rank = (target_node.input_data.get(f'{target_node.id}{GraphConst.INPUT}{target_config_info[1]}', {})
|
|
277
277
|
.get('value'))
|
|
278
278
|
if source_rank is None:
|
|
279
|
-
logger.
|
|
279
|
+
logger.debug(
|
|
280
280
|
f'The kwarg {target_config_info[1]} of node {target_node.id} does not exist, '
|
|
281
281
|
f'{CANNOT_MATCH}{source_rank}')
|
|
282
282
|
return
|
|
283
283
|
if source_rank != rank:
|
|
284
284
|
# 点对点通信,待匹配目标节点包含的rank信息要与当前rank一致
|
|
285
|
-
logger.
|
|
285
|
+
logger.debug(
|
|
286
286
|
f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}, '
|
|
287
287
|
f'but the data shows that {target_node.id} communicates with rank{source_rank}.'
|
|
288
288
|
f'The rank is inconsistent, cannot match distributed node')
|
|
@@ -291,7 +291,7 @@ class DistributedAnalyzer:
|
|
|
291
291
|
# 点对点通信,两个匹配节点的输出数据要一致
|
|
292
292
|
if not DistributedAnalyzer._node_output_all_equal(node.output_data.get(node.id + '.output.0'),
|
|
293
293
|
target_node.output_data.get(target_node.id + '.output.0')):
|
|
294
|
-
logger.
|
|
294
|
+
logger.debug(f'{node.id} output of rank{rank} is different from the {target_node.id} '
|
|
295
295
|
f'output of rank{target_rank}, cannot match distributed node')
|
|
296
296
|
return
|
|
297
297
|
|
|
@@ -332,7 +332,7 @@ class DistributedAnalyzer:
|
|
|
332
332
|
if not target_group_id:
|
|
333
333
|
continue
|
|
334
334
|
if group_id != target_group_id:
|
|
335
|
-
logger.
|
|
335
|
+
logger.debug(
|
|
336
336
|
f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}'
|
|
337
337
|
f', but the data shows that the group id of the two nodes are different, '
|
|
338
338
|
f'cannot match distributed node')
|
|
@@ -368,7 +368,7 @@ class DistributedAnalyzer:
|
|
|
368
368
|
target_api_name = self.config.get(api_name)[0]
|
|
369
369
|
target_rank = int(id_info[1].replace(Const.RANK, ''))
|
|
370
370
|
except Exception as e:
|
|
371
|
-
logger.
|
|
371
|
+
logger.debug(f'Failed to parse batch p2p parameter with error info: {e}.')
|
|
372
372
|
continue
|
|
373
373
|
target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
|
|
374
374
|
if not target_node:
|
|
@@ -18,16 +18,22 @@ from msprobe.visualization.graph.node_op import NodeOp
|
|
|
18
18
|
from msprobe.visualization.utils import GraphConst
|
|
19
19
|
from msprobe.core.common.log import logger
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
class Graph:
|
|
24
|
-
def __init__(self, model_name, data_path='', dump_data=None):
|
|
25
|
+
def __init__(self, model_name, data_path='', dump_data=None, micro_step_num=None):
    """
    Create a graph whose root is a module node named ``model_name``.

    Args:
        model_name: id of the root node; it is registered through ``add_node`` with
            ``NodeOp.module`` and then looked up again to become ``self.root``.
        data_path: stored unchanged on the instance (defaults to an empty string).
        dump_data: stored unchanged on the instance (defaults to ``None``).
        micro_step_num: stored unchanged on the instance (defaults to ``None``).
    """
    # The two lookup tables are created before the root is registered.
    # NOTE(review): add_node presumably writes into them — confirm in add_node.
    self.node_map, self.node_id_map = {}, {}
    self.add_node(NodeOp.module, model_name)
    self.root = self.get_node(model_name)

    # Caller-supplied inputs, kept as given.
    self.data_path, self.dump_data = data_path, dump_data

    # Defaults for the graph's metadata fields.
    self.data_source = GraphConst.JSON_NPU_KEY
    self.step, self.rank = 0, 0
    self.compare_mode = GraphConst.SUMMARY_COMPARE

    self.micro_step_num = micro_step_num
|
|
31
37
|
|
|
32
38
|
def __str__(self):
|
|
33
39
|
infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map]
|
|
@@ -65,8 +71,10 @@ class Graph:
|
|
|
65
71
|
return node_b, ancestors_n, ancestors_b
|
|
66
72
|
|
|
67
73
|
@staticmethod
|
|
68
|
-
def fuzzy_match(node_n, node_b):
|
|
69
|
-
if not node_n or not node_b
|
|
74
|
+
def fuzzy_match(node_n, node_b, check_shape=True):
|
|
75
|
+
if not node_n or not node_b:
|
|
76
|
+
return None, [], []
|
|
77
|
+
if check_shape and not node_n.fuzzy_eq(node_b):
|
|
70
78
|
return None, [], []
|
|
71
79
|
ancestors_n = node_n.get_ancestors()
|
|
72
80
|
ancestors_b = node_b.get_ancestors()
|
|
@@ -116,6 +124,25 @@ class Graph:
|
|
|
116
124
|
result[micro_step].append(node)
|
|
117
125
|
return result
|
|
118
126
|
|
|
127
|
+
def get_sorted_nodes(self):
    """
    Traverse the graph depth-first from the root and return the visited nodes in order.

    A node is appended only after all of its sub-nodes have been visited (post-order),
    so sub-nodes always come before their parent and the root is the last element.
    Nodes are de-duplicated by ``id``: a node reached a second time is skipped.

    Returns:
        list: the node objects reachable from ``self.root`` in post-order.
    """
    visited = set()
    order = []

    # The label handed to the decorator now names this method; it previously pointed at
    # a non-existent ``get_nodes_order``.
    # NOTE(review): the decorator presumably rejects recursion deeper than ``max_depth``
    # and reports this label — confirm in msprobe.core.common.decorator.
    @recursion_depth_decorator('msprobe.visualization.graph.graph.Graph.get_sorted_nodes.visit', max_depth=500)
    def visit(node):
        if node.id in visited:
            return
        visited.add(node.id)
        for sub_node in node.subnodes:
            visit(sub_node)
        # Appended after the sub-nodes, which is what yields post-order.
        order.append(node)

    visit(self.root)
    return order
|
|
145
|
+
|
|
119
146
|
def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
|
|
120
147
|
"""
|
|
121
148
|
在graph中进行节点的添加
|
|
@@ -146,19 +173,6 @@ class Graph:
|
|
|
146
173
|
"""
|
|
147
174
|
return self.node_map.get(node_id, None)
|
|
148
175
|
|
|
149
|
-
def to_dict(self, compare_mode=None):
|
|
150
|
-
"""
|
|
151
|
-
用于数据输出
|
|
152
|
-
"""
|
|
153
|
-
result = {}
|
|
154
|
-
result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
|
|
155
|
-
result[GraphConst.JSON_DATA_KEY] = self.data_path
|
|
156
|
-
result[GraphConst.JSON_NODE_KEY] = {}
|
|
157
|
-
for node_id in self.node_map:
|
|
158
|
-
info = self.node_map.get(node_id).to_dict(compare_mode)
|
|
159
|
-
result[GraphConst.JSON_NODE_KEY][node_id] = info
|
|
160
|
-
return result
|
|
161
|
-
|
|
162
176
|
def paging_by_micro_step(self, graph_other=None):
|
|
163
177
|
"""
|
|
164
178
|
给graph首层节点增加micro step标记,供前端分页展示,有助于在处理大规模图数据时进行优化和管理
|
|
@@ -168,6 +182,18 @@ class Graph:
|
|
|
168
182
|
graph_other: 可选参数,另一个graph
|
|
169
183
|
Returns: 分批的数量
|
|
170
184
|
"""
|
|
185
|
+
|
|
186
|
+
@recursion_depth_decorator(
    'msprobe.visualization.graph.graph.Graph.paging_by_micro_step.propagate_micro_step_id', max_depth=500)
def propagate_micro_step_id(node):
    """
    Copy ``micro_step_id`` from parent to child throughout the subtree rooted at ``node``.

    A node takes its parent's ``micro_step_id`` only when it has a parent and its own
    value is still ``None``; values that are already set are left alone. The walk then
    continues into every sub-node.
    """
    parent = node.upnode
    if parent is not None and node.micro_step_id is None:
        node.micro_step_id = parent.micro_step_id
    # Parents are handled before their children, so a value inherited above is
    # available to the sub-nodes visited next.
    for child in node.subnodes:
        propagate_micro_step_id(child)
|
|
193
|
+
|
|
194
|
+
if self.micro_step_num is not None:
|
|
195
|
+
return self.micro_step_num + 1
|
|
196
|
+
|
|
171
197
|
batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
|
|
172
198
|
for batch_number, nodes in batches_n.items():
|
|
173
199
|
for node in nodes:
|
|
@@ -177,6 +203,7 @@ class Graph:
|
|
|
177
203
|
node_other = graph_other.get_node(node.matched_node_link[-1])
|
|
178
204
|
if node_other:
|
|
179
205
|
node_other.micro_step_id = batch_number
|
|
206
|
+
propagate_micro_step_id(self.root)
|
|
180
207
|
# 遍历graph_other根节点下的所有子节点,确保未匹配节点也有micro_step_id
|
|
181
208
|
if graph_other:
|
|
182
209
|
for node in graph_other.root.subnodes:
|
|
@@ -186,6 +213,7 @@ class Graph:
|
|
|
186
213
|
except ValueError:
|
|
187
214
|
micro_step_id = 0
|
|
188
215
|
node.micro_step_id = micro_step_id
|
|
216
|
+
propagate_micro_step_id(graph_other.root)
|
|
189
217
|
return len(batches_n)
|
|
190
218
|
|
|
191
219
|
def overflow_check(self):
|