mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import OrderedDict
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
import sys
|
|
19
|
+
import time
|
|
20
|
+
import psutil
|
|
21
|
+
|
|
22
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, load_json
|
|
23
|
+
from msprobe.core.common.const import Const
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
|
|
27
|
+
class RankPath:
|
|
28
|
+
rank: int
|
|
29
|
+
dump_path: str
|
|
30
|
+
|
|
31
|
+
def __init__(self, rank, dump_path):
|
|
32
|
+
self.rank = rank
|
|
33
|
+
check_file_or_directory_path(dump_path)
|
|
34
|
+
self.dump_path = dump_path
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class FileCache:
|
|
38
|
+
"""
|
|
39
|
+
lazy load file
|
|
40
|
+
"""
|
|
41
|
+
_instance = None
|
|
42
|
+
|
|
43
|
+
def __new__(cls, *args, **kwargs):
|
|
44
|
+
if not cls._instance:
|
|
45
|
+
cls._instance = super().__new__(cls, *args, **kwargs)
|
|
46
|
+
return cls._instance
|
|
47
|
+
|
|
48
|
+
def __init__(self):
|
|
49
|
+
self._max_memory_usage = psutil.virtual_memory().available / 4 # 最大占用当前可用内存空间的1/4
|
|
50
|
+
self._cache = OrderedDict()
|
|
51
|
+
self._access_cnt = {}
|
|
52
|
+
self._access_time = {}
|
|
53
|
+
self._size = {}
|
|
54
|
+
|
|
55
|
+
@staticmethod
|
|
56
|
+
def _sizeof(obj):
|
|
57
|
+
seen = set()
|
|
58
|
+
objs = [obj]
|
|
59
|
+
size = 0
|
|
60
|
+
while objs:
|
|
61
|
+
obj = objs.pop()
|
|
62
|
+
obj_id = id(obj)
|
|
63
|
+
if obj_id in seen:
|
|
64
|
+
continue
|
|
65
|
+
seen.add(obj_id)
|
|
66
|
+
size += sys.getsizeof(obj)
|
|
67
|
+
if isinstance(obj, dict):
|
|
68
|
+
objs.extend(obj.keys())
|
|
69
|
+
objs.extend(obj.values())
|
|
70
|
+
elif isinstance(obj, (list, tuple, set, frozenset)):
|
|
71
|
+
objs.extend(obj)
|
|
72
|
+
return size
|
|
73
|
+
|
|
74
|
+
def load_json(self, json_path):
|
|
75
|
+
if json_path in self._cache:
|
|
76
|
+
self._access_cnt[json_path] += 1
|
|
77
|
+
self._access_time[json_path] = time.monotonic()
|
|
78
|
+
self._cache.move_to_end(json_path)
|
|
79
|
+
return self._cache[json_path]
|
|
80
|
+
self._cleanup()
|
|
81
|
+
return self._load(json_path)
|
|
82
|
+
|
|
83
|
+
def _load(self, json_path):
|
|
84
|
+
data = load_json(json_path)
|
|
85
|
+
self._add_to_cache(json_path, data)
|
|
86
|
+
return data
|
|
87
|
+
|
|
88
|
+
def _add_to_cache(self, key, data):
|
|
89
|
+
if key in self._cache:
|
|
90
|
+
self._cache.move_to_end(key)
|
|
91
|
+
else:
|
|
92
|
+
self._cache[key] = data
|
|
93
|
+
self._access_cnt[key] = 0
|
|
94
|
+
self._access_time[key] = time.monotonic()
|
|
95
|
+
self._size[key] = self._sizeof(data)
|
|
96
|
+
|
|
97
|
+
def _calc_cache_size(self):
|
|
98
|
+
return sys.getsizeof(self._cache) + sum(self._size.values())
|
|
99
|
+
|
|
100
|
+
def _cleanup(self):
|
|
101
|
+
while self._calc_cache_size() > self._max_memory_usage and self._cache:
|
|
102
|
+
least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k])
|
|
103
|
+
least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k])
|
|
104
|
+
largest_key = max(self._cache.keys(), key=lambda k: self._size[k])
|
|
105
|
+
key_to_rm = min([least_frequent_key, least_recent_key, largest_key],
|
|
106
|
+
key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k]))
|
|
107
|
+
del self._cache[key_to_rm]
|
|
108
|
+
del self._access_cnt[key_to_rm]
|
|
109
|
+
del self._access_time[key_to_rm]
|
|
110
|
+
del self._size[key_to_rm]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def is_communication_op(op_name):
|
|
114
|
+
# 定义通信算子的关键字,覆盖各种通信操作,如all_reduce, send, broadcast等
|
|
115
|
+
# 从wrap文件中读取,先硬编码在文件中
|
|
116
|
+
return (op_name.startswith((Const.DISTRIBUTED, Const.MINT_DIST_API_TYPE_PREFIX, Const.MS_API_TYPE_COM)) and
|
|
117
|
+
any(keyword in op_name for keyword in DiffAnalyseConst.COMMUNICATION_KEYWORDS))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def is_ignore_op(op_name):
|
|
121
|
+
ignore_keywords = [
|
|
122
|
+
'Torch.empty',
|
|
123
|
+
'Torch.fill',
|
|
124
|
+
'Tensor.__setitem__'
|
|
125
|
+
]
|
|
126
|
+
return any(keyword in op_name for keyword in ignore_keywords)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class DiffAnalyseConst:
|
|
130
|
+
COMMUNICATION_KEYWORDS = {
|
|
131
|
+
'send', # send 算子
|
|
132
|
+
'recv', # recv 算子
|
|
133
|
+
'broadcast', # broadcast 算子
|
|
134
|
+
'all_reduce', # all_reduce 算子
|
|
135
|
+
'reduce', # reduce 算子
|
|
136
|
+
'all_gather', # all_gather 算子
|
|
137
|
+
'gather', # gather 算子
|
|
138
|
+
'isend', # isend 算子
|
|
139
|
+
'irecv', # irecv 算子
|
|
140
|
+
'scatter', # scatter 算子
|
|
141
|
+
'reduce_scatter', # reduce_scatter 算子
|
|
142
|
+
'_reduce_scatter_base', # _reduce_scatter_base 算子
|
|
143
|
+
'_all_gather_base', # _all_gather_base 算子
|
|
144
|
+
'all_to_all_single', # all_to_all_single 算子
|
|
145
|
+
'all_to_all', # all_to_all 算子
|
|
146
|
+
'all_gather_into_tensor', # all_gather_into_tensor 算子
|
|
147
|
+
'reduce_scatter_tensor', # reduce_scatter_tensor 算子
|
|
148
|
+
'send_object_list', # send_object_list 算子
|
|
149
|
+
'recv_object_list' # recv_object_list 算子
|
|
150
|
+
}
|
|
151
|
+
P2P_API_MAPPING = {'send': 'recv', 'recv': 'send', 'isend': 'irecv', 'irecv': 'isend',
|
|
152
|
+
'send_object_list': 'recv_object_list', 'recv_object_list': 'send_object_list'}
|
|
153
|
+
SRC = 'src'
|
|
154
|
+
DST = 'dst'
|
|
155
|
+
SRC_GROUP = 'group_src'
|
|
156
|
+
DST_GROUP = 'group_dst'
|
|
157
|
+
LINK = 'link'
|
|
158
|
+
DIRECTED_API = {'send': DST, 'recv': SRC, 'isend': DST, 'irecv': SRC, 'broadcast': SRC, 'scatter': SRC,
|
|
159
|
+
'gather': DST, 'send_object_list': DST, 'recv_object_list': SRC}
|
|
160
|
+
OPPOSITE_DIR = {SRC: DST, DST: SRC}
|
|
161
|
+
DUMP_FILE = "dump.json"
|
|
162
|
+
CONSTRUCT_FILE = "construct.json"
|
|
163
|
+
STACK_FILE = "stack.json"
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def analyze_diff_in_group(nodes_group):
|
|
167
|
+
diff_nodes = []
|
|
168
|
+
|
|
169
|
+
def get_compute_ops_from_comm_nodes(comm_nodes):
|
|
170
|
+
for comm_node in comm_nodes:
|
|
171
|
+
for op_node in comm_node.compute_ops:
|
|
172
|
+
op_node.layer = comm_node.layer
|
|
173
|
+
diff_nodes.append(op_node)
|
|
174
|
+
|
|
175
|
+
def get_comm_ops(comm_nodes):
|
|
176
|
+
for node in comm_nodes:
|
|
177
|
+
node.data.layer = node.layer
|
|
178
|
+
diff_nodes.append(node.data)
|
|
179
|
+
|
|
180
|
+
# 先看src或link中input是否有异常
|
|
181
|
+
src_list = list(filter(lambda node: node.type in [DiffAnalyseConst.SRC, DiffAnalyseConst.LINK], nodes_group))
|
|
182
|
+
input_diff_nodes = list(filter(lambda node: node.is_diff, src_list))
|
|
183
|
+
# 如果有异常回溯计算节点找到异常来源
|
|
184
|
+
# 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。
|
|
185
|
+
get_compute_ops_from_comm_nodes(nodes_group)
|
|
186
|
+
# 筛选入参没问题但出参有问题的通信节点
|
|
187
|
+
output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group))
|
|
188
|
+
get_comm_ops(output_diff_nodes)
|
|
189
|
+
return diff_nodes
|
|
@@ -16,10 +16,8 @@
|
|
|
16
16
|
import abc
|
|
17
17
|
import math
|
|
18
18
|
import multiprocessing
|
|
19
|
-
import re
|
|
20
19
|
from collections import namedtuple
|
|
21
20
|
|
|
22
|
-
import numpy as np
|
|
23
21
|
import openpyxl
|
|
24
22
|
from openpyxl.styles import PatternFill
|
|
25
23
|
from openpyxl.utils.dataframe import dataframe_to_rows
|
|
@@ -28,8 +26,8 @@ from tqdm import tqdm
|
|
|
28
26
|
from msprobe.core.common.const import CompareConst, Const
|
|
29
27
|
from msprobe.core.common.file_utils import save_workbook
|
|
30
28
|
from msprobe.core.common.log import logger
|
|
31
|
-
from msprobe.core.common.utils import get_header_index,
|
|
32
|
-
from msprobe.core.compare.utils import table_value_is_valid,
|
|
29
|
+
from msprobe.core.common.utils import get_header_index, CompareException
|
|
30
|
+
from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches
|
|
33
31
|
from msprobe.core.compare.config import ModeConfig
|
|
34
32
|
|
|
35
33
|
|
|
@@ -54,10 +52,12 @@ class CheckOrderMagnitude(HighlightCheck):
|
|
|
54
52
|
api_in, api_out, num = info
|
|
55
53
|
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
56
54
|
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
57
|
-
|
|
55
|
+
max_diff_in = abs(api_in[max_diff_index])
|
|
56
|
+
max_diff_out = abs(api_out[max_diff_index])
|
|
57
|
+
if max_diff_in > max_diff_out or (max_diff_in <= 1 or max_diff_out <= 1):
|
|
58
58
|
return
|
|
59
|
-
in_order = 0 if
|
|
60
|
-
out_order = 0 if
|
|
59
|
+
in_order = 0 if max_diff_in < 1 else math.log10(max_diff_in)
|
|
60
|
+
out_order = 0 if max_diff_out < 1 else math.log10(max_diff_out)
|
|
61
61
|
if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
|
|
62
62
|
add_highlight_row_info(color_columns.yellow, num,
|
|
63
63
|
"maximum absolute error of both input/parameters and output exceed 1, "
|
|
@@ -102,20 +102,28 @@ class CheckMaxRelativeDiff(HighlightCheck):
|
|
|
102
102
|
"""检查最大相对差异"""
|
|
103
103
|
|
|
104
104
|
def apply(self, info, color_columns, dump_mode):
|
|
105
|
+
def get_number(data):
|
|
106
|
+
"""统计量相对值如果为正常百分数据,str格式并以%结尾"""
|
|
107
|
+
if isinstance(data, str) and data.endswith("%"):
|
|
108
|
+
return float(data[:-1]) / 100
|
|
109
|
+
return data
|
|
110
|
+
|
|
105
111
|
api_in, api_out, num = info
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
output_max_relative_diff =
|
|
111
|
-
|
|
112
|
-
if not isinstance(
|
|
113
|
-
(float, int)):
|
|
112
|
+
max_rel_diff = get_header_index(CompareConst.MAX_RELATIVE_ERR, dump_mode)
|
|
113
|
+
input_max_relative_diff = api_in[max_rel_diff] # 内部数据,长度总是和表头一致,不会越界
|
|
114
|
+
output_max_relative_diff = api_out[max_rel_diff]
|
|
115
|
+
input_max_relative_diff = get_number(input_max_relative_diff)
|
|
116
|
+
output_max_relative_diff = get_number(output_max_relative_diff)
|
|
117
|
+
|
|
118
|
+
if not isinstance(output_max_relative_diff, (float, int)):
|
|
114
119
|
return
|
|
115
120
|
if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
|
|
116
121
|
add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
|
|
117
|
-
|
|
118
|
-
|
|
122
|
+
|
|
123
|
+
if not isinstance(input_max_relative_diff, (float, int)):
|
|
124
|
+
return
|
|
125
|
+
if (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
|
|
126
|
+
input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
|
|
119
127
|
add_highlight_row_info(color_columns.yellow, num,
|
|
120
128
|
"The output's maximum relative error exceeds 0.1, "
|
|
121
129
|
"while the input/parameter's is below 0.01")
|
|
@@ -139,12 +147,25 @@ class CheckOverflow(HighlightCheck):
|
|
|
139
147
|
add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
|
|
140
148
|
|
|
141
149
|
|
|
150
|
+
class CheckReqGradConsist(HighlightCheck):
|
|
151
|
+
"""检查requires_grad是否一致"""
|
|
152
|
+
|
|
153
|
+
def apply(self, info, color_columns, dump_mode):
|
|
154
|
+
line, num = info
|
|
155
|
+
req_grad_consist_index = get_header_index(CompareConst.REQ_GRAD_CONSIST, dump_mode)
|
|
156
|
+
if not line[req_grad_consist_index]:
|
|
157
|
+
add_highlight_row_info(color_columns.yellow, num, "requires_grad is inconsistent")
|
|
158
|
+
|
|
159
|
+
|
|
142
160
|
class HighlightRules:
|
|
143
161
|
"""高亮规则集合,用于检查API的误差"""
|
|
144
162
|
# 适用于每行的规则
|
|
145
163
|
basic_rules = {
|
|
146
164
|
"check_overflow": CheckOverflow()
|
|
147
165
|
}
|
|
166
|
+
consist_rules = {
|
|
167
|
+
"check_req_grad_consist": CheckReqGradConsist()
|
|
168
|
+
}
|
|
148
169
|
|
|
149
170
|
# 用于比较输入和输出的规则
|
|
150
171
|
# 真实数据检查规则
|
|
@@ -160,64 +181,10 @@ class HighlightRules:
|
|
|
160
181
|
}
|
|
161
182
|
|
|
162
183
|
|
|
163
|
-
class ApiBatch:
|
|
164
|
-
def __init__(self, api_name: str, start: int):
|
|
165
|
-
self.api_name = api_name
|
|
166
|
-
self.start = start
|
|
167
|
-
self.input_len = 1 # input的数量
|
|
168
|
-
self.params_end_index = start + 1 # params的结束index
|
|
169
|
-
self.output_end_index = start + 1 # output的结束index
|
|
170
|
-
self.params_grad_end_index = start + 1 # params_grad的结束index
|
|
171
|
-
# 内部state的标志("input", "output", "parameters", "parameters_grad"),
|
|
172
|
-
# 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index
|
|
173
|
-
self._state = Const.INPUT # api_batch初始化为input
|
|
174
|
-
|
|
175
|
-
def set_state(self, state: str):
|
|
176
|
-
"""设置当前状态"""
|
|
177
|
-
if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}:
|
|
178
|
-
self._state = state
|
|
179
|
-
else:
|
|
180
|
-
raise ValueError(f"Invalid state: {state}")
|
|
181
|
-
|
|
182
|
-
def increment(self, state: str):
|
|
183
|
-
self.set_state(state)
|
|
184
|
-
if self._state == Const.INPUT or self._state == Const.KWARGS:
|
|
185
|
-
self.input_len += 1
|
|
186
|
-
self.params_end_index += 1
|
|
187
|
-
self.output_end_index += 1
|
|
188
|
-
if self._state == Const.PARAMS:
|
|
189
|
-
self.params_end_index += 1
|
|
190
|
-
self.output_end_index += 1
|
|
191
|
-
if self._state == Const.OUTPUT:
|
|
192
|
-
self.output_end_index += 1
|
|
193
|
-
self.params_grad_end_index += 1
|
|
194
|
-
|
|
195
|
-
|
|
196
184
|
class HighLight:
|
|
197
|
-
def __init__(self, mode_config: ModeConfig):
|
|
185
|
+
def __init__(self, mode_config: ModeConfig, rank):
|
|
198
186
|
self.mode_config = mode_config
|
|
199
|
-
|
|
200
|
-
@staticmethod
|
|
201
|
-
def api_batches_update(api_batches, api_name, state, index):
|
|
202
|
-
"""
|
|
203
|
-
当一个api的所有item更新完后,input, output的索引范围:
|
|
204
|
-
input: [start: start+input_len]
|
|
205
|
-
output: [start+input_len: output_end_index]
|
|
206
|
-
params: [output_end_index: params_end_index]
|
|
207
|
-
"""
|
|
208
|
-
if not api_batches:
|
|
209
|
-
api_batches.append(ApiBatch(api_name, index))
|
|
210
|
-
else:
|
|
211
|
-
api_batch = api_batches[-1]
|
|
212
|
-
if api_batch.api_name == api_name or (
|
|
213
|
-
not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
|
|
214
|
-
try:
|
|
215
|
-
api_batch.increment(state)
|
|
216
|
-
except ValueError as e:
|
|
217
|
-
logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
|
|
218
|
-
raise CompareException(CompareException.INVALID_STATE_ERROR) from e
|
|
219
|
-
else:
|
|
220
|
-
api_batches.append(ApiBatch(api_name, index))
|
|
187
|
+
self.rank = rank
|
|
221
188
|
|
|
222
189
|
@staticmethod
|
|
223
190
|
def check_indices_numeric(api_items, indices: list):
|
|
@@ -232,7 +199,7 @@ class HighLight:
|
|
|
232
199
|
if CompareConst.NPU_MD5 in result_df.columns:
|
|
233
200
|
return
|
|
234
201
|
|
|
235
|
-
err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
|
|
202
|
+
err_msg = result_df.get(CompareConst.ERROR_MESSAGE).copy()
|
|
236
203
|
red_lines_num_set = highlight_dict.get('red_rows')
|
|
237
204
|
|
|
238
205
|
for color in ['red', 'yellow']:
|
|
@@ -273,12 +240,11 @@ class HighLight:
|
|
|
273
240
|
def find_compare_result_error_rows(self, result_df, highlight_dict):
|
|
274
241
|
"""将dataframe根据API分组,并找到有误差的算子用于高亮"""
|
|
275
242
|
result = result_df.values
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
|
|
243
|
+
header = result_df.columns.tolist()
|
|
244
|
+
api_batches = gen_api_batches(result, header)
|
|
245
|
+
default_bar_desc = 'API/Module Analyse Progress'
|
|
246
|
+
bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc
|
|
247
|
+
with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="item", ncols=100) as progress_bar:
|
|
282
248
|
for api_batch in api_batches:
|
|
283
249
|
self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch,
|
|
284
250
|
highlight_dict)
|
|
@@ -328,6 +294,13 @@ class HighLight:
|
|
|
328
294
|
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
|
|
329
295
|
self.apply_comparison_rules(api_info, color_columns)
|
|
330
296
|
|
|
297
|
+
# 对单行API的输入或输出进行requires_grad是否一致判断
|
|
298
|
+
for i, line in enumerate(result):
|
|
299
|
+
index = api_batch_start + i
|
|
300
|
+
line_info = LineInfo(line_data=line, num_pointer=index)
|
|
301
|
+
for rule in HighlightRules.consist_rules.values():
|
|
302
|
+
rule.apply(line_info, color_columns, self.mode_config.dump_mode)
|
|
303
|
+
|
|
331
304
|
red_lines_num_set = {x[0] for x in red_lines}
|
|
332
305
|
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
333
306
|
highlight_dict.get('red_rows', set()).update(red_lines_num_set)
|
|
@@ -349,28 +322,19 @@ class HighLight:
|
|
|
349
322
|
|
|
350
323
|
self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
|
|
351
324
|
|
|
325
|
+
self.df_malicious_value_check(result_df)
|
|
326
|
+
|
|
352
327
|
wb = openpyxl.Workbook()
|
|
353
328
|
ws = wb.active
|
|
354
|
-
|
|
355
|
-
# write header
|
|
356
|
-
logger.info('Initializing Excel file.')
|
|
357
|
-
|
|
358
|
-
self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df)
|
|
359
|
-
|
|
360
329
|
result_df_convert = result_df.applymap(self.compare_result_df_convert)
|
|
361
|
-
|
|
362
330
|
for row in dataframe_to_rows(result_df_convert, index=False, header=True):
|
|
363
331
|
ws.append(row)
|
|
364
332
|
|
|
365
333
|
# 对可疑数据标色
|
|
366
334
|
logger.info('Coloring Excel in progress.')
|
|
335
|
+
red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid")
|
|
336
|
+
yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid")
|
|
367
337
|
col_len = len(result_df.columns)
|
|
368
|
-
red_fill = PatternFill(
|
|
369
|
-
start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
|
|
370
|
-
)
|
|
371
|
-
yellow_fill = PatternFill(
|
|
372
|
-
start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
|
|
373
|
-
)
|
|
374
338
|
for i in highlight_dict.get("red_rows", []):
|
|
375
339
|
for j in range(1, col_len + 1):
|
|
376
340
|
ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
|
|
@@ -378,7 +342,6 @@ class HighLight:
|
|
|
378
342
|
for j in range(1, col_len + 1):
|
|
379
343
|
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
380
344
|
|
|
381
|
-
logger.info('Saving Excel file to disk: %s' % file_path)
|
|
382
345
|
save_workbook(wb, file_path)
|
|
383
346
|
|
|
384
347
|
def handle_multi_process_malicious_value_check(self, func, result_df):
|
|
@@ -396,22 +359,32 @@ class HighLight:
|
|
|
396
359
|
|
|
397
360
|
def err_call(args):
|
|
398
361
|
logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
|
|
399
|
-
try:
|
|
400
|
-
pool.close()
|
|
401
|
-
except OSError:
|
|
402
|
-
logger.error("Pool terminate failed")
|
|
403
362
|
|
|
404
363
|
result_df_columns = result_df.columns.tolist()
|
|
405
364
|
for column in result_df_columns:
|
|
406
365
|
self.value_check(column)
|
|
366
|
+
async_results = []
|
|
407
367
|
for df_chunk in chunks:
|
|
408
|
-
pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
368
|
+
result = pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
369
|
+
async_results.append(result)
|
|
409
370
|
|
|
410
371
|
pool.close()
|
|
372
|
+
|
|
373
|
+
for ar in async_results:
|
|
374
|
+
try:
|
|
375
|
+
ar.get(timeout=3600)
|
|
376
|
+
except Exception as e:
|
|
377
|
+
logger.error(f"Task failed with exception: {e}")
|
|
378
|
+
pool.terminate()
|
|
379
|
+
raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
|
|
380
|
+
|
|
411
381
|
pool.join()
|
|
412
382
|
|
|
413
|
-
def df_malicious_value_check(self,
|
|
414
|
-
|
|
383
|
+
def df_malicious_value_check(self, result_df):
|
|
384
|
+
result_df_columns = result_df.columns.tolist()
|
|
385
|
+
for column in result_df_columns:
|
|
386
|
+
self.value_check(column)
|
|
387
|
+
for row in result_df.itertuples(index=False):
|
|
415
388
|
api_name = row[0]
|
|
416
389
|
for i, value in enumerate(row):
|
|
417
390
|
self.value_check(value, api_name, i, result_df_columns)
|
|
@@ -18,12 +18,12 @@ from collections import defaultdict
|
|
|
18
18
|
|
|
19
19
|
from msprobe.core.common.const import CompareConst, Const
|
|
20
20
|
from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml
|
|
21
|
-
from msprobe.core.common.utils import
|
|
22
|
-
|
|
23
|
-
get_stack_construct_by_dump_json_path)
|
|
21
|
+
from msprobe.core.common.utils import add_time_with_yaml, detect_framework_by_dump_json, \
|
|
22
|
+
get_stack_construct_by_dump_json_path, CompareException
|
|
24
23
|
from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
|
|
25
24
|
from msprobe.core.compare.utils import read_op, reorder_op_name_list
|
|
26
25
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
26
|
+
from msprobe.core.common.log import logger
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class LayerTrie:
|
|
@@ -63,7 +63,11 @@ class LayerTrie:
|
|
|
63
63
|
node = node.children[name]
|
|
64
64
|
if index >= len(node.data_items[state]):
|
|
65
65
|
return default_value
|
|
66
|
-
|
|
66
|
+
if node.data_items[state]:
|
|
67
|
+
return node.data_items[state][index]
|
|
68
|
+
else:
|
|
69
|
+
logger.error(f"node.data_items of state:{state} is empty, please check.")
|
|
70
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
67
71
|
|
|
68
72
|
def save_to_yaml(self, output_path):
|
|
69
73
|
result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)}
|
|
@@ -208,7 +212,8 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa
|
|
|
208
212
|
def read_full_op_names(data, op_name):
|
|
209
213
|
op_parsed_list = read_op(data.get(op_name, {}), op_name)
|
|
210
214
|
full_op_names = [op_parsed.get('full_op_name') for op_parsed in op_parsed_list]
|
|
211
|
-
|
|
215
|
+
states = [op_parsed.get(Const.STATE) for op_parsed in op_parsed_list]
|
|
216
|
+
return full_op_names, states
|
|
212
217
|
|
|
213
218
|
def generate_op_data_mapping(npu_op_name, npu_full_op_names, bench_op_name, bench_full_op_names):
|
|
214
219
|
suffix_to_full_op_name = {}
|
|
@@ -228,10 +233,10 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa
|
|
|
228
233
|
for npu_op_name, bench_op_name in api_mapping.items():
|
|
229
234
|
if not npu_op_name:
|
|
230
235
|
continue
|
|
231
|
-
npu_full_op_names = read_full_op_names(npu_data, npu_op_name)
|
|
232
|
-
bench_full_op_names = read_full_op_names(bench_data, bench_op_name)
|
|
233
|
-
npu_full_op_names_reorder = reorder_op_name_list(npu_full_op_names)
|
|
234
|
-
bench_full_op_names_reorder = reorder_op_name_list(bench_full_op_names)
|
|
236
|
+
npu_full_op_names, npu_states = read_full_op_names(npu_data, npu_op_name)
|
|
237
|
+
bench_full_op_names, bench_states = read_full_op_names(bench_data, bench_op_name)
|
|
238
|
+
npu_full_op_names_reorder, _ = reorder_op_name_list(npu_full_op_names, npu_states)
|
|
239
|
+
bench_full_op_names_reorder, _ = reorder_op_name_list(bench_full_op_names, bench_states)
|
|
235
240
|
mapping = generate_op_data_mapping(npu_op_name, npu_full_op_names_reorder,
|
|
236
241
|
bench_op_name, bench_full_op_names_reorder)
|
|
237
242
|
data_mapping.update(mapping)
|
|
@@ -109,8 +109,8 @@ def check_index_dump_mode_consistent(dump_mode, rank_num):
|
|
|
109
109
|
return []
|
|
110
110
|
|
|
111
111
|
dump_mode_compare_index_map = {
|
|
112
|
-
Const.ALL: CompareConst.ALL_COMPARE_INDEX,
|
|
113
|
-
Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX
|
|
112
|
+
Const.ALL: CompareConst.ALL_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST],
|
|
113
|
+
Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
|
|
114
114
|
}
|
|
115
115
|
valid_compare_index = dump_mode_compare_index_map.get(dump_mode)
|
|
116
116
|
|