mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -13,418 +13,30 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import
|
|
17
|
-
import re
|
|
18
|
-
from collections import defaultdict
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
import pandas as pd
|
|
22
|
-
|
|
23
|
-
from msprobe.core.common.const import CompareConst, Const
|
|
24
|
-
from msprobe.core.common.exceptions import FileCheckException
|
|
25
|
-
from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml
|
|
26
|
-
from msprobe.core.common.log import logger
|
|
27
|
-
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
|
|
28
|
-
check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json
|
|
29
|
-
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
30
|
-
from msprobe.core.compare.check import dtype_mapping
|
|
16
|
+
from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
|
|
31
17
|
from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
|
|
32
|
-
from msprobe.
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class MappingConfig:
|
|
36
|
-
def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
|
|
37
|
-
self.cell_mapping = cell_mapping
|
|
38
|
-
self.api_mapping = api_mapping
|
|
39
|
-
self.data_mapping = data_mapping
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class MSComparator(Comparator):
|
|
43
|
-
"""
|
|
44
|
-
用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
|
|
45
|
-
cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
|
|
46
|
-
api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
|
|
47
|
-
data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
|
|
48
|
-
is_cross_framework: 是否跨框架。
|
|
49
|
-
"""
|
|
50
|
-
def __init__(self, mode_config, mapping_config=None, is_cross_framework=False):
|
|
51
|
-
super().__init__(mode_config)
|
|
52
|
-
self.frame_name = MSComparator.__name__
|
|
53
|
-
|
|
54
|
-
self.stack_mode = mode_config.stack_mode
|
|
55
|
-
self.auto_analyze = mode_config.auto_analyze
|
|
56
|
-
self.fuzzy_match = mode_config.fuzzy_match
|
|
57
|
-
self.dump_mode = mode_config.dump_mode
|
|
58
|
-
|
|
59
|
-
if mapping_config:
|
|
60
|
-
self.cell_mapping = mapping_config.cell_mapping
|
|
61
|
-
self.api_mapping = mapping_config.api_mapping
|
|
62
|
-
self.data_mapping = mapping_config.data_mapping
|
|
63
|
-
|
|
64
|
-
if self.data_mapping:
|
|
65
|
-
self.cross_frame = is_cross_framework
|
|
66
|
-
else:
|
|
67
|
-
self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None
|
|
68
|
-
self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
|
|
69
|
-
self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
|
|
70
|
-
if self.api_mapping is not None:
|
|
71
|
-
self.ms_to_pt_mapping = self.load_internal_api()
|
|
72
|
-
|
|
73
|
-
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
74
|
-
self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
|
|
75
|
-
elif isinstance(self.data_mapping, dict):
|
|
76
|
-
self.data_mapping_dict = self.data_mapping
|
|
77
|
-
else:
|
|
78
|
-
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
79
|
-
f"{type(self.data_mapping)}")
|
|
80
|
-
|
|
81
|
-
@staticmethod
|
|
82
|
-
def process_data_name(result):
|
|
83
|
-
result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1)
|
|
84
|
-
return result
|
|
85
|
-
|
|
86
|
-
def calc_accuracy(self, result_df, header):
|
|
87
|
-
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
88
|
-
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
89
|
-
result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
|
|
90
|
-
|
|
91
|
-
def calc_summary_diff(data_type: str):
|
|
92
|
-
def type_check(val):
|
|
93
|
-
check_series = pd.Series(False, index=val.index)
|
|
94
|
-
val_str = val.astype(str)
|
|
95
|
-
check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
|
|
96
|
-
return check_series
|
|
97
|
-
|
|
98
|
-
def get_number(val):
|
|
99
|
-
return pd.to_numeric(val.astype(str), errors='coerce')
|
|
100
|
-
|
|
101
|
-
ms_val = result_df['NPU ' + data_type]
|
|
102
|
-
pt_val = result_df['Bench ' + data_type]
|
|
103
|
-
diff_name = data_type.capitalize() + ' diff'
|
|
104
|
-
rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
|
|
105
|
-
condition_na = ~type_check(ms_val) | ~type_check(pt_val)
|
|
106
|
-
result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
|
|
107
|
-
result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
|
|
108
|
-
condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
|
|
109
|
-
condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
|
|
110
|
-
result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
|
|
111
|
-
condition_pt_zero = pt_val == 0
|
|
112
|
-
result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
|
|
113
|
-
condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
|
|
114
|
-
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
|
|
115
|
-
pt_val[condition_ref_err] * 100)
|
|
116
|
-
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
|
|
117
|
-
.abs().astype(str) + '%')
|
|
118
|
-
magnitude = get_number(result_df[diff_name]).abs() / (
|
|
119
|
-
pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
|
|
120
|
-
return magnitude > CompareConst.MAGNITUDE
|
|
121
|
-
|
|
122
|
-
if self.dump_mode == Const.MD5:
|
|
123
|
-
condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
|
|
124
|
-
result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
|
|
125
|
-
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
126
|
-
elif self.dump_mode == Const.SUMMARY:
|
|
127
|
-
warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
|
|
128
|
-
warning_flag = pd.DataFrame(warning_list).any()
|
|
129
|
-
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
130
|
-
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
131
|
-
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
132
|
-
else:
|
|
133
|
-
fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
|
|
134
|
-
CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
135
|
-
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
136
|
-
CompareConst.ERROR_MESSAGE]
|
|
137
|
-
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
138
|
-
result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
|
|
139
|
-
return result_df[header]
|
|
140
|
-
|
|
141
|
-
def make_result_df(self, result):
|
|
142
|
-
header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
|
|
143
|
-
|
|
144
|
-
if self.stack_mode:
|
|
145
|
-
header.append(CompareConst.STACK)
|
|
146
|
-
if self.dump_mode == Const.ALL:
|
|
147
|
-
header.append(CompareConst.DATA_NAME)
|
|
148
|
-
result = self.process_data_name(result)
|
|
149
|
-
|
|
150
|
-
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
151
|
-
'op_name_y': CompareConst.BENCH_NAME,
|
|
152
|
-
'dtype_x': CompareConst.NPU_DTYPE,
|
|
153
|
-
'dtype_y': CompareConst.BENCH_DTYPE,
|
|
154
|
-
'shape_x': CompareConst.NPU_SHAPE,
|
|
155
|
-
'shape_y': CompareConst.BENCH_SHAPE,
|
|
156
|
-
'md5_x': CompareConst.NPU_MD5,
|
|
157
|
-
'md5_y': CompareConst.BENCH_MD5,
|
|
158
|
-
'data_name_x': CompareConst.DATA_NAME,
|
|
159
|
-
'stack_info_x': CompareConst.STACK}, inplace=True)
|
|
160
|
-
|
|
161
|
-
npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
162
|
-
bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
|
|
163
|
-
CompareConst.BENCH_NORM]
|
|
164
|
-
|
|
165
|
-
def set_summary(summary):
|
|
166
|
-
if summary == CompareConst.N_A:
|
|
167
|
-
return [CompareConst.N_A] * 4
|
|
168
|
-
summary_list = []
|
|
169
|
-
for i in summary:
|
|
170
|
-
if i is None:
|
|
171
|
-
summary_list.append(CompareConst.N_A)
|
|
172
|
-
elif str(i).lower() == 'nan':
|
|
173
|
-
summary_list.append(CompareConst.NAN)
|
|
174
|
-
else:
|
|
175
|
-
summary_list.append(i)
|
|
176
|
-
return summary_list
|
|
177
|
-
|
|
178
|
-
result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
|
|
179
|
-
result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
|
|
180
|
-
|
|
181
|
-
result_df = pd.DataFrame(columns=header)
|
|
182
|
-
for h in header:
|
|
183
|
-
if h in result.columns:
|
|
184
|
-
result_df[h] = result[h]
|
|
185
|
-
return self.calc_accuracy(result_df, header)
|
|
186
|
-
|
|
187
|
-
def load_internal_api(self):
|
|
188
|
-
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
189
|
-
yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
|
|
190
|
-
return load_yaml(yaml_path)
|
|
191
|
-
|
|
192
|
-
def load_mapping_file(self, mapping_file):
|
|
193
|
-
if isinstance(mapping_file, str):
|
|
194
|
-
mapping_dict = load_yaml(mapping_file)
|
|
195
|
-
else:
|
|
196
|
-
mapping_dict = {}
|
|
197
|
-
return mapping_dict
|
|
198
|
-
|
|
199
|
-
def process_cell_mapping(self, npu_op_name):
|
|
200
|
-
if not npu_op_name:
|
|
201
|
-
return CompareConst.N_A
|
|
202
|
-
param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
|
|
203
|
-
if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
|
|
204
|
-
return CompareConst.N_A
|
|
205
|
-
npu_op_name = npu_op_name.replace("Cell", "Module", 1)
|
|
206
|
-
if self.cell_mapping_dict:
|
|
207
|
-
# get cell name & class name from op_name
|
|
208
|
-
# Cell.fc1.Dense.forward.0.input.0
|
|
209
|
-
cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
|
|
210
|
-
if cell_name in self.cell_mapping_dict:
|
|
211
|
-
npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
212
|
-
return npu_op_name
|
|
213
|
-
|
|
214
|
-
def read_npy_data(self, dir_path, file_name, load_pt_file=False):
|
|
215
|
-
if not file_name:
|
|
216
|
-
return None
|
|
217
|
-
data_path = os.path.join(dir_path, file_name)
|
|
218
|
-
if load_pt_file:
|
|
219
|
-
import torch
|
|
220
|
-
from msprobe.pytorch.common.utils import load_pt
|
|
221
|
-
data_value = load_pt(data_path, True).detach()
|
|
222
|
-
if data_value.dtype == torch.bfloat16:
|
|
223
|
-
data_value = data_value.to(torch.float32)
|
|
224
|
-
data_value = data_value.numpy()
|
|
225
|
-
else:
|
|
226
|
-
data_value = load_npy(data_path)
|
|
227
|
-
return data_value
|
|
228
|
-
|
|
229
|
-
def process_internal_api_mapping(self, npu_op_name):
|
|
230
|
-
# get api name & class name from op_name
|
|
231
|
-
# Functional.addcmul.0.forward.input.0
|
|
232
|
-
ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
|
|
233
|
-
class_name = ms_api_name.split(Const.SEP)[0]
|
|
234
|
-
if class_name == "Mint":
|
|
235
|
-
return npu_op_name.replace("Mint", "Torch")
|
|
236
|
-
elif class_name == "MintFunctional":
|
|
237
|
-
return npu_op_name.replace("MintFunctional", "Functional")
|
|
238
|
-
elif self.ms_to_pt_mapping.get(ms_api_name):
|
|
239
|
-
return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
|
|
240
|
-
else:
|
|
241
|
-
return npu_op_name
|
|
242
|
-
|
|
243
|
-
def get_api_name(self, api_list):
|
|
244
|
-
try:
|
|
245
|
-
api_name = api_list[0] + Const.SEP + api_list[1]
|
|
246
|
-
except IndexError as error:
|
|
247
|
-
logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
|
|
248
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
249
|
-
return api_name
|
|
250
|
-
|
|
251
|
-
def compare_process(self, file_lists):
|
|
252
|
-
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
253
|
-
npu_json_data = load_json(npu_json_path)
|
|
254
|
-
bench_json_data = load_json(bench_json_path)
|
|
255
|
-
stack_json_data = load_json(stack_json_path) if self.stack_mode else None
|
|
256
|
-
|
|
257
|
-
npu_df = self.gen_data_df(npu_json_data, stack_json_data)
|
|
258
|
-
bench_df = self.gen_data_df(bench_json_data, stack_json_data)
|
|
259
|
-
if self.cell_mapping:
|
|
260
|
-
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
|
|
261
|
-
elif self.api_mapping:
|
|
262
|
-
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
|
|
263
|
-
if isinstance(self.api_mapping, str):
|
|
264
|
-
self.modify_compare_data_with_user_mapping(npu_df, bench_df)
|
|
265
|
-
else:
|
|
266
|
-
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME]
|
|
267
|
-
npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
268
|
-
bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
269
|
-
npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
|
|
270
|
-
bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
|
|
271
|
-
bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
|
|
272
|
-
match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
|
|
273
|
-
how='outer')
|
|
274
|
-
match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
|
|
275
|
-
|
|
276
|
-
def gen_dtype_condition():
|
|
277
|
-
npu_dtype = match_result['dtype_x']
|
|
278
|
-
bench_dtype = match_result['dtype_y']
|
|
279
|
-
if self.cross_frame:
|
|
280
|
-
npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
|
|
281
|
-
|
|
282
|
-
equal_condition = npu_dtype == bench_dtype
|
|
283
|
-
match_condition = (
|
|
284
|
-
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
|
|
285
|
-
CompareConst.DTYPE_MATCH_GROUPS[0])) |
|
|
286
|
-
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
|
|
287
|
-
CompareConst.DTYPE_MATCH_GROUPS[1]))
|
|
288
|
-
)
|
|
289
|
-
return equal_condition | match_condition
|
|
290
|
-
|
|
291
|
-
match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
|
|
292
|
-
return self.make_result_df(match_result)
|
|
293
|
-
|
|
294
|
-
def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
|
|
295
|
-
def get_api_indices_dict(op_name_df):
|
|
296
|
-
api_indices_dict = defaultdict(list)
|
|
297
|
-
for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
|
|
298
|
-
api = self.get_api_name(name.split(Const.SEP))
|
|
299
|
-
api_indices_dict[api].append(op_index)
|
|
300
|
-
return api_indices_dict
|
|
301
|
-
|
|
302
|
-
ms_api_indices_dict = get_api_indices_dict(npu_df)
|
|
303
|
-
pt_api_indices_dict = get_api_indices_dict(bench_df)
|
|
304
|
-
|
|
305
|
-
def gen_input_compare_key(pattern, term):
|
|
306
|
-
flag = True
|
|
307
|
-
for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
|
|
308
|
-
if op_name.split(pattern)[1].startswith(str(prefix)):
|
|
309
|
-
npu_df.loc[index, CompareConst.COMPARE_KEY] = (
|
|
310
|
-
op_name.replace(pattern + str(prefix),
|
|
311
|
-
pattern + str(mapping_dict.get(f'pt_{term}')[i])))
|
|
312
|
-
flag = False
|
|
313
|
-
return flag
|
|
314
|
-
|
|
315
|
-
for mapping_dict in self.api_mapping_dict:
|
|
316
|
-
keys_to_compare = [
|
|
317
|
-
('ms_args', 'pt_args'),
|
|
318
|
-
('ms_output', 'pt_output'),
|
|
319
|
-
('ms_parameters', 'pt_parameters'),
|
|
320
|
-
('ms_parameters_grad', 'pt_parameters_grad'),
|
|
321
|
-
]
|
|
322
|
-
if not all(len(mapping_dict.get(k1, [])) == len(mapping_dict.get(k2, [])) for k1, k2 in keys_to_compare):
|
|
323
|
-
logger.warning('The user-defined mapping table is incorrect,\
|
|
324
|
-
make sure that the number of parameters is equal')
|
|
325
|
-
continue
|
|
326
|
-
|
|
327
|
-
ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
|
|
328
|
-
if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
|
|
329
|
-
continue
|
|
330
|
-
for index in ms_api_indices_dict.get(ms_api):
|
|
331
|
-
op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
|
|
332
|
-
if CompareConst.INPUT_PATTERN in op_name:
|
|
333
|
-
is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
|
|
334
|
-
elif CompareConst.KWARGS_PATTERN in op_name:
|
|
335
|
-
is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
|
|
336
|
-
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
337
|
-
is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
|
|
338
|
-
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
339
|
-
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
|
|
340
|
-
elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
|
|
341
|
-
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
|
|
342
|
-
else:
|
|
343
|
-
logger.error(f'Excepted op_name: {op_name}')
|
|
344
|
-
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
345
|
-
if is_abandoned:
|
|
346
|
-
npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
|
|
347
|
-
|
|
348
|
-
def gen_data_df(self, data_json, stack_json_data):
|
|
349
|
-
result = {
|
|
350
|
-
CompareConst.OP_NAME: [],
|
|
351
|
-
Const.DTYPE: [],
|
|
352
|
-
Const.SHAPE: [],
|
|
353
|
-
Const.SUMMARY: [],
|
|
354
|
-
'stack_info': []
|
|
355
|
-
}
|
|
356
|
-
if self.dump_mode == Const.ALL:
|
|
357
|
-
result['data_name'] = []
|
|
358
|
-
elif self.dump_mode == Const.MD5:
|
|
359
|
-
result[Const.MD5] = []
|
|
360
|
-
for data_name in data_json['data']:
|
|
361
|
-
check_op_str_pattern_valid(data_name)
|
|
362
|
-
merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
|
|
363
|
-
if not merge_list:
|
|
364
|
-
continue
|
|
365
|
-
|
|
366
|
-
op_name_list = merge_list.get(CompareConst.OP_NAME)
|
|
367
|
-
summary_list = merge_list.get(Const.SUMMARY)
|
|
368
|
-
data_name_list = merge_list.get('data_name')
|
|
369
|
-
op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
|
|
370
|
-
summary_list,
|
|
371
|
-
data_name_list)
|
|
372
|
-
for op_name in op_name_reorder:
|
|
373
|
-
result[CompareConst.OP_NAME].append(op_name)
|
|
374
|
-
if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
|
|
375
|
-
struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
|
|
376
|
-
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
377
|
-
struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
|
|
378
|
-
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
379
|
-
struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
|
|
380
|
-
else:
|
|
381
|
-
struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
|
|
382
|
-
result[Const.DTYPE].append(struct[0])
|
|
383
|
-
result[Const.SHAPE].append(struct[1])
|
|
384
|
-
if self.dump_mode == Const.MD5:
|
|
385
|
-
result[Const.MD5].append(struct[2])
|
|
386
|
-
result[Const.SUMMARY].append(summary_reorder.pop(0))
|
|
387
|
-
result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else None)
|
|
388
|
-
if self.dump_mode == Const.ALL:
|
|
389
|
-
result['data_name'].append(data_name_reorder.pop(0))
|
|
390
|
-
return pd.DataFrame(result)
|
|
18
|
+
from msprobe.mindspore.compare.utils import read_npy_data, check_cross_framework
|
|
391
19
|
|
|
392
20
|
|
|
393
|
-
def
|
|
394
|
-
|
|
395
|
-
if
|
|
396
|
-
|
|
21
|
+
def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, cross_frame) -> tuple:
|
|
22
|
+
n_value = read_npy_data(npu_dir, npu_data_name)
|
|
23
|
+
if cross_frame:
|
|
24
|
+
from msprobe.pytorch.compare.utils import read_pt_data
|
|
25
|
+
b_value = read_pt_data(bench_dir, bench_data_name)
|
|
397
26
|
else:
|
|
398
|
-
|
|
27
|
+
b_value = read_npy_data(bench_dir, bench_data_name)
|
|
28
|
+
return n_value, b_value
|
|
399
29
|
|
|
400
30
|
|
|
401
31
|
def ms_compare(input_param, output_path, **kwargs):
|
|
402
|
-
|
|
403
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
404
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
405
|
-
cell_mapping = kwargs.get('cell_mapping', None)
|
|
406
|
-
api_mapping = kwargs.get('api_mapping', None)
|
|
407
|
-
data_mapping = kwargs.get('data_mapping', None)
|
|
408
|
-
layer_mapping = kwargs.get('layer_mapping', None)
|
|
409
|
-
suffix = kwargs.get('suffix', '')
|
|
32
|
+
config = setup_comparison(input_param, output_path, **kwargs)
|
|
410
33
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
if 'stack_json_path' in input_param:
|
|
414
|
-
stack_mode = kwargs.get('stack_mode', False)
|
|
415
|
-
else:
|
|
416
|
-
stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
|
|
417
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
418
|
-
create_directory(output_path)
|
|
419
|
-
check_compare_param(input_param, output_path, dump_mode, stack_mode)
|
|
420
|
-
except (CompareException, FileCheckException) as error:
|
|
421
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
422
|
-
raise CompareException(error.code) from error
|
|
423
|
-
if layer_mapping:
|
|
424
|
-
data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
|
|
34
|
+
if config.layer_mapping:
|
|
35
|
+
config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path)
|
|
425
36
|
|
|
426
|
-
mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
|
|
427
|
-
mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping)
|
|
428
37
|
is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
|
|
429
|
-
|
|
430
|
-
|
|
38
|
+
mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
|
|
39
|
+
config.dump_mode, config.compared_file_type)
|
|
40
|
+
mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping)
|
|
41
|
+
ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework)
|
|
42
|
+
ms_comparator.compare_core(input_param, output_path, suffix=config.suffix)
|
|
@@ -85,11 +85,13 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
85
85
|
}
|
|
86
86
|
for statistic_file in statistic_file_list:
|
|
87
87
|
content = read_csv(statistic_file, as_pd=False)
|
|
88
|
+
if not content:
|
|
89
|
+
logger.error(f'Empty dump file: {statistic_file}')
|
|
90
|
+
raise CompareException(f'Empty dump file: {statistic_file}')
|
|
88
91
|
header = content[0]
|
|
89
|
-
for
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
header_index[key] = index
|
|
92
|
+
for index, value in enumerate(header):
|
|
93
|
+
if value in header_index:
|
|
94
|
+
header_index[value] = index
|
|
93
95
|
statistic_data_list.extend(content[1:])
|
|
94
96
|
|
|
95
97
|
for key in header_index.keys():
|
|
@@ -97,7 +99,14 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
97
99
|
logger.warning(f"Data_path {statistic_file_path} has no key {key}.")
|
|
98
100
|
|
|
99
101
|
for data in statistic_data_list:
|
|
100
|
-
|
|
102
|
+
'''
|
|
103
|
+
13列分别是OpType, OpName, TaskId, StreamId, TimeStamp, IO, Slot, DataSize,
|
|
104
|
+
DataType, Shape, MaxValue, MinValue, L2NormValue
|
|
105
|
+
'''
|
|
106
|
+
if len(data) < 13:
|
|
107
|
+
logger.error(f'Dump file {statistic_file_path} has been modified into incorrect format!')
|
|
108
|
+
raise CompareException(f'Dump file {statistic_file_path} has been modified into incorrect format!')
|
|
109
|
+
compare_key = f"{data[1]}.{data[2]}.{data[5]}.{data[6]}" # OpName, TaskId, IO, Slot
|
|
101
110
|
op_name = f"{compare_key} {statistic_file_path}"
|
|
102
111
|
timestamp = int(data[4])
|
|
103
112
|
result_data = [op_name, compare_key, timestamp]
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst
|
|
20
|
+
from msprobe.core.common.utils import detect_framework_by_dump_json
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def read_npy_data(dir_path, file_name):
|
|
24
|
+
if not file_name:
|
|
25
|
+
return None
|
|
26
|
+
|
|
27
|
+
data_path = os.path.join(dir_path, file_name)
|
|
28
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
29
|
+
FileCheckConst.NUMPY_SUFFIX, False)
|
|
30
|
+
data_path = path_checker.common_check()
|
|
31
|
+
data_value = load_npy(data_path)
|
|
32
|
+
return data_value
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def check_cross_framework(bench_json_path):
|
|
36
|
+
framework = detect_framework_by_dump_json(bench_json_path)
|
|
37
|
+
return framework == Const.PT_FRAMEWORK
|
|
@@ -15,12 +15,18 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
|
|
18
|
+
from mindspore import nn
|
|
19
|
+
|
|
18
20
|
from msprobe.core.common.const import Const
|
|
19
21
|
from msprobe.core.common.exceptions import MsprobeException
|
|
20
22
|
from msprobe.core.common.file_utils import create_directory
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
21
24
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
22
25
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
23
|
-
from msprobe.
|
|
26
|
+
from msprobe.mindspore.common.utils import is_mindtorch
|
|
27
|
+
|
|
28
|
+
if is_mindtorch():
|
|
29
|
+
import torch
|
|
24
30
|
|
|
25
31
|
|
|
26
32
|
class DebuggerConfig:
|
|
@@ -41,8 +47,12 @@ class DebuggerConfig:
|
|
|
41
47
|
self.check_mode = task_config.check_mode
|
|
42
48
|
self.framework = Const.MS_FRAMEWORK
|
|
43
49
|
self.summary_mode = task_config.summary_mode
|
|
50
|
+
self.stat_cal_mode = task_config.stat_cal_mode if hasattr(task_config, 'stat_cal_mode') else None
|
|
51
|
+
self.device_stat_precision_mode = task_config.device_stat_precision_mode \
|
|
52
|
+
if hasattr(task_config, 'device_stat_precision_mode') else None
|
|
44
53
|
self.async_dump = common_config.async_dump if common_config.async_dump else False
|
|
45
54
|
self.check()
|
|
55
|
+
self._check_statistics_config(task_config)
|
|
46
56
|
create_directory(self.dump_path)
|
|
47
57
|
|
|
48
58
|
if self.task == Const.FREE_BENCHMARK:
|
|
@@ -62,6 +72,31 @@ class DebuggerConfig:
|
|
|
62
72
|
raise ValueError
|
|
63
73
|
self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
|
|
64
74
|
|
|
75
|
+
@staticmethod
|
|
76
|
+
def check_model(models, token_range=None):
|
|
77
|
+
if token_range and not models:
|
|
78
|
+
error_info = "The 'model' parameter must be provided when token_range is not None"
|
|
79
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
80
|
+
|
|
81
|
+
target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
|
|
82
|
+
if models is None or isinstance(models, target_module_type[0]):
|
|
83
|
+
return models
|
|
84
|
+
error_model = None
|
|
85
|
+
if isinstance(models, (list, tuple)):
|
|
86
|
+
for model in models:
|
|
87
|
+
if not isinstance(model, target_module_type[0]):
|
|
88
|
+
error_model = model
|
|
89
|
+
break
|
|
90
|
+
else:
|
|
91
|
+
error_model = models
|
|
92
|
+
|
|
93
|
+
if error_model is not None:
|
|
94
|
+
error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
|
|
95
|
+
f"type, currently there is a {type(error_model)} type.")
|
|
96
|
+
raise MsprobeException(
|
|
97
|
+
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
98
|
+
return models
|
|
99
|
+
|
|
65
100
|
def check(self):
|
|
66
101
|
if not self.dump_path:
|
|
67
102
|
raise Exception("Dump path is empty.")
|
|
@@ -76,8 +111,12 @@ class DebuggerConfig:
|
|
|
76
111
|
self.check_mode = "all"
|
|
77
112
|
if not isinstance(self.async_dump, bool):
|
|
78
113
|
raise Exception("The parameters async_dump should be bool.")
|
|
79
|
-
if self.async_dump and self.task == Const.TENSOR
|
|
80
|
-
|
|
114
|
+
if self.async_dump and self.task == Const.TENSOR:
|
|
115
|
+
if self.level_ori == Const.LEVEL_DEBUG:
|
|
116
|
+
self.list = [] # async_dump + debug level case ignore list
|
|
117
|
+
if not self.list and self.level_ori != Const.LEVEL_DEBUG:
|
|
118
|
+
raise Exception("The parameters async_dump is true in tensor task,"
|
|
119
|
+
" the parameters list cannot be empty.")
|
|
81
120
|
if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
82
121
|
logger.warning_on_rank_0(
|
|
83
122
|
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
@@ -98,3 +137,14 @@ class DebuggerConfig:
|
|
|
98
137
|
if not self.list or len(self.list) != 1:
|
|
99
138
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
100
139
|
f"When level is set to L2, the list must be configured as a list with one api name.")
|
|
140
|
+
|
|
141
|
+
def _check_statistics_config(self, task_config):
|
|
142
|
+
if self.task != Const.STATISTICS:
|
|
143
|
+
return
|
|
144
|
+
self.tensor_list = []
|
|
145
|
+
if not hasattr(task_config, "tensor_list"):
|
|
146
|
+
return
|
|
147
|
+
if self.level_ori == Const.LEVEL_DEBUG and task_config.tensor_list:
|
|
148
|
+
logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.")
|
|
149
|
+
return
|
|
150
|
+
self.tensor_list = task_config.tensor_list
|