mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -25,6 +25,7 @@ from msprobe.core.common.file_utils import load_npy
|
|
|
25
25
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
|
|
26
26
|
ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
|
|
27
27
|
dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
|
|
28
|
+
dtype_str_to_mindtorch_dtype,
|
|
28
29
|
dtype_str_to_torch_dtype, type_to_api_info_type_str,
|
|
29
30
|
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
|
|
30
31
|
MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
|
|
@@ -33,6 +34,15 @@ from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_s
|
|
|
33
34
|
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
34
35
|
from msprobe.mindspore.common.log import logger
|
|
35
36
|
|
|
37
|
+
import msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer as env_module
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
if env_module.is_valid_pt_mt_env:
|
|
41
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
42
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
|
|
43
|
+
else:
|
|
44
|
+
import torch
|
|
45
|
+
|
|
36
46
|
|
|
37
47
|
class MstensorMetaData:
|
|
38
48
|
def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
|
|
@@ -86,6 +96,37 @@ class ComputeElement:
|
|
|
86
96
|
torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
|
|
87
97
|
return torch_tensor
|
|
88
98
|
|
|
99
|
+
@staticmethod
|
|
100
|
+
def transfer_to_mindtorch_tensor(ms_tensor):
|
|
101
|
+
"""
|
|
102
|
+
Args:
|
|
103
|
+
ms_tensor: mindspore.Tensor
|
|
104
|
+
Return:
|
|
105
|
+
mindtorch_tensor: mindtorch.Tensor
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
ms_dtype = ms_tensor.dtype
|
|
109
|
+
|
|
110
|
+
dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
|
|
111
|
+
|
|
112
|
+
if dtype_str not in dtype_str_to_mindtorch_dtype:
|
|
113
|
+
err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}"
|
|
114
|
+
logger.error_log_with_exp(err_msg,
|
|
115
|
+
ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
116
|
+
else:
|
|
117
|
+
mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str)
|
|
118
|
+
|
|
119
|
+
if dtype_str in int_dtype_str_list:
|
|
120
|
+
middle_dtype = mindspore.int64
|
|
121
|
+
else:
|
|
122
|
+
middle_dtype = mindspore.float64
|
|
123
|
+
|
|
124
|
+
np_ndarray = ms_tensor.astype(middle_dtype).numpy()
|
|
125
|
+
|
|
126
|
+
mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(ms_dtype)
|
|
127
|
+
|
|
128
|
+
return mindtorch_tensor
|
|
129
|
+
|
|
89
130
|
@staticmethod
|
|
90
131
|
def transfer_to_mindspore_tensor(torch_tensor):
|
|
91
132
|
'''
|
|
@@ -141,8 +182,11 @@ class ComputeElement:
|
|
|
141
182
|
elif isinstance(self.parameter, DtypeMetaData):
|
|
142
183
|
if tensor_platform == Const.MS_FRAMEWORK:
|
|
143
184
|
parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
|
|
144
|
-
|
|
185
|
+
elif tensor_platform == Const.PT_FRAMEWORK:
|
|
145
186
|
parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
|
|
187
|
+
elif tensor_platform == Const.MT_FRAMEWORK:
|
|
188
|
+
parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str)
|
|
189
|
+
|
|
146
190
|
elif isinstance(self.parameter, MstensorMetaData):
|
|
147
191
|
mstensor_meta_data = self.parameter
|
|
148
192
|
ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
|
|
@@ -161,6 +205,8 @@ class ComputeElement:
|
|
|
161
205
|
# if necessary, do transfer
|
|
162
206
|
if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
|
|
163
207
|
parameter = self.transfer_to_torch_tensor(parameter_tmp)
|
|
208
|
+
elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK:
|
|
209
|
+
parameter = self.transfer_to_mindtorch_tensor(parameter_tmp)
|
|
164
210
|
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
|
|
165
211
|
parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
|
|
166
212
|
else:
|
|
@@ -16,12 +16,13 @@
|
|
|
16
16
|
import os
|
|
17
17
|
import csv
|
|
18
18
|
|
|
19
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
19
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
20
20
|
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
|
|
21
21
|
from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
|
|
22
22
|
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
23
23
|
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
24
24
|
from msprobe.mindspore.common.log import logger
|
|
25
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
class ResultCsvEntry:
|
|
@@ -27,10 +27,11 @@ import numpy as np
|
|
|
27
27
|
from tqdm import tqdm
|
|
28
28
|
|
|
29
29
|
# 本地应用/库特定导入
|
|
30
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
30
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
31
31
|
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
|
|
32
32
|
from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
|
|
33
33
|
from msprobe.mindspore.common.log import logger
|
|
34
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
34
35
|
|
|
35
36
|
|
|
36
37
|
class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import gc
|
|
18
|
+
import sys
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
import mindspore
|
|
21
|
+
from msprobe.mindspore.common.log import logger
|
|
22
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
23
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
24
|
+
import torch as mindtorch
|
|
25
|
+
from torch import Tensor as mindtorch_tensor
|
|
26
|
+
import torch.nn.functional as mindtorch_func
|
|
27
|
+
import torch.distributed as mindtorch_dist
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
is_valid_pt_mt_env = True
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def is_mindtorch():
|
|
34
|
+
mindtorch_check_result = False
|
|
35
|
+
try:
|
|
36
|
+
import torch as test_torch
|
|
37
|
+
from mindspore import Tensor as MindsporeTensor
|
|
38
|
+
except ImportError:
|
|
39
|
+
return mindtorch_check_result
|
|
40
|
+
tensor = test_torch.tensor(0.0)
|
|
41
|
+
if isinstance(tensor, MindsporeTensor):
|
|
42
|
+
mindtorch_check_result = True
|
|
43
|
+
|
|
44
|
+
return mindtorch_check_result
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def remove_torch_related_paths():
|
|
48
|
+
removed_paths = []
|
|
49
|
+
if not is_mindtorch():
|
|
50
|
+
return
|
|
51
|
+
try:
|
|
52
|
+
import torch as remove_torch
|
|
53
|
+
torch_file = remove_torch.__file__
|
|
54
|
+
except ImportError:
|
|
55
|
+
return
|
|
56
|
+
|
|
57
|
+
torch_dir = os.path.dirname(torch_file)
|
|
58
|
+
|
|
59
|
+
torch_dir_path = Path(torch_dir).resolve()
|
|
60
|
+
parent_dir = torch_dir_path.parent
|
|
61
|
+
|
|
62
|
+
paths_to_remove = [str(parent_dir)]
|
|
63
|
+
|
|
64
|
+
for path in paths_to_remove:
|
|
65
|
+
try:
|
|
66
|
+
path_resolved = str(Path(path).resolve())
|
|
67
|
+
except Exception as error:
|
|
68
|
+
logger.debug(f"Failed to resolve path {path}: {error}")
|
|
69
|
+
continue
|
|
70
|
+
|
|
71
|
+
if path_resolved in sys.path:
|
|
72
|
+
index = sys.path.index(path_resolved)
|
|
73
|
+
removed_paths.append((path_resolved, index))
|
|
74
|
+
sys.path.pop(index)
|
|
75
|
+
|
|
76
|
+
return
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def clear_torch_from_sys_modules():
|
|
80
|
+
modules_to_remove = []
|
|
81
|
+
for module in sys.modules:
|
|
82
|
+
if module == "torch" or module.startswith("torch."):
|
|
83
|
+
modules_to_remove.append(module)
|
|
84
|
+
|
|
85
|
+
for module in modules_to_remove:
|
|
86
|
+
del sys.modules[module]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def set_pt_mt_env_invalid():
|
|
90
|
+
global is_valid_pt_mt_env
|
|
91
|
+
is_valid_pt_mt_env = False
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def delete_torch_paths():
|
|
95
|
+
|
|
96
|
+
if not is_mindtorch():
|
|
97
|
+
set_pt_mt_env_invalid()
|
|
98
|
+
|
|
99
|
+
clear_torch_from_sys_modules()
|
|
100
|
+
|
|
101
|
+
for count_delete_env_path in range(MsCompareConst.MAX_RECURSION_DEPTH):
|
|
102
|
+
if not is_mindtorch():
|
|
103
|
+
break
|
|
104
|
+
|
|
105
|
+
remove_torch_related_paths()
|
|
106
|
+
|
|
107
|
+
clear_torch_from_sys_modules()
|
|
108
|
+
|
|
109
|
+
if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1:
|
|
110
|
+
raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
|
|
111
|
+
f"the PYTHONPATH environment variable depth does not exceed {Const.MAX_RECURSION_DEPTH}.")
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
if not is_mindtorch():
|
|
115
|
+
set_pt_mt_env_invalid()
|
|
116
|
+
|
|
117
|
+
else:
|
|
118
|
+
initial_sys_path = sys.path.copy()
|
|
119
|
+
delete_torch_paths()
|
|
120
|
+
|
|
121
|
+
gc.collect()
|
|
122
|
+
|
|
123
|
+
import torch
|
|
124
|
+
|
|
125
|
+
if is_mindtorch():
|
|
126
|
+
set_pt_mt_env_invalid()
|
|
127
|
+
|
|
128
|
+
sys.path = initial_sys_path
|
|
129
|
+
|
|
130
|
+
|
|
@@ -15,10 +15,18 @@
|
|
|
15
15
|
|
|
16
16
|
import mindspore
|
|
17
17
|
import numpy as np
|
|
18
|
-
import torch
|
|
19
18
|
from mindspore._c_expression import typing
|
|
20
19
|
from mindspore.common import dtype as mstype
|
|
21
20
|
|
|
21
|
+
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
22
|
+
|
|
23
|
+
if torch_mindtorch_importer.is_valid_pt_mt_env:
|
|
24
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
|
|
26
|
+
else:
|
|
27
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
28
|
+
import torch
|
|
29
|
+
|
|
22
30
|
INT8 = "Int8"
|
|
23
31
|
UINT8 = "UInt8"
|
|
24
32
|
INT16 = "Int16"
|
|
@@ -82,6 +90,21 @@ dtype_str_to_torch_dtype = {
|
|
|
82
90
|
}
|
|
83
91
|
torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()}
|
|
84
92
|
|
|
93
|
+
|
|
94
|
+
dtype_str_to_mindtorch_dtype = {
|
|
95
|
+
INT8: mindtorch.int8,
|
|
96
|
+
UINT8: mindtorch.uint8,
|
|
97
|
+
INT16: mindtorch.int16,
|
|
98
|
+
INT32: mindtorch.int32,
|
|
99
|
+
INT64: mindtorch.int64,
|
|
100
|
+
FLOAT16: mindtorch.float16,
|
|
101
|
+
FLOAT32: mindtorch.float32,
|
|
102
|
+
FLOAT64: mindtorch.float64,
|
|
103
|
+
BOOL: mindtorch.bool,
|
|
104
|
+
BFLOAT16: mindtorch.bfloat16,
|
|
105
|
+
}
|
|
106
|
+
mindtorch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_mindtorch_dtype.items()}
|
|
107
|
+
|
|
85
108
|
MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
|
|
86
109
|
BOOL_TYPE_STR = "bool"
|
|
87
110
|
INT_TYPE_STR = "int"
|
|
@@ -82,10 +82,12 @@ class GlobalContext:
|
|
|
82
82
|
def __init__(self):
|
|
83
83
|
self.is_constructed = True
|
|
84
84
|
self.dump_data_dir = ""
|
|
85
|
+
self.framework = Const.MS_FRAMEWORK
|
|
85
86
|
|
|
86
|
-
def init(self, is_constructed, dump_data_dir):
|
|
87
|
+
def init(self, is_constructed, dump_data_dir, framework):
|
|
87
88
|
self.is_constructed = is_constructed
|
|
88
89
|
self.dump_data_dir = dump_data_dir
|
|
90
|
+
self.framework = framework
|
|
89
91
|
|
|
90
92
|
def get_dump_data_dir(self):
|
|
91
93
|
return self.dump_data_dir
|
|
@@ -93,5 +95,8 @@ class GlobalContext:
|
|
|
93
95
|
def get_is_constructed(self):
|
|
94
96
|
return self.is_constructed
|
|
95
97
|
|
|
98
|
+
def get_framework(self):
|
|
99
|
+
return self.framework
|
|
100
|
+
|
|
96
101
|
|
|
97
102
|
global_context = GlobalContext()
|
|
@@ -70,6 +70,67 @@ class Const:
|
|
|
70
70
|
}
|
|
71
71
|
|
|
72
72
|
|
|
73
|
+
class MsCompareConst:
|
|
74
|
+
# api_info field
|
|
75
|
+
MINT = "Mint"
|
|
76
|
+
MINT_FUNCTIONAL = "MintFunctional"
|
|
77
|
+
TENSOR_API = "Tensor"
|
|
78
|
+
FUNCTIONAL_API = "Functional"
|
|
79
|
+
FUSION_API = "FUSION"
|
|
80
|
+
|
|
81
|
+
API_NAME_STR_LENGTH = 4
|
|
82
|
+
MAX_RECURSION_DEPTH = 20
|
|
83
|
+
|
|
84
|
+
# Mindtorch api_info field
|
|
85
|
+
MINDTORCH_TENSOR = "Tensor"
|
|
86
|
+
MINDTORCH = "Torch"
|
|
87
|
+
MINDTORCH_FUNC = "Functional"
|
|
88
|
+
MINDTORCH_NPU = "NPU"
|
|
89
|
+
MINDTORCH_DIST = "Distributed"
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
MT_VALID_API_TYPES = [
|
|
94
|
+
MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
|
|
95
|
+
]
|
|
96
|
+
SUPPORTED_FUSION_LIST = ["flash_attention_score"]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
TASK_FIELD = "task"
|
|
100
|
+
STATISTICS_TASK = "statistics"
|
|
101
|
+
FRAMEWORK = "framework"
|
|
102
|
+
TENSOR_TASK = "tensor"
|
|
103
|
+
DUMP_DATA_DIR_FIELD = "dump_data_dir"
|
|
104
|
+
DATA_FIELD = "data"
|
|
105
|
+
|
|
106
|
+
# supported api yaml
|
|
107
|
+
SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
|
|
108
|
+
SUPPORTED_TENSOR_LIST_KEY = "tensor"
|
|
109
|
+
|
|
110
|
+
# detail_csv
|
|
111
|
+
DETAIL_CSV_API_NAME = "API Name"
|
|
112
|
+
DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
|
|
113
|
+
DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
|
|
114
|
+
DETAIL_CSV_SHAPE = "Shape"
|
|
115
|
+
DETAIL_CSV_PASS_STATUS = "Status"
|
|
116
|
+
DETAIL_CSV_MESSAGE = "Message"
|
|
117
|
+
DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
|
|
118
|
+
|
|
119
|
+
# result_csv
|
|
120
|
+
RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
|
|
121
|
+
RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
|
|
122
|
+
RESULT_CSV_FILE_NAME = "accuracy_checking_result"
|
|
123
|
+
|
|
124
|
+
EPSILON = 1e-8
|
|
125
|
+
|
|
126
|
+
class ProcessStatus:
|
|
127
|
+
SUCCESS = "success"
|
|
128
|
+
API_NOT_FOUND = "api_not_found"
|
|
129
|
+
EXCEPTION_SKIP = "exception_skip"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
|
|
73
134
|
class FreeBenchmarkConst:
|
|
74
135
|
ADD_NOISE = "add_noise"
|
|
75
136
|
BIT_NOISE = "bit_noise"
|
|
@@ -25,7 +25,31 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
|
25
25
|
from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
|
|
26
26
|
from msprobe.core.common.log import logger
|
|
27
27
|
from msprobe.core.common.const import Const
|
|
28
|
-
from msprobe.core.common.utils import CompareException, check_seed_all
|
|
28
|
+
from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MsprobeStep(ms.train.Callback):
|
|
32
|
+
def __init__(self, debugger):
|
|
33
|
+
super(MsprobeStep, self).__init__()
|
|
34
|
+
self.debugger = debugger
|
|
35
|
+
|
|
36
|
+
def on_train_step_begin(self, run_context):
|
|
37
|
+
self.debugger.start()
|
|
38
|
+
|
|
39
|
+
def on_train_step_end(self, run_context):
|
|
40
|
+
self.debugger.stop()
|
|
41
|
+
self.debugger.step()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MsprobeInitStep(ms.train.Callback):
|
|
45
|
+
def on_train_begin(self, run_context):
|
|
46
|
+
try:
|
|
47
|
+
from ms._c_expression import _set_init_iter
|
|
48
|
+
except ImportError:
|
|
49
|
+
logger.warning('MsprobeInitStep does not work on this version of MindSpore.')
|
|
50
|
+
return
|
|
51
|
+
cb_params = run_context.original_args()
|
|
52
|
+
_set_init_iter(cb_params.cur_step_num)
|
|
29
53
|
|
|
30
54
|
|
|
31
55
|
def get_rank_if_initialized():
|
|
@@ -93,20 +117,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
|
93
117
|
remove_dropout()
|
|
94
118
|
|
|
95
119
|
|
|
96
|
-
class MsprobeStep(ms.train.Callback):
|
|
97
|
-
|
|
98
|
-
def __init__(self, debugger):
|
|
99
|
-
super(MsprobeStep, self).__init__()
|
|
100
|
-
self.debugger = debugger
|
|
101
|
-
|
|
102
|
-
def on_train_step_begin(self, run_context):
|
|
103
|
-
self.debugger.start()
|
|
104
|
-
|
|
105
|
-
def on_train_step_end(self, run_context):
|
|
106
|
-
self.debugger.stop()
|
|
107
|
-
self.debugger.step()
|
|
108
|
-
|
|
109
|
-
|
|
110
120
|
class Dropout(ops.Dropout):
|
|
111
121
|
def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
|
|
112
122
|
super().__init__(1., seed0, seed1)
|
|
@@ -151,11 +161,10 @@ def is_mindtorch():
|
|
|
151
161
|
mindtorch_check_result = False
|
|
152
162
|
try:
|
|
153
163
|
import torch
|
|
154
|
-
from mindspore._c_expression import Tensor
|
|
155
164
|
except ImportError:
|
|
156
165
|
return mindtorch_check_result
|
|
157
166
|
tensor = torch.tensor(0.0)
|
|
158
|
-
if isinstance(tensor, Tensor):
|
|
167
|
+
if isinstance(tensor, ms.Tensor):
|
|
159
168
|
mindtorch_check_result = True
|
|
160
169
|
return mindtorch_check_result
|
|
161
170
|
|
|
@@ -170,7 +179,7 @@ def set_register_backward_hook_functions():
|
|
|
170
179
|
from msprobe.mindspore.mindtorch import (_call_impl,
|
|
171
180
|
register_full_backward_pre_hook,
|
|
172
181
|
register_full_backward_hook)
|
|
173
|
-
if not hasattr(torch, "register_full_backward_hook"):
|
|
182
|
+
if not hasattr(torch.nn.Module, "register_full_backward_hook"):
|
|
174
183
|
setattr(torch.nn.Module, "_call_impl", _call_impl)
|
|
175
184
|
setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
|
|
176
185
|
setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
|
|
@@ -179,3 +188,24 @@ def set_register_backward_hook_functions():
|
|
|
179
188
|
else:
|
|
180
189
|
register_backward_hook_functions["pre"] = ms.nn.Cell.register_backward_pre_hook
|
|
181
190
|
register_backward_hook_functions["full"] = ms.nn.Cell.register_backward_hook
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def check_save_param(variable, name, save_backward):
|
|
194
|
+
# try catch this api to skip invalid call
|
|
195
|
+
valid_data_types = tuple([ms.Tensor, int, float, str])
|
|
196
|
+
if not is_save_variable_valid(variable, valid_data_types):
|
|
197
|
+
valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
|
|
198
|
+
logger.warning("PrecisionDebugger.save variable type not valid, "
|
|
199
|
+
f"should be one of {valid_data_types_with_nested_types}"
|
|
200
|
+
"Skip current save process.")
|
|
201
|
+
raise ValueError
|
|
202
|
+
if not isinstance(name, str):
|
|
203
|
+
logger.warning("PrecisionDebugger.save name not valid, "
|
|
204
|
+
"should be string. "
|
|
205
|
+
"skip current save process.")
|
|
206
|
+
raise ValueError
|
|
207
|
+
if not isinstance(save_backward, bool):
|
|
208
|
+
logger.warning("PrecisionDebugger.save_backward name not valid, "
|
|
209
|
+
"should be bool. "
|
|
210
|
+
"Skip current save process.")
|
|
211
|
+
raise ValueError
|
|
@@ -22,10 +22,10 @@ import pandas as pd
|
|
|
22
22
|
|
|
23
23
|
from msprobe.core.common.const import CompareConst, Const
|
|
24
24
|
from msprobe.core.common.exceptions import FileCheckException
|
|
25
|
-
from msprobe.core.common.file_utils import
|
|
25
|
+
from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml
|
|
26
26
|
from msprobe.core.common.log import logger
|
|
27
27
|
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
|
|
28
|
-
check_op_str_pattern_valid, get_dump_mode, set_dump_path
|
|
28
|
+
check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json
|
|
29
29
|
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
30
30
|
from msprobe.core.compare.check import dtype_mapping
|
|
31
31
|
from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
|
|
@@ -78,6 +78,11 @@ class MSComparator(Comparator):
|
|
|
78
78
|
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
79
79
|
f"{type(self.data_mapping)}")
|
|
80
80
|
|
|
81
|
+
@staticmethod
|
|
82
|
+
def process_data_name(result):
|
|
83
|
+
result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1)
|
|
84
|
+
return result
|
|
85
|
+
|
|
81
86
|
def calc_accuracy(self, result_df, header):
|
|
82
87
|
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
83
88
|
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
@@ -120,12 +125,13 @@ class MSComparator(Comparator):
|
|
|
120
125
|
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
121
126
|
elif self.dump_mode == Const.SUMMARY:
|
|
122
127
|
warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
|
|
123
|
-
warning_flag = pd.DataFrame(warning_list).
|
|
128
|
+
warning_flag = pd.DataFrame(warning_list).any()
|
|
124
129
|
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
125
130
|
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
126
131
|
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
127
132
|
else:
|
|
128
|
-
fill_cols = [CompareConst.COSINE, CompareConst.
|
|
133
|
+
fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
|
|
134
|
+
CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
129
135
|
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
130
136
|
CompareConst.ERROR_MESSAGE]
|
|
131
137
|
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
@@ -139,6 +145,8 @@ class MSComparator(Comparator):
|
|
|
139
145
|
header.append(CompareConst.STACK)
|
|
140
146
|
if self.dump_mode == Const.ALL:
|
|
141
147
|
header.append(CompareConst.DATA_NAME)
|
|
148
|
+
result = self.process_data_name(result)
|
|
149
|
+
|
|
142
150
|
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
143
151
|
'op_name_y': CompareConst.BENCH_NAME,
|
|
144
152
|
'dtype_x': CompareConst.NPU_DTYPE,
|
|
@@ -169,6 +177,7 @@ class MSComparator(Comparator):
|
|
|
169
177
|
|
|
170
178
|
result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
|
|
171
179
|
result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
|
|
180
|
+
|
|
172
181
|
result_df = pd.DataFrame(columns=header)
|
|
173
182
|
for h in header:
|
|
174
183
|
if h in result.columns:
|
|
@@ -269,15 +278,15 @@ class MSComparator(Comparator):
|
|
|
269
278
|
bench_dtype = match_result['dtype_y']
|
|
270
279
|
if self.cross_frame:
|
|
271
280
|
npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
(
|
|
276
|
-
|
|
277
|
-
(
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
+
|
|
282
|
+
equal_condition = npu_dtype == bench_dtype
|
|
283
|
+
match_condition = (
|
|
284
|
+
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
|
|
285
|
+
CompareConst.DTYPE_MATCH_GROUPS[0])) |
|
|
286
|
+
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
|
|
287
|
+
CompareConst.DTYPE_MATCH_GROUPS[1]))
|
|
288
|
+
)
|
|
289
|
+
return equal_condition | match_condition
|
|
281
290
|
|
|
282
291
|
match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
|
|
283
292
|
return self.make_result_df(match_result)
|
|
@@ -382,12 +391,11 @@ class MSComparator(Comparator):
|
|
|
382
391
|
|
|
383
392
|
|
|
384
393
|
def check_cross_framework(bench_json_path):
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
return False
|
|
394
|
+
framework = detect_framework_by_dump_json(bench_json_path)
|
|
395
|
+
if framework == Const.PT_FRAMEWORK:
|
|
396
|
+
return True
|
|
397
|
+
else:
|
|
398
|
+
return False
|
|
391
399
|
|
|
392
400
|
|
|
393
401
|
def ms_compare(input_param, output_path, **kwargs):
|
|
@@ -195,11 +195,12 @@ class GraphMSComparator:
|
|
|
195
195
|
if not error_flag:
|
|
196
196
|
result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
|
|
197
197
|
result_dict[CompareConst.COSINE] = result_list[0]
|
|
198
|
-
result_dict[CompareConst.
|
|
199
|
-
result_dict[CompareConst.
|
|
200
|
-
result_dict[CompareConst.
|
|
201
|
-
result_dict[CompareConst.
|
|
202
|
-
result_dict[CompareConst.
|
|
198
|
+
result_dict[CompareConst.EUC_DIST] = result_list[1]
|
|
199
|
+
result_dict[CompareConst.MAX_ABS_ERR] = result_list[2]
|
|
200
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[3]
|
|
201
|
+
result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[4]
|
|
202
|
+
result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[5]
|
|
203
|
+
result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[2])
|
|
203
204
|
result_dict[CompareConst.ERROR_MESSAGE] = err_msg
|
|
204
205
|
|
|
205
206
|
return pd.Series(result_dict)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,9 +16,11 @@
|
|
|
16
16
|
import os
|
|
17
17
|
|
|
18
18
|
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
19
20
|
from msprobe.core.common.file_utils import create_directory
|
|
20
21
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
21
22
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
class DebuggerConfig:
|
|
@@ -50,12 +52,14 @@ class DebuggerConfig:
|
|
|
50
52
|
if not task_config.handler_type else task_config.handler_type)
|
|
51
53
|
self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage
|
|
52
54
|
if self.handler_type == FreeBenchmarkConst.FIX and \
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
55
|
+
self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
|
|
56
|
+
logger.error("pert_mode must be improve_precision or empty when handler_type is fix, "
|
|
57
|
+
f"but got {self.pert_type}.")
|
|
58
|
+
raise ValueError
|
|
56
59
|
if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
|
|
57
|
-
|
|
58
|
-
|
|
60
|
+
logger.error("handler_type must be check or empty when fuzz_stage is backward, "
|
|
61
|
+
f"but got {self.handler_type}.")
|
|
62
|
+
raise ValueError
|
|
59
63
|
self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
|
|
60
64
|
|
|
61
65
|
def check(self):
|
|
@@ -72,4 +76,25 @@ class DebuggerConfig:
|
|
|
72
76
|
self.check_mode = "all"
|
|
73
77
|
if not isinstance(self.async_dump, bool):
|
|
74
78
|
raise Exception("The parameters async_dump should be bool.")
|
|
79
|
+
if self.async_dump and self.task == Const.TENSOR and not self.list:
|
|
80
|
+
raise Exception("The parameters async_dump is true in tensor task, the parameters list cannot be empty.")
|
|
81
|
+
if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
82
|
+
logger.warning_on_rank_0(
|
|
83
|
+
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
84
|
+
f"If not, the default level is {Const.LEVEL_MIX}."
|
|
85
|
+
)
|
|
86
|
+
self.level_ori = Const.LEVEL_MIX
|
|
75
87
|
return True
|
|
88
|
+
|
|
89
|
+
def check_config_with_l2(self):
|
|
90
|
+
if self.level_ori != Const.LEVEL_L2:
|
|
91
|
+
return
|
|
92
|
+
if self.task != Const.TENSOR:
|
|
93
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
94
|
+
f"When level is set to L2, the task must be set to tensor.")
|
|
95
|
+
if self.scope:
|
|
96
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
97
|
+
f"When level is set to L2, the scope cannot be configured.")
|
|
98
|
+
if not self.list or len(self.list) != 1:
|
|
99
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
100
|
+
f"When level is set to L2, the list must be configured as a list with one api name.")
|