mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -13,111 +13,234 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import multiprocessing
|
|
17
16
|
import os
|
|
18
17
|
import re
|
|
19
|
-
from
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from collections import defaultdict
|
|
20
20
|
|
|
21
|
+
import numpy as np
|
|
21
22
|
import pandas as pd
|
|
22
23
|
from tqdm import tqdm
|
|
23
24
|
|
|
24
25
|
from msprobe.core.advisor.advisor import Advisor
|
|
25
26
|
from msprobe.core.common.const import CompareConst, Const
|
|
26
27
|
from msprobe.core.common.exceptions import FileCheckException
|
|
27
|
-
from msprobe.core.common.file_utils import load_json, remove_path
|
|
28
|
+
from msprobe.core.common.file_utils import load_json, remove_path, create_directory
|
|
28
29
|
from msprobe.core.common.log import logger
|
|
29
|
-
from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid,
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
from msprobe.core.compare.
|
|
33
|
-
|
|
34
|
-
from msprobe.core.compare.
|
|
35
|
-
from msprobe.core.compare.
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
30
|
+
from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \
|
|
31
|
+
set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type
|
|
32
|
+
from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping
|
|
33
|
+
from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \
|
|
34
|
+
reorder_op_x_list, set_stack_json_path
|
|
35
|
+
from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict
|
|
36
|
+
from msprobe.core.compare.multiprocessing_compute import CompareRealData
|
|
37
|
+
from msprobe.core.compare.highlight import HighLight
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
|
|
41
|
+
class ComparisonConfig:
|
|
42
|
+
dump_mode: str
|
|
43
|
+
stack_mode: bool
|
|
44
|
+
auto_analyze: bool
|
|
45
|
+
fuzzy_match: bool
|
|
46
|
+
data_mapping: dict
|
|
47
|
+
suffix: str
|
|
48
|
+
cell_mapping: dict
|
|
49
|
+
api_mapping: dict
|
|
50
|
+
layer_mapping: dict
|
|
51
|
+
compared_file_type: str
|
|
45
52
|
|
|
46
53
|
|
|
47
54
|
class Comparator:
|
|
48
|
-
def __init__(self, mode_config: ModeConfig):
|
|
49
|
-
self.
|
|
50
|
-
self.
|
|
51
|
-
self.
|
|
52
|
-
self.
|
|
55
|
+
def __init__(self, file_reader, mode_config: ModeConfig, mapping_config: MappingConfig, is_cross_framework=False):
|
|
56
|
+
self.file_reader = file_reader
|
|
57
|
+
self.mode_config = mode_config
|
|
58
|
+
self.mapping_config = mapping_config
|
|
59
|
+
self.cross_frame = is_cross_framework
|
|
60
|
+
|
|
61
|
+
self.mapping_dict = MappingDict(mapping_config)
|
|
53
62
|
|
|
54
63
|
@staticmethod
|
|
55
|
-
def
|
|
56
|
-
|
|
57
|
-
|
|
64
|
+
def process_output_file(output_path, suffix, compared_file_type):
|
|
65
|
+
file_name_prefix_mapping = {
|
|
66
|
+
Const.DUMP_JSON_FILE: "compare_result",
|
|
67
|
+
Const.DEBUG_JSON_FILE: "debug_compare_result"
|
|
68
|
+
}
|
|
69
|
+
file_name_prefix = file_name_prefix_mapping.get(compared_file_type, "compare_result")
|
|
70
|
+
file_name = add_time_with_xlsx(file_name_prefix + suffix)
|
|
71
|
+
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
72
|
+
if os.path.exists(file_path):
|
|
73
|
+
logger.warning(f"{file_path} will be deleted.")
|
|
74
|
+
remove_path(file_path)
|
|
75
|
+
return file_path
|
|
58
76
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
77
|
+
def compare_core(self, input_param, output_path, **kwargs):
|
|
78
|
+
"""
|
|
79
|
+
Compares data from multiple JSON files and generates a comparison report.
|
|
63
80
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
81
|
+
Args:
|
|
82
|
+
input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
|
|
83
|
+
"stack_path").
|
|
84
|
+
output_path (str): The path where the output Excel report will be saved.
|
|
85
|
+
**kwargs: Additional keyword arguments including:
|
|
86
|
+
- stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
|
|
87
|
+
- auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
|
|
88
|
+
- suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
|
|
89
|
+
- fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
|
|
90
|
+
- dump_mode (str): ALL, SUMMARY, MD5.
|
|
67
91
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
result_item.append(CompareConst.NONE)
|
|
72
|
-
return result_item
|
|
92
|
+
Returns:
|
|
93
|
+
"""
|
|
94
|
+
logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
|
|
73
95
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
err_msg = ""
|
|
77
|
-
result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
|
|
78
|
-
bench_summary_data, err_msg)
|
|
79
|
-
result_item.append(accuracy_check)
|
|
80
|
-
result_item.append(err_msg)
|
|
96
|
+
# get kwargs or set default value
|
|
97
|
+
suffix = kwargs.get('suffix', '')
|
|
81
98
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
else:
|
|
94
|
-
value[k] = CompareConst.N_A
|
|
95
|
-
return value
|
|
99
|
+
# process output file
|
|
100
|
+
file_path = self.process_output_file(output_path, suffix, self.mode_config.compared_file_type)
|
|
101
|
+
|
|
102
|
+
# initialize the compare result table and compare general data(name, dtype, shape, statistics/md5, etc.)
|
|
103
|
+
npu_json = input_param.get("npu_json_path")
|
|
104
|
+
bench_json = input_param.get("bench_json_path")
|
|
105
|
+
stack_json = input_param.get("stack_json_path")
|
|
106
|
+
result_df = self.compare_statistics([npu_json, bench_json, stack_json])
|
|
107
|
+
if not result_df.values.tolist():
|
|
108
|
+
logger.warning("Can`t match any op. No compare result file generated.")
|
|
109
|
+
return
|
|
96
110
|
|
|
97
|
-
|
|
98
|
-
|
|
111
|
+
# compare real data
|
|
112
|
+
if self.mode_config.dump_mode == Const.ALL:
|
|
113
|
+
compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame)
|
|
114
|
+
result_df = compare_real_data.do_multi_process(input_param, result_df)
|
|
99
115
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
116
|
+
# highlight suspicious API
|
|
117
|
+
highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
|
|
118
|
+
highlight = HighLight(self.mode_config)
|
|
119
|
+
if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
|
|
120
|
+
highlight.find_compare_result_error_rows(result_df, highlight_dict)
|
|
121
|
+
highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path)
|
|
122
|
+
|
|
123
|
+
# output compare analysis suggestions
|
|
124
|
+
if self.mode_config.auto_analyze:
|
|
125
|
+
advisor = Advisor(result_df, output_path, suffix)
|
|
126
|
+
advisor.analysis()
|
|
127
|
+
|
|
128
|
+
print_compare_ends_info()
|
|
129
|
+
|
|
130
|
+
def compare_statistics(self, file_list):
|
|
131
|
+
# load and parse json data
|
|
132
|
+
parse_data = ParseData(self.mode_config)
|
|
133
|
+
npu_df, bench_df = parse_data.parse(file_list)
|
|
134
|
+
|
|
135
|
+
npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
136
|
+
bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
137
|
+
|
|
138
|
+
# create new columns for compare op_name and shape
|
|
139
|
+
# process npu_df's COMPARE_KEY whether same or different framework
|
|
140
|
+
process_df = ProcessDf(self.mode_config, self.mapping_config, self.mapping_dict)
|
|
141
|
+
npu_df, bench_df = process_df.process_compare_key_and_shape(npu_df, bench_df)
|
|
142
|
+
|
|
143
|
+
# match npu and bench, match_result contains both npu_info and bench_info
|
|
144
|
+
match = Match(self.mode_config, self.mapping_config, self.cross_frame)
|
|
145
|
+
match_result = match.match_api_infos(npu_df, bench_df)
|
|
146
|
+
# 筛选出npu_name存在的行并填充筛选出行中的缺失值为N/A
|
|
147
|
+
match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
|
|
148
|
+
bench_columns = [i + '_y' for i in bench_df.columns]
|
|
149
|
+
match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A
|
|
150
|
+
|
|
151
|
+
# organize compare result table by renaming columns
|
|
152
|
+
create_table = CreateTable(self.mode_config)
|
|
153
|
+
result_df, header = create_table.make_result_df(match_result)
|
|
154
|
+
|
|
155
|
+
# calculate statistics diff
|
|
156
|
+
calc_stats_diff = CalcStatsDiff(self.mode_config)
|
|
157
|
+
return calc_stats_diff.calc_accuracy(result_df, header)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class ParseData:
|
|
161
|
+
def __init__(self, mode_config: ModeConfig):
|
|
162
|
+
self.mode_config = mode_config
|
|
163
|
+
|
|
164
|
+
def parse(self, file_list):
|
|
165
|
+
npu_json_path, bench_json_path, stack_json_path = file_list
|
|
166
|
+
npu_json_data = load_json(npu_json_path)
|
|
167
|
+
bench_json_data = load_json(bench_json_path)
|
|
168
|
+
stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None
|
|
169
|
+
|
|
170
|
+
# parse json data and generate df
|
|
171
|
+
npu_df = self.gen_data_df(npu_json_data, stack_json_data)
|
|
172
|
+
bench_df = self.gen_data_df(bench_json_data, stack_json_data)
|
|
173
|
+
|
|
174
|
+
return npu_df, bench_df
|
|
175
|
+
|
|
176
|
+
def gen_data_df(self, data_json, stack_json_data):
|
|
177
|
+
result = {
|
|
178
|
+
CompareConst.OP_NAME: [],
|
|
179
|
+
Const.DTYPE: [],
|
|
180
|
+
Const.SHAPE: [],
|
|
181
|
+
Const.SUMMARY: [],
|
|
182
|
+
Const.STACK_INFO: []
|
|
183
|
+
}
|
|
184
|
+
if self.mode_config.dump_mode == Const.ALL:
|
|
185
|
+
result['data_name'] = []
|
|
186
|
+
elif self.mode_config.dump_mode == Const.MD5:
|
|
187
|
+
result[Const.MD5] = []
|
|
188
|
+
|
|
189
|
+
apis_data = data_json.get('data', None)
|
|
190
|
+
if not apis_data:
|
|
191
|
+
logger.warning('No APIs found in dump.json.')
|
|
192
|
+
return pd.DataFrame(result)
|
|
193
|
+
|
|
194
|
+
api_nums = len(apis_data)
|
|
195
|
+
progress_bar = tqdm(total=api_nums, desc="API/Module Read Progress", unit="api/module", ncols=100)
|
|
196
|
+
|
|
197
|
+
# 从json中循环解析API数据,遍历所有API
|
|
198
|
+
for data_name in apis_data:
|
|
199
|
+
check_op_str_pattern_valid(data_name)
|
|
200
|
+
merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
|
|
201
|
+
if not merge_list:
|
|
202
|
+
continue
|
|
203
|
+
|
|
204
|
+
op_name_list = merge_list.get(CompareConst.OP_NAME)
|
|
205
|
+
summary_list = merge_list.get(Const.SUMMARY)
|
|
206
|
+
data_name_list = merge_list.get('data_name')
|
|
207
|
+
op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
|
|
208
|
+
summary_list,
|
|
209
|
+
data_name_list)
|
|
210
|
+
# 遍历单个API的所有item
|
|
211
|
+
for index, op_name in enumerate(op_name_reorder):
|
|
212
|
+
result[CompareConst.OP_NAME].append(op_name)
|
|
213
|
+
if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
|
|
214
|
+
struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
|
|
215
|
+
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
216
|
+
struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
|
|
217
|
+
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
218
|
+
struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
|
|
219
|
+
elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
|
|
220
|
+
struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
|
|
221
|
+
else:
|
|
222
|
+
struct = merge_list[CompareConst.DEBUG_STRUCT].pop(0)
|
|
223
|
+
result[Const.DTYPE].append(struct[0])
|
|
224
|
+
result[Const.SHAPE].append(struct[1])
|
|
225
|
+
if self.mode_config.dump_mode == Const.MD5:
|
|
226
|
+
result[Const.MD5].append(struct[2])
|
|
227
|
+
result[Const.SUMMARY].append(summary_reorder.pop(0))
|
|
228
|
+
result[Const.STACK_INFO].append(
|
|
229
|
+
merge_list[Const.STACK_INFO][0] if index == 0 and self.mode_config.stack_mode else None)
|
|
230
|
+
if self.mode_config.dump_mode == Const.ALL:
|
|
231
|
+
result['data_name'].append(data_name_reorder.pop(0))
|
|
232
|
+
|
|
233
|
+
progress_bar.update(1)
|
|
234
|
+
progress_bar.close()
|
|
235
|
+
return pd.DataFrame(result)
|
|
114
236
|
|
|
115
237
|
def gen_merge_list(self, json_data, op_name, stack_json_data):
|
|
116
238
|
op_data = json_data['data'][op_name]
|
|
117
|
-
|
|
239
|
+
if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
|
|
240
|
+
check_dump_json_str(op_data, op_name)
|
|
118
241
|
op_parsed_list = read_op(op_data, op_name)
|
|
119
242
|
|
|
120
|
-
if self.stack_mode:
|
|
243
|
+
if self.mode_config.stack_mode:
|
|
121
244
|
stack_info = stack_json_data.get(op_name)
|
|
122
245
|
if stack_info is not None:
|
|
123
246
|
check_stack_json_str(stack_info, op_name)
|
|
@@ -127,392 +250,487 @@ class Comparator:
|
|
|
127
250
|
'full_info': stack_info
|
|
128
251
|
})
|
|
129
252
|
|
|
130
|
-
merge_list = merge_tensor(op_parsed_list, self.dump_mode)
|
|
253
|
+
merge_list = merge_tensor(op_parsed_list, self.mode_config.dump_mode)
|
|
131
254
|
return merge_list
|
|
132
255
|
|
|
133
|
-
def check_op(self, npu_dict, bench_dict):
|
|
134
|
-
npu_op_name = npu_dict[CompareConst.OP_NAME]
|
|
135
|
-
bench_op_name = bench_dict[CompareConst.OP_NAME]
|
|
136
|
-
graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
|
|
137
|
-
safe_get_value(bench_op_name, 0, "bench_op_name"))
|
|
138
|
-
|
|
139
|
-
frame_name = getattr(self, "frame_name")
|
|
140
|
-
if frame_name == "PTComparator":
|
|
141
|
-
from msprobe.pytorch.compare.match import graph_mapping
|
|
142
|
-
if graph_mode:
|
|
143
|
-
return graph_mapping.match(npu_op_name[0], bench_op_name[0])
|
|
144
|
-
struct_match = check_struct_match(npu_dict, bench_dict)
|
|
145
|
-
if not self.fuzzy_match:
|
|
146
|
-
name_match = npu_op_name == bench_op_name
|
|
147
|
-
return name_match and struct_match
|
|
148
|
-
try:
|
|
149
|
-
name_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
150
|
-
except Exception as err:
|
|
151
|
-
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
152
|
-
name_match = False
|
|
153
|
-
return name_match and struct_match
|
|
154
256
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
return len(npu_queue) - 1, len(bench_queue) - 1
|
|
161
|
-
for n_index, n_op in enumerate(npu_queue[0: -1]):
|
|
162
|
-
if self.check_op(n_op, bench_queue[-1]):
|
|
163
|
-
return n_index, len(bench_queue) - 1
|
|
164
|
-
return -1, -1
|
|
257
|
+
class ProcessDf:
|
|
258
|
+
def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, mapping_dict: MappingDict):
|
|
259
|
+
self.mode_config = mode_config
|
|
260
|
+
self.mapping_config = mapping_config
|
|
261
|
+
self.mapping_dict = mapping_dict
|
|
165
262
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
263
|
+
@staticmethod
|
|
264
|
+
def get_api_name(api_list):
|
|
265
|
+
try:
|
|
266
|
+
api_name = api_list[0] + Const.SEP + api_list[1]
|
|
267
|
+
except IndexError as error:
|
|
268
|
+
logger.error('Failed to retrieve API name, please check if the dump data is reasonable')
|
|
269
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
270
|
+
return api_name
|
|
271
|
+
|
|
272
|
+
def process_compare_key_and_shape(self, npu_df, bench_df):
|
|
273
|
+
npu_df = self.assign_npu_df_compare_key(npu_df, bench_df)
|
|
274
|
+
npu_df[CompareConst.CMP_SHAPE] = npu_df[Const.SHAPE]
|
|
275
|
+
bench_df[CompareConst.CMP_KEY] = bench_df[CompareConst.OP_NAME]
|
|
276
|
+
bench_df[CompareConst.CMP_SHAPE] = bench_df[Const.SHAPE]
|
|
277
|
+
return npu_df, bench_df
|
|
278
|
+
|
|
279
|
+
def assign_npu_df_compare_key(self, npu_df, bench_df):
|
|
280
|
+
"""
|
|
281
|
+
处理 npu_df 的 COMPARE_KEY 赋值逻辑
|
|
171
282
|
|
|
172
|
-
|
|
173
|
-
|
|
283
|
+
:param npu_df: DataFrame,NPU 对比数据
|
|
284
|
+
:param bench_df: DataFrame,Bench 对比数据
|
|
285
|
+
:return: compare_key(name)处理后的 npu_df
|
|
286
|
+
"""
|
|
287
|
+
# 处理api_mapping映射
|
|
288
|
+
if self.mapping_config.api_mapping:
|
|
289
|
+
# 如果用户不传api_mapping.yaml,先使用内置api_mapping.yaml替换npu_op_name
|
|
290
|
+
npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
|
|
291
|
+
# 如果用户传入api_mapping.yaml,再使用传入api_mapping.yaml进一步替换npu_op_name
|
|
292
|
+
if isinstance(self.mapping_config.api_mapping, str):
|
|
293
|
+
self.modify_compare_data_with_user_mapping(npu_df, bench_df)
|
|
294
|
+
# 处理cell_mapping映射
|
|
295
|
+
elif self.mapping_config.cell_mapping:
|
|
296
|
+
npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
|
|
297
|
+
# 处理data_mapping映射
|
|
298
|
+
elif self.mapping_config.data_mapping:
|
|
299
|
+
npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_data_mapping)
|
|
300
|
+
else:
|
|
301
|
+
npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME]
|
|
302
|
+
return npu_df
|
|
303
|
+
|
|
304
|
+
def process_internal_api_mapping(self, npu_op_name):
|
|
305
|
+
# get api name & class name from op_name
|
|
306
|
+
ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
|
|
307
|
+
class_name = ms_api_name.split(Const.SEP)[0]
|
|
308
|
+
if class_name == "Mint":
|
|
309
|
+
return npu_op_name.replace("Mint", "Torch")
|
|
310
|
+
elif class_name == "MintFunctional":
|
|
311
|
+
return npu_op_name.replace("MintFunctional", "Functional")
|
|
312
|
+
elif self.mapping_dict.ms_to_pt_mapping.get(ms_api_name):
|
|
313
|
+
return npu_op_name.replace(ms_api_name, self.mapping_dict.ms_to_pt_mapping.get(ms_api_name))
|
|
314
|
+
else:
|
|
315
|
+
return npu_op_name
|
|
316
|
+
|
|
317
|
+
def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
|
|
318
|
+
def gen_input_compare_key(pattern, term):
|
|
319
|
+
is_unmatched = True
|
|
320
|
+
for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
|
|
321
|
+
if op_name.split(pattern)[1].startswith(str(prefix)):
|
|
322
|
+
npu_df.loc[index, CompareConst.CMP_KEY] = (
|
|
323
|
+
op_name.replace(pattern + str(prefix),
|
|
324
|
+
pattern + str(mapping_dict.get(f'pt_{term}')[i])))
|
|
325
|
+
is_unmatched = False
|
|
326
|
+
return is_unmatched
|
|
327
|
+
|
|
328
|
+
ms_api_indices_dict = self.get_api_indices_dict(npu_df)
|
|
329
|
+
pt_api_indices_dict = self.get_api_indices_dict(bench_df)
|
|
330
|
+
|
|
331
|
+
for mapping_dict in self.mapping_dict.api_mapping_dict:
|
|
332
|
+
all_length_equal = True
|
|
333
|
+
for k1, k2 in CompareConst.API_MAPPING_KEYS_TO_COMPARE:
|
|
334
|
+
if len(mapping_dict.get(k1, [])) != len(mapping_dict.get(k2, [])):
|
|
335
|
+
all_length_equal = False
|
|
336
|
+
if not all_length_equal:
|
|
337
|
+
logger.warning('The user-defined mapping table is incorrect,\
|
|
338
|
+
make sure that the number of parameters is equal')
|
|
339
|
+
continue
|
|
174
340
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
341
|
+
ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
|
|
342
|
+
if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
|
|
343
|
+
continue
|
|
344
|
+
for index in ms_api_indices_dict.get(ms_api):
|
|
345
|
+
op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
|
|
346
|
+
if CompareConst.INPUT_PATTERN in op_name:
|
|
347
|
+
is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
|
|
348
|
+
elif CompareConst.KWARGS_PATTERN in op_name:
|
|
349
|
+
is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
|
|
350
|
+
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
351
|
+
is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
|
|
352
|
+
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
353
|
+
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
|
|
354
|
+
elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
|
|
355
|
+
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
|
|
356
|
+
else:
|
|
357
|
+
logger.error(f'Excepted op_name: {op_name}')
|
|
358
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
359
|
+
if is_abandoned:
|
|
360
|
+
npu_df.loc[index, CompareConst.CMP_KEY] = op_name + 'abandoned'
|
|
178
361
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
362
|
+
def get_api_indices_dict(self, op_name_df):
|
|
363
|
+
"""
|
|
364
|
+
生成多个api对应的各自的所有的input、output等的index的键值对字典
|
|
365
|
+
示例:
|
|
366
|
+
{'Functional.conv2d': [0, 1, 2, 3],
|
|
367
|
+
'Functional.batch_norm': [4, 5, 6, 7, 8]
|
|
368
|
+
}
|
|
369
|
+
"""
|
|
370
|
+
api_indices_dict = defaultdict(list)
|
|
371
|
+
for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
|
|
372
|
+
api_name = self.get_api_name(name.split(Const.SEP))
|
|
373
|
+
api_indices_dict[api_name].append(op_index)
|
|
374
|
+
return api_indices_dict
|
|
375
|
+
|
|
376
|
+
def process_cell_mapping(self, npu_op_name):
|
|
377
|
+
if not npu_op_name:
|
|
378
|
+
return CompareConst.N_A
|
|
379
|
+
param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
|
|
380
|
+
if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
|
|
381
|
+
return CompareConst.N_A
|
|
382
|
+
npu_op_name = npu_op_name.replace("Cell", "Module", 1)
|
|
383
|
+
if self.mapping_dict.cell_mapping_dict:
|
|
384
|
+
# get cell name & class name from op_name
|
|
385
|
+
# Cell.fc1.Dense.forward.0.input.0
|
|
386
|
+
cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
|
|
387
|
+
if cell_name in self.mapping_dict.cell_mapping_dict:
|
|
388
|
+
npu_op_name = npu_op_name.replace(cell_name, self.mapping_dict.cell_mapping_dict[cell_name], 1)
|
|
389
|
+
return npu_op_name
|
|
390
|
+
|
|
391
|
+
def process_data_mapping(self, npu_op_name):
|
|
392
|
+
return self.mapping_dict.data_mapping_dict.get(npu_op_name, npu_op_name)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
class Match:
|
|
396
|
+
def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, cross_frame):
|
|
397
|
+
self.mode_config = mode_config
|
|
398
|
+
self.mapping_config = mapping_config
|
|
399
|
+
self.cross_frame = cross_frame
|
|
185
400
|
|
|
186
|
-
|
|
187
|
-
|
|
401
|
+
@staticmethod
|
|
402
|
+
def put_unmatched_in_table(match_result, npu_op_item):
|
|
403
|
+
npu_columns = npu_op_item.index.tolist()[:-2]
|
|
404
|
+
new_columns = [name[:-1] + 'y' for name in npu_columns]
|
|
405
|
+
na_series = pd.Series([CompareConst.N_A] * len(new_columns), index=new_columns)
|
|
406
|
+
new_result_item = pd.concat([npu_op_item, na_series]).to_frame().T
|
|
407
|
+
new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS
|
|
408
|
+
match_result = pd.concat([match_result, new_result_item])
|
|
409
|
+
return match_result
|
|
188
410
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data)
|
|
197
|
-
if npu_merge_list:
|
|
198
|
-
npu_ops_queue.append(npu_merge_list)
|
|
199
|
-
except StopIteration:
|
|
200
|
-
read_err_npu = False
|
|
201
|
-
try:
|
|
202
|
-
last_bench_ops_len = len(bench_ops_queue)
|
|
203
|
-
op_name_bench = next(ops_bench_iter)
|
|
204
|
-
check_op_str_pattern_valid(op_name_bench)
|
|
205
|
-
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data)
|
|
206
|
-
if bench_merge_list:
|
|
207
|
-
bench_ops_queue.append(bench_merge_list)
|
|
208
|
-
except StopIteration:
|
|
209
|
-
read_err_bench = False
|
|
411
|
+
@staticmethod
|
|
412
|
+
def put_matched_in_table(match_result, npu_op_item, bench_op_item):
|
|
413
|
+
head_len = len(CompareConst.MATCH_RESULT_COLUMNS)
|
|
414
|
+
new_result_item = pd.concat([npu_op_item, bench_op_item]).head(head_len).to_frame().T
|
|
415
|
+
new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS
|
|
416
|
+
match_result = pd.concat([match_result, new_result_item])
|
|
417
|
+
return match_result
|
|
210
418
|
|
|
211
|
-
|
|
419
|
+
@staticmethod
|
|
420
|
+
def rename_api(op_name):
|
|
421
|
+
"""
|
|
422
|
+
原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
|
|
423
|
+
rename后: {api_type}.{api_name}.{前向反向}.{input/output}.{参数序号}
|
|
424
|
+
"""
|
|
425
|
+
if Const.FORWARD not in op_name and Const.BACKWARD not in op_name:
|
|
426
|
+
return op_name
|
|
427
|
+
process = Const.FORWARD if Const.FORWARD in op_name else Const.BACKWARD
|
|
428
|
+
name_split = op_name.split(process)
|
|
429
|
+
try:
|
|
430
|
+
torch_func_index, in_out = name_split[0], name_split[1]
|
|
431
|
+
except IndexError as error:
|
|
432
|
+
logger.error(f'{op_name} can not be split with {process}, please check!')
|
|
433
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
434
|
+
torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
|
|
435
|
+
torch_func = str(torch_func_split[0]) + Const.SEP + process + str(in_out)
|
|
436
|
+
return torch_func
|
|
437
|
+
|
|
438
|
+
def check_op_item(self, npu_op_item, bench_op_item):
|
|
439
|
+
name_match = self.rename_api(npu_op_item[CompareConst.CMP_KEY]) == self.rename_api(
|
|
440
|
+
bench_op_item[CompareConst.CMP_KEY])
|
|
441
|
+
shape_match = npu_op_item[CompareConst.CMP_SHAPE] == bench_op_item[CompareConst.CMP_SHAPE]
|
|
442
|
+
if name_match and shape_match:
|
|
443
|
+
return True
|
|
444
|
+
else:
|
|
445
|
+
npu_op_name = npu_op_item[CompareConst.OP_NAME]
|
|
446
|
+
bench_op_name = bench_op_item[CompareConst.OP_NAME]
|
|
447
|
+
check_op_str_pattern_valid(npu_op_name)
|
|
448
|
+
check_op_str_pattern_valid(bench_op_name)
|
|
449
|
+
logger.warning(f"{npu_op_name} and {bench_op_name} can not fuzzy match")
|
|
450
|
+
return False
|
|
212
451
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
452
|
+
def match_api_infos(self, npu_df, bench_df):
|
|
453
|
+
"""
|
|
454
|
+
正常匹配和模糊匹配
|
|
455
|
+
"""
|
|
456
|
+
if self.mapping_config.data_mapping:
|
|
457
|
+
match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY], how='left')
|
|
458
|
+
|
|
459
|
+
# reorder match_result by op_name of npu
|
|
460
|
+
op_name_order = npu_df[CompareConst.OP_NAME].tolist()
|
|
461
|
+
match_result[CompareConst.OP_NAME_X] = pd.Categorical(match_result[CompareConst.OP_NAME_X],
|
|
462
|
+
categories=op_name_order, ordered=True)
|
|
463
|
+
match_result = match_result.sort_values(CompareConst.OP_NAME_X).reset_index(drop=True)
|
|
464
|
+
match_result[CompareConst.OP_NAME_X] = match_result[CompareConst.OP_NAME_X].astype('object')
|
|
465
|
+
elif not self.mode_config.fuzzy_match:
|
|
466
|
+
match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY, CompareConst.CMP_SHAPE],
|
|
467
|
+
how='outer')
|
|
468
|
+
else:
|
|
469
|
+
match_result = self.process_fuzzy_match(npu_df, bench_df)
|
|
470
|
+
return match_result
|
|
218
471
|
|
|
219
|
-
|
|
472
|
+
def process_fuzzy_match(self, npu_df, bench_df):
|
|
473
|
+
"""
|
|
474
|
+
模糊匹配通过循环方式匹配api
|
|
475
|
+
"""
|
|
476
|
+
npu_ops_queue = []
|
|
477
|
+
bench_ops_queue = []
|
|
478
|
+
match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS)
|
|
479
|
+
|
|
480
|
+
max_len = max(len(npu_df), len(bench_df))
|
|
481
|
+
min_len = min(len(npu_df), len(bench_df))
|
|
482
|
+
for i in range(max_len):
|
|
483
|
+
if i < min_len:
|
|
484
|
+
npu_ops_queue.append(npu_df.iloc[i])
|
|
485
|
+
bench_ops_queue.append(bench_df.iloc[i])
|
|
486
|
+
else:
|
|
487
|
+
try:
|
|
488
|
+
npu_ops_queue.append(npu_df.iloc[i])
|
|
489
|
+
except IndexError:
|
|
490
|
+
pass
|
|
491
|
+
try:
|
|
492
|
+
bench_ops_queue.append(bench_df.iloc[i])
|
|
493
|
+
except IndexError:
|
|
494
|
+
pass
|
|
495
|
+
|
|
496
|
+
# 如果append之后queue状态不一致,则判断结束
|
|
220
497
|
if bool(npu_ops_queue) ^ bool(bench_ops_queue):
|
|
221
|
-
logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
|
|
222
498
|
break
|
|
223
499
|
|
|
224
|
-
|
|
500
|
+
npu_match_point, bench_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
|
|
225
501
|
|
|
226
|
-
#
|
|
227
|
-
if
|
|
502
|
+
# 如果没有匹配到,数据放到队列中,跳过。直到后面匹配到,把匹配之前的api放到不匹配中
|
|
503
|
+
if npu_match_point == -1 and bench_match_point == -1:
|
|
228
504
|
continue
|
|
229
505
|
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
for
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
del npu_ops_queue[0:
|
|
237
|
-
del bench_ops_queue[0:
|
|
238
|
-
|
|
506
|
+
npu_op_item = npu_ops_queue[npu_match_point]
|
|
507
|
+
bench_op_item = bench_ops_queue[bench_match_point]
|
|
508
|
+
unmatched_data = npu_ops_queue[0: npu_match_point]
|
|
509
|
+
for op_item in unmatched_data:
|
|
510
|
+
match_result = self.put_unmatched_in_table(match_result, op_item)
|
|
511
|
+
match_result = self.put_matched_in_table(match_result, npu_op_item, bench_op_item)
|
|
512
|
+
del npu_ops_queue[0: npu_match_point + 1]
|
|
513
|
+
del bench_ops_queue[0: bench_match_point + 1]
|
|
514
|
+
|
|
239
515
|
if npu_ops_queue:
|
|
240
|
-
for
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
result_df = self.make_result_table(result)
|
|
244
|
-
return result_df
|
|
245
|
-
|
|
246
|
-
def merge_data(self, json_data, stack_json_data):
|
|
247
|
-
ops_all = {}
|
|
248
|
-
for op_name in json_data.get('data', {}):
|
|
249
|
-
merge_list = self.gen_merge_list(json_data, op_name, stack_json_data)
|
|
250
|
-
if merge_list:
|
|
251
|
-
struct_to_index_mapping = {
|
|
252
|
-
CompareConst.INPUT_STRUCT: 0,
|
|
253
|
-
CompareConst.OUTPUT_STRUCT: 0,
|
|
254
|
-
CompareConst.PARAMS_STRUCT: 0,
|
|
255
|
-
CompareConst.PARAMS_GRAD_STRUCT: 0
|
|
256
|
-
}
|
|
257
|
-
|
|
258
|
-
op_name_list = merge_list.get(CompareConst.OP_NAME)
|
|
259
|
-
summary_list = merge_list.get(Const.SUMMARY)
|
|
260
|
-
data_name_list = merge_list.get('data_name')
|
|
261
|
-
op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
|
|
262
|
-
summary_list,
|
|
263
|
-
data_name_list)
|
|
264
|
-
for index, op_full_name in enumerate(op_name_reorder):
|
|
265
|
-
data_name = data_name_reorder[index] if data_name_reorder else None
|
|
266
|
-
|
|
267
|
-
_, state = get_name_and_state(op_full_name)
|
|
268
|
-
struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
|
|
269
|
-
if not struct_key:
|
|
270
|
-
continue
|
|
271
|
-
ops_all[op_full_name] = {
|
|
272
|
-
CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key),
|
|
273
|
-
"merge_list", key=struct_key),
|
|
274
|
-
CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"),
|
|
275
|
-
'data_name': data_name,
|
|
276
|
-
'stack_info': merge_list.get('stack_info')
|
|
277
|
-
}
|
|
278
|
-
struct_to_index_mapping[struct_key] += 1
|
|
279
|
-
return ops_all
|
|
280
|
-
|
|
281
|
-
def get_accuracy(self, npu_ops_all, bench_ops_all):
|
|
282
|
-
result = []
|
|
283
|
-
bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
|
|
284
|
-
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
285
|
-
check_op_str_pattern_valid(ms_op_name)
|
|
286
|
-
check_op_str_pattern_valid(bench_op_name)
|
|
287
|
-
if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
|
|
288
|
-
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
289
|
-
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
290
|
-
has_stack = npu_stack_info and bench_stack_info
|
|
291
|
-
if self.dump_mode == Const.MD5:
|
|
292
|
-
result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
|
|
293
|
-
bench_ops_all, has_stack, npu_stack_info))
|
|
294
|
-
continue
|
|
295
|
-
|
|
296
|
-
npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
|
|
297
|
-
bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
|
|
298
|
-
|
|
299
|
-
if len(npu_struct) < 2 or len(bench_struct) < 2:
|
|
300
|
-
logger.error(
|
|
301
|
-
f"The length of npu_struct and bench_struct must be >= 2, "
|
|
302
|
-
f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
|
|
303
|
-
f"Please check!"
|
|
304
|
-
)
|
|
305
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
306
|
-
|
|
307
|
-
base_result_item = [
|
|
308
|
-
ms_op_name, bench_op_name,
|
|
309
|
-
npu_struct[0],
|
|
310
|
-
bench_struct[0],
|
|
311
|
-
npu_struct[1],
|
|
312
|
-
bench_struct[1]
|
|
313
|
-
]
|
|
314
|
-
|
|
315
|
-
if self.dump_mode == Const.SUMMARY:
|
|
316
|
-
result_item = base_result_item + [" "] * 8 # 8个统计量数据情况的比对指标
|
|
317
|
-
else:
|
|
318
|
-
result_item = base_result_item + [" "] * 6 # 6个真实数据情况的比对指标
|
|
319
|
-
|
|
320
|
-
npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
|
|
321
|
-
result_item.extend(npu_summary_data)
|
|
322
|
-
bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
|
|
323
|
-
result_item.extend(bench_summary_data)
|
|
324
|
-
if self.dump_mode == Const.SUMMARY:
|
|
325
|
-
self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
|
|
326
|
-
else:
|
|
327
|
-
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
328
|
-
result_item.append("")
|
|
329
|
-
if has_stack:
|
|
330
|
-
result_item.extend(npu_stack_info)
|
|
331
|
-
else:
|
|
332
|
-
result_item.append(CompareConst.NONE)
|
|
333
|
-
if self.dump_mode == Const.ALL:
|
|
334
|
-
ms_data_name = npu_ops_all.get(ms_op_name).get("data_name", None)
|
|
335
|
-
pt_data_name = bench_ops_all.get(bench_op_name).get("data_name", None)
|
|
336
|
-
result_item.append([ms_data_name, pt_data_name])
|
|
337
|
-
result.append(result_item)
|
|
338
|
-
logger.info(f"{ms_op_name}, {bench_op_name} compared.")
|
|
339
|
-
elif ms_op_name not in npu_ops_all:
|
|
340
|
-
logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
|
|
341
|
-
elif bench_op_name not in npu_ops_all:
|
|
342
|
-
logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
|
|
343
|
-
return result
|
|
516
|
+
for op_item in npu_ops_queue:
|
|
517
|
+
match_result = self.put_unmatched_in_table(match_result, op_item)
|
|
344
518
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
npu_json_data = load_json(npu_json_path)
|
|
348
|
-
bench_json_data = load_json(bench_json_path)
|
|
349
|
-
stack_json_data = load_json(stack_json_path) if self.stack_mode else None
|
|
350
|
-
npu_ops_all = self.merge_data(npu_json_data, stack_json_data)
|
|
351
|
-
bench_ops_all = self.merge_data(bench_json_data, stack_json_data)
|
|
519
|
+
match_result.reset_index(drop=True, inplace=True)
|
|
520
|
+
return match_result
|
|
352
521
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
522
|
+
def match_op(self, npu_queue, bench_queue):
|
|
523
|
+
for b_index, b_op in enumerate(bench_queue[0: -1]):
|
|
524
|
+
if self.check_op_item(npu_queue[-1], b_op):
|
|
525
|
+
return len(npu_queue) - 1, b_index
|
|
526
|
+
if self.check_op_item(npu_queue[-1], bench_queue[-1]):
|
|
527
|
+
return len(npu_queue) - 1, len(bench_queue) - 1
|
|
528
|
+
for n_index, n_op in enumerate(npu_queue[0: -1]):
|
|
529
|
+
if self.check_op_item(n_op, bench_queue[-1]):
|
|
530
|
+
return n_index, len(bench_queue) - 1
|
|
531
|
+
return -1, -1
|
|
356
532
|
|
|
357
|
-
def
|
|
533
|
+
def gen_dtype_condition(self, match_result):
|
|
358
534
|
"""
|
|
359
|
-
|
|
360
|
-
:param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
|
|
361
|
-
:param op_name_mapping_dict: op_name和npy或pt文件的映射关系
|
|
362
|
-
:param input_param: npu_json_path/bench_json_path/stack_json_path等参数
|
|
363
|
-
:return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
|
|
364
|
-
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
|
|
365
|
-
最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
|
|
535
|
+
dtype匹配条件为npu、bench的dtype一致或属于规定的映射关系
|
|
366
536
|
"""
|
|
367
|
-
|
|
537
|
+
# 如果使用了data_mapping,不校验dtype,返回全True的DataFrame
|
|
538
|
+
if self.mapping_config.data_mapping:
|
|
539
|
+
return pd.Series(True, index=match_result.index)
|
|
540
|
+
|
|
541
|
+
npu_dtype = match_result['dtype_x']
|
|
542
|
+
bench_dtype = match_result['dtype_y']
|
|
543
|
+
npu_dtype = self.process_cross_frame_dtype(npu_dtype)
|
|
544
|
+
bench_dtype = self.process_cross_frame_dtype(bench_dtype)
|
|
545
|
+
|
|
546
|
+
equal_condition = npu_dtype == bench_dtype
|
|
547
|
+
match_condition = (
|
|
548
|
+
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
|
|
549
|
+
CompareConst.DTYPE_MATCH_GROUPS[0])) |
|
|
550
|
+
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
|
|
551
|
+
CompareConst.DTYPE_MATCH_GROUPS[1]))
|
|
552
|
+
)
|
|
553
|
+
return equal_condition | match_condition
|
|
368
554
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
555
|
+
def process_cross_frame_dtype(self, dtype):
|
|
556
|
+
if self.cross_frame:
|
|
557
|
+
dtype = dtype.map(cross_dtype_mapping).fillna(dtype)
|
|
558
|
+
return dtype
|
|
372
559
|
|
|
373
|
-
if str(npu_data_name) == '-1': # 没有npu真实数据
|
|
374
|
-
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
375
|
-
elif str(bench_data_name) == '-1': # 没有bench真实数据
|
|
376
|
-
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
377
|
-
error_file = 'no_bench_data'
|
|
378
|
-
else:
|
|
379
|
-
npu_dir = input_param.get("npu_dump_data_dir")
|
|
380
|
-
bench_dir = input_param.get("bench_dump_data_dir")
|
|
381
|
-
try:
|
|
382
|
-
frame_name = getattr(self, "frame_name")
|
|
383
|
-
read_npy_data = getattr(self, "read_npy_data")
|
|
384
|
-
if frame_name == "MSComparator":
|
|
385
|
-
n_value = read_npy_data(npu_dir, npu_data_name)
|
|
386
|
-
if self.cross_frame:
|
|
387
|
-
b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True)
|
|
388
|
-
else:
|
|
389
|
-
b_value = read_npy_data(bench_dir, bench_data_name)
|
|
390
|
-
else:
|
|
391
|
-
n_value = read_npy_data(npu_dir, npu_data_name)
|
|
392
|
-
b_value = read_npy_data(bench_dir, bench_data_name)
|
|
393
|
-
except IOError as error:
|
|
394
|
-
error_file = error.filename
|
|
395
|
-
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
396
|
-
error_flag = True
|
|
397
|
-
except (FileCheckException, CompareException):
|
|
398
|
-
error_file = npu_data_name
|
|
399
|
-
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
400
|
-
error_flag = True
|
|
401
|
-
|
|
402
|
-
# 通过n_value, b_value同时得到错误标志和错误信息
|
|
403
|
-
n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
|
|
404
|
-
error_flag=error_flag, error_file=error_file)
|
|
405
|
-
|
|
406
|
-
result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
|
|
407
|
-
|
|
408
|
-
if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
|
|
409
|
-
err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
|
|
410
|
-
result_list.append(err_msg)
|
|
411
|
-
return result_list
|
|
412
560
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
Args:
|
|
418
|
-
input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
|
|
419
|
-
"stack_path").
|
|
420
|
-
output_path (str): The path where the output Excel report will be saved.
|
|
421
|
-
**kwargs: Additional keyword arguments including:
|
|
422
|
-
- stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
|
|
423
|
-
- auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
|
|
424
|
-
- suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
|
|
425
|
-
- fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
|
|
426
|
-
- dump_mode (str): ALL, SUMMARY, MD5.
|
|
561
|
+
class CreateTable:
|
|
562
|
+
def __init__(self, mode_config: ModeConfig):
|
|
563
|
+
self.mode_config = mode_config
|
|
427
564
|
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
565
|
+
@staticmethod
|
|
566
|
+
def process_data_name(result):
|
|
567
|
+
result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1)
|
|
568
|
+
return result
|
|
432
569
|
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
570
|
+
@staticmethod
|
|
571
|
+
def set_summary(summary):
|
|
572
|
+
if summary == CompareConst.N_A:
|
|
573
|
+
return [CompareConst.N_A] * 4 # 4为统计值个数
|
|
574
|
+
summary_list = []
|
|
575
|
+
for i in summary:
|
|
576
|
+
if str(i).lower() == 'nan':
|
|
577
|
+
summary_list.append(CompareConst.NAN)
|
|
578
|
+
else:
|
|
579
|
+
summary_list.append(i)
|
|
580
|
+
return summary_list
|
|
440
581
|
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
if self.
|
|
445
|
-
|
|
582
|
+
def make_result_df(self, result):
|
|
583
|
+
# get header
|
|
584
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[self.mode_config.dump_mode][:]
|
|
585
|
+
if self.mode_config.stack_mode:
|
|
586
|
+
header.append(CompareConst.STACK)
|
|
587
|
+
if self.mode_config.dump_mode == Const.ALL:
|
|
588
|
+
header.append(CompareConst.DATA_NAME)
|
|
589
|
+
result = self.process_data_name(result)
|
|
590
|
+
|
|
591
|
+
# rename match_result columns
|
|
592
|
+
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
593
|
+
'op_name_y': CompareConst.BENCH_NAME,
|
|
594
|
+
'dtype_x': CompareConst.NPU_DTYPE,
|
|
595
|
+
'dtype_y': CompareConst.BENCH_DTYPE,
|
|
596
|
+
'shape_x': CompareConst.NPU_SHAPE,
|
|
597
|
+
'shape_y': CompareConst.BENCH_SHAPE,
|
|
598
|
+
'md5_x': CompareConst.NPU_MD5,
|
|
599
|
+
'md5_y': CompareConst.BENCH_MD5,
|
|
600
|
+
'data_name_x': CompareConst.DATA_NAME,
|
|
601
|
+
'stack_info_x': CompareConst.STACK}, inplace=True)
|
|
602
|
+
|
|
603
|
+
# process summary data
|
|
604
|
+
npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
605
|
+
bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
|
|
606
|
+
CompareConst.BENCH_NORM]
|
|
607
|
+
if result.empty:
|
|
608
|
+
result[npu_summary] = pd.DataFrame(columns=npu_summary)
|
|
609
|
+
result[bench_summary] = pd.DataFrame(columns=bench_summary)
|
|
446
610
|
else:
|
|
447
|
-
|
|
611
|
+
result[npu_summary] = result['summary_x'].apply(self.set_summary).tolist()
|
|
612
|
+
result[bench_summary] = result['summary_y'].apply(self.set_summary).tolist()
|
|
448
613
|
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
614
|
+
result_df = pd.DataFrame(columns=header)
|
|
615
|
+
for h in header:
|
|
616
|
+
if h in result.columns:
|
|
617
|
+
result_df[h] = result[h]
|
|
618
|
+
return result_df, header
|
|
452
619
|
|
|
453
|
-
if self.dump_mode == Const.ALL:
|
|
454
|
-
result_df = self.do_multi_process(input_param, result_df)
|
|
455
620
|
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
if self.auto_analyze:
|
|
460
|
-
advisor = Advisor(result_df, output_path, suffix)
|
|
461
|
-
advisor.analysis()
|
|
621
|
+
class CalcStatsDiff:
|
|
622
|
+
def __init__(self, mode_config: ModeConfig):
|
|
623
|
+
self.mode_config = mode_config
|
|
462
624
|
|
|
463
|
-
|
|
625
|
+
@staticmethod
|
|
626
|
+
def type_check(val):
|
|
627
|
+
"""
|
|
628
|
+
检查是否为数值或字符串形式的nan, 如果是返回True
|
|
629
|
+
"""
|
|
630
|
+
check_series = pd.Series(False, index=val.index)
|
|
631
|
+
val_str = val.astype(str)
|
|
632
|
+
check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
|
|
633
|
+
return check_series
|
|
464
634
|
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
635
|
+
@staticmethod
|
|
636
|
+
def get_number(val):
|
|
637
|
+
return pd.to_numeric(val.astype(str), errors='coerce')
|
|
638
|
+
|
|
639
|
+
def calc_summary_diff(self, result_df, cond_no_bench, stats_index: str):
|
|
640
|
+
npu_val = result_df['NPU ' + stats_index]
|
|
641
|
+
bench_val = result_df['Bench ' + stats_index]
|
|
642
|
+
diff_name = stats_index.capitalize() + ' diff'
|
|
643
|
+
rel_err_name = ('norm' if stats_index == 'l2norm' else stats_index).capitalize() + 'RelativeErr'
|
|
644
|
+
|
|
645
|
+
# npu、bench中统计量均为数字或nan
|
|
646
|
+
cond_num_nan = self.type_check(npu_val) & self.type_check(bench_val)
|
|
647
|
+
|
|
648
|
+
# 如果统计量不是数字或nan,就赋值统计量差异为N/A
|
|
649
|
+
result_df.loc[~cond_num_nan, [diff_name, rel_err_name]] = CompareConst.N_A
|
|
650
|
+
cond_valid_stat = ~cond_no_bench & cond_num_nan # 有效统计条件:bench_name不是N/A,并且NPU和bench的统计量都是数字或nan
|
|
651
|
+
result_df.loc[cond_valid_stat, diff_name] = self.get_number(npu_val) - self.get_number(bench_val)
|
|
652
|
+
|
|
653
|
+
cond_diff_nan = result_df[diff_name].isna() # 统计量差异是nan
|
|
654
|
+
cond_nan_diff = cond_valid_stat & cond_diff_nan
|
|
655
|
+
result_df.loc[cond_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
|
|
656
|
+
|
|
657
|
+
cond_not_nan_diff = cond_valid_stat & ~cond_diff_nan
|
|
658
|
+
condition_pt_zero = bench_val == 0
|
|
659
|
+
result_df.loc[cond_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.N_A
|
|
660
|
+
|
|
661
|
+
# 相对误差转成百分比字符串
|
|
662
|
+
cond_ref_err = cond_not_nan_diff & ~condition_pt_zero
|
|
663
|
+
result_df.loc[cond_ref_err, rel_err_name] = (
|
|
664
|
+
result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err] * 100)
|
|
665
|
+
result_df.loc[cond_ref_err, rel_err_name] = (result_df.loc[cond_ref_err, rel_err_name].abs().astype(str) + '%')
|
|
666
|
+
|
|
667
|
+
magnitude = self.get_number(result_df[diff_name]).abs() / (pd.Series(
|
|
668
|
+
np.maximum(self.get_number(npu_val), self.get_number(bench_val))).abs() + CompareConst.EPSILON)
|
|
669
|
+
return magnitude > CompareConst.MAGNITUDE
|
|
670
|
+
|
|
671
|
+
def calc_accuracy(self, result_df, header):
|
|
672
|
+
# bench name N/A represents no bench data, err_msg adds "No bench data matched."
|
|
673
|
+
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
674
|
+
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
675
|
+
result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
|
|
676
|
+
|
|
677
|
+
if self.mode_config.dump_mode == Const.MD5:
|
|
678
|
+
condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
|
|
679
|
+
result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
|
|
680
|
+
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
681
|
+
elif self.mode_config.dump_mode == Const.SUMMARY:
|
|
682
|
+
warning_list = [
|
|
683
|
+
self.calc_summary_diff(result_df, condition_no_bench, stats_index)
|
|
684
|
+
for stats_index in ['max', 'min', 'mean', 'l2norm']
|
|
685
|
+
]
|
|
686
|
+
warning_flag = pd.DataFrame(warning_list).any()
|
|
687
|
+
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
688
|
+
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
689
|
+
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
690
|
+
else:
|
|
691
|
+
fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
|
|
692
|
+
CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
693
|
+
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
694
|
+
CompareConst.ERROR_MESSAGE]
|
|
695
|
+
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
696
|
+
result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
|
|
697
|
+
|
|
698
|
+
return result_df[header]
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig:
|
|
702
|
+
"""公共的前置处理逻辑,返回封装后的 ComparisonConfig 对象"""
|
|
703
|
+
try:
|
|
704
|
+
config = ComparisonConfig(
|
|
705
|
+
dump_mode='',
|
|
706
|
+
stack_mode=False,
|
|
707
|
+
auto_analyze=kwargs.get('auto_analyze', True),
|
|
708
|
+
fuzzy_match=kwargs.get('fuzzy_match', False),
|
|
709
|
+
data_mapping=kwargs.get('data_mapping', {}),
|
|
710
|
+
suffix=kwargs.get('suffix', ''),
|
|
711
|
+
cell_mapping=kwargs.get('cell_mapping', {}),
|
|
712
|
+
api_mapping=kwargs.get('api_mapping', {}),
|
|
713
|
+
layer_mapping=kwargs.get('layer_mapping', {}),
|
|
714
|
+
compared_file_type='',
|
|
507
715
|
)
|
|
508
716
|
|
|
509
|
-
|
|
717
|
+
set_dump_path(input_param)
|
|
718
|
+
config.dump_mode = get_dump_mode(input_param)
|
|
719
|
+
config.compared_file_type = get_file_type(input_param.get("npu_json_path", None))
|
|
510
720
|
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
721
|
+
# set stack_mode and set "stack_json_path" in input_param
|
|
722
|
+
if 'stack_json_path' in input_param:
|
|
723
|
+
config.stack_mode = kwargs.get('stack_mode', False)
|
|
724
|
+
else:
|
|
725
|
+
config.stack_mode = set_stack_json_path(input_param)
|
|
726
|
+
|
|
727
|
+
check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match,
|
|
728
|
+
input_param.get('is_print_compare_log', True))
|
|
729
|
+
create_directory(output_path)
|
|
730
|
+
check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode)
|
|
731
|
+
|
|
732
|
+
return config
|
|
733
|
+
|
|
734
|
+
except (CompareException, FileCheckException) as error:
|
|
735
|
+
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
736
|
+
raise CompareException(error.code) from error
|