mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/core/compare/utils.py
CHANGED
|
@@ -15,28 +15,31 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
|
+
import math
|
|
19
|
+
import zlib
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
18
22
|
import numpy as np
|
|
23
|
+
|
|
19
24
|
from msprobe.core.common.const import Const, CompareConst
|
|
20
|
-
from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
|
|
25
|
+
from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
|
|
21
26
|
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
22
27
|
|
|
23
28
|
|
|
24
29
|
def extract_json(dirname, stack_json=False):
|
|
25
30
|
json_path = ''
|
|
26
|
-
for
|
|
27
|
-
if
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
json_path = full_path
|
|
32
|
-
if not stack_json and 'stack' not in json_path:
|
|
33
|
-
break
|
|
34
|
-
if stack_json and 'stack' in json_path:
|
|
35
|
-
break
|
|
31
|
+
for filename in os.listdir(dirname):
|
|
32
|
+
target_file_name = 'stack.json' if stack_json else 'dump.json'
|
|
33
|
+
if filename == target_file_name:
|
|
34
|
+
json_path = os.path.join(dirname, filename)
|
|
35
|
+
break
|
|
36
36
|
|
|
37
37
|
# Provide robustness on invalid directory inputs
|
|
38
38
|
if not json_path:
|
|
39
|
-
|
|
39
|
+
if stack_json:
|
|
40
|
+
logger.error(f'stack.json is not found in dump dir {dirname}.')
|
|
41
|
+
else:
|
|
42
|
+
logger.error(f'dump.json is not found in dump dir {dirname}.')
|
|
40
43
|
raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
|
|
41
44
|
return json_path
|
|
42
45
|
|
|
@@ -44,7 +47,7 @@ def extract_json(dirname, stack_json=False):
|
|
|
44
47
|
def check_and_return_dir_contents(dump_dir, prefix):
|
|
45
48
|
"""
|
|
46
49
|
check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
|
|
47
|
-
pattern: ^{prefix}(?:0|[
|
|
50
|
+
pattern: ^{prefix}(?:0|[1-9][0-9]*)?$
|
|
48
51
|
|
|
49
52
|
Args:
|
|
50
53
|
dump_dir (str): dump dir
|
|
@@ -60,7 +63,7 @@ def check_and_return_dir_contents(dump_dir, prefix):
|
|
|
60
63
|
check_regex_prefix_format_valid(prefix)
|
|
61
64
|
check_file_or_directory_path(dump_dir, True)
|
|
62
65
|
contents = os.listdir(dump_dir)
|
|
63
|
-
pattern = re.compile(rf'^{prefix}(?:0|[
|
|
66
|
+
pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
|
|
64
67
|
for name in contents:
|
|
65
68
|
if not pattern.match(name):
|
|
66
69
|
logger.error(
|
|
@@ -84,122 +87,89 @@ def rename_api(npu_name, process):
|
|
|
84
87
|
|
|
85
88
|
|
|
86
89
|
def read_op(op_data, op_name):
|
|
90
|
+
io_name_mapping = {
|
|
91
|
+
Const.INPUT_ARGS: '.input',
|
|
92
|
+
Const.INPUT_KWARGS: '.input',
|
|
93
|
+
Const.INPUT: '.input',
|
|
94
|
+
Const.OUTPUT: '.output'
|
|
95
|
+
}
|
|
96
|
+
|
|
87
97
|
op_parsed_list = []
|
|
88
|
-
|
|
89
|
-
if
|
|
90
|
-
|
|
91
|
-
input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
|
|
92
|
-
op_parsed_list = input_parsed_list.copy()
|
|
93
|
-
input_parsed_list.clear()
|
|
94
|
-
if Const.INPUT_KWARGS in op_data:
|
|
95
|
-
kwargs_item = op_data[Const.INPUT_KWARGS]
|
|
96
|
-
if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
|
|
97
|
-
kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
|
|
98
|
-
op_parsed_list += kwarg_parsed_list
|
|
99
|
-
kwarg_parsed_list.clear()
|
|
100
|
-
elif kwargs_item:
|
|
101
|
-
for kwarg in kwargs_item:
|
|
102
|
-
kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
|
|
103
|
-
op_parsed_list += kwarg_parsed_list
|
|
104
|
-
kwarg_parsed_list.clear()
|
|
105
|
-
if Const.OUTPUT in op_data:
|
|
106
|
-
output_item = op_data[Const.OUTPUT]
|
|
107
|
-
output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
|
|
108
|
-
op_parsed_list += output_parsed_list
|
|
109
|
-
output_parsed_list.clear()
|
|
110
|
-
if Const.BACKWARD in op_name:
|
|
111
|
-
if Const.INPUT in op_data:
|
|
112
|
-
input_item = op_data[Const.INPUT]
|
|
113
|
-
input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
|
|
114
|
-
op_parsed_list = input_parsed_list.copy()
|
|
115
|
-
input_parsed_list.clear()
|
|
116
|
-
if Const.OUTPUT in op_data:
|
|
117
|
-
output_item = op_data[Const.OUTPUT]
|
|
118
|
-
output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
|
|
119
|
-
op_parsed_list += output_parsed_list
|
|
120
|
-
output_parsed_list.clear()
|
|
98
|
+
for name in io_name_mapping:
|
|
99
|
+
if name in op_data:
|
|
100
|
+
op_parsed_list.extend(op_item_parse(op_data[name], op_name + io_name_mapping[name]))
|
|
121
101
|
return op_parsed_list
|
|
122
102
|
|
|
123
103
|
|
|
124
|
-
def op_item_parse(
|
|
104
|
+
def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
|
|
105
|
+
default_item = {
|
|
106
|
+
'full_op_name': op_name,
|
|
107
|
+
'type': None,
|
|
108
|
+
'Max': None,
|
|
109
|
+
'Min': None,
|
|
110
|
+
'Mean': None,
|
|
111
|
+
'Norm': None,
|
|
112
|
+
'dtype': None,
|
|
113
|
+
'shape': None,
|
|
114
|
+
'md5': None,
|
|
115
|
+
'value': None,
|
|
116
|
+
'data_name': '-1'
|
|
117
|
+
}
|
|
118
|
+
|
|
125
119
|
if depth > Const.MAX_DEPTH:
|
|
126
|
-
logger.error(f
|
|
120
|
+
logger.error(f'parse of api/module of {op_name} exceeds the recursion limit.')
|
|
127
121
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
if index is None:
|
|
144
|
-
if isinstance(item, dict):
|
|
145
|
-
full_op_name = op_name + '.0'
|
|
146
|
-
else:
|
|
147
|
-
full_op_name = op_name
|
|
148
|
-
else:
|
|
149
|
-
full_op_name = op_name + Const.SEP + str(index)
|
|
150
|
-
if isinstance(item, dict):
|
|
151
|
-
if 'type' not in item:
|
|
152
|
-
for kwarg in item:
|
|
153
|
-
kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None, depth=depth+1)
|
|
154
|
-
item_list += kwarg_parsed_list
|
|
155
|
-
kwarg_parsed_list.clear()
|
|
156
|
-
elif 'dtype' in item:
|
|
157
|
-
parsed_item = item
|
|
158
|
-
parsed_item['full_op_name'] = full_op_name
|
|
159
|
-
item_list.append(parsed_item)
|
|
160
|
-
elif 'type' in item:
|
|
161
|
-
parsed_item = {}
|
|
162
|
-
if item['type'] == 'torch.Size':
|
|
163
|
-
parsed_item['full_op_name'] = full_op_name
|
|
164
|
-
parsed_item['dtype'] = 'torch.Size'
|
|
165
|
-
parsed_item['shape'] = str(item['value'])
|
|
166
|
-
parsed_item['md5'] = None
|
|
167
|
-
parsed_item['Max'] = None
|
|
168
|
-
parsed_item['Min'] = None
|
|
169
|
-
parsed_item['Mean'] = None
|
|
170
|
-
parsed_item['Norm'] = None
|
|
171
|
-
parsed_item['data_name'] = '-1'
|
|
172
|
-
item_list.append(parsed_item)
|
|
173
|
-
elif item['type'] == 'slice':
|
|
174
|
-
parsed_item['full_op_name'] = full_op_name
|
|
175
|
-
parsed_item['dtype'] = 'slice'
|
|
176
|
-
parsed_item['shape'] = str(np.shape(np.array(item['value'])))
|
|
177
|
-
parsed_item['md5'] = None
|
|
178
|
-
parsed_item['Max'] = None
|
|
179
|
-
parsed_item['Min'] = None
|
|
180
|
-
parsed_item['Mean'] = None
|
|
181
|
-
parsed_item['Norm'] = None
|
|
182
|
-
parsed_item['data_name'] = '-1'
|
|
183
|
-
item_list.append(parsed_item)
|
|
184
|
-
else:
|
|
185
|
-
parsed_item['full_op_name'] = full_op_name
|
|
186
|
-
parsed_item['dtype'] = str(type(item['value']))
|
|
187
|
-
parsed_item['shape'] = '[]'
|
|
188
|
-
parsed_item['md5'] = None
|
|
189
|
-
parsed_item['Max'] = item['value']
|
|
190
|
-
parsed_item['Min'] = item['value']
|
|
191
|
-
parsed_item['Mean'] = item['value']
|
|
192
|
-
parsed_item['Norm'] = item['value']
|
|
193
|
-
parsed_item['data_name'] = '-1'
|
|
194
|
-
item_list.append(parsed_item)
|
|
195
|
-
else:
|
|
196
|
-
resolve_api_special_parameters(item, full_op_name, item_list)
|
|
197
|
-
else:
|
|
198
|
-
for j, item_spec in enumerate(item):
|
|
199
|
-
op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False, depth=depth+1)
|
|
122
|
+
|
|
123
|
+
if op_data is None:
|
|
124
|
+
return [default_item]
|
|
125
|
+
elif not op_data:
|
|
126
|
+
return []
|
|
127
|
+
|
|
128
|
+
item_list = []
|
|
129
|
+
if isinstance(op_data, list):
|
|
130
|
+
for i, data in enumerate(op_data):
|
|
131
|
+
item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
|
|
132
|
+
elif isinstance(op_data, dict):
|
|
133
|
+
if is_leaf_data(op_data):
|
|
134
|
+
return [gen_op_item(op_data, op_name)]
|
|
135
|
+
for sub_name, sub_data in op_data.items():
|
|
136
|
+
item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1))
|
|
200
137
|
return item_list
|
|
201
138
|
|
|
202
139
|
|
|
140
|
+
def is_leaf_data(op_data):
|
|
141
|
+
return 'type' in op_data and isinstance(op_data['type'], str)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def gen_op_item(op_data, op_name):
|
|
145
|
+
op_item = {}
|
|
146
|
+
op_item.update(op_data)
|
|
147
|
+
op_item['full_op_name'] = op_name
|
|
148
|
+
op_item['data_name'] = op_data.get('data_name', '-1')
|
|
149
|
+
|
|
150
|
+
params = ['Max', 'Min', 'Mean', 'Norm']
|
|
151
|
+
for i in params:
|
|
152
|
+
if i not in op_item:
|
|
153
|
+
op_item[i] = None
|
|
154
|
+
|
|
155
|
+
if not op_item.get('dtype'):
|
|
156
|
+
if op_item.get('type') == 'torch.Size':
|
|
157
|
+
op_item['dtype'] = op_data.get('type')
|
|
158
|
+
op_item['shape'] = str(op_data.get('value'))
|
|
159
|
+
elif op_item.get('type') == 'slice':
|
|
160
|
+
op_item['dtype'] = op_data.get('type')
|
|
161
|
+
op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
|
|
162
|
+
else:
|
|
163
|
+
op_item['dtype'] = str(type(op_data.get('value')))
|
|
164
|
+
op_item['shape'] = '[]'
|
|
165
|
+
for i in params:
|
|
166
|
+
op_item[i] = op_data.get('value')
|
|
167
|
+
if not op_item.get('md5'):
|
|
168
|
+
op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
|
|
169
|
+
|
|
170
|
+
return op_item
|
|
171
|
+
|
|
172
|
+
|
|
203
173
|
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
|
|
204
174
|
"""
|
|
205
175
|
Function Description:
|
|
@@ -231,131 +201,173 @@ def resolve_api_special_parameters(data_dict, full_op_name, item_list):
|
|
|
231
201
|
item_list.append(parsed_item)
|
|
232
202
|
|
|
233
203
|
|
|
234
|
-
def
|
|
204
|
+
def process_summary_data(summary_data):
|
|
205
|
+
"""处理summary_data中的nan值,返回处理后的列表"""
|
|
206
|
+
return [CompareConst.NAN if isinstance(x, float) and math.isnan(x) else x for x in summary_data]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg):
|
|
210
|
+
start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
|
|
211
|
+
warning_flag = False
|
|
212
|
+
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
213
|
+
if all(isinstance(val, (float, int)) and not isinstance(val, bool) for val in [npu_val, bench_val]):
|
|
214
|
+
diff = npu_val - bench_val
|
|
215
|
+
if math.isnan(diff):
|
|
216
|
+
diff = CompareConst.NAN
|
|
217
|
+
relative = CompareConst.NAN
|
|
218
|
+
else:
|
|
219
|
+
if bench_val != 0:
|
|
220
|
+
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
221
|
+
else:
|
|
222
|
+
relative = CompareConst.N_A
|
|
223
|
+
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + CompareConst.EPSILON)
|
|
224
|
+
if magnitude_diff > CompareConst.MAGNITUDE:
|
|
225
|
+
warning_flag = True
|
|
226
|
+
result_item[start_idx + i] = diff
|
|
227
|
+
result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = relative
|
|
228
|
+
else:
|
|
229
|
+
result_item[start_idx + i] = CompareConst.N_A
|
|
230
|
+
result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = CompareConst.N_A
|
|
231
|
+
|
|
232
|
+
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
233
|
+
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
234
|
+
for i in range(start_idx, len(result_item)):
|
|
235
|
+
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
236
|
+
result_item[i] = f'{result_item[i]}\t'
|
|
237
|
+
return result_item, accuracy_check, err_msg
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@dataclass
|
|
241
|
+
class ApiItemInfo:
|
|
242
|
+
name: str
|
|
243
|
+
struct: tuple
|
|
244
|
+
stack_info: list
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def stack_column_process(result_item, has_stack, index, key, npu_stack_info):
|
|
248
|
+
if has_stack and index == 0 and key == CompareConst.INPUT_STRUCT:
|
|
249
|
+
result_item.extend(npu_stack_info)
|
|
250
|
+
else:
|
|
251
|
+
result_item.append(CompareConst.NONE)
|
|
252
|
+
return result_item
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def result_item_init(n_info, b_info, dump_mode):
|
|
256
|
+
n_len = len(n_info.struct)
|
|
257
|
+
b_len = len(b_info.struct)
|
|
258
|
+
struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1)
|
|
259
|
+
if struct_long_enough:
|
|
260
|
+
result_item = [
|
|
261
|
+
n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1]
|
|
262
|
+
]
|
|
263
|
+
if dump_mode == Const.MD5:
|
|
264
|
+
md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
|
|
265
|
+
result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
|
|
266
|
+
elif dump_mode == Const.SUMMARY:
|
|
267
|
+
result_item.extend([" "] * 8)
|
|
268
|
+
else:
|
|
269
|
+
result_item.extend([" "] * 5)
|
|
270
|
+
else:
|
|
271
|
+
err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
|
|
272
|
+
f"npu_info_struct is {n_info.struct}\n" \
|
|
273
|
+
f"bench_info_struct is {b_info.struct}"
|
|
274
|
+
logger.error(err_msg)
|
|
275
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
276
|
+
return result_item
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
235
280
|
def get_accuracy_core(n_start, n_len, b_start, b_len, key):
|
|
236
281
|
min_len = min(n_len, b_len)
|
|
237
282
|
npu_stack_info = n_dict.get("stack_info", None)
|
|
238
283
|
bench_stack_info = b_dict.get("stack_info", None)
|
|
239
284
|
has_stack = npu_stack_info and bench_stack_info
|
|
240
285
|
|
|
241
|
-
|
|
242
|
-
if all_mode_bool:
|
|
286
|
+
if dump_mode == Const.ALL:
|
|
243
287
|
npu_data_name = n_dict.get("data_name", None)
|
|
244
288
|
bench_data_name = b_dict.get("data_name", None)
|
|
245
289
|
|
|
246
290
|
for index in range(min_len):
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
b_struct = b_dict[key][index]
|
|
291
|
+
n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
|
|
292
|
+
b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name")
|
|
293
|
+
n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
|
|
294
|
+
b_struct = safe_get_value(b_dict, index, "b_dict", key=key)
|
|
252
295
|
err_msg = ""
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
else:
|
|
261
|
-
result_item.append(CompareConst.NONE)
|
|
296
|
+
|
|
297
|
+
npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info)
|
|
298
|
+
bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info)
|
|
299
|
+
result_item = result_item_init(npu_info, bench_info, dump_mode)
|
|
300
|
+
|
|
301
|
+
if dump_mode == Const.MD5:
|
|
302
|
+
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
262
303
|
result.append(result_item)
|
|
263
304
|
continue
|
|
264
305
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
result_item =
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
npu_summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
|
|
277
|
-
result_item.extend(npu_summary_data)
|
|
278
|
-
bench_summary_data = b_dict.get(CompareConst.SUMMARY)[b_start + index]
|
|
279
|
-
result_item.extend(bench_summary_data)
|
|
280
|
-
|
|
281
|
-
if summary_compare:
|
|
282
|
-
start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
|
|
283
|
-
warning_flag = False
|
|
284
|
-
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
285
|
-
if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
|
|
286
|
-
diff = npu_val - bench_val
|
|
287
|
-
if bench_val != 0:
|
|
288
|
-
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
289
|
-
else:
|
|
290
|
-
relative = CompareConst.N_A
|
|
291
|
-
result_item[start_idx + i] = diff
|
|
292
|
-
result_item[start_idx + i + 4] = relative
|
|
293
|
-
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
|
|
294
|
-
if magnitude_diff > 0.5:
|
|
295
|
-
warning_flag = True
|
|
296
|
-
else:
|
|
297
|
-
result_item[start_idx + i] = CompareConst.NONE
|
|
298
|
-
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
299
|
-
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
300
|
-
for i in range(start_idx, len(result_item)):
|
|
301
|
-
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
302
|
-
result_item[i] = f'{result_item[i]}\t'
|
|
303
|
-
|
|
304
|
-
result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
|
|
306
|
+
npu_summary_data = safe_get_value(n_dict, n_start + index, "n_dict", key=CompareConst.SUMMARY)
|
|
307
|
+
bench_summary_data = safe_get_value(b_dict, b_start + index, "b_dict", key=CompareConst.SUMMARY)
|
|
308
|
+
result_item.extend(process_summary_data(npu_summary_data))
|
|
309
|
+
result_item.extend(process_summary_data(bench_summary_data))
|
|
310
|
+
|
|
311
|
+
if dump_mode == Const.SUMMARY:
|
|
312
|
+
result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
|
|
313
|
+
bench_summary_data, err_msg)
|
|
314
|
+
|
|
315
|
+
result_item.append(accuracy_check if dump_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES)
|
|
305
316
|
result_item.append(err_msg)
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
result_item.append(CompareConst.NONE)
|
|
310
|
-
if all_mode_bool:
|
|
311
|
-
result_item.append(npu_data_name[n_start + index])
|
|
317
|
+
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
318
|
+
if dump_mode == Const.ALL:
|
|
319
|
+
result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
|
|
312
320
|
|
|
313
321
|
result.append(result_item)
|
|
314
322
|
|
|
315
323
|
if n_len > b_len:
|
|
316
324
|
for index in range(b_len, n_len):
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
325
|
+
try:
|
|
326
|
+
n_name = n_dict['op_name'][n_start + index]
|
|
327
|
+
n_struct = n_dict[key][index]
|
|
328
|
+
if dump_mode == Const.MD5:
|
|
329
|
+
result_item = [
|
|
330
|
+
n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
|
|
331
|
+
n_struct[2], CompareConst.NAN, CompareConst.NAN
|
|
332
|
+
]
|
|
333
|
+
result.append(result_item)
|
|
334
|
+
continue
|
|
320
335
|
result_item = [
|
|
321
336
|
n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
|
|
322
|
-
|
|
337
|
+
" ", " ", " ", " ", " "
|
|
323
338
|
]
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
result_item.extend(summary_data)
|
|
339
|
+
summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
|
|
340
|
+
result_item.extend(summary_data)
|
|
341
|
+
summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
|
|
342
|
+
result_item.extend(summary_data)
|
|
343
|
+
except IndexError as e:
|
|
344
|
+
err_msg = "index out of bounds error occurs, please check!\n" \
|
|
345
|
+
f"n_dict is {n_dict}"
|
|
346
|
+
logger.error(err_msg)
|
|
347
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
334
348
|
|
|
335
349
|
err_msg = ""
|
|
336
350
|
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
337
351
|
result_item.append(err_msg)
|
|
338
|
-
|
|
339
|
-
if
|
|
340
|
-
result_item.
|
|
341
|
-
else:
|
|
342
|
-
result_item.append(CompareConst.NONE)
|
|
343
|
-
if all_mode_bool:
|
|
344
|
-
result_item.append(npu_data_name[n_start + index])
|
|
352
|
+
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
353
|
+
if dump_mode == Const.ALL:
|
|
354
|
+
result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
|
|
345
355
|
|
|
346
356
|
result.append(result_item)
|
|
347
357
|
|
|
348
358
|
n_num = len(n_dict['op_name'])
|
|
349
359
|
b_num = len(b_dict['op_name'])
|
|
350
|
-
n_num_input = len([name for name in n_dict['op_name']
|
|
351
|
-
|
|
360
|
+
n_num_input = len([name for name in n_dict['op_name']
|
|
361
|
+
if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
|
|
362
|
+
b_num_input = len([name for name in b_dict['op_name']
|
|
363
|
+
if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
|
|
352
364
|
n_num_output = n_num - n_num_input
|
|
353
365
|
b_num_output = b_num - b_num_input
|
|
354
366
|
get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
|
|
355
367
|
get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
|
|
356
368
|
|
|
357
369
|
|
|
358
|
-
def get_un_match_accuracy(result, n_dict,
|
|
370
|
+
def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
359
371
|
index_out = 0
|
|
360
372
|
npu_stack_info = n_dict.get("stack_info", None)
|
|
361
373
|
bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
|
|
@@ -363,14 +375,22 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
|
|
|
363
375
|
accuracy_check_res = CompareConst.N_A
|
|
364
376
|
for index, n_name in enumerate(n_dict["op_name"]):
|
|
365
377
|
name_ele_list = n_name.split(Const.SEP)
|
|
366
|
-
if
|
|
367
|
-
n_struct = n_dict
|
|
368
|
-
|
|
369
|
-
n_struct = n_dict
|
|
378
|
+
if Const.INPUT in name_ele_list or Const.KWARGS in name_ele_list:
|
|
379
|
+
n_struct = safe_get_value(n_dict, index, "n_dict", key=CompareConst.INPUT_STRUCT)
|
|
380
|
+
if Const.OUTPUT in name_ele_list:
|
|
381
|
+
n_struct = safe_get_value(n_dict, index_out, "n_dict", key=CompareConst.OUTPUT_STRUCT)
|
|
370
382
|
index_out += 1
|
|
371
383
|
|
|
372
|
-
|
|
373
|
-
|
|
384
|
+
try:
|
|
385
|
+
result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
|
|
386
|
+
except IndexError as e:
|
|
387
|
+
err_msg = "index out of bounds error occurs, please check!\n" \
|
|
388
|
+
f"op_name of n_dict is {n_dict['op_name']}\n" \
|
|
389
|
+
f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
|
|
390
|
+
f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
|
|
391
|
+
logger.error(err_msg)
|
|
392
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
393
|
+
if dump_mode == Const.MD5:
|
|
374
394
|
result_item.extend([CompareConst.N_A] * 3)
|
|
375
395
|
if npu_stack_info and index == 0:
|
|
376
396
|
result_item.extend(npu_stack_info)
|
|
@@ -378,11 +398,11 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
|
|
|
378
398
|
result_item.append(CompareConst.NONE)
|
|
379
399
|
result.append(result_item)
|
|
380
400
|
continue
|
|
381
|
-
if
|
|
401
|
+
if dump_mode == Const.SUMMARY:
|
|
382
402
|
result_item.extend([CompareConst.N_A] * 8)
|
|
383
403
|
else:
|
|
384
404
|
result_item.extend([CompareConst.N_A] * 5)
|
|
385
|
-
npu_summary_data = n_dict
|
|
405
|
+
npu_summary_data = safe_get_value(n_dict, index, "n_dict", key=CompareConst.SUMMARY)
|
|
386
406
|
result_item.extend(npu_summary_data)
|
|
387
407
|
bench_summary_data = [CompareConst.N_A] * 4
|
|
388
408
|
result_item.extend(bench_summary_data)
|
|
@@ -392,22 +412,21 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
|
|
|
392
412
|
result_item.extend(npu_stack_info)
|
|
393
413
|
else:
|
|
394
414
|
result_item.append(CompareConst.NONE)
|
|
395
|
-
if
|
|
415
|
+
if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
|
|
396
416
|
result_item.extend(["-1"])
|
|
397
417
|
result.append(result_item)
|
|
398
418
|
|
|
399
419
|
|
|
400
|
-
def merge_tensor(tensor_list,
|
|
420
|
+
def merge_tensor(tensor_list, dump_mode):
|
|
401
421
|
op_dict = {}
|
|
402
422
|
op_dict["op_name"] = []
|
|
403
|
-
op_dict[
|
|
404
|
-
op_dict[
|
|
405
|
-
op_dict[
|
|
406
|
-
op_dict[
|
|
423
|
+
op_dict[CompareConst.INPUT_STRUCT] = []
|
|
424
|
+
op_dict[CompareConst.KWARGS_STRUCT] = []
|
|
425
|
+
op_dict[CompareConst.OUTPUT_STRUCT] = []
|
|
426
|
+
op_dict[Const.SUMMARY] = []
|
|
407
427
|
op_dict["stack_info"] = []
|
|
408
428
|
|
|
409
|
-
|
|
410
|
-
if all_mode_bool:
|
|
429
|
+
if dump_mode == Const.ALL:
|
|
411
430
|
op_dict["data_name"] = []
|
|
412
431
|
|
|
413
432
|
for tensor in tensor_list:
|
|
@@ -416,38 +435,44 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
|
|
|
416
435
|
break
|
|
417
436
|
op_dict["op_name"].append(tensor['full_op_name'])
|
|
418
437
|
name_ele_list = tensor['full_op_name'].split(Const.SEP)
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
if all_mode_bool:
|
|
438
|
+
name_to_struct_mapping = {
|
|
439
|
+
Const.INPUT: CompareConst.INPUT_STRUCT,
|
|
440
|
+
Const.KWARGS: CompareConst.KWARGS_STRUCT,
|
|
441
|
+
Const.OUTPUT: CompareConst.OUTPUT_STRUCT
|
|
442
|
+
}
|
|
443
|
+
for name_key, struct_key in name_to_struct_mapping.items():
|
|
444
|
+
if name_key in name_ele_list:
|
|
445
|
+
if dump_mode == Const.MD5:
|
|
446
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
|
|
447
|
+
else:
|
|
448
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
|
|
449
|
+
break
|
|
450
|
+
op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
|
|
451
|
+
|
|
452
|
+
if dump_mode == Const.ALL:
|
|
436
453
|
op_dict["data_name"].append(tensor['data_name'])
|
|
437
|
-
data_name = op_dict
|
|
454
|
+
data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
|
|
438
455
|
if data_name != "-1":
|
|
439
456
|
op_dict["op_name"][-1] = data_name
|
|
440
457
|
|
|
441
|
-
if not op_dict[
|
|
442
|
-
del op_dict[
|
|
458
|
+
if not op_dict[CompareConst.KWARGS_STRUCT]:
|
|
459
|
+
del op_dict[CompareConst.KWARGS_STRUCT]
|
|
443
460
|
return op_dict if op_dict["op_name"] else {}
|
|
444
461
|
|
|
445
462
|
|
|
463
|
+
def print_compare_ends_info():
|
|
464
|
+
total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
465
|
+
logger.info('*' * total_len)
|
|
466
|
+
logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
|
|
467
|
+
logger.info('*' * total_len)
|
|
468
|
+
|
|
469
|
+
|
|
446
470
|
def _compare_parser(parser):
|
|
447
471
|
parser.add_argument("-i", "--input_path", dest="input_path", type=str,
|
|
448
472
|
help="<Required> The compare input path, a dict json.", required=True)
|
|
449
473
|
parser.add_argument("-o", "--output_path", dest="output_path", type=str,
|
|
450
|
-
help="<Required> The compare task result out path.",
|
|
474
|
+
help="<Required> The compare task result out path. Default path: ./output",
|
|
475
|
+
required=False, default="./output", nargs="?", const="./output")
|
|
451
476
|
parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
|
|
452
477
|
help="<optional> Whether to save stack info.", required=False)
|
|
453
478
|
parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
|
|
@@ -457,8 +482,8 @@ def _compare_parser(parser):
|
|
|
457
482
|
parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
|
|
458
483
|
help="<optional> The cell mapping file path.", required=False)
|
|
459
484
|
parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
|
|
460
|
-
help="<optional> The api mapping file path.", required=False)
|
|
485
|
+
help="<optional> The api mapping file path.", required=False)
|
|
461
486
|
parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
|
|
462
487
|
help="<optional> The data mapping file path.", required=False)
|
|
463
|
-
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
|
|
488
|
+
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
|
|
464
489
|
help="<optional> The layer mapping file path.", required=False)
|