mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/visualization/graph/distributed_analyzer.py ADDED
@@ -0,0 +1,318 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from enum import Enum
+from msprobe.visualization.utils import GraphConst
+from msprobe.core.common.const import Const, CompareConst
+from msprobe.core.common.log import logger
+
+
+class CommunicationType(Enum):
+    """
+    Communication type: send, receive, or send-receive
+    """
+    SEND = 'send'
+    RECEIVE = 'receive'
+    SEND_RECEIVE = 'send_receive'
+
+
+class DistributedType(Enum):
+    """
+    Distributed type: point-to-point communication or collective communication
+    """
+    P2P = 'p2p'
+    COLLECTIVE = 'collective'
+
+
+CANNOT_MATCH = 'cannot match distributed node in rank'
+
+
+class DistributedAnalyzer:
+
+    def __init__(self, graphs: dict, overflow_check: bool):
+        self.graphs = graphs
+        self.overflow_check = overflow_check
+        self.config = {
+            # current communication api name: [target api name to match, positional/keyword parameter holding the rank info, communication type, distributed type]
+            'send': ['recv', GraphConst.DST, CommunicationType.SEND.value, DistributedType.P2P],
+            'isend': ['irecv', GraphConst.DST, CommunicationType.SEND.value, DistributedType.P2P],
+            'recv': ['send', GraphConst.SRC, CommunicationType.RECEIVE.value, DistributedType.P2P],
+            'irecv': ['isend', GraphConst.SRC, CommunicationType.RECEIVE.value, DistributedType.P2P],
+            'broadcast': ['broadcast', '1', CommunicationType.SEND.value, DistributedType.COLLECTIVE],
+            'scatter': ['scatter', GraphConst.SRC, CommunicationType.SEND.value, DistributedType.COLLECTIVE],
+            'gather': ['gather', GraphConst.DST, CommunicationType.RECEIVE.value, DistributedType.COLLECTIVE],
+            'reduce': ['reduce', '1', CommunicationType.RECEIVE.value, DistributedType.COLLECTIVE]
+        }
+        self.group_node_mapping = {}
+        self._make_group_node_mapping()
+
+    @staticmethod
+    def _get_opposite_communication_type(action):
+        if action == CommunicationType.SEND.value:
+            return CommunicationType.RECEIVE.value
+        elif action == CommunicationType.RECEIVE.value:
+            return CommunicationType.SEND.value
+        return action
+
+    @staticmethod
+    def _node_output_all_equal(data: dict, target_data: dict):
+        keys_to_compare = [Const.DTYPE, Const.SHAPE, Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
+        return all(data.get(key) == target_data.get(key) for key in keys_to_compare)
+
+    @staticmethod
+    def _get_target_rank(node, rank, parameter):
+        """
+        Point-to-point communication: get the target rank of the communication from the src or dst parameter.
+        One-to-many and many-to-one communication: get the source rank of the send or receive from the src/dst or positional parameter.
+        :param node: current node
+        :param rank: current rank
+        :param parameter: data parameter name
+        :return: target rank
+        """
+        target_rank = node.input_data.get(f'{node.id}{GraphConst.INPUT}{parameter}', {}).get('value')
+        if target_rank is None:
+            logger.warning(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+        return target_rank
+
+    @staticmethod
+    def _get_group_info(node, rank):
+        """
+        Get group_ranks and group_id from the group parameter of the current communication node
+        :param node: current communication node
+        :param rank: current rank
+        :return: group_ranks and group_id
+        """
+        group = node.input_data.get(f'{node.id}{GraphConst.INPUT}group', {})
+        if not group:
+            logger.warning(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+            return None, None
+        group_ranks = group.get('group_ranks')
+        if not group_ranks:
+            logger.warning(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+            return None, None
+        group_id = group.get('group_id')
+        if not group_id:
+            logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+            return None, None
+        return group_ranks, group_id
+
+    def distributed_match(self):
+        for rank, graph in self.graphs.items():
+            nodes = graph.node_map
+            for node_id, node in nodes.items():
+                # skip nodes that are not communication nodes or that have already been matched
+                if not node_id.startswith(Const.DISTRIBUTED) or node.matched_distributed:
+                    continue
+                api_name, distributed_type = self._get_distributed_name_and_type(node_id)
+                if distributed_type == DistributedType.P2P:
+                    self._p2p_match(node, rank, api_name)
+                else:
+                    self._collective_match(node, rank, api_name)
+
+    def _make_group_node_mapping(self):
+        """
+        Build a globally unique identifier mapping for communication nodes
+        key: rank number, value: mapping between unique_group_id and node_id
+        {
+            "0": {
+                "unique_group_id1": "node_id1",
+                "unique_group_id2": "node_id2",
+                "node_id1": "unique_group_id1",
+                "node_id2": "unique_group_id2"
+            },
+            "1": {},
+            "2": {}
+        }
+        """
+        for rank, graph in self.graphs.items():
+            group_count = {}
+            group_info = {}
+            nodes = graph.node_map
+            for node_id, node in nodes.items():
+                if not node_id.startswith(Const.DISTRIBUTED):
+                    continue
+                api_name, distributed_type = self._get_distributed_name_and_type(node_id)
+                if distributed_type == DistributedType.P2P:
+                    config_info = self.config.get(api_name)
+                    target_rank = self._get_target_rank(node, rank, config_info[1])
+                    if target_rank is None:
+                        continue
+                    # for p2p communication nodes, api name + target rank is used as the group_id
+                    group_id = api_name + Const.RANK + str(target_rank)
+                else:
+                    # other communication nodes take the group_id directly and append the api name
+                    _, group_id = self._get_group_info(node, rank)
+                    if not group_id:
+                        continue
+                    group_id += api_name
+                # accumulate the call count for each group_id
+                group_count[group_id] = group_count.get(group_id, 0) + 1
+                # group_id + its call count forms the unique unique_group_id
+                unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(group_count.get(group_id))
+                group_info[unique_group_id] = node_id
+                group_info[node_id] = unique_group_id
+            self.group_node_mapping[rank] = group_info
+
+    def _get_distributed_name_and_type(self, node_id):
+        if Const.SEP not in node_id:
+            raise ValueError(f'Invalid node id {node_id}.')
+        api_name = node_id.split(Const.SEP)[1]
+        if api_name in self.config:
+            return api_name, self.config.get(api_name)[3]
+        return api_name, DistributedType.COLLECTIVE
+
+    def _get_target_node(self, rank, unique_group_id, api_name, target_rank, target_api_name=None):
+        """
+        Get the target node matched by name
+        :param rank: current rank
+        :param unique_group_id: unique group id of the current node
+        :param api_name: api name of the current node, e.g. for Distributed.isend.0.forward the api name is isend
+        :param target_rank: rank that communicates with the current node
+        :param target_api_name: api name of the node that communicates with the current node, only needed for p2p communication
+        :return: target node
+        """
+        target_graph = self.graphs.get(target_rank)
+        if not target_graph:
+            logger.warning(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}')
+            return None
+        target_group_mapping = self.group_node_mapping.get(target_rank)
+        # for p2p communication, the rank and api name in unique_group_id must be replaced to find the target node,
+        # e.g. isend sending to rank1 pairs with irecv receiving from rank0, so isend_rank1 maps to irecv_rank0
+        target_unique_group_id = (unique_group_id
+                                  .replace(Const.RANK + str(target_rank), Const.RANK + str(rank))
+                                  .replace(api_name, target_api_name)) if target_api_name else unique_group_id
+        target_node_id = target_group_mapping.get(target_unique_group_id, '')
+        target_node = target_graph.node_map.get(target_node_id)
+        if not target_node:
+            logger.warning(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}')
+            return None
+        return target_node
+
+    def _add_node_matched_distributed(self, node, target_node, api_name, target_rank, reversal_type=False):
+        """
+        Add the matched_distributed field to the current node
+        :param node: current node
+        :param target_node: matched target node
+        :param api_name: api name of the current node
+        :param target_rank: matched target rank
+        :param reversal_type: whether to reverse the communication type, e.g. broadcast is a send on rank0 but a receive on the other ranks
+        """
+        communications_type = self.config.get(api_name)[2]
+        communications_type = self._get_opposite_communication_type(communications_type) if reversal_type \
+            else communications_type
+        index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
+            else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
+        matched_distributed = {
+            'communications_type': communications_type,
+            'nodes_info': {target_rank: [str(index), target_node.id]}
+        }
+        node.matched_distributed = matched_distributed
+
+    def _p2p_match(self, node, rank, api_name):
+        """
+        Point-to-point communication matching
+
+        Determine the target rank from the src or dst parameter in the output data of the current p2p node, find the corresponding p2p node on that rank, and verify that the output data is consistent;
+        if the check passes, add matching information to both matched nodes
+        Args:
+            node: current point-to-point communication node
+            rank: rank the current node belongs to
+            api_name: api name of the current node
+        Returns:
+        """
+        config_info = self.config.get(api_name)
+        target_api_name = config_info[0]
+        #
+        target_rank = self._get_target_rank(node, rank, config_info[1])
+        if target_rank is None:
+            return
+        unique_group_id = self.group_node_mapping.get(rank, {}).get(node.id, '')
+        target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
+        if not target_node:
+            return
+        target_config_info = self.config.get(target_api_name)
+        source_rank = (target_node.input_data.get(f'{target_node.id}{GraphConst.INPUT}{target_config_info[1]}', {})
+                       .get('value'))
+        if source_rank is None:
+            logger.warning(
+                f'The kwarg {target_config_info[1]} of node {target_node.id} does not exist, '
+                f'{CANNOT_MATCH}{source_rank}')
+            return
+        if source_rank != rank:
+            # for p2p communication, the rank recorded in the target node must equal the current rank
+            logger.warning(
+                f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}, '
+                f'but the data shows that {target_node.id} communicates with rank{source_rank}.'
+                f'The rank is inconsistent, cannot match distributed node')
+            return
+
+        # for p2p communication, the output data of the two matched nodes must be identical
+        if not DistributedAnalyzer._node_output_all_equal(node.output_data.get(node.id + '.output.0'),
+                                                          target_node.output_data.get(target_node.id + '.output.0')):
+            logger.warning(f'{node.id} output of rank{rank} is different from the {target_node.id} '
+                           f'output of rank{target_rank}, cannot match distributed node')
+            return
+
+        self._add_node_matched_distributed(node, target_node, api_name, target_rank)
+        self._add_node_matched_distributed(target_node, node, target_api_name, rank)
+
+    def _collective_match(self, node, rank, api_name):
+        """
+        Collective communication matching
+
+        For one-to-many and many-to-one communication, the src/dst or positional parameter in the node's output data is read first to determine the send or receive source; many-to-many communication does not need this
+        :param node: current collective communication node
+        :param rank: rank the current node belongs to
+        :param api_name: api name of the current node
+        :return:
+        """
+        communications_type = CommunicationType.SEND_RECEIVE.value
+        config_info = self.config.get(api_name)
+        if config_info:
+            # this is one-to-many or many-to-one communication
+            source_rank = self._get_target_rank(node, rank, config_info[1])
+            if source_rank is None or str(source_rank) != str(rank):
+                return
+            communications_type = config_info[2]
+        group_ranks, group_id = self._get_group_info(node, rank)
+        if not group_ranks or not group_id:
+            return
+        unique_group_id = self.group_node_mapping.get(rank, {}).get(node.id, '')
+        matched_distributed = {'communications_type': communications_type}
+        nodes_info = {}
+        for target_rank in group_ranks:
+            if str(target_rank) == str(rank):
+                continue
+            target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank)
+            if not target_node:
+                continue
+            _, target_group_id = self._get_group_info(target_node, target_rank)
+            if not target_group_id:
+                continue
+            if group_id != target_group_id:
+                logger.warning(
+                    f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}'
+                    f', but the data shows that the group id of the two nodes are different, '
+                    f'cannot match distributed node')
+                continue
+            # add the matched_distributed field to the current communication node
+            index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
+                else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
+            nodes_info[target_rank] = [str(index), target_node.id]
+            if config_info:
+                # also add the matched_distributed field to the matched target node
+                self._add_node_matched_distributed(target_node, node, api_name, rank, True)
+        if nodes_info:
+            matched_distributed['nodes_info'] = nodes_info
+            node.matched_distributed = matched_distributed
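For orientation, the analyzer above is constructed once per set of per-rank graphs and then annotates matching communication nodes on each rank. A minimal sketch of that flow, with illustrative dump paths and rank numbers (only the constructor arguments and the `distributed_match()` call come from the code above):

```python
# Hedged sketch, not part of the diff: driving DistributedAnalyzer over per-rank graphs.
from msprobe.visualization.builder.graph_builder import GraphBuilder
from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer

# One graph per rank, keyed by the integer rank number; file paths are illustrative placeholders.
graphs = {
    rank: GraphBuilder.build(f"step0/rank{rank}/construct.json",
                             f"step0/rank{rank}/dump.json",
                             f"step0/rank{rank}/stack.json")
    for rank in (0, 1)
}

# With overflow_check=False the matched node's precision_index is recorded;
# with True its overflow_level is recorded instead.
DistributedAnalyzer(graphs, overflow_check=False).distributed_match()

# Every matched Distributed.* node now carries a matched_distributed dict, e.g.
# {'communications_type': 'send', 'nodes_info': {1: ['0.5', 'Distributed.recv.0.forward']}}
```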
msprobe/visualization/graph/graph.py CHANGED
@@ -67,6 +67,15 @@ class Graph:
         ancestors_b = node_b.get_ancestors()
         return node_b, ancestors_n, ancestors_b
 
+
+    @staticmethod
+    def fuzzy_match(node_n, node_b):
+        if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
+            return None, [], []
+        ancestors_n = node_n.get_ancestors()
+        ancestors_b = node_b.get_ancestors()
+        return node_b, ancestors_n, ancestors_b
+
     @staticmethod
     def dfs(node, result):
         info = node.to_dict()
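The new `fuzzy_match` mirrors the exact-match helper directly above it (same `(matched node, ancestors_n, ancestors_b)` return shape) but defers the decision to the node's `fuzzy_eq` method. A rough, hedged sketch of calling it on nodes pulled from two built graphs (dump paths and node ids are placeholders):

```python
# Hedged sketch, not part of the diff: exercising Graph.fuzzy_match directly.
from msprobe.visualization.builder.graph_builder import GraphBuilder
from msprobe.visualization.graph.graph import Graph

graph_n = GraphBuilder.build("npu/construct.json", "npu/dump.json", "npu/stack.json")        # illustrative paths
graph_b = GraphBuilder.build("bench/construct.json", "bench/dump.json", "bench/stack.json")

node_n = graph_n.node_map.get("Torch.matmul.0.forward")   # placeholder node id
node_b = graph_b.node_map.get("Torch.matmul.1.forward")   # would not match by exact id

# Returns (None, [], []) when either node is missing or fuzzy_eq rejects the pair.
matched, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, node_b)
if matched is None:
    print("no fuzzy match between the two nodes")
```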
msprobe/visualization/graph_service.py CHANGED
@@ -28,11 +28,12 @@ from msprobe.core.common.log import logger
 from msprobe.visualization.graph.node_colors import NodeColors
 from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
 from msprobe.core.compare.utils import check_and_return_dir_contents
+from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
 
 current_time = time.strftime("%Y%m%d%H%M%S")
 
 
-def _compare_graph(input_param, args
+def _compare_graph(input_param, args):
     logger.info('Start building model graphs...')
     # build graphs for the two sets of dump data
     dump_path_n = input_param.get('npu_path')
@@ -49,8 +50,8 @@ def _compare_graph(input_param, args, output_file_name=f'compare_{current_time}.
                                  FileCheckConst.READ_ABLE).common_check()
     stack_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.STACK_FILE), FileCheckConst.FILE,
                                FileCheckConst.READ_ABLE).common_check()
-    graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n)
-    graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b)
+    graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack)
+    graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack)
     logger.info('Model graphs built successfully, start Comparing graphs...')
     # compare based on graph, stack and data
     dump_path_param = {
@@ -66,8 +67,7 @@ def _compare_graph(input_param, args, output_file_name=f'compare_{current_time}.
         mapping_dict = generate_api_mapping_by_layer_mapping(data_path_n, data_path_b, yaml_path)
     except Exception:
         logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.')
-    graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args
-                                       mapping_dict=mapping_dict)
+    graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args, mapping_dict=mapping_dict)
     graph_comparator.compare()
     micro_steps = graph_n.paging_by_micro_step(graph_b)
     # overflow check enabled
@@ -75,16 +75,22 @@ def _compare_graph(input_param, args, output_file_name=f'compare_{current_time}.
         graph_n.overflow_check()
         graph_b.overflow_check()
 
+    return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps)
+
+
+def _export_compare_graph_result(args, graphs, graph_comparator, micro_steps,
+                                 output_file_name=f'compare_{current_time}.vis'):
     create_directory(args.output_path)
     output_path = os.path.join(args.output_path, output_file_name)
     task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
-    export_config = GraphExportConfig(
-                                      NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task
+    export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(),
+                                      NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task,
+                                      args.overflow_check)
     GraphBuilder.to_json(output_path, export_config)
     logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}')
 
 
-def _build_graph(dump_path,
+def _build_graph(dump_path, args):
     logger.info('Start building model graph...')
     construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE,
                                  FileCheckConst.READ_ABLE).common_check()
@@ -92,14 +98,19 @@ def _build_graph(dump_path, out_path, overflow_check=False, output_file_name=f'b
                              FileCheckConst.READ_ABLE).common_check()
     stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
                              FileCheckConst.READ_ABLE).common_check()
-
-    output_path = os.path.join(out_path, output_file_name)
-    graph = GraphBuilder.build(construct_path, data_path, stack_path)
+    graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
     micro_steps = graph.paging_by_micro_step()
     # overflow check enabled
-    if overflow_check:
+    if args.overflow_check:
         graph.overflow_check()
-
+    return BuildGraphResult(graph, micro_steps)
+
+
+def _export_build_graph_result(out_path, graph, micro_steps, overflow_check,
+                               output_file_name=f'build_{current_time}.vis'):
+    create_directory(out_path)
+    output_path = os.path.join(out_path, output_file_name)
+    GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check))
     logger.info(f'Model graph built successfully, the result file is saved in {output_path}')
 
 
@@ -111,12 +122,33 @@ def _compare_graph_ranks(input_param, args, step=None):
     if npu_ranks != bench_ranks:
         logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
         raise CompareException(CompareException.INVALID_PATH_ERROR)
+    compare_graph_results = []
     for nr, br in zip(npu_ranks, bench_ranks):
         logger.info(f'Start processing data for {nr}...')
         input_param['npu_path'] = os.path.join(dump_rank_n, nr)
        input_param['bench_path'] = os.path.join(dump_rank_b, br)
         output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
-        _compare_graph(input_param, args
+        result = _compare_graph(input_param, args)
+        result.output_file_name = output_file_name
+        if nr != Const.RANK:
+            try:
+                result.rank = int(nr.replace(Const.RANK, ""))
+            except Exception as e:
+                logger.error('The folder name format is incorrect, expected rank+number.')
+                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
+        # cache the graphs of all ranks for matching distributed nodes across ranks
+        compare_graph_results.append(result)
+
+    # match distributed nodes across ranks
+    if len(compare_graph_results) > 1:
+        DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
+                            args.overflow_check).distributed_match()
+        DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
+                            args.overflow_check).distributed_match()
+
+    for result in compare_graph_results:
+        _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator,
+                                     result.micro_steps, output_file_name=result.output_file_name)
 
 
 def _compare_graph_steps(input_param, args):
@@ -138,21 +170,38 @@ def _compare_graph_steps(input_param, args):
         _compare_graph_ranks(input_param, args, step=folder_step)
 
 
-def _build_graph_ranks(dump_ranks_path,
+def _build_graph_ranks(dump_ranks_path, args, step=None):
     ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
+    build_graph_results = []
     for rank in ranks:
         logger.info(f'Start processing data for {rank}...')
         dump_path = os.path.join(dump_ranks_path, rank)
         output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
-        _build_graph(dump_path,
+        result = _build_graph(dump_path, args)
+        result.output_file_name = output_file_name
+        if rank != Const.RANK:
+            try:
+                result.rank = int(rank.replace(Const.RANK, ""))
+            except Exception as e:
+                logger.error('The folder name format is incorrect, expected rank+number.')
+                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
+        build_graph_results.append(result)
+
+    if len(build_graph_results) > 1:
+        DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
+                            args.overflow_check).distributed_match()
+
+    for result in build_graph_results:
+        _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check,
+                                   result.output_file_name)
 
 
-def _build_graph_steps(dump_steps_path,
+def _build_graph_steps(dump_steps_path, args):
     steps = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP))
     for step in steps:
         logger.info(f'Start processing data for {step}...')
         dump_ranks_path = os.path.join(dump_steps_path, step)
-        _build_graph_ranks(dump_ranks_path,
+        _build_graph_ranks(dump_ranks_path, args, step)
 
 
 def _graph_service_parser(parser):
@@ -161,9 +210,13 @@ def _graph_service_parser(parser):
     parser.add_argument("-o", "--output_path", dest="output_path", type=str,
                         help="<Required> The compare task result out path.", required=True)
     parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
-                        help="<
+                        help="<Optional> The layer mapping file path.", required=False)
     parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true",
                         help="<Optional> whether open overflow_check for graph.", required=False)
+    parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
+                        help="<Optional> Whether to perform a fuzzy match on the api name.", required=False)
+    parser.add_argument("-cs", "--complete_stack", dest="complete_stack", action="store_true",
+                        help="<Optional> Whether to use complete stack information.", required=False)
 
 
 def _graph_service_command(args):
@@ -177,11 +230,12 @@ def _graph_service_command(args):
     if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path:
         content = check_directory_content(npu_path)
         if content == GraphConst.RANKS:
-            _build_graph_ranks(npu_path, args
+            _build_graph_ranks(npu_path, args)
         elif content == GraphConst.STEPS:
-            _build_graph_steps(npu_path, args
+            _build_graph_steps(npu_path, args)
        else:
-            _build_graph(npu_path, args
+            result = _build_graph(npu_path, args)
+            _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check)
     elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
         content_n = check_directory_content(npu_path)
         content_b = check_directory_content(bench_path)
@@ -192,7 +246,9 @@ def _graph_service_command(args):
         elif content_n == GraphConst.STEPS:
             _compare_graph_steps(input_param, args)
         else:
-            _compare_graph(input_param, args)
+            result = _compare_graph(input_param, args)
+            _export_compare_graph_result(args, [result.graph_n, result.graph_b],
+                                         result.graph_comparator, result.micro_steps)
     else:
         logger.error("The npu_path or bench_path should be a folder.")
         raise CompareException(CompareException.INVALID_COMPARE_MODE)
@@ -212,3 +268,21 @@ def _ms_graph_service_parser(parser):
 
 def _ms_graph_service_command(args):
     _graph_service_command(args)
+
+
+class CompareGraphResult:
+    def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, output_file_name=''):
+        self.graph_n = graph_n
+        self.graph_b = graph_b
+        self.graph_comparator = graph_comparator
+        self.micro_steps = micro_steps
+        self.rank = rank
+        self.output_file_name = output_file_name
+
+
+class BuildGraphResult:
+    def __init__(self, graph, micro_steps, rank=0, output_file_name=''):
+        self.graph = graph
+        self.micro_steps = micro_steps
+        self.rank = rank
+        self.output_file_name = output_file_name

msprobe/visualization/utils.py CHANGED
@@ -18,7 +18,7 @@ import re
 import json
 from msprobe.core.common.file_utils import FileOpen
 from msprobe.core.common.const import CompareConst, Const
-from msprobe.core.compare.acc_compare import Comparator
+from msprobe.core.compare.acc_compare import Comparator, ModeConfig
 
 
 def load_json_file(file_path):
@@ -50,12 +50,13 @@ def save_json_file(file_path, data):
         f.write(json.dumps(data, indent=4))
 
 
-def get_csv_df(
+def get_csv_df(stack_mode, csv_data, compare_mode):
     """
    Call the acc compare interface to write the csv
     """
     dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
-
+    mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=dump_mode)
+    return Comparator(mode_config).make_result_table(csv_data)
 
 
 def str2float(percentage_str):
@@ -96,6 +97,10 @@ def check_directory_content(input_path):
     if all(os.path.isfile(os.path.join(input_path, item)) for item in contents):
         return GraphConst.FILES
 
+    # a single-card run has only one rank folder
+    if contents == [Const.RANK]:
+        return GraphConst.RANKS
+
     rank_pattern = re.compile(r'^rank\d+$')
     step_pattern = re.compile(r'^step\d+$')
 
@@ -141,16 +146,6 @@ class ToolTip:
     SMALL_VALUE_TIP = '{}, 由于{}小于{}, 建议不参考此相对误差,请参考绝对误差'
 
 
-class Suggestions:
-    Module = '此模块精度比对结果疑似异常,请使用msprobe工具的数据采集功能对模块中的api进行dump比对'
-    API = '此api精度比对结果疑似异常,请使用msprobe工具的预检功能对api进行精度检测'
-    DUMP = 'msprobe工具的数据采集功能'
-    DUMP_URL = 'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/dump.md'
-    API_ACCURACY_CHECKER = 'msprobe工具的预检功能'
-    API_ACCURACY_CHECKER_URL = \
-        'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md'
-
-
 class GraphConst:
     CONSTRUCT_FILE = 'construct.json'
     DUMP_FILE = 'dump.json'
@@ -172,6 +167,8 @@ class GraphConst:
     MAX_RELATIVE_ERR_TH = 0.5
     ROUND_TH = 6
     JSON_INDEX_KEY = 'precision_index'
+    MATCHED_DISTRIBUTED = 'matched_distributed'
+    OVERFLOW_LEVEL = 'overflow_level'
     MAX_INDEX_KEY = 1
     MIN_INDEX_KEY = 0
     SUGGEST_KEY = 'text'
@@ -198,6 +195,7 @@
     DESCRIPTION = 'description'
     COLORS = 'Colors'
     MICRO_STEPS = 'MicroSteps'
+    OVERFLOW_CHECK = 'OverflowCheck'
 
     DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING = {
         Const.ALL: REAL_DATA_COMPARE,
@@ -210,23 +208,10 @@
         SUMMARY_COMPARE: Const.SUMMARY,
         MD5_COMPARE: Const.MD5
     }
-    SMALL_VALUES = {
-        Const.TORCH_FLOAT32: 1e-6,
-        Const.TORCH_FLOAT16: 1e-3,
-        Const.TORCH_BFLOAT16: 1e-3,
-        Const.FLOAT32: 1e-6,
-        Const.FLOAT16: 1e-3,
-        Const.BFLOAT16: 1e-3
-    }
-    SMALL_VALUES_ABS_ERROR = {
-        Const.TORCH_FLOAT32: 1e-6,
-        Const.TORCH_FLOAT16: 1e-3,
-        Const.TORCH_BFLOAT16: 1e-3,
-        Const.FLOAT32: 1e-6,
-        Const.FLOAT16: 1e-3,
-        Const.BFLOAT16: 1e-3
-    }
 
     RANKS = 'ranks'
     STEPS = 'steps'
     FILES = 'files'
+
+    SRC = 'src'
+    DST = 'dst'
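The `get_csv_df` change mirrors the wider `acc_compare` refactor in this release: the `Comparator` is now configured through a `ModeConfig` object rather than loose keyword arguments. A small, hedged sketch of the updated call, with placeholder row data:

```python
# Hedged sketch, not part of the diff: calling the reworked get_csv_df directly.
from msprobe.visualization.utils import get_csv_df, GraphConst

csv_data = []                                    # rows accumulated during graph comparison (placeholder)
compare_mode = GraphConst.SUMMARY_COMPARE        # one of the compare modes mapped to a dump_mode above
df = get_csv_df(False, csv_data, compare_mode)   # stack_mode=False; returns the Comparator result table
```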