mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -1,21 +1,47 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import re
|
|
3
|
-
import copy
|
|
4
|
-
import sys
|
|
5
|
-
from itertools import zip_longest
|
|
6
18
|
|
|
7
|
-
from
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
19
|
+
from collections import defaultdict
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import pandas as pd
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
12
25
|
from msprobe.core.common.exceptions import FileCheckException
|
|
26
|
+
from msprobe.core.common.file_utils import (FileOpen, create_directory, load_json,
|
|
27
|
+
load_npy, load_yaml)
|
|
28
|
+
from msprobe.core.common.log import logger
|
|
29
|
+
from msprobe.core.common.utils import (CompareException, check_compare_param,
|
|
30
|
+
check_configuration_param,
|
|
31
|
+
get_dump_mode, set_dump_path, check_op_str_pattern_valid)
|
|
32
|
+
from msprobe.core.compare.check import dtype_mapping
|
|
13
33
|
from msprobe.core.compare.acc_compare import Comparator
|
|
14
|
-
from msprobe.core.compare.
|
|
15
|
-
|
|
16
|
-
from msprobe.mindspore.compare.layer_mapping import get_layer_mapping
|
|
34
|
+
from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
|
|
35
|
+
|
|
17
36
|
|
|
18
37
|
class MSComparator(Comparator):
|
|
38
|
+
"""
|
|
39
|
+
用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
|
|
40
|
+
cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
|
|
41
|
+
api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
|
|
42
|
+
data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
|
|
43
|
+
is_cross_framework: 是否跨框架。
|
|
44
|
+
"""
|
|
19
45
|
def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
|
|
20
46
|
self.frame_name = MSComparator.__name__
|
|
21
47
|
self.cell_mapping = cell_mapping
|
|
@@ -37,10 +63,108 @@ class MSComparator(Comparator):
|
|
|
37
63
|
else:
|
|
38
64
|
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
39
65
|
f"{type(self.data_mapping)}")
|
|
66
|
+
|
|
67
|
+
@classmethod
|
|
68
|
+
def calc_accuracy(cls, result_df, dump_mode, header):
|
|
69
|
+
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
70
|
+
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
71
|
+
result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
|
|
72
|
+
|
|
73
|
+
def calc_summary_diff(data_type: str):
|
|
74
|
+
def type_check(val):
|
|
75
|
+
check_series = pd.Series(False, index=val.index)
|
|
76
|
+
val_str = val.astype(str)
|
|
77
|
+
check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
|
|
78
|
+
return check_series
|
|
79
|
+
|
|
80
|
+
def get_number(val):
|
|
81
|
+
return pd.to_numeric(val.astype(str), errors='coerce')
|
|
82
|
+
|
|
83
|
+
ms_val = result_df['NPU ' + data_type]
|
|
84
|
+
pt_val = result_df['Bench ' + data_type]
|
|
85
|
+
diff_name = data_type.capitalize() + ' diff'
|
|
86
|
+
rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
|
|
87
|
+
condition_na = ~type_check(ms_val) | ~type_check(pt_val)
|
|
88
|
+
result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
|
|
89
|
+
result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
|
|
90
|
+
condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
|
|
91
|
+
condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
|
|
92
|
+
result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
|
|
93
|
+
condition_pt_zero = pt_val == 0
|
|
94
|
+
result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
|
|
95
|
+
condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
|
|
96
|
+
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
|
|
97
|
+
pt_val[condition_ref_err] * 100)
|
|
98
|
+
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
|
|
99
|
+
.abs().astype(str) + '%')
|
|
100
|
+
magnitude = get_number(result_df[diff_name]).abs() / (
|
|
101
|
+
pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
|
|
102
|
+
return magnitude > CompareConst.MAGNITUDE
|
|
103
|
+
|
|
104
|
+
if dump_mode == Const.MD5:
|
|
105
|
+
condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
|
|
106
|
+
result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
|
|
107
|
+
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
108
|
+
elif dump_mode == Const.SUMMARY:
|
|
109
|
+
warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
|
|
110
|
+
warning_flag = pd.DataFrame(warning_list).all()
|
|
111
|
+
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
112
|
+
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
113
|
+
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
114
|
+
else:
|
|
115
|
+
fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
116
|
+
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
117
|
+
CompareConst.ERROR_MESSAGE]
|
|
118
|
+
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
119
|
+
result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
|
|
120
|
+
return result_df[header]
|
|
121
|
+
|
|
122
|
+
@classmethod
|
|
123
|
+
def make_result_df(cls, result, stack_mode, dump_mode):
|
|
124
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode]
|
|
125
|
+
|
|
126
|
+
if stack_mode:
|
|
127
|
+
header.append(CompareConst.STACK)
|
|
128
|
+
if dump_mode == Const.ALL:
|
|
129
|
+
header.append(CompareConst.DATA_NAME)
|
|
130
|
+
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
131
|
+
'op_name_y': CompareConst.BENCH_NAME,
|
|
132
|
+
'dtype_x': CompareConst.NPU_DTYPE,
|
|
133
|
+
'dtype_y': CompareConst.BENCH_DTYPE,
|
|
134
|
+
'shape_x': CompareConst.NPU_SHAPE,
|
|
135
|
+
'shape_y': CompareConst.BENCH_SHAPE,
|
|
136
|
+
'md5_x': CompareConst.NPU_MD5,
|
|
137
|
+
'md5_y': CompareConst.BENCH_MD5,
|
|
138
|
+
'data_name_x': CompareConst.DATA_NAME,
|
|
139
|
+
'stack_info_x': CompareConst.STACK}, inplace=True)
|
|
140
|
+
|
|
141
|
+
npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
142
|
+
bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
|
|
143
|
+
CompareConst.BENCH_NORM]
|
|
144
|
+
def set_summary(summary):
|
|
145
|
+
if summary == CompareConst.N_A:
|
|
146
|
+
return [CompareConst.N_A] * 4
|
|
147
|
+
summary_list = []
|
|
148
|
+
for i in summary:
|
|
149
|
+
if i is None:
|
|
150
|
+
summary_list.append(CompareConst.N_A)
|
|
151
|
+
elif str(i).lower() == 'nan':
|
|
152
|
+
summary_list.append(CompareConst.NAN)
|
|
153
|
+
else:
|
|
154
|
+
summary_list.append(i)
|
|
155
|
+
return summary_list
|
|
156
|
+
|
|
157
|
+
result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
|
|
158
|
+
result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
|
|
159
|
+
result_df = pd.DataFrame(columns=header)
|
|
160
|
+
for h in header:
|
|
161
|
+
if h in result.columns:
|
|
162
|
+
result_df[h] = result[h]
|
|
163
|
+
return cls.calc_accuracy(result_df, dump_mode, header)
|
|
40
164
|
|
|
41
165
|
def load_internal_api(self):
|
|
42
166
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
43
|
-
yaml_path = os.path.join(cur_path,
|
|
167
|
+
yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
|
|
44
168
|
return load_yaml(yaml_path)
|
|
45
169
|
|
|
46
170
|
def load_mapping_file(self, mapping_file):
|
|
@@ -51,42 +175,20 @@ class MSComparator(Comparator):
|
|
|
51
175
|
return mapping_dict
|
|
52
176
|
|
|
53
177
|
def process_cell_mapping(self, npu_op_name):
|
|
54
|
-
npu_op_name
|
|
178
|
+
if not npu_op_name or not re.match(r'.+(?:for|back)ward\..+', npu_op_name):
|
|
179
|
+
return CompareConst.N_A
|
|
180
|
+
npu_op_name = npu_op_name.replace("Cell", "Module", 1)
|
|
55
181
|
if self.cell_mapping_dict:
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
182
|
+
# get cell name & class name from op_name
|
|
183
|
+
# Cell.fc1.Dense.forward.0.input.0
|
|
184
|
+
cell_name = re.split(r'\.(?:for|back)ward\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
|
|
185
|
+
if cell_name in self.cell_mapping_dict:
|
|
186
|
+
npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
62
187
|
return npu_op_name
|
|
63
188
|
|
|
64
|
-
def check_op(self, npu_dict, bench_dict, fuzzy_match):
|
|
65
|
-
npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
|
|
66
|
-
npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
|
|
67
|
-
if self.cell_mapping is not None:
|
|
68
|
-
npu_op_name = self.process_cell_mapping(npu_op_name)
|
|
69
|
-
if self.api_mapping is not None:
|
|
70
|
-
npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
|
|
71
|
-
if isinstance(self.api_mapping, str):
|
|
72
|
-
npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
|
|
73
|
-
bench_dict_new)
|
|
74
|
-
if target_dict:
|
|
75
|
-
bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
|
|
76
|
-
npu_op_name = npu_dict_new.get(CompareConst.OP_NAME)
|
|
77
|
-
bench_op_name = bench_dict_new.get(CompareConst.OP_NAME)
|
|
78
|
-
struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
|
|
79
|
-
if not fuzzy_match:
|
|
80
|
-
return npu_op_name == bench_op_name and struct_match
|
|
81
|
-
is_match = True
|
|
82
|
-
try:
|
|
83
|
-
is_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
84
|
-
except Exception as err:
|
|
85
|
-
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
86
|
-
is_match = False
|
|
87
|
-
return is_match and struct_match
|
|
88
|
-
|
|
89
189
|
def read_npy_data(self, dir_path, file_name, load_pt_file=False):
|
|
190
|
+
if not file_name:
|
|
191
|
+
return None
|
|
90
192
|
data_path = os.path.join(dir_path, file_name)
|
|
91
193
|
if load_pt_file:
|
|
92
194
|
import torch
|
|
@@ -97,34 +199,22 @@ class MSComparator(Comparator):
|
|
|
97
199
|
data_value = data_value.numpy()
|
|
98
200
|
else:
|
|
99
201
|
data_value = load_npy(data_path)
|
|
100
|
-
return data_value
|
|
202
|
+
return data_value
|
|
101
203
|
|
|
102
|
-
def
|
|
103
|
-
for idx, _ in enumerate(npu_op_name):
|
|
104
|
-
npu_op_name[idx] = npu_op_name[idx].replace(target, para)
|
|
105
|
-
return npu_op_name
|
|
106
|
-
|
|
107
|
-
def process_internal_api_mapping(self, npu_op_name, bench_op_name):
|
|
204
|
+
def process_internal_api_mapping(self, npu_op_name):
|
|
108
205
|
# get api name & class name from op_name
|
|
109
206
|
# Functional.addcmul.0.forward.input.0
|
|
110
|
-
|
|
111
|
-
ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
|
|
112
|
-
pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
|
|
207
|
+
ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
|
|
113
208
|
class_name = ms_api_name.split(Const.SEP)[0]
|
|
114
209
|
if class_name == "Mint":
|
|
115
|
-
return
|
|
210
|
+
return npu_op_name.replace("Mint", "Torch")
|
|
116
211
|
elif class_name == "MintFunctional":
|
|
117
|
-
return
|
|
118
|
-
elif self.ms_to_pt_mapping.get(ms_api_name)
|
|
119
|
-
return
|
|
212
|
+
return npu_op_name.replace("MintFunctional", "Functional")
|
|
213
|
+
elif self.ms_to_pt_mapping.get(ms_api_name):
|
|
214
|
+
return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
|
|
120
215
|
else:
|
|
121
216
|
return npu_op_name
|
|
122
217
|
|
|
123
|
-
def remove_element(self, op_name, struct, summary, idx):
|
|
124
|
-
del op_name[idx]
|
|
125
|
-
del struct[idx]
|
|
126
|
-
del summary[idx]
|
|
127
|
-
|
|
128
218
|
def get_api_name(self, api_list):
|
|
129
219
|
try:
|
|
130
220
|
api_name = api_list[0] + Const.SEP + api_list[1]
|
|
@@ -132,184 +222,126 @@ class MSComparator(Comparator):
|
|
|
132
222
|
logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
|
|
133
223
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
134
224
|
return api_name
|
|
135
|
-
|
|
136
|
-
def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
|
|
137
|
-
"""
|
|
138
|
-
Transform user mapping API based on new NPU and benchmark dictionaries.
|
|
139
|
-
Parameters:
|
|
140
|
-
new_npu_dict (dict): New NPU operation dictionary.
|
|
141
|
-
new_bench_dict (dict): New benchmark operation dictionary.
|
|
142
|
-
Returns:
|
|
143
|
-
tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
|
|
144
|
-
"""
|
|
145
|
-
npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
|
|
146
|
-
npu_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT)
|
|
147
|
-
bench_struct_in = new_bench_dict.get(CompareConst.INPUT_STRUCT)
|
|
148
|
-
npu_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT)
|
|
149
|
-
bench_struct_out = new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
|
|
150
|
-
npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
|
|
151
|
-
npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
|
|
152
|
-
npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
|
|
153
|
-
ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
|
|
154
|
-
ms_api_name = self.get_api_name(ms_api_list)
|
|
155
|
-
pt_api_name = self.get_api_name(pt_api_list)
|
|
156
|
-
target_dict = {}
|
|
157
|
-
for api_dict in self.api_mapping_dict:
|
|
158
|
-
if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
|
|
159
|
-
ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
|
|
160
|
-
ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
|
|
161
|
-
if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
|
|
162
|
-
logger.warning("The user-defined mapping table is incorrect,\
|
|
163
|
-
make sure that the number of parameters is equal")
|
|
164
|
-
break
|
|
165
|
-
ms_out_list = api_dict.get("ms_output", [])
|
|
166
|
-
for idx in reversed(range(npu_out_len)):
|
|
167
|
-
if idx not in ms_out_list:
|
|
168
|
-
del npu_struct_out[idx]
|
|
169
|
-
if idx + npu_in_len < len(npu_summary) and idx + npu_in_len < len(npu_op_name):
|
|
170
|
-
del npu_summary[idx + npu_in_len]
|
|
171
|
-
del npu_op_name[idx + npu_in_len]
|
|
172
|
-
pt_out_list = api_dict.get("pt_output", [])
|
|
173
|
-
for idx in reversed(range(bench_out_len)):
|
|
174
|
-
if idx not in pt_out_list:
|
|
175
|
-
del bench_struct_out[idx]
|
|
176
|
-
if idx + bench_in_len < len(bench_summary) and idx + bench_in_len < len(bench_op_name):
|
|
177
|
-
del bench_summary[idx + bench_in_len]
|
|
178
|
-
del bench_op_name[idx + bench_in_len]
|
|
179
|
-
ms_para_list = api_dict.get("ms_args", [])
|
|
180
|
-
for idx in reversed(range(npu_in_len)):
|
|
181
|
-
if idx not in ms_para_list:
|
|
182
|
-
self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
|
|
183
|
-
pt_para_list = api_dict.get("pt_args", [])
|
|
184
|
-
for idx in reversed(range(bench_in_len)):
|
|
185
|
-
if idx not in pt_para_list:
|
|
186
|
-
self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
|
|
187
|
-
npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
|
|
188
|
-
npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
|
|
189
|
-
target_dict = api_dict
|
|
190
|
-
break
|
|
191
|
-
if target_dict:
|
|
192
|
-
new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
|
|
193
|
-
CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
|
|
194
|
-
new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
|
|
195
|
-
CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
|
|
196
|
-
return new_npu_dict, new_bench_dict, target_dict
|
|
197
|
-
|
|
198
|
-
def para_sequence_update(self, npu_op_name, bench_op_name):
|
|
199
|
-
for idx, _ in enumerate(npu_op_name):
|
|
200
|
-
bench_op_name_list = bench_op_name[idx].rsplit(Const.SEP, 1)
|
|
201
|
-
if len(bench_op_name_list) != 0:
|
|
202
|
-
npu_op_name[idx] = npu_op_name[idx][:-1] + bench_op_name_list[-1]
|
|
203
|
-
return npu_op_name
|
|
204
225
|
|
|
205
|
-
def
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
npu_in_len = len(npu_struct_in)
|
|
211
|
-
npu_out_len = len(npu_struct_out)
|
|
212
|
-
if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
|
|
213
|
-
return del_bench_dict
|
|
214
|
-
ms_input_args_list = [i for i in range(npu_in_len)]
|
|
215
|
-
input_sub_list = list(set(ms_input_args_list) - set(ms_user_args_list))
|
|
216
|
-
ms_output_args_list = [i for i in range(npu_out_len)]
|
|
217
|
-
output_sub_list = list(set(ms_output_args_list) - set(ms_user_output_list))
|
|
218
|
-
bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
|
|
219
|
-
bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
|
|
220
|
-
bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
|
|
221
|
-
bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
|
|
222
|
-
for idx in input_sub_list: # Fill in the blank value field in the pt dictionary
|
|
223
|
-
bench_op_name.insert(idx, CompareConst.N_A)
|
|
224
|
-
bench_struct_in.insert(idx, CompareConst.N_A)
|
|
225
|
-
bench_summary.insert(idx, CompareConst.N_A)
|
|
226
|
-
for idx in output_sub_list: # Fill in the blank value field in the pt dictionary
|
|
227
|
-
bench_op_name.insert(npu_in_len + idx, CompareConst.N_A)
|
|
228
|
-
bench_struct_out.insert(idx, CompareConst.N_A)
|
|
229
|
-
bench_summary.insert(npu_in_len + idx, CompareConst.N_A)
|
|
230
|
-
del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
|
|
231
|
-
CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
|
|
232
|
-
return del_bench_dict
|
|
233
|
-
|
|
226
|
+
def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
|
|
227
|
+
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
228
|
+
npu_json_data = load_json(npu_json_path)
|
|
229
|
+
bench_json_data = load_json(bench_json_path)
|
|
230
|
+
stack_json_data = load_json(stack_json_path)
|
|
234
231
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
232
|
+
npu_df = self.gen_data_df(npu_json_data, stack_json_data, dump_mode)
|
|
233
|
+
bench_df = self.gen_data_df(bench_json_data, stack_json_data, dump_mode)
|
|
234
|
+
if self.cell_mapping:
|
|
235
|
+
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
|
|
236
|
+
elif self.api_mapping:
|
|
237
|
+
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
|
|
238
|
+
if isinstance(self.api_mapping, str):
|
|
239
|
+
self.modify_compare_data_with_user_mapping(npu_df, bench_df)
|
|
240
|
+
else:
|
|
241
|
+
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME]
|
|
242
|
+
npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
243
|
+
bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
244
|
+
npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
|
|
245
|
+
bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
|
|
246
|
+
bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
|
|
247
|
+
match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
|
|
248
|
+
how='outer')
|
|
249
|
+
match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
|
|
248
250
|
|
|
249
|
-
|
|
250
|
-
|
|
251
|
+
def gen_dtype_condition():
|
|
252
|
+
npu_dtype = match_result['dtype_x']
|
|
253
|
+
bench_dtype = match_result['dtype_y']
|
|
254
|
+
if self.cross_frame:
|
|
255
|
+
npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
|
|
256
|
+
return ((npu_dtype == bench_dtype) |
|
|
257
|
+
((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) |
|
|
258
|
+
((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) |
|
|
259
|
+
((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) |
|
|
260
|
+
((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) |
|
|
261
|
+
((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) |
|
|
262
|
+
((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
|
|
263
|
+
((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
|
|
264
|
+
((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
|
|
265
|
+
|
|
266
|
+
match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
|
|
267
|
+
return MSComparator.make_result_df(match_result, stack_mode, dump_mode)
|
|
251
268
|
|
|
252
|
-
def
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
269
|
+
def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
|
|
270
|
+
def get_api_indices_dict(op_name_df):
|
|
271
|
+
api_indices_dict = defaultdict(list)
|
|
272
|
+
for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
|
|
273
|
+
api = self.get_api_name(name.split(Const.SEP))
|
|
274
|
+
api_indices_dict[api].append(op_index)
|
|
275
|
+
return api_indices_dict
|
|
256
276
|
|
|
257
|
-
|
|
277
|
+
ms_api_indices_dict = get_api_indices_dict(npu_df)
|
|
278
|
+
pt_api_indices_dict = get_api_indices_dict(bench_df)
|
|
258
279
|
|
|
280
|
+
def gen_input_compare_key(pattern, term):
|
|
281
|
+
flag = True
|
|
282
|
+
for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
|
|
283
|
+
if op_name.split(pattern)[1].startswith(str(prefix)):
|
|
284
|
+
npu_df.loc[index, CompareConst.COMPARE_KEY] = (
|
|
285
|
+
op_name.replace(pattern + str(prefix),
|
|
286
|
+
pattern + str(mapping_dict.get(f'pt_{term}')[i])))
|
|
287
|
+
flag = False
|
|
288
|
+
return flag
|
|
259
289
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
map_split = map_value.split(Const.SEP)
|
|
266
|
-
map_name = Const.SEP.join(map_split[0:-1])
|
|
267
|
-
map_index = map_split[-1]
|
|
268
|
-
for key, value in data.items():
|
|
269
|
-
if key.find(flag) != -1 and key.find(map_name) != -1:
|
|
270
|
-
if key.split(Const.SEP)[-1] != map_index and key.split(Const.SEP)[-2] != map_index :
|
|
290
|
+
for mapping_dict in self.api_mapping_dict:
|
|
291
|
+
if (len(mapping_dict.get('ms_args')) != len(mapping_dict.get('pt_args')) or
|
|
292
|
+
len(mapping_dict.get('ms_output')) != len(mapping_dict.get('pt_output'))):
|
|
293
|
+
logger.warning('The user-defined mapping table is incorrect,\
|
|
294
|
+
make sure that the number of parameters is equal')
|
|
271
295
|
continue
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
forward_data = []
|
|
290
|
-
mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.FORWARD)
|
|
291
|
-
for map_value in mapping_list:
|
|
292
|
-
npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "forward")
|
|
293
|
-
bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "forward")
|
|
294
|
-
inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
|
|
295
|
-
outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
|
|
296
|
-
forward_data.extend(inputs_zip)
|
|
297
|
-
forward_data.extend(outputs_zip)
|
|
298
|
-
|
|
299
|
-
backward_data = []
|
|
300
|
-
mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.BACKWARD)
|
|
301
|
-
for map_value in mapping_list:
|
|
302
|
-
npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "backward")
|
|
303
|
-
bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "backward")
|
|
304
|
-
inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
|
|
305
|
-
outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
|
|
306
|
-
backward_data.extend(inputs_zip)
|
|
307
|
-
backward_data.extend(outputs_zip)
|
|
308
|
-
|
|
309
|
-
kernel_data = forward_data + backward_data
|
|
310
|
-
result = {key: value for key, value in kernel_data if key is not None}
|
|
296
|
+
ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
|
|
297
|
+
if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
|
|
298
|
+
continue
|
|
299
|
+
for index in ms_api_indices_dict.get(ms_api):
|
|
300
|
+
op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
|
|
301
|
+
if CompareConst.INPUT_PATTERN in op_name:
|
|
302
|
+
is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
|
|
303
|
+
elif CompareConst.KWARGS_PATTERN in op_name:
|
|
304
|
+
is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
|
|
305
|
+
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
306
|
+
is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
|
|
307
|
+
else:
|
|
308
|
+
logger.error(f'Excepted op_name: {op_name}')
|
|
309
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
310
|
+
if is_abandoned:
|
|
311
|
+
npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
|
|
311
312
|
|
|
312
|
-
|
|
313
|
+
def gen_data_df(self, data_json, stack_json, dump_mode):
|
|
314
|
+
result = {
|
|
315
|
+
CompareConst.OP_NAME: [],
|
|
316
|
+
Const.DTYPE: [],
|
|
317
|
+
Const.SHAPE: [],
|
|
318
|
+
Const.SUMMARY: [],
|
|
319
|
+
'stack_info': []
|
|
320
|
+
}
|
|
321
|
+
if dump_mode == Const.ALL:
|
|
322
|
+
result['data_name'] = []
|
|
323
|
+
elif dump_mode == Const.MD5:
|
|
324
|
+
result[Const.MD5] = []
|
|
325
|
+
for data_name in data_json['data']:
|
|
326
|
+
check_op_str_pattern_valid(data_name)
|
|
327
|
+
merge_list = self.gen_merge_list(data_json, data_name, stack_json, dump_mode)
|
|
328
|
+
if not merge_list:
|
|
329
|
+
continue
|
|
330
|
+
for op_name in merge_list[CompareConst.OP_NAME]:
|
|
331
|
+
result[CompareConst.OP_NAME].append(op_name)
|
|
332
|
+
if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
|
|
333
|
+
struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
|
|
334
|
+
else:
|
|
335
|
+
struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
|
|
336
|
+
result[Const.DTYPE].append(struct[0])
|
|
337
|
+
result[Const.SHAPE].append(struct[1])
|
|
338
|
+
if dump_mode == Const.MD5:
|
|
339
|
+
result[Const.MD5].append(struct[2])
|
|
340
|
+
result[Const.SUMMARY].append(merge_list[Const.SUMMARY].pop(0))
|
|
341
|
+
result['stack_info'].append(merge_list['stack_info'][0])
|
|
342
|
+
if dump_mode == Const.ALL:
|
|
343
|
+
result['data_name'].append(merge_list['data_name'].pop(0))
|
|
344
|
+
return pd.DataFrame(result)
|
|
313
345
|
|
|
314
346
|
|
|
315
347
|
def check_cross_framework(bench_json_path):
|
|
@@ -330,28 +362,19 @@ def ms_compare(input_param, output_path, **kwargs):
|
|
|
330
362
|
api_mapping = kwargs.get('api_mapping', None)
|
|
331
363
|
data_mapping = kwargs.get('data_mapping', None)
|
|
332
364
|
layer_mapping = kwargs.get('layer_mapping', None)
|
|
365
|
+
suffix = kwargs.get('suffix', '')
|
|
333
366
|
|
|
334
|
-
|
|
367
|
+
set_dump_path(input_param)
|
|
368
|
+
dump_mode = get_dump_mode(input_param)
|
|
335
369
|
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
336
370
|
create_directory(output_path)
|
|
337
|
-
check_compare_param(input_param, output_path,
|
|
371
|
+
check_compare_param(input_param, output_path, dump_mode)
|
|
338
372
|
except (CompareException, FileCheckException) as error:
|
|
339
373
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
340
374
|
raise CompareException(error.code) from error
|
|
341
375
|
if layer_mapping:
|
|
342
|
-
|
|
343
|
-
ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK)
|
|
344
|
-
mapping = load_yaml(layer_mapping)
|
|
345
|
-
ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct)
|
|
346
|
-
pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct)
|
|
347
|
-
layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping)
|
|
348
|
-
data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping)
|
|
349
|
-
|
|
350
|
-
data_mapping_name = add_time_with_yaml(f"data_mapping")
|
|
351
|
-
data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}")
|
|
352
|
-
save_yaml(data_mapping_path, data_mapping)
|
|
376
|
+
data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
|
|
353
377
|
is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
|
|
354
378
|
ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
|
|
355
|
-
ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
|
|
356
|
-
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match,
|
|
357
|
-
md5_compare=md5_compare)
|
|
379
|
+
ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, suffix=suffix,
|
|
380
|
+
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)
|