mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
21
|
+
from msprobe.core.common.utils import logger, CompareException
|
|
22
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
23
|
+
from msprobe.core.compare.config import ModeConfig
|
|
24
|
+
from msprobe.core.compare.utils import gen_api_batches
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# YAML configuration files are shipped alongside this module.
cur_dir = os.path.split(os.path.realpath(__file__))[0]
diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml')
ignore_op_list_yaml_path = os.path.join(cur_dir, 'ignore_op_list.yaml')

# api_name -> output indices that are skipped when deciding whether an API differs.
ignore_list = load_yaml(ignore_op_list_yaml_path)
# metric name -> one-element threshold list; 'compare_metrics' lists the metrics to check.
thresholds = load_yaml(diff_threshold_yaml_path)
cmp_metrics = thresholds.get('compare_metrics', None)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class FirstDiffAnalyze:
    """Mark every API/module in a compare result as same or different (first-diff analysis)."""

    def __init__(self, mode_config: ModeConfig, rank):
        self.mode_config = mode_config
        self.rank = rank

    @staticmethod
    def single_metric_diff_check(cmp_metric, metric_value):
        """Return True when a percentage metric value exceeds its configured threshold.

        :param cmp_metric: metric (column) name; its threshold comes from 'diff_analyze_threshold.yaml'
        :param metric_value: cell value; only strings ending in '%' are evaluated, any other value
            is treated as not exceeding the threshold
        :return: True if the value (as a fraction) is greater than the threshold, else False
        :raises CompareException: MISSING_THRESHOLD_ERROR when no threshold is configured,
            WRONG_THRESHOLD_ERROR when the threshold is not a one-element list
        """
        threshold = thresholds.get(cmp_metric, None)
        if threshold is None:
            logger.error(f"Check diff or {cmp_metric} need to configure the threshold. "
                         f"Please configure it in 'diff_analyze_threshold.yaml'.")
            raise CompareException(CompareException.MISSING_THRESHOLD_ERROR)
        if not isinstance(threshold, list) or len(threshold) != 1:
            logger.error(f"{cmp_metric} threshold configure wrong. Please check.")
            raise CompareException(CompareException.WRONG_THRESHOLD_ERROR)
        if isinstance(metric_value, str) and metric_value.endswith('%'):
            metric_value_float = float(metric_value[:-1]) / 100
            if metric_value_float > threshold[0]:
                return True
        return False

    def single_api_check(self, result_slice, header, api_name=None):
        """Check a single API for differences.

        :param result_slice: rows of the compare result belonging to this API
        :param header: list of column names
        :param api_name: short API name used to look up ignored output indices in 'ignore_op_list.yaml'
        :return: {'is_same': bool, 'op_items': list[dict]}
        """
        single_check_result = {
            'is_same': True,
            'op_items': []
        }

        output_idx = -1
        for line in result_slice:
            # Map column name -> cell value for this row.
            op_item = dict(zip(header, line))
            single_check_result['op_items'].append(op_item)
            # Only outputs decide whether the API differs.
            if op_item['state'] != 'output':
                continue
            output_idx += 1
            if output_idx in ignore_list.get(api_name, []):
                continue
            if self.mode_config.dump_mode == Const.MD5:
                if op_item[CompareConst.RESULT] == CompareConst.DIFF:
                    single_check_result['is_same'] = False
            else:
                for cmp_metric in cmp_metrics:
                    if self.single_metric_diff_check(cmp_metric, op_item[cmp_metric]):
                        single_check_result['is_same'] = False
                        break
        return single_check_result

    def check(self, result_df):
        """After comparison, walk every API/module and check it for differences.

        :param result_df: compare result DataFrame (one row per input/output item)
        :return: dict keyed by full API name, e.g.::

            {
                'Functional.conv2d.0.forward': {
                    'is_same': true,
                    'op_items': [
                        {
                            'NPU name': 'Functional.conv2d.0.forward.input.0',
                            'Bench name': 'Functional.conv2d.0.forward.input.0',
                            'xxx': 1,
                            'NormRelativeErr': 2,
                            'yyy': 3,
                            ...
                        }
                    ]
                }
            }
        """
        result = result_df.values
        header = result_df.columns.tolist()

        api_batches = gen_api_batches(result, header)

        check_result = {}

        default_bar_desc = 'API/Module diff check Progress'
        bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc
        with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar:
            for api_batch in api_batches:
                api_compo = api_batch.api_name.split('.')
                # Names look like Tensor.MatMul.0.forward; shorter names are not checked.
                if len(api_compo) >= 4:
                    # Take 'MatMul' as the short api name.
                    api_name = api_compo[-3]
                    result_slice = result[api_batch.start: api_batch.params_grad_end_index]
                    check_result[api_batch.api_name] = self.single_api_check(result_slice, header, api_name)
                # Advance for every batch (including skipped ones) so the bar reaches its total.
                progress_bar.update(1)

        return check_result
|
|
File without changes
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import time
|
|
17
|
+
from collections import defaultdict
|
|
18
|
+
import os
|
|
19
|
+
from itertools import dropwhile, chain
|
|
20
|
+
|
|
21
|
+
from msprobe.core.common import const
|
|
22
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
24
|
+
from msprobe.core.common.const import Const
|
|
25
|
+
from msprobe.core.compare.find_first.data_processor import DataProcessor
|
|
26
|
+
from msprobe.core.compare.find_first.utils import (RankPath, FileCache, is_communication_op, is_ignore_op,
|
|
27
|
+
DiffAnalyseConst, analyze_diff_in_group)
|
|
28
|
+
from msprobe.core.compare.find_first.graph import DataNode, CommunicationNode
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DiffAnalyzer:
    """Find the earliest diverging op(s) across ranks from per-rank compare results.

    ``analyze`` first runs the framework-specific distributed comparison, then searches in three
    stages and stops at the first stage that yields diff nodes:

    1. ``_pre_analyze``: ops before each rank's first communication op.
    2. ``_analyze``: communication ops (plus the diff compute ops between them), linked across
       ranks and scanned layer by layer.
    3. ``_post_analyze``: diff compute ops after each rank's last communication op.
    """

    def __init__(self, npu_path, bench_path, output_path, data_frame=Const.PT_FRAMEWORK):
        self._bench_path = bench_path
        self._npu_path = npu_path
        self._output_path = output_path
        self.pre_processor = DataProcessor(data_frame)
        self._paths = {}  # rank -> RankPath of that rank's compare-result json
        self._diff_nodes = []  # diff nodes found by the analysis
        self._cache = FileCache()
        self._first_comm_nodes = {}  # rank -> op name of that rank's first communication op
        self._after_comm_diffs = {}  # rank -> diff compute nodes after that rank's last communication op
        self._rank_comm_nodes_dict = {}  # rank -> {node_id: CommunicationNode}

    def analyze(self):
        """Compare the dumps, search for the first diff and write the report if one is found."""
        self._pre_process()
        for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]:
            analyze_func()
            # Later stages only matter when no diff was found earlier.
            if self._diff_nodes:
                self._gen_analyze_info()
                return
        logger.info('Cannot find any diff node, no need to generate analyze file.')

    def _pre_process(self):
        """Produce compare-result json files in the output path and register them per rank."""
        self.pre_processor.process(self._npu_path, self._bench_path, self._output_path)
        self._resolve_input_path(self._output_path)
        logger.info("Pre Process completed.")

    # NOTE: no separate stack file is built; stack information is taken from each op's
    # ``op_items`` entries (NPU_Stack_Info) in the compare-result json.
    def _resolve_input_path(self, result_input_path):
        """Register ``compare_result_rank{rank_id}_{timestamp}.json`` files found in *result_input_path*.

        A file without a rank id (``compare_result_rank_{timestamp}.json``) is mapped to rank 0.
        Files whose rank id is not an integer are skipped. Entries are added to ``self._paths``
        in ascending rank order.
        """
        rank_prefix = 'rank'
        rank_paths = {}

        for path in os.listdir(result_input_path):
            if not path.startswith('compare_result_rank') or not path.endswith('.json'):
                continue

            path_ele_list = path.split('_')
            if len(path_ele_list) <= 2:
                continue
            rank_part = path_ele_list[2]
            if not rank_part.startswith(rank_prefix):
                continue
            # Remove only the literal 'rank' prefix; str.strip('rank') would strip any of the
            # characters r/a/n/k from both ends.
            rank_str = rank_part[len(rank_prefix):]
            try:
                rank = int(rank_str) if rank_str else 0
            except ValueError:
                continue

            dump_path = os.path.join(result_input_path, path)
            rank_paths[rank] = RankPath(rank, dump_path)

        for rank in sorted(rank_paths):
            self._paths[rank] = rank_paths[rank]

    def _pre_analyze(self):
        """Per rank, record the first communication op, or stop at the first diff op before it."""
        logger.info('Start searching diff node before communication.')
        for path in self._paths.values():
            dump_data = self._cache.load_json(path.dump_path)
            if not dump_data:
                logger.warning(f'Rank {path.rank} has no dump data!')
                continue
            for op_name, op_data in dump_data.items():
                if is_ignore_op(op_name):
                    continue
                if is_communication_op(op_name):
                    self._first_comm_nodes[path.rank] = op_name
                    break
                data_node = DataNode(op_name, path.rank, op_data)
                if data_node.is_diff:
                    self._diff_nodes.append(data_node)
                    break

    def _analyze(self):
        """Build, link and prune communication nodes, then search them for the first diff."""
        logger.info('Start searching diff node during communication.')
        self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths}
        self._connect_comm_nodes()
        self._pruning()
        self._search_first_diff()

    def _post_analyze(self):
        """Take the first diff compute node recorded after each rank's last communication op."""
        logger.info('Start searching diff node after communication.')
        for nodes in self._after_comm_diffs.values():
            if nodes:
                self._diff_nodes.append(nodes[0])

    def _connect_comm_nodes(self):
        """Link each rank's communication nodes to matching nodes on ranks not yet searched.

        The last rank is never used as a search origin; its nodes are reached as peers of
        earlier ranks.
        """
        searched_ranks = set()
        for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]:
            searched_ranks.add(rank)
            seen_nodes = set()
            last_node = None
            for cur_node in nodes.values():
                # Keep layers strictly increasing along this rank's communication sequence.
                is_overflow = last_node and hasattr(last_node, 'layer') and hasattr(cur_node, 'layer') and \
                    last_node.layer >= cur_node.layer
                if is_overflow:
                    cur_node.layer = last_node.layer + 1
                conn_info = cur_node.find_connected_nodes()
                if not conn_info.get('ranks'):
                    # No peer ranks recorded: consider every rank.
                    conn_info['ranks'] = self._rank_comm_nodes_dict.keys()
                last_node = cur_node
                if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes):
                    logger.debug(f'Cannot find connected communication node for "{cur_node.node_id}".')

    def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes):
        """Connect *cur_node* to at most one matching node on each candidate rank.

        :return: True if *cur_node* was already connected or a new connection was made
        """
        def connect(search_node):
            seen_nodes.add(search_node.node_id)
            if search_node.type == DiffAnalyseConst.DST:
                cur_node.add_dst(search_node)
            elif search_node.type == DiffAnalyseConst.SRC:
                search_node.layer = cur_node.layer
                search_node.add_dst(cur_node)
            else:
                cur_node.add_link(search_node)

        found = cur_node.connected
        for connected_rank in conn_info['ranks']:
            if connected_rank in searched_ranks:
                continue
            tar_id_prefix = f'{connected_rank}.{conn_info["api"]}'
            for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items():
                if search_id in seen_nodes:
                    continue
                if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')):
                    continue
                # 'ranks' may be absent (None); normalise so the membership test below cannot
                # raise TypeError for directed APIs.
                search_conn_ranks = search_node.find_connected_nodes().get('ranks') or ()
                # Some undirected communication ops do not record a ProcessGroup; they are
                # treated as connected to every rank.
                if ((not search_conn_ranks and search_node.api not in DiffAnalyseConst.DIRECTED_API) or
                        cur_node.rank in search_conn_ranks):
                    connect(search_node)
                    found = True
                    break
        return found

    def _analyze_comm_nodes(self, rank):
        """Build the communication nodes of *rank*, starting at its first communication op.

        Diff compute nodes between two communication ops are attached to the following
        communication node; those after the last one are stored in ``self._after_comm_diffs``.

        :return: {node_id: CommunicationNode}, empty if the rank has no communication op
        """
        path = self._paths[rank]
        data = self._cache.load_json(path.dump_path)
        communication_nodes = {}
        if rank not in self._first_comm_nodes:  # this rank has no communication node
            return communication_nodes
        last_node_id = None  # node_id of the previous communication node
        compute_ops = []  # diff compute nodes between two communication nodes
        sub_layer = 0  # call ordinal of compute ops since the previous communication op
        for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data):
            node_id = f'{rank}.{op_name}'
            op_data = data[op_name]
            if is_communication_op(op_name):
                comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer),
                                              compute_ops=compute_ops)
                if last_node_id:
                    communication_nodes.get(last_node_id).add_next(comm_node)
                communication_nodes[node_id] = comm_node
                last_node_id = node_id
                compute_ops = []
                sub_layer = 0
            elif not is_ignore_op(op_name):
                data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer)
                if data_node.is_diff:
                    compute_ops.append(data_node)
                sub_layer += 1
        if compute_ops:
            self._after_comm_diffs[rank] = compute_ops
        return communication_nodes

    def _pruning(self):
        """Remove communication nodes that neither differ nor carry diff compute ops."""
        deleted_node_id = []
        for nodes in self._rank_comm_nodes_dict.values():
            for node_id in list(nodes.keys()):
                node = nodes[node_id]
                if node.is_diff or node.compute_ops:
                    continue
                deleted_node_id.append(node_id)
                node.delete()
                del nodes[node_id]
        logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]')

    def _search_first_diff(self):
        """Scan communication nodes layer by layer and keep the earliest diff nodes found.

        Each round takes the next unseen node of every rank, expands it into its group of
        connected nodes, and analyses each group. The first round that yields diff nodes keeps
        only those with the smallest (layer, sub_layer) and stops.
        """
        nodes_queues = []
        for comm_nodes in self._rank_comm_nodes_dict.values():
            nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer))
        seen_nodes = set()

        def get_next_node(node_list):
            while node_list:
                next_node = node_list.pop(0)
                if next_node.node_id not in seen_nodes:
                    return next_node
            return None

        def get_relative_ids(ori_node):
            if not ori_node:
                return set()
            return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() |
                    ori_node.dst_nodes.keys())

        def find_all_members(ori_node):
            # Breadth-first closure over link/src/dst relations.
            ids = get_relative_ids(ori_node)
            id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids]))
            while id_queue:
                new_id = id_queue.pop(0)
                ids.add(new_id)
                id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids))
            return ids

        while any(nodes_queues):
            groups = []
            all_ids_in_groups = set()
            for nodes in nodes_queues:
                node = get_next_node(nodes)
                if not node:
                    continue
                if not groups or node.node_id not in all_ids_in_groups:
                    new_group = find_all_members(node)
                    groups.append(new_group)
                    all_ids_in_groups.update(new_group)
            for group in groups:
                seen_nodes.update(group)
                self._diff_nodes.extend(analyze_diff_in_group([self._get_node_by_id(n_id) for n_id in group]))
            if self._diff_nodes:
                # Keep every node that shares the smallest (layer, sub_layer).
                min_layer_sublayer = min((x.layer, x.sub_layer) for x in self._diff_nodes)
                self._diff_nodes = [
                    node
                    for node in self._diff_nodes
                    if (node.layer, node.sub_layer) == min_layer_sublayer
                ]
                return

    def _get_node_by_id(self, node_id):
        """Return the communication node for ``'{rank}.{op_name}'``, or None if it is absent.

        :raises RuntimeError: if *node_id* does not start with a numeric rank followed by the separator
        """
        splits = node_id.split(Const.SEP, 1)
        if len(splits) < 2 or not splits[0].isdigit():
            logger.error(f'invalid node_id {node_id}')
            raise RuntimeError(f'invalid node_id {node_id}')
        rank = int(splits[0])
        return self._rank_comm_nodes_dict.get(rank, {}).get(node_id)

    def _gen_analyze_info(self):
        """Write the diff nodes, grouped by rank, to ``diff_analyze_{ns_timestamp}.json``."""
        if not os.path.exists(self._output_path):
            make_dir(self._output_path)
        file_name = f'diff_analyze_{time.time_ns()}.json'
        result_file = os.path.join(self._output_path, file_name)
        result_content = defaultdict(list)
        for node in self._diff_nodes:
            result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank]))
        save_json(result_file, result_content, 2)
        logger.info(f"The analyze result is saved in: {result_file}")
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.log import logger
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DataProcessor:
    """Dispatch to the framework-specific distributed comparison used by first-diff analysis."""

    def __init__(self, data_frame):
        """Select the comparison entry point for *data_frame*.

        :param data_frame: Const.PT_FRAMEWORK or Const.MS_FRAMEWORK
        :raises ValueError: for any other framework value
        """
        self.data_frame = data_frame
        framework = self.data_frame
        # The imports live inside the branches, presumably so only the selected framework's
        # compare module is loaded.
        if framework == Const.PT_FRAMEWORK:
            from msprobe.pytorch.compare.distributed_compare import compare_distributed as compare_entry
        elif framework == Const.MS_FRAMEWORK:
            from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed as compare_entry
        else:
            raise ValueError(f"Unsupported data_frame: {self.data_frame}")
        self.process_func = compare_entry

    def process(self, npu_path, bench_path, output_path):
        """Run the selected comparison with first-diff analysis enabled and return its result."""
        logger.info("Start comparing data ......")
        compare_kwargs = {'first_diff_analyze': True}
        return self.process_func(npu_path, bench_path, output_path, **compare_kwargs)
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.log import logger
|
|
19
|
+
from msprobe.core.common.const import CompareConst
|
|
20
|
+
from msprobe.core.compare.find_first.utils import RankPath, DiffAnalyseConst
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class DataNode:
    """Comparison data of one operator on one rank, split into input and output metrics.

    ``@dataclass`` generates ``__repr__``/``__eq__`` from the annotated fields below.
    The explicit ``__init__`` is kept, because ``dataclass`` does not overwrite an
    ``__init__`` defined in the class body.
    """
    op_name: str
    rank: int
    inputs: dict
    outputs: dict
    op_data: list
    layer: int = 0  # kept consistent with the owning CommunicationNode's layer
    sub_layer: int = 0  # call order; a smaller value means the op was called earlier

    def __init__(self, op_name, rank, op_data, **kwargs):
        """Build the node from one entry of the comparison result.

        Args:
            op_name: operator name.
            rank: rank the data belongs to.
            op_data: dict holding an ``"is_same"`` flag and an ``"op_items"`` list of compare rows.
            **kwargs: optional ``sub_layer`` and ``layer`` (both default to 0).
        """
        self.op_name = op_name
        self.rank = rank
        self.stack = None
        self.inputs = {}
        self.outputs = {}
        self.is_diff = False
        # Set explicitly instead of relying on the class-level dataclass default.
        self.layer = kwargs.get('layer', 0)
        self.parse_data(op_data)
        self.sub_layer = kwargs.get('sub_layer', 0)

    def find_stack(self):
        """Return the second element of the first stack entry whose first element contains ``op_name``.

        Returns ``{}`` when nothing matches. ``self.stack`` stays ``None`` when no compare
        row carried a stack; that case is treated as "no stack" instead of raising ``TypeError``.
        """
        for item in self.stack or ():
            if len(item) >= 2 and self.op_name in item[0]:
                return item[1]
        return {}

    def parse_data(self, op_data):
        """Populate ``is_diff``, ``op_data``, ``stack``, ``inputs`` and ``outputs`` from ``op_data``.

        Each compare row is reduced to a metrics dict: the NPU max/min/mean/norm
        statistics, or the NPU MD5 when no max column is present, plus the P2POp peer
        when that column exists. The dict is stored under the row's NPU name in
        ``inputs`` or ``outputs``, according to the row's ``'state'``.
        """
        self.is_diff = not op_data.get("is_same", True)
        # The compare rows: a list with one dict per row. A missing entry is treated as no rows.
        self.op_data = op_data.get("op_items") or []
        for cmp_data in self.op_data:
            name = cmp_data.get(CompareConst.NPU_NAME)
            # Build the metrics dict for this row.
            metrics = {}
            if CompareConst.NPU_MAX in cmp_data:
                metrics = {CompareConst.NPU_MAX: cmp_data.get(CompareConst.NPU_MAX),
                           CompareConst.NPU_MIN: cmp_data.get(CompareConst.NPU_MIN),
                           CompareConst.NPU_MEAN: cmp_data.get(CompareConst.NPU_MEAN),
                           CompareConst.NPU_NORM: cmp_data.get(CompareConst.NPU_NORM)}
            elif CompareConst.NPU_MD5 in cmp_data:
                metrics[CompareConst.NPU_MD5] = cmp_data.get(CompareConst.NPU_MD5)

            if CompareConst.NPU_P2POP_PEER in cmp_data:
                metrics[CompareConst.NPU_P2POP_PEER] = cmp_data.get(CompareConst.NPU_P2POP_PEER)

            # Keep the first stack that is not N/A.
            if cmp_data.get(CompareConst.STACK) != CompareConst.N_A and not self.stack:
                self.stack = cmp_data.get(CompareConst.STACK)
            state = cmp_data.get('state')
            if state == "input":
                self.inputs[name] = metrics
            elif state == "output":
                self.outputs[name] = metrics

    def gen_node_info(self, path: "RankPath"):
        """Return a summary dict with the op name, input/output metrics and stack info.

        ``path`` is accepted for interface compatibility but is not used here.
        """
        data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs}
        return {'op_name': self.op_name,
                'data_info': data_info_list,
                'stack_info': self.stack}
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class CommunicationNode:
    """Communication operator of one rank, used as a vertex of the cross-rank graph.

    Wraps a :class:`DataNode` and keeps the relations built by the ``add_*`` methods:
    ``pre_node``/``next_nodes`` (``add_next`` puts the successor one layer deeper),
    ``link_nodes`` (``add_link``) and ``src_nodes``/``dst_nodes`` (``add_dst``). The
    last two keep the peer on the same layer as this node.
    """

    def __init__(self, node_id, rank, data: "DataNode", layer=0, **kwargs):
        """Create the node and resolve its SRC/DST/LINK type.

        Args:
            node_id: key under which neighbours store this node.
            rank: rank this node belongs to.
            data: the operator's compare data.
            layer: initial layer.
            **kwargs: optional pre-built relations ``pre_node``, ``link_nodes``,
                ``dst_nodes``, ``src_nodes``, ``next_nodes`` and ``compute_ops``.

        Raises:
            RuntimeError: if ``data.op_name`` has fewer than 4 ``Const.SEP``-separated fields.
        """
        self.node_id = node_id
        self.rank = rank
        self.data = data
        self.is_diff = data.is_diff
        self.layer = layer
        op_name_split = self.data.op_name.split(Const.SEP)
        if len(op_name_split) < 4:
            logger.error(f'invalid op_name: {self.data.op_name}')
            raise RuntimeError(f'invalid op_name: {self.data.op_name}')
        # Field 1 of the op name is the API name; field 2 is its call count.
        self.api = op_name_split[1]
        self.call_cnt = op_name_split[2]
        # kwargs.get(..., {}) builds a fresh container per call, so instances never share state.
        self.pre_node = kwargs.get('pre_node')
        self.link_nodes = kwargs.get('link_nodes', {})
        self.dst_nodes = kwargs.get('dst_nodes', {})
        self.src_nodes = kwargs.get('src_nodes', {})
        self.next_nodes = kwargs.get('next_nodes', {})
        self.compute_ops = kwargs.get('compute_ops', [])
        # Must run after self.data and self.rank are set.
        self.type = self._resolve_type()
        self.connected = False

    def add_next(self, node):
        """Register ``node`` as a successor and place it (and its data) one layer deeper."""
        self.next_nodes[node.node_id] = node
        node.pre_node = self
        node.layer = self.layer + 1
        node.data.layer = node.layer

    def add_link(self, node):
        """Link ``node`` in both directions, move it to this layer and mark both as connected."""
        self.link_nodes[node.node_id] = node
        node.link_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def add_dst(self, node):
        """Register ``node`` as a destination (and this node as its source) on the same layer."""
        self.dst_nodes[node.node_id] = node
        node.src_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def delete(self):
        """Remove every back-reference that neighbours hold to this node.

        ``pop(..., None)`` is used so that a neighbour without a back-reference to this
        node (e.g. a one-sided relation) does not raise ``KeyError``.
        """
        for node in self.next_nodes.values():
            node.pre_node = None
        for node in self.dst_nodes.values():
            node.src_nodes.pop(self.node_id, None)
        for node in self.src_nodes.values():
            node.dst_nodes.pop(self.node_id, None)
        for node in self.link_nodes.values():
            node.link_nodes.pop(self.node_id, None)
        if self.pre_node is not None:
            self.pre_node.next_nodes.pop(self.node_id, None)

    def find_connected_nodes(self):
        """
        Determine the nodes to connect to from the api, node type, inputs and call count.

        Peer ranks are collected from inputs whose name contains a src/dst keyword, from
        ``.group`` inputs holding a bracketed rank-list string, and otherwise from the
        P2POp peer value. Values that cannot be read as an integer rank are skipped.

        Returns:
            dict with ``'ranks'`` (set of peer ranks), ``'api'`` (``'Distributed.<api>'``,
            after ``DiffAnalyseConst.P2P_API_MAPPING``) and ``'type'`` (the opposite of
            this node's type, or ``DiffAnalyseConst.LINK``).
        """
        tar_api = DiffAnalyseConst.P2P_API_MAPPING.get(self.api, self.api)
        target_types = (DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP,
                        DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP)
        ranks = set()
        for input_name, v in self.data.inputs.items():
            if any(keyword in input_name for keyword in target_types):
                # Prefer the MD5 value; fall back to the NPU max value.
                rank_val = self._to_rank(self._pick_value(v))
                # NOTE(review): a rank value of 0 is skipped here (unchanged from the original
                # behaviour); confirm whether src/dst == 0 should be recorded.
                if rank_val:
                    ranks.add(rank_val)
            elif input_name.endswith('.group'):
                # Prefer the MD5 value; fall back to the NPU max value.
                val = self._pick_value(v)
                # Only a bracketed rank-list string (e.g. "[0, 1]") is parsed; "[]" yields no ranks.
                if isinstance(val, str) and val.startswith('[') and val.endswith(']'):
                    for part in val.strip('[]').split(','):
                        if not part.strip():
                            continue
                        group_rank = self._to_rank(part)
                        if group_rank is not None:
                            ranks.add(group_rank)
            else:
                # A missing peer or the string "None" does not identify a rank.
                peer = self._to_rank(v.get(CompareConst.NPU_P2POP_PEER))
                if peer is not None:
                    ranks.add(peer)

        return {'ranks': ranks, 'api': f'Distributed.{tar_api}',
                'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)}

    @staticmethod
    def _pick_value(metrics):
        """Return the MD5 value when that key is present, otherwise the NPU max value (may be None)."""
        if CompareConst.NPU_MD5 in metrics:
            return metrics.get(CompareConst.NPU_MD5)
        return metrics.get(CompareConst.NPU_MAX)

    @staticmethod
    def _to_rank(value):
        """Convert a dumped rank value (int, integral float or numeric string) to ``int``; else ``None``."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            pass
        try:
            as_float = float(value)
        except (TypeError, ValueError):
            return None
        return int(as_float) if as_float.is_integer() else None

    def _resolve_type(self):
        """Classify the node as SRC, DST or LINK from its src/dst inputs.

        SRC keywords are checked before DST keywords. The first input whose name contains
        the keyword decides the type: the keyword's type if its value equals this node's
        rank, otherwise ``DiffAnalyseConst.OPPOSITE_DIR`` of it. The value is compared as
        stored and also as an integer, matching how ``find_connected_nodes`` reads ranks.
        """
        own_rank = self._to_rank(self.rank)
        for prefix in (DiffAnalyseConst.SRC, DiffAnalyseConst.DST):
            for k, v in self.data.inputs.items():
                # A name containing f"{prefix}_GROUP" also contains prefix, so one test suffices.
                if prefix not in k:
                    continue
                # Prefer the MD5 value; fall back to the NPU max value.
                compare_val = self._pick_value(v)
                is_own_rank = compare_val == self.rank or (
                    own_rank is not None and self._to_rank(compare_val) == own_rank)
                return prefix if is_own_rank else DiffAnalyseConst.OPPOSITE_DIR[prefix]
        return DiffAnalyseConst.LINK
|