mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -1,16 +1,34 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
import os
|
|
17
|
+
from tqdm import tqdm
|
|
3
18
|
|
|
4
|
-
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
|
|
5
|
-
from msprobe.core.common.utils import add_time_as_suffix
|
|
6
19
|
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
7
|
-
from msprobe.
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
|
|
21
|
+
from msprobe.core.common.utils import add_time_as_suffix
|
|
8
22
|
from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
|
|
9
23
|
from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
|
|
10
24
|
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
|
|
11
26
|
from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
|
|
12
27
|
trim_output_compute_element_list)
|
|
28
|
+
from msprobe.mindspore.common.log import logger
|
|
13
29
|
|
|
30
|
+
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
31
|
+
yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
|
|
14
32
|
|
|
15
33
|
class BasicInfoAndStatus:
|
|
16
34
|
def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
|
|
@@ -21,6 +39,7 @@ class BasicInfoAndStatus:
|
|
|
21
39
|
self.status = status
|
|
22
40
|
self.err_msg = err_msg
|
|
23
41
|
|
|
42
|
+
|
|
24
43
|
class ResultCsvEntry:
|
|
25
44
|
def __init__(self) -> None:
|
|
26
45
|
self.forward_pass_status = None
|
|
@@ -31,9 +50,9 @@ class ResultCsvEntry:
|
|
|
31
50
|
|
|
32
51
|
|
|
33
52
|
class ApiAccuracyChecker:
|
|
34
|
-
def __init__(self):
|
|
53
|
+
def __init__(self, args):
|
|
35
54
|
self.api_infos = dict()
|
|
36
|
-
self.
|
|
55
|
+
self.data_manager = DataManager(args.out_path, args.result_csv_path) # 在初始化时实例化 DataManager
|
|
37
56
|
|
|
38
57
|
@staticmethod
|
|
39
58
|
def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
|
|
@@ -80,13 +99,13 @@ class ApiAccuracyChecker:
|
|
|
80
99
|
compare_result_dict[compare_algorithm_name] = compare_result
|
|
81
100
|
|
|
82
101
|
if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
|
|
83
|
-
|
|
102
|
+
compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
|
|
84
103
|
status = CompareConst.PASS
|
|
85
104
|
err_msg = ""
|
|
86
105
|
else:
|
|
87
106
|
status = CompareConst.ERROR
|
|
88
107
|
err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
|
|
89
|
-
|
|
108
|
+
compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
|
|
90
109
|
basic_info_status = \
|
|
91
110
|
BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
|
|
92
111
|
output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
|
|
@@ -109,13 +128,35 @@ class ApiAccuracyChecker:
|
|
|
109
128
|
gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
|
|
110
129
|
return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
|
|
111
130
|
|
|
131
|
+
@staticmethod
|
|
132
|
+
def is_api_checkable(api_name_str):
|
|
133
|
+
'''
|
|
134
|
+
Args:
|
|
135
|
+
api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
|
|
136
|
+
Returns:
|
|
137
|
+
is_checkable: bool
|
|
138
|
+
Description:
|
|
139
|
+
tell whether this api is checkable based on the key in "data" dict in api_info.json
|
|
140
|
+
'''
|
|
141
|
+
api_name_str_list = api_name_str.split(Const.SEP)
|
|
142
|
+
if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
|
|
143
|
+
return False
|
|
144
|
+
api_type_str = api_name_str_list[0]
|
|
145
|
+
real_api_str = Const.SEP.join(api_name_str_list[1:-2])
|
|
146
|
+
api_list = load_yaml(yaml_path)
|
|
147
|
+
supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
|
|
148
|
+
if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
|
|
149
|
+
return True
|
|
150
|
+
if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list:
|
|
151
|
+
return True
|
|
152
|
+
return False
|
|
153
|
+
|
|
112
154
|
def parse(self, api_info_path):
|
|
113
|
-
|
|
114
|
-
api_info_dict = json.load(f)
|
|
155
|
+
api_info_dict = load_json(api_info_path)
|
|
115
156
|
|
|
116
157
|
# init global context
|
|
117
158
|
task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
|
|
118
|
-
"task field in api_info.json",accepted_type=str,
|
|
159
|
+
"task field in api_info.json", accepted_type=str,
|
|
119
160
|
accepted_value=(MsCompareConst.STATISTICS_TASK,
|
|
120
161
|
MsCompareConst.TENSOR_TASK))
|
|
121
162
|
is_constructed = task == MsCompareConst.STATISTICS_TASK
|
|
@@ -129,14 +170,12 @@ class ApiAccuracyChecker:
|
|
|
129
170
|
api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
|
|
130
171
|
"data field in api_info.json", accepted_type=dict)
|
|
131
172
|
for api_name, api_info in api_info_data.items():
|
|
132
|
-
|
|
133
|
-
(MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
|
|
134
|
-
if not is_mint:
|
|
173
|
+
if not self.is_api_checkable(api_name):
|
|
135
174
|
continue
|
|
136
175
|
forbackward_str = api_name.split(Const.SEP)[-1]
|
|
137
176
|
if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
|
|
138
177
|
logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
|
|
139
|
-
api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1])
|
|
178
|
+
api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
|
|
140
179
|
if api_name not in self.api_infos:
|
|
141
180
|
self.api_infos[api_name] = ApiInfo(api_name)
|
|
142
181
|
|
|
@@ -145,135 +184,64 @@ class ApiAccuracyChecker:
|
|
|
145
184
|
else:
|
|
146
185
|
self.api_infos[api_name].load_backward_info(api_info)
|
|
147
186
|
|
|
187
|
+
def process_forward(self, api_name_str, api_info):
|
|
188
|
+
"""处理前向检查"""
|
|
189
|
+
if not api_info.check_forward_info():
|
|
190
|
+
logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
|
|
191
|
+
return Const.EXCEPTION_NONE
|
|
192
|
+
|
|
193
|
+
try:
|
|
194
|
+
forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
|
|
195
|
+
except Exception as e:
|
|
196
|
+
logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
|
|
197
|
+
f"Skipping forward check. Detailed exception information: {e}.")
|
|
198
|
+
return Const.EXCEPTION_NONE
|
|
199
|
+
|
|
200
|
+
forward_output_list = None
|
|
201
|
+
try:
|
|
202
|
+
forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
|
|
203
|
+
except Exception as e:
|
|
204
|
+
logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
|
|
205
|
+
f"Detailed exception information: {e}.")
|
|
206
|
+
return forward_output_list
|
|
207
|
+
|
|
208
|
+
def process_backward(self, api_name_str, api_info):
|
|
209
|
+
"""处理反向检查"""
|
|
210
|
+
if not api_info.check_backward_info():
|
|
211
|
+
logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
|
|
212
|
+
return Const.EXCEPTION_NONE
|
|
213
|
+
|
|
214
|
+
try:
|
|
215
|
+
backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
|
|
216
|
+
except Exception as e:
|
|
217
|
+
logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
|
|
218
|
+
f"Skipping backward check. Detailed exception information: {e}.")
|
|
219
|
+
return Const.EXCEPTION_NONE
|
|
220
|
+
|
|
221
|
+
backward_output_list = None
|
|
222
|
+
try:
|
|
223
|
+
backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
|
|
224
|
+
except Exception as e:
|
|
225
|
+
logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
|
|
226
|
+
f"Detailed exception information: {e}.")
|
|
227
|
+
return backward_output_list
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
|
|
148
231
|
def run_and_compare(self):
|
|
149
|
-
for api_name_str, api_info in self.api_infos.items():
|
|
150
|
-
if not
|
|
151
|
-
logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check.")
|
|
152
|
-
continue
|
|
153
|
-
try:
|
|
154
|
-
forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
|
|
155
|
-
except Exception as e:
|
|
156
|
-
logger.warning(f"exception occurs when getting inputs for {api_name_str} forward api. "
|
|
157
|
-
f"skip forward and backward check. detailed exception information: {e}.")
|
|
158
|
-
continue
|
|
159
|
-
forward_output_list = None
|
|
160
|
-
try:
|
|
161
|
-
forward_output_list = \
|
|
162
|
-
self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
|
|
163
|
-
except Exception as e:
|
|
164
|
-
logger.warning(f"exception occurs when running and comparing {api_name_str} forward api. "
|
|
165
|
-
f"detailed exception information: {e}.")
|
|
166
|
-
self.record(forward_output_list)
|
|
167
|
-
|
|
168
|
-
if not api_info.check_backward_info():
|
|
169
|
-
logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check.")
|
|
170
|
-
continue
|
|
171
|
-
try:
|
|
172
|
-
backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
|
|
173
|
-
except Exception as e:
|
|
174
|
-
logger.warning(f"exception occurs when getting inputs for {api_name_str} backward api. "
|
|
175
|
-
f"skip backward check. detailed exception information: {e}.")
|
|
232
|
+
for api_name_str, api_info in tqdm(self.api_infos.items()):
|
|
233
|
+
if not self.data_manager.is_unique_api(api_name_str):
|
|
176
234
|
continue
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
self.
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
api_real_name, forward_or_backward, basic_info, compare_result_dict = output
|
|
191
|
-
key = tuple([api_real_name, forward_or_backward])
|
|
192
|
-
if key not in self.results:
|
|
193
|
-
self.results[key] = []
|
|
194
|
-
self.results[key].append(tuple([basic_info, compare_result_dict]))
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
def to_detail_csv(self, csv_dir):
|
|
198
|
-
# detail_csv
|
|
199
|
-
detail_csv = []
|
|
200
|
-
detail_csv_header_basic_info = [
|
|
201
|
-
MsCompareConst.DETAIL_CSV_API_NAME,
|
|
202
|
-
MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
|
|
203
|
-
MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
|
|
204
|
-
MsCompareConst.DETAIL_CSV_SHAPE,
|
|
205
|
-
]
|
|
206
|
-
detail_csv_header_compare_result = list(compare_algorithms.keys())
|
|
207
|
-
detail_csv_header_status = [
|
|
208
|
-
MsCompareConst.DETAIL_CSV_PASS_STATUS,
|
|
209
|
-
MsCompareConst.DETAIL_CSV_MESSAGE,
|
|
210
|
-
]
|
|
211
|
-
|
|
212
|
-
detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
|
|
213
|
-
detail_csv.append(detail_csv_header)
|
|
214
|
-
|
|
215
|
-
for _, results in self.results.items():
|
|
216
|
-
# detail csv
|
|
217
|
-
for res in results:
|
|
218
|
-
basic_info, compare_result_dict = res
|
|
219
|
-
csv_row_basic_info = \
|
|
220
|
-
[basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
|
|
221
|
-
csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
|
|
222
|
-
for algorithm_name in detail_csv_header_compare_result)
|
|
223
|
-
csv_row_status = [basic_info.status, basic_info.err_msg]
|
|
224
|
-
csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
|
|
225
|
-
detail_csv.append(csv_row)
|
|
226
|
-
|
|
227
|
-
file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
|
|
228
|
-
create_directory(csv_dir)
|
|
229
|
-
write_csv(detail_csv, file_name, mode="w")
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
def to_result_csv(self, csv_dir):
|
|
233
|
-
result_csv_dict = dict()
|
|
234
|
-
for key, results in self.results.items():
|
|
235
|
-
api_real_name, forward_or_backward = key
|
|
236
|
-
forward_or_backward_pass_status = CompareConst.PASS
|
|
237
|
-
forward_or_backward_overall_err_msg = ""
|
|
238
|
-
# detail csv
|
|
239
|
-
for res in results:
|
|
240
|
-
basic_info, _ = res
|
|
241
|
-
if basic_info.status != CompareConst.PASS:
|
|
242
|
-
forward_or_backward_pass_status = CompareConst.ERROR
|
|
243
|
-
forward_or_backward_overall_err_msg += basic_info.err_msg
|
|
244
|
-
forward_or_backward_overall_err_msg = \
|
|
245
|
-
"" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
|
|
246
|
-
|
|
247
|
-
#result_csv_dict
|
|
248
|
-
if api_real_name not in result_csv_dict:
|
|
249
|
-
result_csv_dict[api_real_name] = ResultCsvEntry()
|
|
250
|
-
if forward_or_backward == Const.FORWARD:
|
|
251
|
-
result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
|
|
252
|
-
result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
|
|
253
|
-
else:
|
|
254
|
-
result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
|
|
255
|
-
result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
|
|
256
|
-
|
|
257
|
-
#result_csv
|
|
258
|
-
result_csv = []
|
|
259
|
-
result_csv_header = [
|
|
260
|
-
MsCompareConst.DETAIL_CSV_API_NAME,
|
|
261
|
-
MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
|
|
262
|
-
MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
|
|
263
|
-
MsCompareConst.DETAIL_CSV_MESSAGE,
|
|
264
|
-
]
|
|
265
|
-
result_csv.append(result_csv_header)
|
|
266
|
-
|
|
267
|
-
for api_name, result_csv_entry in result_csv_dict.items():
|
|
268
|
-
if result_csv_entry.forward_pass_status == CompareConst.PASS and \
|
|
269
|
-
result_csv_entry.backward_pass_status == CompareConst.PASS:
|
|
270
|
-
overall_err_msg = ""
|
|
271
|
-
else:
|
|
272
|
-
overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
|
|
273
|
-
row = [api_name, result_csv_entry.forward_pass_status,
|
|
274
|
-
result_csv_entry.backward_pass_status, overall_err_msg]
|
|
275
|
-
result_csv.append(row)
|
|
276
|
-
|
|
277
|
-
file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
|
|
278
|
-
create_directory(csv_dir)
|
|
279
|
-
write_csv(result_csv, file_name, mode="w")
|
|
235
|
+
|
|
236
|
+
# 处理前向
|
|
237
|
+
forward_output_list = self.process_forward(api_name_str, api_info)
|
|
238
|
+
if forward_output_list is not Const.EXCEPTION_NONE:
|
|
239
|
+
self.data_manager.record(forward_output_list)
|
|
240
|
+
|
|
241
|
+
# 处理反向
|
|
242
|
+
backward_output_list = self.process_backward(api_name_str, api_info)
|
|
243
|
+
if backward_output_list is not Const.EXCEPTION_NONE:
|
|
244
|
+
self.data_manager.record(backward_output_list)
|
|
245
|
+
|
|
246
|
+
self.data_manager.save_results(api_name_str)
|
|
247
|
+
|
|
@@ -1,9 +1,25 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
from msprobe.core.common.const import Const
|
|
3
|
-
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
4
17
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
5
|
-
from msprobe.mindspore.common.log import logger
|
|
6
18
|
from msprobe.core.common.utils import is_invalid_pattern
|
|
19
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
21
|
+
from msprobe.mindspore.common.log import logger
|
|
22
|
+
|
|
7
23
|
|
|
8
24
|
class ApiInfo:
|
|
9
25
|
def __init__(self, api_name):
|
|
@@ -66,11 +82,10 @@ class ApiInfo:
|
|
|
66
82
|
err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
|
|
67
83
|
logger.error_log_with_exp(err_msg,
|
|
68
84
|
ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
|
|
69
|
-
if not isinstance(compute_element_info, (list, dict)):
|
|
70
|
-
err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or
|
|
85
|
+
if not (isinstance(compute_element_info, (list, dict)) or compute_element_info is None):
|
|
86
|
+
err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list, dict or null"
|
|
71
87
|
logger.error_log_with_exp(err_msg,
|
|
72
88
|
ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
|
|
73
89
|
kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
|
|
74
90
|
for key_str, compute_element_info in kwargs_dict.items()}
|
|
75
91
|
return kwargs_compute_element_dict
|
|
76
|
-
|
|
@@ -1,15 +1,27 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
2
15
|
|
|
3
16
|
import mindspore
|
|
4
17
|
import torch
|
|
5
18
|
from mindspore import ops
|
|
6
|
-
|
|
7
|
-
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
8
19
|
from msprobe.core.common.const import Const, MsCompareConst
|
|
9
20
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
10
|
-
from msprobe.mindspore.
|
|
11
|
-
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
21
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
12
22
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
|
|
23
|
+
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
24
|
+
from msprobe.mindspore.common.log import logger
|
|
13
25
|
|
|
14
26
|
|
|
15
27
|
class ApiInputAggregation:
|
|
@@ -24,11 +36,23 @@ class ApiInputAggregation:
|
|
|
24
36
|
self.kwargs = kwargs
|
|
25
37
|
self.gradient_inputs = gradient_inputs
|
|
26
38
|
|
|
39
|
+
|
|
27
40
|
api_parent_module_mapping = {
|
|
28
41
|
(MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
|
|
29
42
|
(MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
|
|
30
43
|
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
|
|
31
|
-
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional
|
|
44
|
+
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
45
|
+
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
|
|
46
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
api_parent_module_str_mapping = {
|
|
50
|
+
(MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
|
|
51
|
+
(MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
|
|
52
|
+
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
|
|
53
|
+
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
54
|
+
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
|
|
55
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
|
|
32
56
|
}
|
|
33
57
|
|
|
34
58
|
|
|
@@ -60,7 +84,7 @@ class ApiRunner:
|
|
|
60
84
|
api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
|
|
61
85
|
|
|
62
86
|
Return:
|
|
63
|
-
api_type_str: str, Union["MintFunctional", "Mint"]
|
|
87
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
64
88
|
api_sub_name: str, e.g. "relu"
|
|
65
89
|
'''
|
|
66
90
|
api_name_list = api_name_str.split(Const.SEP)
|
|
@@ -68,8 +92,8 @@ class ApiRunner:
|
|
|
68
92
|
err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
|
|
69
93
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
70
94
|
api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
|
|
71
|
-
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
|
|
72
|
-
err_msg = f"ApiRunner.get_info_from_name failed: not mint
|
|
95
|
+
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API]:
|
|
96
|
+
err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
|
|
73
97
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
74
98
|
|
|
75
99
|
return api_type_str, api_sub_name
|
|
@@ -78,7 +102,7 @@ class ApiRunner:
|
|
|
78
102
|
def get_api_instance(api_type_str, api_sub_name, api_platform):
|
|
79
103
|
'''
|
|
80
104
|
Args:
|
|
81
|
-
api_type_str: str, Union["MintFunctional", "Mint"]
|
|
105
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
82
106
|
api_sub_name: str, e.g. "relu"
|
|
83
107
|
api_platform: str: Union["mindpore", "torch"]
|
|
84
108
|
|
|
@@ -92,9 +116,8 @@ class ApiRunner:
|
|
|
92
116
|
'''
|
|
93
117
|
|
|
94
118
|
api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
full_api_name = module_str + submodule_str + api_sub_name
|
|
119
|
+
api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
|
|
120
|
+
full_api_name = api_parent_module_str + Const.SEP + api_sub_name
|
|
98
121
|
if not hasattr(api_parent_module, api_sub_name):
|
|
99
122
|
err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
|
|
100
123
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
|
|
@@ -115,7 +138,7 @@ class ApiRunner:
|
|
|
115
138
|
gradient_inputs = api_input_aggregation.gradient_inputs
|
|
116
139
|
|
|
117
140
|
if forward_or_backward == Const.FORWARD:
|
|
118
|
-
forward_result = api_instance(*inputs, **kwargs)
|
|
141
|
+
forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
|
|
119
142
|
forward_result_tuple = convert_to_tuple(forward_result)
|
|
120
143
|
res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
|
|
121
144
|
else:
|
|
@@ -127,18 +150,20 @@ class ApiRunner:
|
|
|
127
150
|
if api_platform == Const.MS_FRAMEWORK:
|
|
128
151
|
if len(gradient_inputs) == 1:
|
|
129
152
|
gradient_inputs = gradient_inputs[0]
|
|
153
|
+
|
|
130
154
|
def api_with_kwargs(*forward_inputs):
|
|
131
155
|
return api_instance(*forward_inputs, **kwargs)
|
|
156
|
+
|
|
132
157
|
grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
|
|
133
|
-
backward_result = grad_func(*inputs, gradient_inputs)
|
|
158
|
+
backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
|
|
134
159
|
backward_result_tuple = convert_to_tuple(backward_result)
|
|
135
160
|
res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
|
|
136
161
|
else:
|
|
137
|
-
#set requires_grad
|
|
162
|
+
# set requires_grad
|
|
138
163
|
requires_grad_index = []
|
|
139
164
|
for index, tensor in enumerate(inputs):
|
|
140
165
|
if isinstance(tensor, torch.Tensor) and \
|
|
141
|
-
|
|
166
|
+
torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
|
|
142
167
|
setattr(tensor, "requires_grad", True)
|
|
143
168
|
requires_grad_index.append(index)
|
|
144
169
|
forward_results = api_instance(*inputs, **kwargs)
|
|
@@ -153,4 +178,4 @@ class ApiRunner:
|
|
|
153
178
|
return res_compute_element_list
|
|
154
179
|
|
|
155
180
|
|
|
156
|
-
api_runner = ApiRunner()
|
|
181
|
+
api_runner = ApiRunner()
|
|
@@ -1,12 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC, abstractmethod
|
|
2
17
|
|
|
3
18
|
import mindspore
|
|
4
|
-
import torch
|
|
5
19
|
import numpy as np
|
|
6
|
-
|
|
20
|
+
import torch
|
|
21
|
+
from msprobe.core.common.const import CompareConst, MsCompareConst
|
|
7
22
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
8
23
|
from msprobe.mindspore.common.log import logger
|
|
9
|
-
|
|
24
|
+
|
|
10
25
|
|
|
11
26
|
class CompareResult:
|
|
12
27
|
def __init__(self, compare_value, pass_status, err_msg):
|
|
@@ -28,7 +43,7 @@ class BaseCompareAlgorithm(ABC):
|
|
|
28
43
|
CompareConst.MAX_ABS_ERR: {
|
|
29
44
|
CompareConst.PASS: "",
|
|
30
45
|
CompareConst.ERROR: "max absolute difference is greater than " \
|
|
31
|
-
|
|
46
|
+
f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
|
|
32
47
|
CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
|
|
33
48
|
},
|
|
34
49
|
CompareConst.MAX_RELATIVE_ERR: {
|
|
@@ -68,7 +83,7 @@ class BaseCompareAlgorithm(ABC):
|
|
|
68
83
|
ndarray = tensor.to(torch.float64, copy=True).numpy()
|
|
69
84
|
else:
|
|
70
85
|
err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
|
|
71
|
-
|
|
86
|
+
"input is not mindspore.Tensor or torch.Tensor"
|
|
72
87
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
73
88
|
return ndarray
|
|
74
89
|
|
|
@@ -189,9 +204,8 @@ class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
|
|
|
189
204
|
return CompareConst.ERROR
|
|
190
205
|
|
|
191
206
|
|
|
192
|
-
|
|
193
207
|
compare_algorithms = {
|
|
194
208
|
CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
|
|
195
209
|
CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
|
|
196
210
|
CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
|
|
197
|
-
}
|
|
211
|
+
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
# list of api that can be checked
|
|
17
|
+
|
|
18
|
+
tensor:
|
|
19
|
+
- add_
|
|
20
|
+
- add
|
|
21
|
+
- addmm_
|
|
22
|
+
- all
|
|
23
|
+
- allclose
|
|
24
|
+
- any
|
|
25
|
+
- bool
|
|
26
|
+
- byte
|
|
27
|
+
- ceil
|
|
28
|
+
- clamp
|
|
29
|
+
- contiguous
|
|
30
|
+
- copy_
|
|
31
|
+
- cos
|
|
32
|
+
- clone
|
|
33
|
+
- cumprod
|
|
34
|
+
- expand_as
|
|
35
|
+
- flatten
|
|
36
|
+
- float
|
|
37
|
+
- half
|
|
38
|
+
- int
|
|
39
|
+
- is_contiguous
|
|
40
|
+
- isnan
|
|
41
|
+
- item
|
|
42
|
+
- log
|
|
43
|
+
- log2
|
|
44
|
+
- long
|
|
45
|
+
- masked_fill
|
|
46
|
+
- max
|
|
47
|
+
- mean
|
|
48
|
+
- min
|
|
49
|
+
- numel
|
|
50
|
+
- numpy
|
|
51
|
+
- repeat
|
|
52
|
+
- repeat_interleave
|
|
53
|
+
- reshape
|
|
54
|
+
- round
|
|
55
|
+
- select
|
|
56
|
+
- sin
|
|
57
|
+
- size
|
|
58
|
+
- split
|
|
59
|
+
- sqrt
|
|
60
|
+
- square
|
|
61
|
+
- sub
|
|
62
|
+
- swapaxes
|
|
63
|
+
- to
|
|
64
|
+
- t
|
|
65
|
+
- tolist
|
|
66
|
+
- topk
|
|
67
|
+
- transpose
|
|
68
|
+
- trunc
|
|
69
|
+
- type
|
|
70
|
+
- unsqueeze
|
|
71
|
+
- view
|
|
72
|
+
- view_as
|
|
73
|
+
- fill_
|
|
74
|
+
- floor_
|
|
75
|
+
- clamp_
|
|
76
|
+
- type_as
|
|
77
|
+
- zero_
|