mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/core/compare/utils.py
CHANGED
@@ -18,13 +18,14 @@ import re
 import math
 import zlib
 from dataclasses import dataclass
+import multiprocessing

 import numpy as np
 import pandas as pd

 from msprobe.core.common.const import Const, CompareConst, FileCheckConst
 from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
-from msprobe.core.common.file_utils import check_file_or_directory_path
+from msprobe.core.common.file_utils import check_file_or_directory_path, load_json

 json_file_mapping = {
     Const.DUMP_JSON_FILE: "dump.json",
@@ -94,30 +95,39 @@ def check_and_return_dir_contents(dump_dir, prefix):


 def read_op(op_data, op_name):
+    if not isinstance(op_name, str):
+        logger.error(f"api name error: {op_name} is not a string, please check.")
+        raise CompareException(CompareException.INVALID_API_NAME_ERROR)
     split_name = op_name.split(Const.SEP)
-    if
-        op_parsed_list = op_item_parse(op_data, op_name)
+    if split_name[-1] == Const.DEBUG:
+        op_parsed_list = op_item_parse(op_data, op_name, Const.DEBUG)
+    elif split_name[-1] == Const.PARAMS_GRAD:
+        op_parsed_list = op_item_parse(op_data, op_name, Const.PARAMS_GRAD)
     else:
         op_parsed_list = []
         for name in CompareConst.IO_NAME_MAPPING:
             if name in op_data:
-                op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name]))
+                op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name], name))
     return op_parsed_list


-def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
+def op_item_parse(op_data, op_name: str, state: str, depth: int = 0) -> list:
+    if state == Const.INPUT_ARGS or state == Const.INPUT_KWARGS:
+        state = Const.INPUT
     default_item = {
         'full_op_name': op_name,
-
-
-
-
-
-
-
-
-
-
+        Const.TYPE: None,
+        Const.MAX: None,
+        Const.MIN: None,
+        Const.MEAN: None,
+        Const.NORM: None,
+        Const.DTYPE: None,
+        Const.SHAPE: None,
+        Const.MD5: None,
+        Const.VALUE: None,
+        Const.DATA_NAME: '-1',
+        Const.STATE: state,
+        Const.REQ_GRAD: None
     }

     if depth > Const.MAX_DEPTH:
@@ -133,33 +143,53 @@ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
     if isinstance(op_data, list):
         for i, data in enumerate(op_data):
             if Const.PARAMS_GRAD not in op_name.split(Const.SEP):
-                item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
+                item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), state, depth + 1))
             else:
-                item_list.extend(op_item_parse(data, op_name, depth + 1))
+                item_list.extend(op_item_parse(data, op_name, state, depth + 1))
     elif isinstance(op_data, dict):
+        if is_p2pop_leaf_data(op_data):
+            p2pop_item = {}
+            for key in ['class_type', 'op', 'peer', 'tag', 'group_id']:
+                p2pop_item[key] = op_data.get(key)
+            op_data = op_data.get('tensor')
+            if isinstance(op_data, dict):
+                op_item = gen_op_item(op_data, op_name, state)
+            else:
+                op_item = default_item
+            op_item.update(p2pop_item)
+            return [op_item]
         if is_leaf_data(op_data):
-            return [gen_op_item(op_data, op_name)]
+            return [gen_op_item(op_data, op_name, state)]
        for sub_name, sub_data in op_data.items():
-            item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1))
+            item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), state, depth + 1))
     return item_list


+def is_p2pop_leaf_data(op_data):
+    return op_data.get('class_type') == 'torch.distributed.P2POp'
+
+
 def is_leaf_data(op_data):
     return 'type' in op_data and isinstance(op_data['type'], str)


-def gen_op_item(op_data, op_name):
+def gen_op_item(op_data, op_name, state):
     op_item = {}
-    op_item.update(op_data)
-    data_name = op_data.get(
-    op_item[
+    op_item.update({key: str(value) if isinstance(value, bool) else value for key, value in op_data.items()})
+    data_name = op_data.get(Const.DATA_NAME) if op_data.get(Const.DATA_NAME) else '-1'  # also return '-1' when the value is ""
+    op_item[Const.DATA_NAME] = data_name
     op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name
+    op_item[Const.STATE] = state
+    if Const.REQ_GRAD not in op_item:
+        op_item[Const.REQ_GRAD] = None

-
+    # fill in any missing statistics fields
+    params = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
     for i in params:
         if i not in op_item:
             op_item[i] = None

+    # special cases
     if not op_item.get('dtype'):
         if op_item.get('type') == 'torch.Size':
             op_item['dtype'] = op_data.get('type')
@@ -172,11 +202,18 @@ def gen_op_item(op_data, op_name):
             op_item['shape'] = '[]'
             for i in params:
                 op_item[i] = op_data.get('value')
+        elif op_name.split(Const.SEP)[-1] in ['src', 'dst', 'group_src', 'group_dst']:
+            op_item['dtype'] = op_data.get('type')
+            op_item['shape'] = '[]'
+            for i in params:
+                op_item[i] = str(op_data.get('value'))
+            op_item['md5'] = str(op_data.get('value'))
         elif op_item.get('type') == 'torch.ProcessGroup':
             op_item['dtype'] = op_data.get('type')
             op_item['shape'] = '[]'
             for i in params:
                 op_item[i] = str(op_data.get('group_ranks'))
+            op_item['md5'] = str(op_data.get('group_ranks'))
         else:
             op_item['dtype'] = str(type(op_data.get('value')))
             op_item['shape'] = '[]'
@@ -205,22 +242,26 @@ def merge_tensor(tensor_list, dump_mode):
         CompareConst.PARAMS_GRAD_STRUCT,
         CompareConst.DEBUG_STRUCT,
         Const.SUMMARY,
-        Const.STACK_INFO
+        Const.STACK_INFO,
+        Const.STATE,
+        Const.REQ_GRAD
     ]
     op_dict = {key: [] for key in keys}

     if dump_mode == Const.ALL:
-        op_dict[
+        op_dict[Const.DATA_NAME] = []

     for tensor in tensor_list:
         # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
         if len(tensor) == 2:
-            op_dict[Const.STACK_INFO].append(tensor
+            op_dict[Const.STACK_INFO].append(tensor.get('full_info'))
             break

-        op_dict[CompareConst.OP_NAME].append(tensor
+        op_dict[CompareConst.OP_NAME].append(tensor.get('full_op_name'))
+        state = tensor.get(Const.STATE)
+        op_dict[Const.STATE].append(state)
+        op_dict[Const.REQ_GRAD].append(tensor.get(Const.REQ_GRAD))

-        _, state = get_name_and_state(tensor['full_op_name'])
         struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
         if not struct_key:
             continue
@@ -228,22 +269,19 @@ def merge_tensor(tensor_list, dump_mode):
             op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
         else:
             op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
-
+
+        # when a statistic is None, convert it to the string 'None' so it is not silently turned into NaN when the list is later loaded into pandas
+        op_dict[Const.SUMMARY].append(
+            [str(tensor[key]) if tensor[key] is None else tensor[key] for key in Const.SUMMARY_METRICS_LIST])

         if dump_mode == Const.ALL:
-            op_dict[
+            op_dict[Const.DATA_NAME].append(tensor.get(Const.DATA_NAME))

     if not op_dict[CompareConst.KWARGS_STRUCT]:
         del op_dict[CompareConst.KWARGS_STRUCT]
     return op_dict if op_dict[CompareConst.OP_NAME] else {}


-def check_api_info_len(op_name, info_list, len_require):
-    if len(info_list) < len_require:
-        logger.error(f'Index out of bounds error, please check info of api: {op_name}.')
-        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
-
-
 def print_compare_ends_info():
     total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
     logger.info('*' * total_len)
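A note on the merge_tensor change above: missing statistics are now stored as the string "None" before the rows are handed to pandas. A minimal standalone sketch of the coercion this guards against (plain pandas with made-up op names; not msprobe code):

import pandas as pd

# A numeric column that contains a bare None is inferred as float64, so the
# None silently becomes NaN once the rows are loaded into a DataFrame.
rows = [
    ["Functional.relu.0.forward.input.0", None, None],
    ["Functional.relu.0.forward.output.0", 3.5, -1.0],
]
df = pd.DataFrame(rows, columns=["op_name", "Max", "Min"])
print(df["Max"].tolist())        # [nan, 3.5]

# Stringifying the None first (as merge_tensor now does) keeps it visible.
rows_safe = [[str(v) if v is None else v for v in row] for row in rows]
df_safe = pd.DataFrame(rows_safe, columns=["op_name", "Max", "Min"])
print(df_safe["Max"].tolist())   # ['None', 3.5]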
@@ -263,83 +301,113 @@ def table_value_is_valid(value: str) -> bool:
     return True


-
+class ApiBatch:
+    def __init__(self, api_name: str, start: int):
+        self.api_name = api_name
+        self.start = start
+        self.input_len = 1  # number of inputs
+        self.params_end_index = start + 1  # end index of params
+        self.output_end_index = start + 1  # end index of output
+        self.params_grad_end_index = start + 1  # end index of params_grad
+        # internal state flag ("input", "output", "parameters", "parameters_grad"),
+        # used to control how input_len, output_end_index, params_end_index and params_grad_end_index are updated
+        self._state = Const.INPUT  # an api_batch starts in the input state
+
+    def set_state(self, state: str):
+        """Set the current state."""
+        if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}:
+            self._state = state
+        else:
+            raise ValueError(f"Invalid state: {state}")
+
+    def increment(self, state: str):
+        self.set_state(state)
+        if self._state == Const.INPUT or self._state == Const.KWARGS:
+            self.input_len += 1
+            self.params_end_index += 1
+            self.output_end_index += 1
+        if self._state == Const.PARAMS:
+            self.params_end_index += 1
+            self.output_end_index += 1
+        if self._state == Const.OUTPUT:
+            self.output_end_index += 1
+        self.params_grad_end_index += 1
+
+
+def api_batches_update(api_batches, api_name, state, index):
     """
-
-
-
-
-
-
-
+    After all items of an api have been processed, the index ranges of input and output are:
+    input: [start: start+input_len]
+    output: [start+input_len: output_end_index]
+    params: [output_end_index: params_end_index]
+    """
+    if not api_batches:
+        api_batches.append(ApiBatch(api_name, index))
+    else:
+        api_batch = api_batches[-1]
+        if api_batch.api_name == api_name or (
+                not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
+            try:
+                api_batch.increment(state)
+            except ValueError as e:
+                logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
+                raise CompareException(CompareException.INVALID_STATE_ERROR) from e
+        else:
+            api_batches.append(ApiBatch(api_name, index))

-    name = 'x_tensor.0.debug.{index}'
-    return: ('x_tensor.0.', 'debug')

-
+def reorder_index(op_parsed_list):
     """
-
-
-
+    Reorder the indexes of the op_items parsed from one api, moving parameter indexes ahead of output.
+    Return the new reordered index list; op_parsed_list itself is left unchanged.
+    """
+    index_param = []
+    index_output = []
+    index_param_grad = []
+    index_other = []
+    for i, op_item in enumerate(op_parsed_list[:-1]):
+        state = op_item.get(Const.STATE)
+        if state == Const.PARAMS:
+            index_param.append(i)
+        elif state == Const.OUTPUT:
+            index_output.append(i)
+        elif state == Const.PARAMS_GRAD:
+            index_param_grad.append(i)
+        else:
+            index_other.append(i)
+    # merge others, parameters and output, making sure parameters come before output
+    reordered_index_list = index_other + index_param + index_output + index_param_grad
+    return reordered_index_list

-    if Const.DEBUG in name.split(Const.SEP):
-        return name.split(Const.DEBUG)[0], Const.DEBUG
-    if Const.PARAMS_GRAD in name.split(Const.SEP):
-        return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD

-
-    if len(split) < 3:
-        logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.')
-        raise CompareException(CompareException.INVALID_API_NAME_ERROR)
-    api = f'{split[0]}.{split[1]}.'
-    state_str = split[2]
-    match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
-    if not match:
-        raise CompareException(f'Invalid name string: {name}')
-    if match.group(1):
-        api = f'{api}{match.group(1)}'
-    state = match.group(2)
-    return api, state
-
-
-def reorder_op_name_list(op_name_list):
+def reorder_op_name_list(op_name_list, state_list):
     if not op_name_list:
-        return op_name_list
+        return op_name_list, state_list

     parameters = []
     output = []
     parameters_grad = []
     others = []
-
-
+    parameters_s = []
+    output_s = []
+    parameters_grad_s = []
+    others_s = []
+    for op_name, state in zip(op_name_list, state_list):
         if state == Const.PARAMS:
-            parameters.append(
+            parameters.append(op_name)
+            parameters_s.append(state)
         elif state == Const.OUTPUT:
-            output.append(
+            output.append(op_name)
+            output_s.append(state)
         elif state == Const.PARAMS_GRAD:
-            parameters_grad.append(
+            parameters_grad.append(op_name)
+            parameters_grad_s.append(state)
         else:
-            others.append(
+            others.append(op_name)
+            others_s.append(state)
     # merge others, parameters and output, making sure parameters come before output
     op_name_reorder = others + parameters + output + parameters_grad
-
-
-
-def reorder_op_x_list(op_name_list, summary_list, data_name_list):
-    """Reorder op_name, summary and data_name, placing parameters after input and before output; data_name is handled separately because it is None in statistics comparison."""
-    if not op_name_list or not summary_list:
-        return op_name_list, summary_list, data_name_list
-
-    index_map = {name: index for index, name in enumerate(op_name_list)}
-
-    op_name_reorder = reorder_op_name_list(op_name_list)
-    summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
-    if data_name_list:
-        data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
-    else:
-        data_name_reorder = data_name_list
-
-    return op_name_reorder, summary_reorder, data_name_reorder
+    state_reorder = others_s + parameters_s + output_s + parameters_grad_s
+    return op_name_reorder, state_reorder


 def process_summary_data(summary_data):
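To make the ApiBatch bookkeeping above easier to follow, here is a simplified standalone trace (the literal state strings named in the inline comment stand in for the msprobe Const members, the api name is made up, and the state validation is omitted; this mirrors the increment logic shown above rather than importing msprobe):

INPUT, KWARGS, OUTPUT, PARAMS = "input", "kwargs", "output", "parameters"

class ApiBatchSketch:
    """Simplified copy of the ApiBatch counter logic, for illustration only."""
    def __init__(self, api_name, start):
        self.api_name = api_name
        self.start = start
        self.input_len = 1                    # the row that opened the batch counts as an input
        self.params_end_index = start + 1
        self.output_end_index = start + 1
        self.params_grad_end_index = start + 1

    def increment(self, state):
        if state in (INPUT, KWARGS):
            self.input_len += 1
            self.params_end_index += 1
            self.output_end_index += 1
        if state == PARAMS:
            self.params_end_index += 1
            self.output_end_index += 1
        if state == OUTPUT:
            self.output_end_index += 1
        self.params_grad_end_index += 1

# Four rows of one api, already reordered as inputs, parameters, output:
batch = ApiBatchSketch("Functional.linear.0.forward", start=0)
for state in (INPUT, PARAMS, OUTPUT):         # rows 1..3; row 0 created the batch
    batch.increment(state)

print(batch.input_len)                        # 2 -> inputs occupy rows [start, start + input_len)
print(batch.params_end_index)                 # 3
print(batch.output_end_index)                 # 4
print(batch.params_grad_end_index)            # 4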
@@ -393,17 +461,22 @@ def stack_column_process(result_item, has_stack, index, key, npu_stack_info):
     return result_item


-def result_item_init(n_info, b_info, dump_mode):
+def result_item_init(n_info, b_info, requires_grad_pair, dump_mode):
     n_len = len(n_info.struct)
     b_len = len(b_info.struct)
+    # requires_grad_pair is built internally and always has exactly two elements
+    n_requires_grad = requires_grad_pair[0]
+    b_requires_grad = requires_grad_pair[1]
+    req_grad_consist = n_requires_grad == b_requires_grad
     struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1)
     if struct_long_enough:
         result_item = [
-            n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1]
+            n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1],
+            n_requires_grad, b_requires_grad
         ]
         if dump_mode == Const.MD5:
             md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
-            result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
+            result_item.extend([n_info.struct[2], b_info.struct[2], req_grad_consist, md5_compare_result])
         elif dump_mode == Const.SUMMARY:
             result_item.extend([" "] * 8)  # 8 comparison metrics for the statistics data
         else:
@@ -449,11 +522,15 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
         b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name")
         n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
         b_struct = safe_get_value(b_dict, index, "b_dict", key=key)
+        n_requires_grad = safe_get_value(n_dict, n_start + index, "n_dict", key='requires_grad')
+        b_requires_grad = safe_get_value(b_dict, b_start + index, "b_dict", key='requires_grad')
+        requires_grad_pair = [n_requires_grad, b_requires_grad]
+        req_grad_consist = n_requires_grad == b_requires_grad
         err_msg = ""

         npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info)
         bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info)
-        result_item = result_item_init(npu_info, bench_info, dump_mode)
+        result_item = result_item_init(npu_info, bench_info, requires_grad_pair, dump_mode)

         if dump_mode == Const.MD5:
             result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
@@ -469,6 +546,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
             result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
                                                                               bench_summary_data, err_msg)

+        result_item.append(req_grad_consist)
+        err_msg += "Requires_grad inconsistent." if not req_grad_consist else ""
         result_item.append(accuracy_check if dump_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES)
         result_item.append(err_msg)
         result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
@@ -482,23 +561,30 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
     if n_len > b_len:
         for index in range(b_len, n_len):
             try:
-                n_name = n_dict
-                n_struct = n_dict
+                n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
+                n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
+                n_requires_grad = safe_get_value(n_dict, n_start + index, "n_dict", key='requires_grad')
+
                 if dump_mode == Const.MD5:
                     result_item = [
                         n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
-
+                        n_requires_grad, CompareConst.NAN,
+                        n_struct[2], CompareConst.NAN,
+                        False,
+                        CompareConst.NAN
                     ]
                     result.append(result_item)
                     continue
                 result_item = [
                     n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
+                    n_requires_grad, CompareConst.NAN,
                     " ", " ", " ", " ", " ", " "
                 ]
                 summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
                 result_item.extend(summary_data)
                 summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
                 result_item.extend(summary_data)
+                result_item.append(False)
             except IndexError as e:
                 err_msg = "index out of bounds error occurs, please check!\n" \
                           f"n_dict is {n_dict}"
@@ -546,6 +632,23 @@ def make_result_table(result, dump_mode, stack_mode):
     return result_df


+def gen_api_batches(result: np.ndarray, header: list):
+    api_name_index = header.index(Const.API_ORIGIN_NAME)
+    state_name_index = header.index(Const.STATE)
+    api_batches = []
+    for i, res_i in enumerate(result):
+        api_name = safe_get_value(res_i, api_name_index, "res_i")
+        state = safe_get_value(res_i, state_name_index, "res_i")
+        api_batches_update(api_batches, api_name, state, i)
+    return api_batches
+
+
+def get_paired_dirs(npu_path, bench_path):
+    npu_dirs = set(os.listdir(npu_path))
+    bench_dirs = set(os.listdir(bench_path))
+    return list(npu_dirs & bench_dirs)
+
+
 def _compare_parser(parser):
     parser.add_argument("-i", "--input_path", dest="input_path", type=str,
                         help="<Required> The compare input path, a dict json.", required=True)
@@ -558,6 +661,8 @@ def _compare_parser(parser):
                         help="<optional> Whether to give advisor.", required=False)
     parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
                         help="<optional> Whether to perform a fuzzy match on the api name.", required=False)
+    parser.add_argument("-hl", "--highlight", dest="highlight", action="store_true",
+                        help="<optional> Whether to set result highlighting.", required=False)
     parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
                         help="<optional> The cell mapping file path.", required=False)
     parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
@@ -566,40 +671,143 @@ def _compare_parser(parser):
                         help="<optional> The data mapping file path.", required=False)
     parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
                         help="<optional> The layer mapping file path.", required=False)
+    parser.add_argument("-da", "--diff_analyze", dest="diff_analyze", action="store_true",
+                        help="<optional> Whether to perform a diff analyze on the api name.", required=False)


-def
-
-
-
-
-
-
-
+def get_sorted_ranks(npu_dump_dir, bench_dump_dir):
+    """
+    get the ranks and match by order
+    """
+    unsorted_npu_ranks = check_and_return_dir_contents(npu_dump_dir, 'rank')
+    unsorted_bench_ranks = check_and_return_dir_contents(bench_dump_dir, 'rank')
+    # the regex check has already ensured that 'rank' is followed by digits, or is a bare 'rank' with no digits
+    npu_ranks = sorted(unsorted_npu_ranks, key=lambda x: int(x[4:]) if len(x) > 4 else -1)  # the first four characters are 'rank'; the rest is the card number
+    bench_ranks = sorted(unsorted_bench_ranks, key=lambda x: int(x[4:]) if len(x) > 4 else -1)
     if len(npu_ranks) != len(bench_ranks):
         logger.error('The number of ranks in the two runs are different. '
                      'Unable to match the ranks. Please use another folder to compare '
                      'or use compare() api and manually match the ranks.')
         raise CompareException(CompareException.INVALID_PATH_ERROR)
-
+    return npu_ranks, bench_ranks
+
+
+def multi_statistics_compare(func, func_args):
+    def err_call(args):
+        logger.error(f'Multiprocess statistics compare failed! Reason: {args}')
+
+    compare_func, input_param_nr_list, output_path, kwargs = func_args
+
+    param_num = len(input_param_nr_list)
+    process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
+    if param_num <= process_num:
+        process_num = param_num
+        chunks = [[input_param_nr] for input_param_nr in input_param_nr_list]
+    else:
+        chunk_size = param_num // process_num
+        remainder = param_num % process_num
+        chunks = [input_param_nr_list[i:i + chunk_size] for i in range(0, param_num - remainder, chunk_size)]
+        for i in range(remainder):
+            chunks[i].append(input_param_nr_list[param_num - remainder + i])
+
+    pool = multiprocessing.Pool(process_num)
+
+    async_results = []
+    for chunk in chunks:
+        result = pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call)
+        async_results.append(result)
+
+    pool.close()
+
+    for ar in async_results:
+        try:
+            ar.get(timeout=3600)
+        except Exception as e:
+            logger.error(f"Task failed with exception: {e}")
+            pool.terminate()
+            raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
+
+    pool.join()
+
+
+def mp_logger_init(ranks_str):
+    """
+    Multiprocess comparison wraps and patches the logger so the rank information is prepended to each message,
+    keeping the logs of different processes apart.
+    """
+
+    def wrap_logger(fn):
+        def inner(msg, *args, **kwargs):
+            return fn(ranks_str + msg, *args, **kwargs)
+        return inner
+
+    logger.info = wrap_logger(logger.info)
+    logger.warning = wrap_logger(logger.warning)
+    logger.error = wrap_logger(logger.error)
+
+
+def multi_ranks_compare(compare_func, input_param_nr_list, output_path, kwargs):
+    """
+    After the multi-rank data has been split across processes, one process may still hold the data of several
+    ranks, so multiple comparisons are still needed.
+    """
+    rank_list = [input_param_nr[1] for input_param_nr in input_param_nr_list]  # input_param_nr is a 2-element tuple
+    ranks_str = f"[{' '.join(rank_list)}]"
+    mp_logger_init(ranks_str)
+    for input_param_nr in input_param_nr_list:
+        input_param, nr = input_param_nr
+        compare_entry(compare_func, input_param, output_path, nr, kwargs)
+
+
+def compare_entry(compare_func, input_param, output_path, nr, kwargs):
+    try:
+        compare_func(input_param=input_param, output_path=output_path, suffix=f'_{nr}', **kwargs)
+    except CompareException as e:
+        if e.code == CompareException.INVALID_DATA_ERROR:
+            logger.error(f"Invalid or missing 'data' in dump.json. Skipping {nr} comparison.")
+        if e.code == CompareException.INVALID_TASK_ERROR:
+            logger.error(f"Invalid or missing 'task' in dump.json. Skipping {nr} comparison.")
+
+
+def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare_func, **kwargs):
+    def extract_compare_param(_file_type):
         npu_data_dir = os.path.join(npu_dump_dir, nr)
         bench_data_dir = os.path.join(bench_dump_dir, br)
+        npu_path = extract_json(npu_data_dir, _file_type)
+        bench_path = extract_json(bench_data_dir, _file_type)
+        if npu_path == "" or bench_path == "":
+            logger.debug(f'Did not find paired {_file_type} in {nr} and {br}, skip comparing.')
+            return {}, True
+        _input_param = {
+            'npu_json_path': npu_path,
+            'bench_json_path': bench_path,
+            'is_print_compare_log': kwargs.get('is_print_compare_log', True)
+        }
+        return _input_param, False
+
+    if kwargs.get('suffix'):
+        logger.error("Argument 'suffix' is not supported for compare_distributed.")
+        raise CompareException(CompareException.INVALID_PARAM_ERROR)
+
+    npu_ranks, bench_ranks = get_sorted_ranks(npu_dump_dir, bench_dump_dir)
+
+    # statistics / md5 comparison
+    pre_check_dump_path = os.path.join(npu_dump_dir, npu_ranks[0], 'dump.json') if npu_ranks else ''
+    if not pre_check_dump_path:
+        return
+    dump_data = load_json(pre_check_dump_path)
+    if dump_data.get('task') == Const.STATISTICS:
+        # when the dump data is statistics or md5, use multiple processes to speed up the comparison
+        input_param_nr_list = []
+        for nr, br in zip(npu_ranks, bench_ranks):
+            input_param, skip = extract_compare_param(Const.DUMP_JSON_FILE)
+            if not skip:
+                input_param_nr_list.append((input_param, nr))
+        func_args = (compare_func, input_param_nr_list, output_path, kwargs)
+        multi_statistics_compare(multi_ranks_compare, func_args)
+        return
+
+    # real (tensor) data comparison
+    for nr, br in zip(npu_ranks, bench_ranks):
         for file_type in [Const.DUMP_JSON_FILE, Const.DEBUG_JSON_FILE]:
-
-
-
-            logger.debug(f'Did not find paired {file_type} in {npu_data_dir} and {bench_data_dir},'
-                         ' skip comparing.')
-            continue
-            dump_result_param = {
-                'npu_json_path': npu_path,
-                'bench_json_path': bench_path,
-                'is_print_compare_log': is_print_compare_log
-            }
-            try:
-                compare_func(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}', **kwargs)
-            except CompareException as e:
-                if e.code == CompareException.INVALID_DATA_ERROR:
-                    logger.error(f"Invalid or missing 'data' in dump.json. Skipping {nr} comparison.")
-                if e.code == CompareException.INVALID_TASK_ERROR:
-                    logger.error(f"Invalid or missing 'task' in dump.json. Skipping {nr} comparison.")
+            input_param, skip = extract_compare_param(file_type)
+            if not skip:
+                compare_entry(compare_func, input_param, output_path, nr, kwargs)
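The multiprocess path added above derives its worker count from multiprocessing.cpu_count() and splits the per-rank compare jobs into near-equal chunks, appending the remainder one job per chunk. A standalone sketch of just that splitting step (not msprobe code; the rank names are made up):

# Standalone sketch of the chunking used by multi_statistics_compare:
# deal out equal slices first, then append the leftover jobs one per chunk.
def split_jobs(jobs, process_num):
    param_num = len(jobs)
    if param_num <= process_num:
        return [[job] for job in jobs]
    chunk_size = param_num // process_num
    remainder = param_num % process_num
    chunks = [jobs[i:i + chunk_size] for i in range(0, param_num - remainder, chunk_size)]
    for i in range(remainder):
        chunks[i].append(jobs[param_num - remainder + i])
    return chunks

# Seven ranks spread over three worker processes:
print(split_jobs([f"rank{i}" for i in range(7)], 3))
# [['rank0', 'rank1', 'rank6'], ['rank2', 'rank3'], ['rank4', 'rank5']]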
|