mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -19,28 +19,37 @@ import re
|
|
|
19
19
|
from copy import deepcopy
|
|
20
20
|
|
|
21
21
|
import pandas as pd
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
|
|
22
24
|
from msprobe.core.advisor.advisor import Advisor
|
|
23
25
|
from msprobe.core.common.const import CompareConst, Const
|
|
24
26
|
from msprobe.core.common.exceptions import FileCheckException
|
|
25
|
-
from msprobe.core.common.file_utils import load_json
|
|
26
|
-
from msprobe.core.common.file_utils import remove_path
|
|
27
|
+
from msprobe.core.common.file_utils import load_json, remove_path
|
|
27
28
|
from msprobe.core.common.log import logger
|
|
28
|
-
from msprobe.core.common.utils import
|
|
29
|
-
from msprobe.core.compare.check import
|
|
30
|
-
|
|
29
|
+
from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value
|
|
30
|
+
from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \
|
|
31
|
+
check_struct_match, fuzzy_check_op
|
|
31
32
|
from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
|
|
32
|
-
from msprobe.core.compare.multiprocessing_compute import
|
|
33
|
-
from msprobe.core.compare.npy_compare import compare_ops_apply,
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
get_rela_diff_summary_mode, print_compare_ends_info
|
|
37
|
-
from tqdm import tqdm
|
|
33
|
+
from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result
|
|
34
|
+
from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
|
|
35
|
+
from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \
|
|
36
|
+
print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list
|
|
38
37
|
|
|
39
38
|
|
|
40
|
-
class
|
|
39
|
+
class ModeConfig:
|
|
40
|
+
def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None):
|
|
41
|
+
self.stack_mode = stack_mode
|
|
42
|
+
self.auto_analyze = auto_analyze
|
|
43
|
+
self.fuzzy_match = fuzzy_match
|
|
44
|
+
self.dump_mode = dump_mode
|
|
41
45
|
|
|
42
|
-
|
|
43
|
-
|
|
46
|
+
|
|
47
|
+
class Comparator:
|
|
48
|
+
def __init__(self, mode_config: ModeConfig):
|
|
49
|
+
self.stack_mode = mode_config.stack_mode
|
|
50
|
+
self.auto_analyze = mode_config.auto_analyze
|
|
51
|
+
self.fuzzy_match = mode_config.fuzzy_match
|
|
52
|
+
self.dump_mode = mode_config.dump_mode
|
|
44
53
|
|
|
45
54
|
@staticmethod
|
|
46
55
|
def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
|
|
@@ -85,16 +94,15 @@ class Comparator:
|
|
|
85
94
|
value[k] = CompareConst.N_A
|
|
86
95
|
return value
|
|
87
96
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
|
|
97
|
+
def make_result_table(self, result):
|
|
98
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
|
|
91
99
|
|
|
92
|
-
if stack_mode:
|
|
100
|
+
if self.stack_mode:
|
|
93
101
|
header.append(CompareConst.STACK)
|
|
94
|
-
if dump_mode == Const.ALL:
|
|
102
|
+
if self.dump_mode == Const.ALL:
|
|
95
103
|
header.append(CompareConst.DATA_NAME)
|
|
96
104
|
else:
|
|
97
|
-
if dump_mode == Const.ALL:
|
|
105
|
+
if self.dump_mode == Const.ALL:
|
|
98
106
|
for row in result:
|
|
99
107
|
del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
|
|
100
108
|
header.append(CompareConst.DATA_NAME)
|
|
@@ -104,24 +112,25 @@ class Comparator:
|
|
|
104
112
|
result_df = pd.DataFrame(result, columns=header, dtype='object')
|
|
105
113
|
return result_df
|
|
106
114
|
|
|
107
|
-
|
|
108
|
-
def gen_merge_list(cls, json_data, op_name, stack_json_data, dump_mode):
|
|
115
|
+
def gen_merge_list(self, json_data, op_name, stack_json_data):
|
|
109
116
|
op_data = json_data['data'][op_name]
|
|
110
117
|
check_dump_json_str(op_data, op_name)
|
|
111
118
|
op_parsed_list = read_op(op_data, op_name)
|
|
112
119
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
120
|
+
if self.stack_mode:
|
|
121
|
+
stack_info = stack_json_data.get(op_name)
|
|
122
|
+
if stack_info is not None:
|
|
123
|
+
check_stack_json_str(stack_info, op_name)
|
|
124
|
+
# append only when stack_mode is True,
|
|
125
|
+
op_parsed_list.append({
|
|
126
|
+
'full_op_name': op_name,
|
|
127
|
+
'full_info': stack_info
|
|
128
|
+
})
|
|
129
|
+
|
|
130
|
+
merge_list = merge_tensor(op_parsed_list, self.dump_mode)
|
|
122
131
|
return merge_list
|
|
123
132
|
|
|
124
|
-
def check_op(self, npu_dict, bench_dict
|
|
133
|
+
def check_op(self, npu_dict, bench_dict):
|
|
125
134
|
npu_op_name = npu_dict[CompareConst.OP_NAME]
|
|
126
135
|
bench_op_name = bench_dict[CompareConst.OP_NAME]
|
|
127
136
|
graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
|
|
@@ -133,34 +142,34 @@ class Comparator:
|
|
|
133
142
|
if graph_mode:
|
|
134
143
|
return graph_mapping.match(npu_op_name[0], bench_op_name[0])
|
|
135
144
|
struct_match = check_struct_match(npu_dict, bench_dict)
|
|
136
|
-
if not fuzzy_match:
|
|
137
|
-
|
|
138
|
-
|
|
145
|
+
if not self.fuzzy_match:
|
|
146
|
+
name_match = npu_op_name == bench_op_name
|
|
147
|
+
return name_match and struct_match
|
|
139
148
|
try:
|
|
140
|
-
|
|
149
|
+
name_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
141
150
|
except Exception as err:
|
|
142
151
|
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
143
|
-
|
|
144
|
-
return
|
|
152
|
+
name_match = False
|
|
153
|
+
return name_match and struct_match
|
|
145
154
|
|
|
146
|
-
def match_op(self, npu_queue, bench_queue
|
|
155
|
+
def match_op(self, npu_queue, bench_queue):
|
|
147
156
|
for b_index, b_op in enumerate(bench_queue[0: -1]):
|
|
148
|
-
if self.check_op(npu_queue[-1], b_op
|
|
157
|
+
if self.check_op(npu_queue[-1], b_op):
|
|
149
158
|
return len(npu_queue) - 1, b_index
|
|
150
|
-
if self.check_op(npu_queue[-1], bench_queue[-1]
|
|
159
|
+
if self.check_op(npu_queue[-1], bench_queue[-1]):
|
|
151
160
|
return len(npu_queue) - 1, len(bench_queue) - 1
|
|
152
161
|
for n_index, n_op in enumerate(npu_queue[0: -1]):
|
|
153
|
-
if self.check_op(n_op, bench_queue[-1]
|
|
162
|
+
if self.check_op(n_op, bench_queue[-1]):
|
|
154
163
|
return n_index, len(bench_queue) - 1
|
|
155
164
|
return -1, -1
|
|
156
165
|
|
|
157
|
-
def compare_process(self, file_lists
|
|
166
|
+
def compare_process(self, file_lists):
|
|
158
167
|
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
159
168
|
npu_json_data = load_json(npu_json_path)
|
|
160
169
|
bench_json_data = load_json(bench_json_path)
|
|
161
|
-
stack_json_data = load_json(stack_json_path)
|
|
170
|
+
stack_json_data = load_json(stack_json_path) if self.stack_mode else None
|
|
162
171
|
|
|
163
|
-
if fuzzy_match:
|
|
172
|
+
if self.fuzzy_match:
|
|
164
173
|
logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
|
|
165
174
|
|
|
166
175
|
npu_ops_queue = []
|
|
@@ -184,8 +193,7 @@ class Comparator:
|
|
|
184
193
|
last_npu_ops_len = len(npu_ops_queue)
|
|
185
194
|
op_name_npu = next(ops_npu_iter)
|
|
186
195
|
check_op_str_pattern_valid(op_name_npu)
|
|
187
|
-
|
|
188
|
-
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data, dump_mode)
|
|
196
|
+
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data)
|
|
189
197
|
if npu_merge_list:
|
|
190
198
|
npu_ops_queue.append(npu_merge_list)
|
|
191
199
|
except StopIteration:
|
|
@@ -194,7 +202,7 @@ class Comparator:
|
|
|
194
202
|
last_bench_ops_len = len(bench_ops_queue)
|
|
195
203
|
op_name_bench = next(ops_bench_iter)
|
|
196
204
|
check_op_str_pattern_valid(op_name_bench)
|
|
197
|
-
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data
|
|
205
|
+
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data)
|
|
198
206
|
if bench_merge_list:
|
|
199
207
|
bench_ops_queue.append(bench_merge_list)
|
|
200
208
|
except StopIteration:
|
|
@@ -213,59 +221,64 @@ class Comparator:
|
|
|
213
221
|
logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
|
|
214
222
|
break
|
|
215
223
|
|
|
216
|
-
n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue
|
|
224
|
+
n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
|
|
225
|
+
|
|
226
|
+
# 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中
|
|
217
227
|
if n_match_point == -1 and b_match_point == -1:
|
|
218
228
|
continue
|
|
229
|
+
|
|
219
230
|
n_match_data = npu_ops_queue[n_match_point]
|
|
220
231
|
b_match_data = bench_ops_queue[b_match_point]
|
|
221
232
|
un_match_data = npu_ops_queue[0: n_match_point]
|
|
222
233
|
for npu_data in un_match_data:
|
|
223
|
-
get_un_match_accuracy(result, npu_data, dump_mode)
|
|
224
|
-
get_accuracy(result, n_match_data, b_match_data, dump_mode)
|
|
234
|
+
get_un_match_accuracy(result, npu_data, self.dump_mode)
|
|
235
|
+
get_accuracy(result, n_match_data, b_match_data, self.dump_mode)
|
|
225
236
|
del npu_ops_queue[0: n_match_point + 1]
|
|
226
237
|
del bench_ops_queue[0: b_match_point + 1]
|
|
227
238
|
progress_bar.close()
|
|
228
239
|
if npu_ops_queue:
|
|
229
240
|
for npu_data in npu_ops_queue:
|
|
230
|
-
get_un_match_accuracy(result, npu_data, dump_mode)
|
|
241
|
+
get_un_match_accuracy(result, npu_data, self.dump_mode)
|
|
231
242
|
|
|
232
|
-
result_df = self.make_result_table(result
|
|
243
|
+
result_df = self.make_result_table(result)
|
|
233
244
|
return result_df
|
|
234
245
|
|
|
235
|
-
def merge_data(self, json_data, stack_json_data
|
|
246
|
+
def merge_data(self, json_data, stack_json_data):
|
|
236
247
|
ops_all = {}
|
|
237
248
|
for op_name in json_data.get('data', {}):
|
|
238
|
-
merge_list = self.gen_merge_list(json_data, op_name, stack_json_data
|
|
249
|
+
merge_list = self.gen_merge_list(json_data, op_name, stack_json_data)
|
|
239
250
|
if merge_list:
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
251
|
+
struct_to_index_mapping = {
|
|
252
|
+
CompareConst.INPUT_STRUCT: 0,
|
|
253
|
+
CompareConst.OUTPUT_STRUCT: 0,
|
|
254
|
+
CompareConst.PARAMS_STRUCT: 0,
|
|
255
|
+
CompareConst.PARAMS_GRAD_STRUCT: 0
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
op_name_list = merge_list.get(CompareConst.OP_NAME)
|
|
259
|
+
summary_list = merge_list.get(Const.SUMMARY)
|
|
260
|
+
data_name_list = merge_list.get('data_name')
|
|
261
|
+
op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
|
|
262
|
+
summary_list,
|
|
263
|
+
data_name_list)
|
|
264
|
+
for index, op_full_name in enumerate(op_name_reorder):
|
|
265
|
+
data_name = data_name_reorder[index] if data_name_reorder else None
|
|
266
|
+
|
|
267
|
+
_, state = get_name_and_state(op_full_name)
|
|
268
|
+
struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
|
|
269
|
+
if not struct_key:
|
|
270
|
+
continue
|
|
271
|
+
ops_all[op_full_name] = {
|
|
272
|
+
CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key),
|
|
273
|
+
"merge_list", key=struct_key),
|
|
274
|
+
CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"),
|
|
275
|
+
'data_name': data_name,
|
|
276
|
+
'stack_info': merge_list.get('stack_info')
|
|
277
|
+
}
|
|
278
|
+
struct_to_index_mapping[struct_key] += 1
|
|
266
279
|
return ops_all
|
|
267
280
|
|
|
268
|
-
def get_accuracy(self, npu_ops_all, bench_ops_all
|
|
281
|
+
def get_accuracy(self, npu_ops_all, bench_ops_all):
|
|
269
282
|
result = []
|
|
270
283
|
bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
|
|
271
284
|
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
@@ -273,7 +286,7 @@ class Comparator:
|
|
|
273
286
|
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
274
287
|
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
275
288
|
has_stack = npu_stack_info and bench_stack_info
|
|
276
|
-
if dump_mode == Const.MD5:
|
|
289
|
+
if self.dump_mode == Const.MD5:
|
|
277
290
|
result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
|
|
278
291
|
bench_ops_all, has_stack, npu_stack_info))
|
|
279
292
|
continue
|
|
@@ -297,7 +310,7 @@ class Comparator:
|
|
|
297
310
|
bench_struct[1]
|
|
298
311
|
]
|
|
299
312
|
|
|
300
|
-
if dump_mode == Const.SUMMARY:
|
|
313
|
+
if self.dump_mode == Const.SUMMARY:
|
|
301
314
|
result_item = base_result_item + [" "] * 8
|
|
302
315
|
else:
|
|
303
316
|
result_item = base_result_item + [" "] * 5
|
|
@@ -306,7 +319,7 @@ class Comparator:
|
|
|
306
319
|
result_item.extend(npu_summary_data)
|
|
307
320
|
bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
|
|
308
321
|
result_item.extend(bench_summary_data)
|
|
309
|
-
if dump_mode == Const.SUMMARY:
|
|
322
|
+
if self.dump_mode == Const.SUMMARY:
|
|
310
323
|
self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
|
|
311
324
|
else:
|
|
312
325
|
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
@@ -315,7 +328,7 @@ class Comparator:
|
|
|
315
328
|
result_item.extend(npu_stack_info)
|
|
316
329
|
else:
|
|
317
330
|
result_item.append(CompareConst.NONE)
|
|
318
|
-
if dump_mode == Const.ALL:
|
|
331
|
+
if self.dump_mode == Const.ALL:
|
|
319
332
|
result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
|
|
320
333
|
result.append(result_item)
|
|
321
334
|
elif ms_op_name not in npu_ops_all:
|
|
@@ -324,17 +337,16 @@ class Comparator:
|
|
|
324
337
|
logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
|
|
325
338
|
return result
|
|
326
339
|
|
|
327
|
-
def compare_process_custom(self, file_lists
|
|
340
|
+
def compare_process_custom(self, file_lists):
|
|
328
341
|
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
329
342
|
npu_json_data = load_json(npu_json_path)
|
|
330
343
|
bench_json_data = load_json(bench_json_path)
|
|
331
|
-
stack_json_data = load_json(stack_json_path)
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
bench_ops_all = self.merge_data(bench_json_data, stack_json_data, dump_mode)
|
|
344
|
+
stack_json_data = load_json(stack_json_path) if self.stack_mode else None
|
|
345
|
+
npu_ops_all = self.merge_data(npu_json_data, stack_json_data)
|
|
346
|
+
bench_ops_all = self.merge_data(bench_json_data, stack_json_data)
|
|
335
347
|
|
|
336
|
-
result = self.get_accuracy(npu_ops_all, bench_ops_all
|
|
337
|
-
result_df = self.make_result_table(result
|
|
348
|
+
result = self.get_accuracy(npu_ops_all, bench_ops_all)
|
|
349
|
+
result_df = self.make_result_table(result)
|
|
338
350
|
return result_df
|
|
339
351
|
|
|
340
352
|
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
|
|
@@ -381,25 +393,23 @@ class Comparator:
|
|
|
381
393
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
382
394
|
error_flag = True
|
|
383
395
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
n_value, b_value = reshape_value(n_value, b_value)
|
|
396
|
+
# 通过n_value, b_value同时得到错误标志和错误信息
|
|
397
|
+
n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
|
|
398
|
+
error_flag=error_flag, error_file=error_file)
|
|
388
399
|
|
|
389
|
-
err_msg =
|
|
390
|
-
result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)
|
|
400
|
+
result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
|
|
391
401
|
|
|
392
|
-
if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
|
|
402
|
+
if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
|
|
393
403
|
err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
|
|
394
404
|
result_list.append(err_msg)
|
|
395
405
|
return result_list
|
|
396
406
|
|
|
397
|
-
def compare_core(self,
|
|
407
|
+
def compare_core(self, input_param, output_path, **kwargs):
|
|
398
408
|
"""
|
|
399
409
|
Compares data from multiple JSON files and generates a comparison report.
|
|
400
410
|
|
|
401
411
|
Args:
|
|
402
|
-
|
|
412
|
+
input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
|
|
403
413
|
"stack_path").
|
|
404
414
|
output_path (str): The path where the output Excel report will be saved.
|
|
405
415
|
**kwargs: Additional keyword arguments including:
|
|
@@ -412,11 +422,7 @@ class Comparator:
|
|
|
412
422
|
Returns:
|
|
413
423
|
"""
|
|
414
424
|
# get kwargs or set default value
|
|
415
|
-
stack_mode = kwargs.get('stack_mode', False)
|
|
416
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
417
425
|
suffix = kwargs.get('suffix', '')
|
|
418
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
419
|
-
dump_mode = kwargs.get('dump_mode', None)
|
|
420
426
|
|
|
421
427
|
logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
|
|
422
428
|
file_name = add_time_with_xlsx("compare_result" + suffix)
|
|
@@ -424,30 +430,25 @@ class Comparator:
|
|
|
424
430
|
remove_path(file_path)
|
|
425
431
|
highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
|
|
426
432
|
|
|
427
|
-
npu_json =
|
|
428
|
-
bench_json =
|
|
429
|
-
stack_json =
|
|
433
|
+
npu_json = input_param.get("npu_json_path")
|
|
434
|
+
bench_json = input_param.get("bench_json_path")
|
|
435
|
+
stack_json = input_param.get("stack_json_path")
|
|
430
436
|
if self.data_mapping:
|
|
431
|
-
result_df = self.compare_process_custom([npu_json, bench_json, stack_json]
|
|
437
|
+
result_df = self.compare_process_custom([npu_json, bench_json, stack_json])
|
|
432
438
|
else:
|
|
433
|
-
result_df = self.compare_process(
|
|
434
|
-
[npu_json, bench_json, stack_json],
|
|
435
|
-
stack_mode,
|
|
436
|
-
fuzzy_match,
|
|
437
|
-
dump_mode
|
|
438
|
-
)
|
|
439
|
+
result_df = self.compare_process([npu_json, bench_json, stack_json])
|
|
439
440
|
|
|
440
441
|
if not result_df.values.tolist():
|
|
441
442
|
logger.warning("Can`t match any op.")
|
|
442
443
|
return
|
|
443
444
|
|
|
444
|
-
if dump_mode == Const.ALL:
|
|
445
|
-
result_df = self.do_multi_process(
|
|
445
|
+
if self.dump_mode == Const.ALL:
|
|
446
|
+
result_df = self.do_multi_process(input_param, result_df)
|
|
446
447
|
|
|
447
|
-
find_compare_result_error_rows(result_df, highlight_dict, dump_mode)
|
|
448
|
+
find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode)
|
|
448
449
|
highlight_rows_xlsx(result_df, highlight_dict, file_path)
|
|
449
450
|
|
|
450
|
-
if auto_analyze:
|
|
451
|
+
if self.auto_analyze:
|
|
451
452
|
advisor = Advisor(result_df, output_path, suffix)
|
|
452
453
|
advisor.analysis()
|
|
453
454
|
|
|
@@ -504,14 +505,18 @@ class Comparator:
|
|
|
504
505
|
logger.error('result dataframe is not found.')
|
|
505
506
|
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
506
507
|
|
|
508
|
+
|
|
507
509
|
def get_bench_data_name(bench_op_name, bench_data):
|
|
508
|
-
bench_name_list = re.split(r'\.(input|output|kwargs)\.', bench_op_name)
|
|
509
|
-
|
|
510
|
+
bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name)
|
|
511
|
+
if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD:
|
|
512
|
+
bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {})
|
|
513
|
+
else:
|
|
514
|
+
bench_data_bundle = bench_data.get(bench_name_list[0], {})
|
|
510
515
|
if not bench_data_bundle or len(bench_name_list) < 3:
|
|
511
516
|
return None
|
|
512
517
|
layers = bench_name_list[2].split(Const.SEP)
|
|
513
518
|
|
|
514
|
-
def
|
|
519
|
+
def _get(key, container):
|
|
515
520
|
if isinstance(container, dict):
|
|
516
521
|
return container.get(key)
|
|
517
522
|
if isinstance(container, list):
|
|
@@ -521,11 +526,14 @@ def get_bench_data_name(bench_op_name, bench_data):
|
|
|
521
526
|
return None
|
|
522
527
|
return None
|
|
523
528
|
|
|
524
|
-
def get_by_layer(container):
|
|
529
|
+
def get_by_layer(container, params_grad=False):
|
|
525
530
|
data = container
|
|
531
|
+
# dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0'
|
|
532
|
+
if params_grad:
|
|
533
|
+
layers.append('0')
|
|
526
534
|
for layer in layers:
|
|
527
|
-
data =
|
|
528
|
-
return
|
|
535
|
+
data = _get(layer, data)
|
|
536
|
+
return _get(CompareConst.DATA_NAME.lower(), data)
|
|
529
537
|
|
|
530
538
|
if Const.INPUT == bench_name_list[1]:
|
|
531
539
|
return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
|
|
@@ -533,6 +541,9 @@ def get_bench_data_name(bench_op_name, bench_data):
|
|
|
533
541
|
return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
|
|
534
542
|
elif Const.OUTPUT == bench_name_list[1]:
|
|
535
543
|
return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
|
|
544
|
+
elif Const.PARAMS == bench_name_list[1]:
|
|
545
|
+
return get_by_layer(bench_data_bundle.get(Const.PARAMS))
|
|
546
|
+
elif Const.PARAMS_GRAD == bench_name_list[1]:
|
|
547
|
+
return get_by_layer(bench_data_bundle, params_grad=True)
|
|
536
548
|
else:
|
|
537
549
|
return None
|
|
538
|
-
|
msprobe/core/compare/check.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,8 +16,7 @@
|
|
|
16
16
|
from msprobe.core.common.log import logger
|
|
17
17
|
from msprobe.core.compare.utils import rename_api
|
|
18
18
|
from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
|
|
19
|
-
from msprobe.core.common.const import Const
|
|
20
|
-
|
|
19
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
21
20
|
|
|
22
21
|
dtype_mapping = {
|
|
23
22
|
"Int8": "torch.int8",
|
|
@@ -38,31 +37,40 @@ dtype_mapping = {
|
|
|
38
37
|
}
|
|
39
38
|
|
|
40
39
|
|
|
41
|
-
def
|
|
42
|
-
|
|
43
|
-
bench_struct_in = bench_dict.get("input_struct")
|
|
44
|
-
npu_struct_out = npu_dict.get("output_struct")
|
|
45
|
-
bench_struct_out = bench_dict.get("output_struct")
|
|
40
|
+
def compare_op_dict_struct(npu_dict, bench_dict):
|
|
41
|
+
return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY)
|
|
46
42
|
|
|
47
|
-
|
|
43
|
+
|
|
44
|
+
def check_struct_match(npu_dict, bench_dict):
|
|
45
|
+
is_match = compare_op_dict_struct(npu_dict, bench_dict)
|
|
48
46
|
if not is_match:
|
|
49
|
-
|
|
50
|
-
return False
|
|
47
|
+
struct_match_list = []
|
|
51
48
|
try:
|
|
52
|
-
|
|
53
|
-
|
|
49
|
+
for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY):
|
|
50
|
+
# 首先额外检查input_struct是否空,input_struct不可能为空
|
|
51
|
+
if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])):
|
|
52
|
+
return False
|
|
53
|
+
struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, [])))
|
|
54
54
|
except CompareException as error:
|
|
55
55
|
err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
|
|
56
56
|
f'npu_dict: {npu_dict}' \
|
|
57
57
|
f'bench_dict: {bench_dict}'
|
|
58
58
|
logger.error(err_msg)
|
|
59
59
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
60
|
-
is_match =
|
|
60
|
+
is_match = all(struct_match_list)
|
|
61
61
|
return is_match
|
|
62
62
|
|
|
63
63
|
|
|
64
64
|
def check_type_shape_match(npu_struct, bench_struct):
|
|
65
|
-
|
|
65
|
+
"""
|
|
66
|
+
further check dtypes with a dtype mapping list when dtypes are not entirely consistent.
|
|
67
|
+
"""
|
|
68
|
+
if len(npu_struct) != len(bench_struct):
|
|
69
|
+
return False
|
|
70
|
+
if not npu_struct and not bench_struct:
|
|
71
|
+
return True
|
|
72
|
+
|
|
73
|
+
struct_match = False
|
|
66
74
|
for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
|
|
67
75
|
try:
|
|
68
76
|
npu_type = npu_type_shape[0]
|
|
@@ -76,22 +84,14 @@ def check_type_shape_match(npu_struct, bench_struct):
|
|
|
76
84
|
shape_match = npu_shape == bench_shape
|
|
77
85
|
type_match = npu_type == bench_type
|
|
78
86
|
if not type_match:
|
|
79
|
-
|
|
80
|
-
[Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
|
|
81
|
-
[Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
|
|
82
|
-
]
|
|
83
|
-
torch_type = [
|
|
84
|
-
[Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
|
|
85
|
-
[Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
|
|
86
|
-
]
|
|
87
|
-
if ([npu_type, bench_type] in ms_type) or ([npu_type, bench_type] in torch_type):
|
|
87
|
+
if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
|
|
88
88
|
type_match = True
|
|
89
89
|
else:
|
|
90
90
|
type_match = False
|
|
91
|
-
|
|
92
|
-
if not
|
|
91
|
+
struct_match = shape_match and type_match
|
|
92
|
+
if not struct_match:
|
|
93
93
|
return False
|
|
94
|
-
return
|
|
94
|
+
return struct_match
|
|
95
95
|
|
|
96
96
|
|
|
97
97
|
def check_graph_mode(a_op_name, b_op_name):
|
|
@@ -103,6 +103,8 @@ def check_graph_mode(a_op_name, b_op_name):
|
|
|
103
103
|
|
|
104
104
|
|
|
105
105
|
def fuzzy_check_op(npu_name_list, bench_name_list):
|
|
106
|
+
# 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0
|
|
107
|
+
# 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1
|
|
106
108
|
if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
|
|
107
109
|
return False
|
|
108
110
|
is_match = True
|
|
@@ -148,11 +150,11 @@ def check_json_key_value(input_output, op_name, depth=0):
|
|
|
148
150
|
return
|
|
149
151
|
if isinstance(input_output, list):
|
|
150
152
|
for item in input_output:
|
|
151
|
-
check_json_key_value(item, op_name, depth+1)
|
|
153
|
+
check_json_key_value(item, op_name, depth + 1)
|
|
152
154
|
elif isinstance(input_output, dict):
|
|
153
155
|
for key, value in input_output.items():
|
|
154
156
|
if isinstance(value, dict):
|
|
155
|
-
check_json_key_value(value, op_name, depth+1)
|
|
157
|
+
check_json_key_value(value, op_name, depth + 1)
|
|
156
158
|
else:
|
|
157
159
|
valid_key_value(key, value, op_name)
|
|
158
160
|
|
|
@@ -38,40 +38,41 @@ def compare_cli(args):
|
|
|
38
38
|
else:
|
|
39
39
|
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
40
40
|
from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
|
|
41
|
+
|
|
42
|
+
common_kwargs = {
|
|
43
|
+
"auto_analyze": auto_analyze,
|
|
44
|
+
"fuzzy_match": args.fuzzy_match,
|
|
45
|
+
"data_mapping": args.data_mapping,
|
|
46
|
+
}
|
|
47
|
+
|
|
41
48
|
if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
|
|
42
|
-
if "stack_path" not in input_param:
|
|
43
|
-
logger.error(f"Missing stack_path in configuration file {args.input_path}, please check!")
|
|
44
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
45
49
|
input_param["npu_json_path"] = input_param.pop("npu_path")
|
|
46
50
|
input_param["bench_json_path"] = input_param.pop("bench_path")
|
|
47
|
-
|
|
51
|
+
if "stack_path" not in input_param:
|
|
52
|
+
logger.warning(f"Missing stack_path in the configuration file. "
|
|
53
|
+
f"Automatically detecting stack.json to determine whether to display NPU_Stack_Info.")
|
|
54
|
+
else:
|
|
55
|
+
input_param["stack_json_path"] = input_param.pop("stack_path")
|
|
56
|
+
|
|
48
57
|
if frame_name == Const.PT_FRAMEWORK:
|
|
49
|
-
kwargs = {
|
|
50
|
-
|
|
51
|
-
}
|
|
52
|
-
compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
|
|
53
|
-
fuzzy_match=args.fuzzy_match, **kwargs)
|
|
58
|
+
kwargs = {**common_kwargs, "stack_mode": args.stack_mode}
|
|
59
|
+
compare(input_param, args.output_path, **kwargs)
|
|
54
60
|
else:
|
|
55
61
|
kwargs = {
|
|
62
|
+
**common_kwargs,
|
|
56
63
|
"stack_mode": args.stack_mode,
|
|
57
|
-
"auto_analyze": auto_analyze,
|
|
58
|
-
"fuzzy_match": args.fuzzy_match,
|
|
59
64
|
"cell_mapping": args.cell_mapping,
|
|
60
65
|
"api_mapping": args.api_mapping,
|
|
61
|
-
"data_mapping": args.data_mapping,
|
|
62
66
|
"layer_mapping": args.layer_mapping
|
|
63
67
|
}
|
|
64
|
-
|
|
65
68
|
ms_compare(input_param, args.output_path, **kwargs)
|
|
66
69
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
67
70
|
kwargs = {
|
|
71
|
+
**common_kwargs,
|
|
68
72
|
"stack_mode": args.stack_mode,
|
|
69
|
-
"auto_analyze": auto_analyze,
|
|
70
|
-
"fuzzy_match": args.fuzzy_match,
|
|
71
73
|
"is_print_compare_log": input_param.get("is_print_compare_log", True),
|
|
72
74
|
"cell_mapping": args.cell_mapping,
|
|
73
75
|
"api_mapping": args.api_mapping,
|
|
74
|
-
"data_mapping": args.data_mapping,
|
|
75
76
|
"layer_mapping": args.layer_mapping
|
|
76
77
|
}
|
|
77
78
|
if input_param.get("rank_id") is not None:
|