mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,14 +1,30 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import multiprocessing
|
|
2
17
|
import os
|
|
3
|
-
import json
|
|
4
18
|
import pandas as pd
|
|
5
|
-
from
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
from msprobe.core.common.file_utils import load_json
|
|
6
21
|
from msprobe.core.common.const import CompareConst, Const
|
|
7
22
|
from msprobe.core.common.exceptions import FileCheckException
|
|
8
23
|
from msprobe.core.common.log import logger
|
|
9
|
-
from msprobe.core.common.utils import add_time_with_xlsx, CompareException
|
|
24
|
+
from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid
|
|
10
25
|
from msprobe.core.common.file_utils import remove_path
|
|
11
|
-
from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op
|
|
26
|
+
from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
|
|
27
|
+
check_stack_json_str
|
|
12
28
|
from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
|
|
13
29
|
from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
|
|
14
30
|
from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
|
|
@@ -21,10 +37,53 @@ class Comparator:
|
|
|
21
37
|
|
|
22
38
|
def __init__(self):
|
|
23
39
|
pass
|
|
40
|
+
|
|
41
|
+
@staticmethod
|
|
42
|
+
def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
|
|
43
|
+
result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
|
|
44
|
+
bench_ops_all.get(bench_op_name).get('struct')[0],
|
|
45
|
+
npu_ops_all.get(ms_op_name).get('struct')[1],
|
|
46
|
+
bench_ops_all.get(bench_op_name).get('struct')[1],
|
|
47
|
+
npu_ops_all.get(ms_op_name).get('struct')[2],
|
|
48
|
+
bench_ops_all.get(bench_op_name).get('struct')[2],
|
|
49
|
+
CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2]
|
|
50
|
+
== bench_ops_all.get(bench_op_name).get('struct')[2]
|
|
51
|
+
else CompareConst.DIFF]
|
|
52
|
+
if args[0]:
|
|
53
|
+
result_item.extend(args[1])
|
|
54
|
+
else:
|
|
55
|
+
result_item.append(CompareConst.NONE)
|
|
56
|
+
return result_item
|
|
57
|
+
|
|
58
|
+
@staticmethod
|
|
59
|
+
def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
|
|
60
|
+
err_msg = ""
|
|
61
|
+
start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
|
|
62
|
+
warning_flag = False
|
|
63
|
+
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
64
|
+
if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
|
|
65
|
+
diff = npu_val - bench_val
|
|
66
|
+
if bench_val != 0:
|
|
67
|
+
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
68
|
+
else:
|
|
69
|
+
relative = "N/A"
|
|
70
|
+
result_item[start_idx + i] = diff
|
|
71
|
+
result_item[start_idx + i + 4] = relative
|
|
72
|
+
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
|
|
73
|
+
if magnitude_diff > 0.5:
|
|
74
|
+
warning_flag = True
|
|
75
|
+
else:
|
|
76
|
+
result_item[start_idx + i] = CompareConst.NONE
|
|
77
|
+
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
78
|
+
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
79
|
+
for i in range(start_idx, len(result_item)):
|
|
80
|
+
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
81
|
+
result_item[i] = f'{result_item[i]}\t'
|
|
82
|
+
result_item.append(accuracy_check)
|
|
83
|
+
result_item.append(err_msg)
|
|
24
84
|
|
|
25
85
|
@classmethod
|
|
26
|
-
def make_result_table(cls,result, md5_compare, summary_compare, stack_mode):
|
|
27
|
-
header = []
|
|
86
|
+
def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
|
|
28
87
|
if md5_compare:
|
|
29
88
|
header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
|
|
30
89
|
elif summary_compare:
|
|
@@ -47,17 +106,22 @@ class Comparator:
|
|
|
47
106
|
else:
|
|
48
107
|
for row in result:
|
|
49
108
|
del row[-1]
|
|
50
|
-
result_df = pd.DataFrame(result, columns=header)
|
|
109
|
+
result_df = pd.DataFrame(result, columns=header, dtype='object')
|
|
51
110
|
return result_df
|
|
52
111
|
|
|
53
112
|
@classmethod
|
|
54
|
-
def gen_merge_list(
|
|
113
|
+
def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
|
|
55
114
|
op_data = json_data['data'][op_name]
|
|
115
|
+
check_dump_json_str(op_data, op_name)
|
|
56
116
|
op_parsed_list = read_op(op_data, op_name)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
117
|
+
|
|
118
|
+
stack_info = stack_json_data.get(op_name)
|
|
119
|
+
if stack_info is not None:
|
|
120
|
+
check_stack_json_str(stack_info, op_name)
|
|
121
|
+
op_parsed_list.append({
|
|
122
|
+
'full_op_name': op_name,
|
|
123
|
+
'full_info': stack_info
|
|
124
|
+
})
|
|
61
125
|
|
|
62
126
|
merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
|
|
63
127
|
return merge_list
|
|
@@ -67,7 +131,7 @@ class Comparator:
|
|
|
67
131
|
b_op_name = bench_dict["op_name"]
|
|
68
132
|
graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
|
|
69
133
|
|
|
70
|
-
frame_name = getattr(self,"frame_name")
|
|
134
|
+
frame_name = getattr(self, "frame_name")
|
|
71
135
|
if frame_name == "PTComparator":
|
|
72
136
|
from msprobe.pytorch.compare.match import graph_mapping
|
|
73
137
|
if graph_mode:
|
|
@@ -94,11 +158,11 @@ class Comparator:
|
|
|
94
158
|
return n_index, len(bench_queue) - 1
|
|
95
159
|
return -1, -1
|
|
96
160
|
|
|
97
|
-
def compare_process(self,
|
|
98
|
-
|
|
99
|
-
npu_json_data =
|
|
100
|
-
bench_json_data =
|
|
101
|
-
stack_json_data =
|
|
161
|
+
def compare_process(self, file_lists, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
|
|
162
|
+
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
163
|
+
npu_json_data = load_json(npu_json_path)
|
|
164
|
+
bench_json_data = load_json(bench_json_path)
|
|
165
|
+
stack_json_data = load_json(stack_json_path)
|
|
102
166
|
|
|
103
167
|
if fuzzy_match:
|
|
104
168
|
logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
|
|
@@ -114,14 +178,19 @@ class Comparator:
|
|
|
114
178
|
last_npu_ops_len = 0
|
|
115
179
|
last_bench_ops_len = 0
|
|
116
180
|
|
|
181
|
+
npu_api_nums = len(npu_json_data['data'])
|
|
182
|
+
progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100)
|
|
183
|
+
|
|
117
184
|
while True:
|
|
118
185
|
if not read_err_npu and not read_err_bench:
|
|
119
186
|
break
|
|
120
187
|
try:
|
|
121
188
|
last_npu_ops_len = len(npu_ops_queue)
|
|
122
189
|
op_name_npu = next(ops_npu_iter)
|
|
190
|
+
check_op_str_pattern_valid(op_name_npu)
|
|
123
191
|
read_err_npu = True
|
|
124
|
-
npu_merge_list = self.gen_merge_list(npu_json_data,op_name_npu,stack_json_data,
|
|
192
|
+
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
|
|
193
|
+
summary_compare, md5_compare)
|
|
125
194
|
if npu_merge_list:
|
|
126
195
|
npu_ops_queue.append(npu_merge_list)
|
|
127
196
|
except StopIteration:
|
|
@@ -129,12 +198,16 @@ class Comparator:
|
|
|
129
198
|
try:
|
|
130
199
|
last_bench_ops_len = len(bench_ops_queue)
|
|
131
200
|
op_name_bench = next(ops_bench_iter)
|
|
132
|
-
|
|
201
|
+
check_op_str_pattern_valid(op_name_bench)
|
|
202
|
+
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
|
|
203
|
+
summary_compare, md5_compare)
|
|
133
204
|
if bench_merge_list:
|
|
134
205
|
bench_ops_queue.append(bench_merge_list)
|
|
135
206
|
except StopIteration:
|
|
136
207
|
read_err_bench = False
|
|
137
208
|
|
|
209
|
+
progress_bar.update(1)
|
|
210
|
+
|
|
138
211
|
# merge all boolean expressions
|
|
139
212
|
both_empty = not npu_ops_queue and not bench_ops_queue
|
|
140
213
|
no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
|
|
@@ -163,7 +236,91 @@ class Comparator:
|
|
|
163
236
|
|
|
164
237
|
result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
|
|
165
238
|
return result_df
|
|
166
|
-
|
|
239
|
+
|
|
240
|
+
def merge_data(self, json_data, stack_json_data, summary_compare, md5_compare):
|
|
241
|
+
ops_all = {}
|
|
242
|
+
for op_name in json_data.get('data', {}):
|
|
243
|
+
merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, summary_compare,
|
|
244
|
+
md5_compare)
|
|
245
|
+
if merge_list:
|
|
246
|
+
input_index, output_index = 0, 0
|
|
247
|
+
for index, input_or_output in enumerate(merge_list['op_name']):
|
|
248
|
+
input_or_output_list = input_or_output.split(Const.SEP)
|
|
249
|
+
data_name = merge_list.get('data_name')
|
|
250
|
+
data_name = data_name[index] if data_name else None
|
|
251
|
+
if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
|
|
252
|
+
ops_all[input_or_output] = {'struct': merge_list.get('input_struct')[input_index],
|
|
253
|
+
'summary': merge_list.get('summary')[index],
|
|
254
|
+
'data_name': data_name,
|
|
255
|
+
'stack_info': merge_list.get('stack_info')}
|
|
256
|
+
input_index += 1
|
|
257
|
+
|
|
258
|
+
elif Const.OUTPUT in input_or_output_list:
|
|
259
|
+
ops_all[input_or_output] = {'struct': merge_list.get('output_struct')[output_index],
|
|
260
|
+
'summary': merge_list.get('summary')[index],
|
|
261
|
+
'data_name': data_name,
|
|
262
|
+
'stack_info': merge_list.get('stack_info')}
|
|
263
|
+
output_index += 1
|
|
264
|
+
return ops_all
|
|
265
|
+
|
|
266
|
+
def get_accuracy(self, npu_ops_all, bench_ops_all, summary_compare, md5_compare):
|
|
267
|
+
result = []
|
|
268
|
+
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
269
|
+
if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
|
|
270
|
+
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
271
|
+
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
272
|
+
has_stack = npu_stack_info and bench_stack_info
|
|
273
|
+
if md5_compare:
|
|
274
|
+
result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
|
|
275
|
+
bench_ops_all, has_stack, npu_stack_info))
|
|
276
|
+
continue
|
|
277
|
+
if summary_compare:
|
|
278
|
+
result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
|
|
279
|
+
bench_ops_all.get(bench_op_name).get('struct')[0],
|
|
280
|
+
npu_ops_all.get(ms_op_name).get('struct')[1],
|
|
281
|
+
bench_ops_all.get(bench_op_name).get('struct')[1],
|
|
282
|
+
" ", " ", " ", " ", " ", " ", " ", " "]
|
|
283
|
+
else:
|
|
284
|
+
result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
|
|
285
|
+
bench_ops_all.get(bench_op_name).get('struct')[0],
|
|
286
|
+
npu_ops_all.get(ms_op_name).get('struct')[1],
|
|
287
|
+
bench_ops_all.get(bench_op_name).get('struct')[1],
|
|
288
|
+
" ", " ", " ", " ", " "]
|
|
289
|
+
npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
|
|
290
|
+
result_item.extend(npu_summary_data)
|
|
291
|
+
bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
|
|
292
|
+
result_item.extend(bench_summary_data)
|
|
293
|
+
if summary_compare:
|
|
294
|
+
self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
|
|
295
|
+
else:
|
|
296
|
+
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
297
|
+
result_item.append("")
|
|
298
|
+
if has_stack:
|
|
299
|
+
result_item.extend(npu_stack_info)
|
|
300
|
+
else:
|
|
301
|
+
result_item.append(CompareConst.NONE)
|
|
302
|
+
if not (summary_compare or md5_compare):
|
|
303
|
+
result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
|
|
304
|
+
result.append(result_item)
|
|
305
|
+
elif ms_op_name not in npu_ops_all:
|
|
306
|
+
logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
|
|
307
|
+
elif bench_op_name not in npu_ops_all:
|
|
308
|
+
logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
|
|
309
|
+
return result
|
|
310
|
+
|
|
311
|
+
def compare_process_custom(self, file_lists, stack_mode, summary_compare=False, md5_compare=False):
|
|
312
|
+
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
313
|
+
npu_json_data = load_json(npu_json_path)
|
|
314
|
+
bench_json_data = load_json(bench_json_path)
|
|
315
|
+
stack_json_data = load_json(stack_json_path)
|
|
316
|
+
|
|
317
|
+
npu_ops_all = self.merge_data(npu_json_data, stack_json_data, summary_compare, md5_compare)
|
|
318
|
+
bench_ops_all = self.merge_data(bench_json_data, stack_json_data, summary_compare, md5_compare)
|
|
319
|
+
|
|
320
|
+
result = self.get_accuracy(npu_ops_all, bench_ops_all, summary_compare, md5_compare)
|
|
321
|
+
result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
|
|
322
|
+
return result_df
|
|
323
|
+
|
|
167
324
|
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
|
|
168
325
|
npu_bench_name_list = op_name_mapping_dict[npu_op_name]
|
|
169
326
|
data_name = npu_bench_name_list[1]
|
|
@@ -178,9 +335,11 @@ class Comparator:
|
|
|
178
335
|
if frame_name == "MSComparator":
|
|
179
336
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
|
|
180
337
|
if self.cross_frame:
|
|
181
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
338
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
339
|
+
bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
|
|
182
340
|
else:
|
|
183
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
341
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
342
|
+
bench_op_name + Const.NUMPY_SUFFIX)
|
|
184
343
|
else:
|
|
185
344
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
|
|
186
345
|
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
|
|
@@ -237,19 +396,31 @@ class Comparator:
|
|
|
237
396
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
238
397
|
remove_path(file_path)
|
|
239
398
|
highlight_dict = {'red_rows': [], 'yellow_rows': []}
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
399
|
+
|
|
400
|
+
npu_json = input_parma.get("npu_json_path")
|
|
401
|
+
bench_json = input_parma.get("bench_json_path")
|
|
402
|
+
stack_json = input_parma.get("stack_json_path")
|
|
403
|
+
if self.data_mapping:
|
|
404
|
+
result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode,
|
|
405
|
+
summary_compare, md5_compare)
|
|
406
|
+
else:
|
|
244
407
|
result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
|
|
245
|
-
|
|
408
|
+
summary_compare, md5_compare)
|
|
409
|
+
|
|
410
|
+
if not result_df.values.tolist():
|
|
411
|
+
logger.warning("Can`t match any op.")
|
|
412
|
+
return
|
|
246
413
|
|
|
247
414
|
if not md5_compare and not summary_compare:
|
|
248
415
|
result_df = self._do_multi_process(input_parma, result_df)
|
|
416
|
+
|
|
417
|
+
logger.info("Highlight suspicious API/Module start.")
|
|
249
418
|
find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
|
|
250
419
|
highlight_rows_xlsx(result_df, highlight_dict, file_path)
|
|
420
|
+
logger.info("Highlight suspicious API/Module finish.")
|
|
421
|
+
|
|
251
422
|
if auto_analyze:
|
|
252
|
-
advisor = Advisor(result_df, output_path)
|
|
423
|
+
advisor = Advisor(result_df, output_path, suffix)
|
|
253
424
|
advisor.analysis()
|
|
254
425
|
|
|
255
426
|
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
|
|
@@ -265,13 +436,14 @@ class Comparator:
|
|
|
265
436
|
bench_op_name = result_df.iloc[i, 1]
|
|
266
437
|
if is_print_compare_log:
|
|
267
438
|
logger.info("start compare: {}".format(npu_op_name))
|
|
268
|
-
cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg =
|
|
269
|
-
npu_op_name, bench_op_name, dump_path_dict, input_param)
|
|
439
|
+
cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
|
|
440
|
+
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
|
|
270
441
|
if is_print_compare_log:
|
|
271
442
|
logger.info(
|
|
272
|
-
"[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {},
|
|
273
|
-
|
|
274
|
-
|
|
443
|
+
"[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
|
|
444
|
+
one_thousand_err_ratio {}, "
|
|
445
|
+
"five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
|
|
446
|
+
err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
|
|
275
447
|
cos_result.append(cos_sim)
|
|
276
448
|
max_err_result.append(max_abs_err)
|
|
277
449
|
max_relative_err_result.append(max_relative_err)
|
|
@@ -290,9 +462,10 @@ class Comparator:
|
|
|
290
462
|
|
|
291
463
|
return _save_cmp_result(idx, cr, result_df, lock)
|
|
292
464
|
|
|
293
|
-
def _do_multi_process(self,input_parma, result_df):
|
|
465
|
+
def _do_multi_process(self, input_parma, result_df):
|
|
294
466
|
try:
|
|
295
|
-
result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
|
|
467
|
+
result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
|
|
468
|
+
multiprocessing.Manager().RLock())
|
|
296
469
|
return result_df
|
|
297
470
|
except ValueError as e:
|
|
298
471
|
logger.error('result dataframe is not found.')
|
msprobe/core/compare/check.py
CHANGED
|
@@ -1,5 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.log import logger
|
|
2
|
-
from msprobe.core.compare.utils import rename_api
|
|
17
|
+
from msprobe.core.compare.utils import rename_api
|
|
18
|
+
from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
3
20
|
|
|
4
21
|
|
|
5
22
|
dtype_mapping = {
|
|
@@ -34,8 +51,15 @@ def check_struct_match(npu_dict, bench_dict, cross_frame=False):
|
|
|
34
51
|
if not is_match:
|
|
35
52
|
if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
|
|
36
53
|
return False
|
|
37
|
-
|
|
38
|
-
|
|
54
|
+
try:
|
|
55
|
+
struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
|
|
56
|
+
struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
|
|
57
|
+
except CompareException as error:
|
|
58
|
+
err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
|
|
59
|
+
f'npu_dict: {npu_dict}' \
|
|
60
|
+
f'bench_dict: {bench_dict}'
|
|
61
|
+
logger.error(err_msg)
|
|
62
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
39
63
|
is_match = struct_in_is_match and struct_out_is_match
|
|
40
64
|
return is_match
|
|
41
65
|
|
|
@@ -43,17 +67,27 @@ def check_struct_match(npu_dict, bench_dict, cross_frame=False):
|
|
|
43
67
|
def check_type_shape_match(npu_struct, bench_struct):
|
|
44
68
|
shape_type_match = False
|
|
45
69
|
for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
70
|
+
try:
|
|
71
|
+
npu_type = npu_type_shape[0]
|
|
72
|
+
npu_shape = npu_type_shape[1]
|
|
73
|
+
bench_type = bench_type_shape[0]
|
|
74
|
+
bench_shape = bench_type_shape[1]
|
|
75
|
+
except IndexError as error:
|
|
76
|
+
logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} '
|
|
77
|
+
f'should both be 2, please check!')
|
|
78
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
50
79
|
shape_match = npu_shape == bench_shape
|
|
51
80
|
type_match = npu_type == bench_type
|
|
52
81
|
if not type_match:
|
|
53
|
-
ms_type=
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
82
|
+
ms_type = [
|
|
83
|
+
[Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
|
|
84
|
+
[Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
|
|
85
|
+
]
|
|
86
|
+
torch_type = [
|
|
87
|
+
[Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
|
|
88
|
+
[Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
|
|
89
|
+
]
|
|
90
|
+
if ([npu_type, bench_type] in ms_type) or ([npu_type, bench_type] in torch_type):
|
|
57
91
|
type_match = True
|
|
58
92
|
else:
|
|
59
93
|
type_match = False
|
|
@@ -64,9 +98,9 @@ def check_type_shape_match(npu_struct, bench_struct):
|
|
|
64
98
|
|
|
65
99
|
|
|
66
100
|
def check_graph_mode(a_op_name, b_op_name):
    """Report whether exactly one of the two op names contains ``Const.ATEN``.

    Presumably used to detect an aten-style (graph) name being compared against a
    non-aten name -- confirm against callers.

    Args:
        a_op_name: First op name.
        b_op_name: Second op name.

    Returns:
        True if one name contains ``Const.ATEN`` and the other does not, else False.
    """
    a_has_aten = Const.ATEN in a_op_name
    b_has_aten = Const.ATEN in b_op_name
    # Mismatch in either direction counts, i.e. a logical XOR of the two flags.
    return a_has_aten != b_has_aten
|
|
72
106
|
|
|
@@ -83,13 +117,64 @@ def fuzzy_check_op(npu_name_list, bench_name_list):
|
|
|
83
117
|
|
|
84
118
|
|
|
85
119
|
def fuzzy_check_name(npu_name, bench_name):
    """Loosely compare an NPU op name with a bench op name.

    If both names contain ``Const.FORWARD`` (checked first), or failing that both
    contain ``Const.BACKWARD``, the names are compared after being normalized with
    ``rename_api`` for that marker. Otherwise the names must be exactly equal.

    Returns:
        True when the names are considered a match, else False.
    """
    # Order matters: the forward marker takes precedence over the backward one.
    for direction in (Const.FORWARD, Const.BACKWARD):
        if direction in npu_name and direction in bench_name:
            return rename_api(npu_name, direction) == rename_api(bench_name, direction)
    return npu_name == bench_name
|
|
93
127
|
|
|
94
128
|
|
|
129
|
+
def check_dump_json_str(op_data, op_name):
    """Run the key/value safety checks over every input/output section of one op's dump data.

    Positional inputs come from ``Const.INPUT_ARGS`` when that entry is present and
    truthy, otherwise from ``Const.INPUT``. Keyword inputs come from
    ``Const.INPUT_KWARGS`` and outputs from ``Const.OUTPUT``. Missing or empty
    sections are skipped. A dict section is checked as a whole. Any other section is
    iterated, and each truthy element is checked.

    Args:
        op_data: Mapping holding the dump entries of one op.
        op_name: Op name, used only in error messages.
    """
    sections = (
        op_data.get(Const.INPUT_ARGS) or op_data.get(Const.INPUT),
        op_data.get(Const.INPUT_KWARGS),
        op_data.get(Const.OUTPUT),
    )
    for section in sections:
        if not section:
            continue
        if isinstance(section, dict):
            check_json_key_value(section, op_name)
            continue
        for element in section:
            if element:
                check_json_key_value(element, op_name)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def check_json_key_value(input_output, op_name, depth=0):
    """Recursively validate the key/value pairs found in dump data for one op.

    Lists are walked element by element. Dicts are walked key by key: a dict value
    is recursed into, and any other value (including a list) is handed to
    ``valid_key_value`` together with its key. Inputs that are neither list nor dict
    are ignored.

    Once ``depth`` exceeds ``Const.MAX_DEPTH``, an error is logged and the check
    returns without raising, so deeper levels are left unchecked.

    Args:
        input_output: Dump fragment to validate.
        op_name: Op name, used only in error messages.
        depth: Current nesting level; callers normally leave the default.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f"string check of data info of {op_name} exceeds the recursion limit.")
        return

    child_depth = depth + 1
    if isinstance(input_output, dict):
        for key, value in input_output.items():
            if isinstance(value, dict):
                check_json_key_value(value, op_name, child_depth)
            else:
                valid_key_value(key, value, op_name)
    elif isinstance(input_output, list):
        for element in input_output:
            check_json_key_value(element, op_name, child_depth)
|
|
161
|
+
|
|
95
162
|
|
|
163
|
+
def valid_key_value(key, value, op_name):
    """Validate a single key/value pair of dump data.

    ``shape`` must be a list or tuple and ``requires_grad`` must be a bool. A value
    that fails its type check is logged and rejected. Every value that passes,
    including valid ``shape``/``requires_grad`` values and values of any other key,
    is then passed to ``check_op_str_pattern_valid``.

    Raises:
        CompareException: with ``INVALID_OBJECT_TYPE_ERROR`` on a type mismatch.
    """
    if key == "shape":
        if not isinstance(value, (list, tuple)):
            logger.error(f"shape of input or output of {op_name} is not list or tuple, please check!")
            raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
    elif key == "requires_grad":
        if not isinstance(value, bool):
            logger.error(f"requires_grad of input or output of {op_name} is not bool, please check!")
            raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
    check_op_str_pattern_valid(value, op_name)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def check_stack_json_str(stack_info, op_name):
    """Validate every entry of an op's stack information.

    Args:
        stack_info: Expected to be a list; each item goes to
            ``check_op_str_pattern_valid`` with ``stack=True``.
        op_name: Op name, used in error messages.

    Raises:
        CompareException: with ``INVALID_OBJECT_TYPE_ERROR`` if ``stack_info`` is not a list.
    """
    if not isinstance(stack_info, list):
        logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'")
        raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
    for entry in stack_info:
        check_op_str_pattern_valid(entry, op_name, stack=True)
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import json
|
|
2
17
|
from msprobe.core.common.file_utils import FileOpen, check_file_type
|
|
3
18
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
@@ -23,8 +38,11 @@ def compare_cli(args):
|
|
|
23
38
|
input_param["bench_json_path"] = input_param.pop("bench_path")
|
|
24
39
|
input_param["stack_json_path"] = input_param.pop("stack_path")
|
|
25
40
|
if frame_name == Const.PT_FRAMEWORK:
|
|
41
|
+
kwargs = {
|
|
42
|
+
"data_mapping": args.data_mapping
|
|
43
|
+
}
|
|
26
44
|
compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
|
|
27
|
-
fuzzy_match=args.fuzzy_match)
|
|
45
|
+
fuzzy_match=args.fuzzy_match, **kwargs)
|
|
28
46
|
else:
|
|
29
47
|
kwargs = {
|
|
30
48
|
"stack_mode": args.stack_mode,
|
|
@@ -32,6 +50,8 @@ def compare_cli(args):
|
|
|
32
50
|
"fuzzy_match": args.fuzzy_match,
|
|
33
51
|
"cell_mapping": args.cell_mapping,
|
|
34
52
|
"api_mapping": args.api_mapping,
|
|
53
|
+
"data_mapping": args.data_mapping,
|
|
54
|
+
"layer_mapping": args.layer_mapping
|
|
35
55
|
}
|
|
36
56
|
|
|
37
57
|
ms_compare(input_param, args.output_path, **kwargs)
|
|
@@ -1,5 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
import abc
|
|
18
|
+
import re
|
|
3
19
|
from collections import namedtuple
|
|
4
20
|
import numpy as np
|
|
5
21
|
import openpyxl
|
|
@@ -7,7 +23,7 @@ from openpyxl.styles import PatternFill
|
|
|
7
23
|
from msprobe.core.common.utils import get_header_index
|
|
8
24
|
from msprobe.core.common.file_utils import save_workbook
|
|
9
25
|
from msprobe.core.common.log import logger
|
|
10
|
-
from msprobe.core.common.const import CompareConst
|
|
26
|
+
from msprobe.core.common.const import CompareConst, FileCheckConst
|
|
11
27
|
|
|
12
28
|
|
|
13
29
|
class HighlightCheck(abc.ABC):
|
|
@@ -34,9 +50,11 @@ class CheckOneThousandErrorRatio(HighlightCheck):
|
|
|
34
50
|
def apply(self, info, color_columns, summary_compare=True):
|
|
35
51
|
api_in, api_out, num = info
|
|
36
52
|
one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
|
|
37
|
-
if not isinstance(api_in[one_thousand_index], (float, int)) or
|
|
53
|
+
if (not isinstance(api_in[one_thousand_index], (float, int)) or
|
|
54
|
+
not isinstance(api_out[one_thousand_index], (float, int))):
|
|
38
55
|
return
|
|
39
|
-
if api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
|
|
56
|
+
if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
|
|
57
|
+
api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
|
|
40
58
|
color_columns.red.append(num)
|
|
41
59
|
elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
|
|
42
60
|
color_columns.yellow.append(num)
|
|
@@ -66,7 +84,8 @@ class CheckMaxRelativeDiff(HighlightCheck):
|
|
|
66
84
|
return
|
|
67
85
|
if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
|
|
68
86
|
color_columns.red.append(num)
|
|
69
|
-
elif output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
|
|
87
|
+
elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
|
|
88
|
+
input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
|
|
70
89
|
color_columns.yellow.append(num)
|
|
71
90
|
|
|
72
91
|
|
|
@@ -193,7 +212,8 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m
|
|
|
193
212
|
input_num = num
|
|
194
213
|
else:
|
|
195
214
|
output_num = num
|
|
196
|
-
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
215
|
+
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
216
|
+
summary_compare, md5_compare)
|
|
197
217
|
|
|
198
218
|
|
|
199
219
|
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
@@ -205,12 +225,16 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
|
205
225
|
|
|
206
226
|
# write header
|
|
207
227
|
for j, col_name in enumerate(result_df.columns, start=1):
|
|
228
|
+
if not csv_value_is_valid(col_name):
|
|
229
|
+
raise RuntimeError(f"Malicious value [{col_name}] is not allowed to be written into the xlsx: {file_path}.")
|
|
208
230
|
ws.cell(row=1, column=j, value=col_name)
|
|
209
231
|
|
|
210
232
|
for i, row in enumerate(result_df.iterrows(), start=2):
|
|
211
233
|
for j, value in enumerate(row[1], start=1):
|
|
212
234
|
if not isinstance(value, (float, int)):
|
|
213
235
|
value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
|
|
236
|
+
if not csv_value_is_valid(value):
|
|
237
|
+
raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: {file_path}.")
|
|
214
238
|
ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
|
|
215
239
|
|
|
216
240
|
if (i - 2) in highlight_dict['red_rows']:
|
|
@@ -221,3 +245,15 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
|
221
245
|
end_color=CompareConst.YELLOW, fill_type="solid")
|
|
222
246
|
|
|
223
247
|
save_workbook(wb, file_path)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def csv_value_is_valid(value: str) -> bool:
    """Return whether a cell value is safe to write into a spreadsheet.

    Non-string values are always accepted. A string that ``float()`` can parse
    (e.g. "-1.00", "+1.00") is treated as a number and accepted. Any other string is
    rejected if it matches ``FileCheckConst.CSV_BLACK_LIST``, as a guard against
    formula injection.
    """
    if isinstance(value, str):
        try:
            # Signed numeric text such as "-1.00" or "+1.00" is data, not a formula.
            float(value)
        except ValueError:
            # Non-numeric text is screened against the blacklist to block formula injection.
            return re.search(FileCheckConst.CSV_BLACK_LIST, value) is None
    return True
|