mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/core/compare/acc_compare.py CHANGED

@@ -25,16 +25,18 @@ from tqdm import tqdm
 from msprobe.core.advisor.advisor import Advisor
 from msprobe.core.common.const import CompareConst, Const
 from msprobe.core.common.exceptions import FileCheckException
-from msprobe.core.common.file_utils import load_json, remove_path, create_directory
+from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel, save_json
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \
-    set_dump_path, get_dump_mode, check_compare_param,
-from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping
-
-
+    set_dump_path, get_dump_mode, check_compare_param, load_stack_json, get_file_type, add_time_with_json
+from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping, \
+    check_configuration_param
+from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, set_stack_json_path, \
+    reorder_index
 from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict
 from msprobe.core.compare.multiprocessing_compute import CompareRealData
 from msprobe.core.compare.highlight import HighLight
+from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze


 @dataclass
@@ -43,12 +45,15 @@ class ComparisonConfig:
     stack_mode: bool
     auto_analyze: bool
     fuzzy_match: bool
+    highlight: bool
     data_mapping: dict
     suffix: str
     cell_mapping: dict
     api_mapping: dict
     layer_mapping: dict
     compared_file_type: str
+    first_diff_analyze: bool
+    is_print_compare_log: bool


 class Comparator:
@@ -57,17 +62,18 @@ class Comparator:
         self.mode_config = mode_config
         self.mapping_config = mapping_config
         self.cross_frame = is_cross_framework
-
         self.mapping_dict = MappingDict(mapping_config)

-
-    def process_output_file(output_path, suffix, compared_file_type):
+    def process_output_file(self, output_path, suffix, compared_file_type):
         file_name_prefix_mapping = {
             Const.DUMP_JSON_FILE: "compare_result",
             Const.DEBUG_JSON_FILE: "debug_compare_result"
         }
         file_name_prefix = file_name_prefix_mapping.get(compared_file_type, "compare_result")
-
+        if self.mode_config.first_diff_analyze:
+            file_name = add_time_with_json("compare_result" + suffix)
+        else:
+            file_name = add_time_with_xlsx(file_name_prefix + suffix)
         file_path = os.path.join(os.path.realpath(output_path), file_name)
         if os.path.exists(file_path):
             logger.warning(f"{file_path} will be deleted.")
@@ -95,6 +101,7 @@ class Comparator:

         # get kwargs or set default value
         suffix = kwargs.get('suffix', '')
+        rank = suffix[1:]

         # process output file
         file_path = self.process_output_file(output_path, suffix, self.mode_config.compared_file_type)
@@ -103,22 +110,45 @@ class Comparator:
         npu_json = input_param.get("npu_json_path")
         bench_json = input_param.get("bench_json_path")
         stack_json = input_param.get("stack_json_path")
-
+        parse_data = ParseData(self.mode_config, rank)  # load and parse json data
+        npu_df, bench_df = parse_data.parse([npu_json, bench_json, stack_json])
+        result_df = self.compare_statistics(npu_df, bench_df)
         if not result_df.values.tolist():
             logger.warning("Can`t match any op. No compare result file generated.")
             return

+        if self.mode_config.first_diff_analyze:
+            # add P2POp additional info from npu_df and bench_df to result_df
+            result_df['NPU P2POp op'] = npu_df['op']
+            result_df['Bench P2POp op'] = bench_df['op']
+            result_df['NPU P2POp peer'] = npu_df['peer']
+            result_df['Bench P2POp peer'] = bench_df['peer']
+
+            first_diff_analyze = FirstDiffAnalyze(self.mode_config, rank)
+            check_result = first_diff_analyze.check(result_df)
+            save_json(file_path, check_result, indent=4)
+            logger.info(f"Saving json file to disk: {file_path}")
+            return
+
         # compare real data
         if self.mode_config.dump_mode == Const.ALL:
             compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame)
             result_df = compare_real_data.do_multi_process(input_param, result_df)

-        #
-
-        highlight
-
-
-
+        # save result excel file
+        logger.info(f'Saving result excel file in progress. The file path is: {file_path}.')
+        if self.mode_config.highlight and len(result_df) <= CompareConst.MAX_EXCEL_LENGTH:
+            # highlight if not too long
+            highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
+            highlight = HighLight(self.mode_config, rank)
+            if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
+                highlight.find_compare_result_error_rows(result_df, highlight_dict)
+            result_df.drop(columns=['state', 'api_origin_name'], inplace=True)  # 删除中间数据,两列不落盘
+            highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path)
+        else:
+            # fallback to simple save without highlight
+            result_df.drop(columns=['state', 'api_origin_name'], inplace=True)  # 删除中间数据,两列不落盘
+            save_excel(file_path, result_df)

         # output compare analysis suggestions
         if self.mode_config.auto_analyze:
@@ -127,11 +157,7 @@ class Comparator:

         print_compare_ends_info()

-    def compare_statistics(self,
-        # load and parse json data
-        parse_data = ParseData(self.mode_config)
-        npu_df, bench_df = parse_data.parse(file_list)
-
+    def compare_statistics(self, npu_df, bench_df):
         npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
         bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)

@@ -149,6 +175,8 @@ class Comparator:
         match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A

         # organize compare result table by renaming columns
+        if self.mode_config.dump_mode == Const.ALL and self.mode_config.first_diff_analyze:
+            self.mode_config.dump_mode = Const.SUMMARY
         create_table = CreateTable(self.mode_config)
         result_df, header = create_table.make_result_df(match_result)

@@ -158,8 +186,9 @@ class Comparator:


 class ParseData:
-    def __init__(self, mode_config: ModeConfig):
+    def __init__(self, mode_config: ModeConfig, rank):
         self.mode_config = mode_config
+        self.rank = rank

     def parse(self, file_list):
         npu_json_path, bench_json_path, stack_json_path = file_list
@@ -168,21 +197,24 @@ class ParseData:
         stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None

         # parse json data and generate df
-        npu_df = self.gen_data_df(npu_json_data, stack_json_data)
-        bench_df = self.gen_data_df(bench_json_data, stack_json_data)
+        npu_df = self.gen_data_df(npu_json_data, stack_json_data, 'NPU')
+        bench_df = self.gen_data_df(bench_json_data, stack_json_data, 'Bench')

         return npu_df, bench_df

-    def gen_data_df(self, data_json, stack_json_data):
+    def gen_data_df(self, data_json, stack_json_data, device: str):
         result = {
             CompareConst.OP_NAME: [],
             Const.DTYPE: [],
             Const.SHAPE: [],
             Const.SUMMARY: [],
-            Const.STACK_INFO: []
+            Const.STACK_INFO: [],
+            Const.STATE: [],
+            Const.API_ORIGIN_NAME: [],
+            Const.REQ_GRAD: []
         }
         if self.mode_config.dump_mode == Const.ALL:
-            result[
+            result[Const.DATA_NAME] = []
         elif self.mode_config.dump_mode == Const.MD5:
             result[Const.MD5] = []

@@ -192,56 +224,50 @@ class ParseData:
             return pd.DataFrame(result)

         api_nums = len(apis_data)
-
+        default_bar_desc = f'{device} API/Module Read Progress'
+        bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc
+        progress_bar = tqdm(total=api_nums, desc=bar_desc_add_rank, unit="api/module", ncols=100)

         # 从json中循环解析API数据,遍历所有API
         for data_name in apis_data:
             check_op_str_pattern_valid(data_name)
-
-            if not
+            op_parsed_list = self.gen_merge_list(data_json, data_name, stack_json_data)
+            if not op_parsed_list:
                 continue
-
-
-
-
-
-
-
-
-
-            result[
-
-
-
-
-
-
-
-
-            else:
-                info_list = merge_list[CompareConst.DEBUG_STRUCT]
-                check_api_info_len(op_name, info_list, 1)
-                struct = info_list.pop(0)
-
-                check_api_info_len(op_name, struct, 2)
-                result[Const.DTYPE].append(struct[0])
-                result[Const.SHAPE].append(struct[1])
+            reordered_index_list = reorder_index(op_parsed_list)
+            for i, index in enumerate(reordered_index_list):
+                op_item = op_parsed_list[index]
+
+                # common key
+                result[CompareConst.OP_NAME].append(op_item.get('full_op_name'))
+                result[Const.DTYPE].append(op_item.get(Const.DTYPE))
+                result[Const.SHAPE].append(op_item.get(Const.SHAPE))
+                result[Const.STATE].append(op_item.get(Const.STATE))
+                result[Const.REQ_GRAD].append(op_item.get(Const.REQ_GRAD))
+                result[Const.API_ORIGIN_NAME].append(data_name)
+                summary_data = [
+                    str(op_item.get(key)) if op_item.get(key) is None else op_item.get(key)
+                    for key in Const.SUMMARY_METRICS_LIST
+                ]
+                result[Const.SUMMARY].append(summary_data)
+
+                # dump_mode differ key
                 if self.mode_config.dump_mode == Const.MD5:
-
-
-
-                    check_api_info_len(op_name, summary_reorder, 1)
-                    result[Const.SUMMARY].append(summary_reorder.pop(0))
+                    result[Const.MD5].append(op_parsed_list[index].get(Const.MD5))
+                if self.mode_config.dump_mode == Const.ALL:
+                    result[Const.DATA_NAME].append(op_item.get(Const.DATA_NAME))

-
-
-                result[Const.STACK_INFO].append(
+                # mode_config stack_mode addition key
+                if i == 0 and self.mode_config.stack_mode:
+                    result[Const.STACK_INFO].append(op_parsed_list[-1].get('full_info'))
                 else:
                     result[Const.STACK_INFO].append(None)

-
-
-            result
+                # mode_config first_diff_analyze addition key
+                if self.mode_config.first_diff_analyze:
+                    result.setdefault('op', []).append(op_item.get('op', str(None)))
+                    result.setdefault('peer', []).append(op_item.get('peer', str(None)))
+
             progress_bar.update(1)
         progress_bar.close()
         return pd.DataFrame(result)
@@ -256,14 +282,14 @@ class ParseData:
             stack_info = stack_json_data.get(op_name)
             if stack_info is not None:
                 check_stack_json_str(stack_info, op_name)
-
-
-
-
-
-
-
-            return
+        else:
+            stack_info = None
+        # always add stack_info whether stack_mode is True
+        op_parsed_list.append({
+            'full_op_name': op_name,
+            'full_info': stack_info
+        })
+        return op_parsed_list


 class ProcessDf:
@@ -327,13 +353,17 @@ class ProcessDf:
         return npu_op_name

     def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
+        def remove_prefix(string, prefix):
+            if string.startswith(prefix):
+                return string[len(prefix):]
+            return string
+
         def gen_input_compare_key(pattern, term):
             is_unmatched = True
             for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
-                if op_name
+                if remove_prefix(op_name, api_origin_name + pattern) == str(prefix):
                     npu_df.loc[index, CompareConst.CMP_KEY] = (
-                        op_name.replace(pattern + str(prefix),
-                                        pattern + str(mapping_dict.get(f'pt_{term}')[i])))
+                        op_name.replace(pattern + str(prefix), pattern + str(mapping_dict.get(f'pt_{term}')[i])))
                     is_unmatched = False
             return is_unmatched

@@ -355,15 +385,17 @@ class ProcessDf:
                 continue
             for index in ms_api_indices_dict.get(ms_api):
                 op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
-
+                state = npu_df.loc[index, Const.STATE]
+                api_origin_name = npu_df.loc[index, Const.API_ORIGIN_NAME].replace(ms_api, pt_api, 1)
+                if state == Const.INPUT:
                     is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
-                elif
+                elif state == Const.KWARGS:
                     is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
-                elif
+                elif state == Const.OUTPUT:
                     is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
-                elif
+                elif state == Const.PARAMS:
                     is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
-                elif
+                elif state == Const.PARAMS_GRAD:
                     is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
                 else:
                     logger.error(f'Excepted op_name: {op_name}')
@@ -413,8 +445,8 @@ class Match:
     @staticmethod
     def put_unmatched_in_table(match_result, npu_op_item):
         npu_columns = npu_op_item.index.tolist()[:-2]
-
-        na_series = pd.Series([CompareConst.N_A] * len(
+        bench_columns = [name + '_y' for name in npu_columns]
+        na_series = pd.Series([CompareConst.N_A] * len(bench_columns), index=bench_columns)
         new_result_item = pd.concat([npu_op_item, na_series]).to_frame().T
         new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS
         match_result = pd.concat([match_result, new_result_item])
@@ -610,12 +642,21 @@ class CreateTable:
             'md5_x': CompareConst.NPU_MD5,
             'md5_y': CompareConst.BENCH_MD5,
             'data_name_x': CompareConst.DATA_NAME,
-            'stack_info_x': CompareConst.STACK
+            'stack_info_x': CompareConst.STACK,
+            'state_x': Const.STATE,
+            'api_origin_name_x': Const.API_ORIGIN_NAME,
+            'requires_grad_x': CompareConst.NPU_REQ_GRAD,
+            'requires_grad_y': CompareConst.BENCH_REQ_GRAD
+        },
+        inplace=True)

         # process summary data
         npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
         bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
                          CompareConst.BENCH_NORM]
+        # process requires_grad
+        result[CompareConst.REQ_GRAD_CONSIST] = result[CompareConst.NPU_REQ_GRAD] == result[CompareConst.BENCH_REQ_GRAD]
+
         if result.empty:
             result[npu_summary] = pd.DataFrame(columns=npu_summary)
             result[bench_summary] = pd.DataFrame(columns=bench_summary)
@@ -623,6 +664,7 @@ class CreateTable:
             result[npu_summary] = result['summary_x'].apply(self.set_summary).tolist()
             result[bench_summary] = result['summary_y'].apply(self.set_summary).tolist()

+        header.extend([Const.STATE, Const.API_ORIGIN_NAME])
         result_df = pd.DataFrame(columns=header)
         for h in header:
             if h in result.columns:
@@ -667,13 +709,13 @@ class CalcStatsDiff:
         result_df.loc[cond_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN

         cond_not_nan_diff = cond_valid_stat & ~cond_diff_nan
-        condition_pt_zero = bench_val == 0
+        condition_pt_zero = self.get_number(bench_val) == 0
         result_df.loc[cond_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.N_A

         # 相对误差转成百分比字符串
         cond_ref_err = cond_not_nan_diff & ~condition_pt_zero
         result_df.loc[cond_ref_err, rel_err_name] = (
-                result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err] * 100)
+                result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err].astype(float) * 100)
         result_df.loc[cond_ref_err, rel_err_name] = (result_df.loc[cond_ref_err, rel_err_name].abs().astype(str) + '%')

         magnitude = self.get_number(result_df[diff_name]).abs() / (pd.Series(
@@ -685,12 +727,13 @@ class CalcStatsDiff:
         condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
         result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
         result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
+        condition_req_grad_consist = result_df[CompareConst.NPU_REQ_GRAD] == result_df[CompareConst.BENCH_REQ_GRAD]

         if self.mode_config.dump_mode == Const.MD5:
             condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
             result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
             result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
-        elif self.mode_config.dump_mode == Const.SUMMARY:
+        elif self.mode_config.first_diff_analyze or self.mode_config.dump_mode == Const.SUMMARY:
             warning_list = [
                 self.calc_summary_diff(result_df, condition_no_bench, stats_index)
                 for stats_index in ['max', 'min', 'mean', 'l2norm']
@@ -698,14 +741,16 @@ class CalcStatsDiff:
             warning_flag = pd.DataFrame(warning_list).any()
             result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
             result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
-            result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
+            result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy. '
+            result_df.loc[~condition_req_grad_consist, CompareConst.ERROR_MESSAGE] += 'Requires_grad inconsistent. '
         else:
             fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
                          CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
                          CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
                          CompareConst.ERROR_MESSAGE]
-            result_df.loc[~condition_no_bench, fill_cols] = ''
+            result_df.loc[~condition_no_bench, fill_cols] = ''  # 默认填充'', df默认省缺值为nan,不便后续处理,容易出现意外情况
             result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
+            result_df.loc[~condition_req_grad_consist, CompareConst.ERROR_MESSAGE] = 'Requires_grad inconsistent. '

         return result_df[header]

@@ -718,12 +763,15 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig:
         stack_mode=False,
         auto_analyze=kwargs.get('auto_analyze', True),
         fuzzy_match=kwargs.get('fuzzy_match', False),
+        highlight=kwargs.get('highlight', False),
        data_mapping=kwargs.get('data_mapping', {}),
         suffix=kwargs.get('suffix', ''),
         cell_mapping=kwargs.get('cell_mapping', {}),
         api_mapping=kwargs.get('api_mapping', {}),
         layer_mapping=kwargs.get('layer_mapping', {}),
+        first_diff_analyze=kwargs.get('first_diff_analyze', False),
         compared_file_type='',
+        is_print_compare_log=input_param.get('is_print_compare_log', True)
     )

     set_dump_path(input_param)
@@ -736,8 +784,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig:
     else:
         config.stack_mode = set_stack_json_path(input_param)

-    check_configuration_param(config
-                               input_param.get('is_print_compare_log', True))
+    check_configuration_param(config)
     create_directory(output_path)
     check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode)

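The acc_compare.py changes above thread a requires_grad consistency check through the result table (the REQ_GRAD_CONSIST column and the 'Requires_grad inconsistent.' message). A minimal pandas sketch of that pattern, with illustrative column names standing in for the CompareConst constants:

```python
import pandas as pd

# Illustrative stand-ins for CompareConst.NPU_REQ_GRAD / BENCH_REQ_GRAD / ERROR_MESSAGE.
df = pd.DataFrame({
    "npu_requires_grad": [True, True, False],
    "bench_requires_grad": [True, False, False],
    "error_message": ["", "", ""],
})

# Rows where the two flags agree are marked consistent; disagreeing rows get a note
# appended, mirroring `result_df.loc[~condition_req_grad_consist, ERROR_MESSAGE] += ...`.
consistent = df["npu_requires_grad"] == df["bench_requires_grad"]
df["requires_grad_consistent"] = consistent
df.loc[~consistent, "error_message"] += "Requires_grad inconsistent. "
print(df)
```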
msprobe/core/compare/check.py CHANGED

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
 from msprobe.core.common.const import Const
@@ -106,3 +108,14 @@ def check_stack_json_str(stack_info, op_name):
     else:
         logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'")
         raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
+
+
+def check_configuration_param(config):
+    arg_list = [config.stack_mode, config.auto_analyze, config.fuzzy_match,
+                config.highlight, config.first_diff_analyze, config.is_print_compare_log]
+    arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match',
+                 'highlight', 'first_diff_analyze', 'is_print_compare_log']
+    for arg, name in zip(arg_list, arg_names):
+        if not isinstance(arg, bool):
+            logger.error(f"Invalid input parameter, {name} which should be only bool type.")
+            raise CompareException(CompareException.INVALID_PARAM_ERROR)
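The new check_configuration_param rejects any comparison switch that is not a plain bool. A minimal sketch of that validation, using types.SimpleNamespace as a stand-in for the real ComparisonConfig object:

```python
from types import SimpleNamespace

# Stand-in for the ComparisonConfig produced by setup_comparison(); all switches are bools.
config = SimpleNamespace(stack_mode=False, auto_analyze=True, fuzzy_match=False,
                         highlight=False, first_diff_analyze=False,
                         is_print_compare_log=True)

for name in ("stack_mode", "auto_analyze", "fuzzy_match",
             "highlight", "first_diff_analyze", "is_print_compare_log"):
    if not isinstance(getattr(config, name), bool):
        # The real code logs the error and raises CompareException(INVALID_PARAM_ERROR).
        raise TypeError(f"Invalid input parameter, {name} should be bool.")
```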
msprobe/core/compare/compare_cli.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,28 +13,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import
+import os
+
 from msprobe.core.common.file_utils import check_file_type, load_json, check_file_or_directory_path
 from msprobe.core.common.const import FileCheckConst, Const
 from msprobe.core.common.utils import CompareException
 from msprobe.core.common.log import logger
+from msprobe.core.compare.utils import get_paired_dirs
+

+def compare_cli(args, depth=1):
+    if depth > 2:
+        logger.error("Recursive compare error, depth exceeds 2.")
+        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)

-
-
+    if isinstance(args.input_path, dict):  # special for dyn-graph mix compare
+        input_param = args.input_path
+    else:
+        input_param = load_json(args.input_path)
     if not isinstance(input_param, dict):
         logger.error("input_param should be dict, please check!")
         raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
+
     npu_path = input_param.get("npu_path", None)
     bench_path = input_param.get("bench_path", None)
     if not npu_path:
-        logger.error(f"Missing npu_path in configuration file
+        logger.error(f"Missing npu_path in input configuration file, please check!")
         raise CompareException(CompareException.INVALID_PATH_ERROR)
     if not bench_path:
-        logger.error(f"Missing bench_path in configuration file
+        logger.error(f"Missing bench_path in input configuration file, please check!")
         raise CompareException(CompareException.INVALID_PATH_ERROR)
+
     frame_name = args.framework
     auto_analyze = not args.compare_only
+
     if frame_name == Const.PT_FRAMEWORK:
         from msprobe.pytorch.compare.pt_compare import compare
         from msprobe.pytorch.compare.distributed_compare import compare_distributed
@@ -46,7 +58,9 @@ def compare_cli(args):
     common_kwargs = {
         "auto_analyze": auto_analyze,
         "fuzzy_match": args.fuzzy_match,
+        "highlight": args.highlight,
         "data_mapping": args.data_mapping,
+        "diff_analyze": args.diff_analyze
     }

     if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
@@ -75,6 +89,12 @@ def compare_cli(args):
     elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
         check_file_or_directory_path(npu_path, isdir=True)
         check_file_or_directory_path(bench_path, isdir=True)
+
+        if depth == 1:
+            mix_compare_success = mix_compare(args, input_param, depth)
+            if mix_compare_success:
+                return
+
         kwargs = {
             **common_kwargs,
             "stack_mode": args.stack_mode,
@@ -90,6 +110,13 @@ def compare_cli(args):
         if isinstance(common, bool) and common:
             common_dir_compare(input_param, args.output_path)
             return
+
+        if common_kwargs.get('diff_analyze', False):
+            logger.info("Start finding first diff node......")
+            from msprobe.core.compare.find_first.analyzer import DiffAnalyzer
+            DiffAnalyzer(npu_path, bench_path, args.output_path, frame_name).analyze()
+            return
+
         if frame_name == Const.PT_FRAMEWORK:
             compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
         else:
@@ -97,3 +124,34 @@ def compare_cli(args):
     else:
         logger.error("The npu_path and bench_path need to be of the same type.")
         raise CompareException(CompareException.INVALID_COMPARE_MODE)
+
+
+def mix_compare(args, input_param, depth):
+    npu_path = input_param.get("npu_path", None)
+    bench_path = input_param.get("bench_path", None)
+
+    npu_bench_same_dirs_set = set(get_paired_dirs(npu_path, bench_path))
+    compare_cross_set = npu_bench_same_dirs_set & Const.MIX_DUMP_NAMES
+
+    if compare_cross_set:
+        logger.info("Start mix compare.")
+        origin_output = args.output_path
+
+        for folder_name in list(compare_cross_set):
+            new_npu_path = os.path.join(npu_path, folder_name)
+            new_bench_path = os.path.join(bench_path, folder_name)
+            paired_steps = get_paired_dirs(new_npu_path, new_bench_path)
+
+            for step_name in paired_steps:
+                logger.info(f"[mix compare] Start comparing {folder_name}/{step_name}")
+                npu_dir = os.path.join(new_npu_path, step_name)
+                bench_dir = os.path.join(new_bench_path, step_name)
+                args.input_path = {
+                    "npu_path": npu_dir,
+                    "bench_path": bench_dir,
+                    "is_print_compare_log": input_param.get("is_print_compare_log", True)
+                }
+                args.output_path = os.path.join(origin_output, folder_name, step_name)
+                compare_cli(args, depth + 1)
+        return True
+    return False
msprobe/core/compare/config.py CHANGED

@@ -20,13 +20,15 @@ from msprobe.core.common.file_utils import load_yaml


 class ModeConfig:
-    def __init__(self,
-
-        self.
-        self.
-        self.
-        self.dump_mode = dump_mode
-        self.
+    def __init__(self, **kwargs):
+        self.stack_mode = kwargs.get('stack_mode', False)
+        self.auto_analyze = kwargs.get('auto_analyze', True)
+        self.fuzzy_match = kwargs.get('fuzzy_match', False)
+        self.highlight = kwargs.get('highlight', False)
+        self.dump_mode = kwargs.get('dump_mode', Const.SUMMARY)
+        self.first_diff_analyze = kwargs.get('first_diff_analyze', False)
+        self.diff_analyze = kwargs.get('diff_analyze', False)
+        self.compared_file_type = kwargs.get('compared_file_type', Const.DUMP_JSON_FILE)


 class MappingConfig:
@@ -69,4 +71,4 @@ class MappingDict:
         else:
             raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
                             f"{type(data_mapping)}")
-        return data_mapping_dict
+        return data_mapping_dict