mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -1,6 +1,68 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory
|
|
21
|
+
from msprobe.core.common.utils import Const, MsprobeBaseException
|
|
22
|
+
|
|
23
|
+
class UniqueDeviceAction(argparse.Action):
    """argparse action that validates a list of device ids before storing it.

    The list is rejected (via ``parser.error``) if it contains duplicate ids
    or any id outside the inclusive range [0, 4095].
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if len(values) != len(set(values)):
            parser.error("device id must be unique")
        # Report every out-of-range id in input order.
        out_of_range = [dev_id for dev_id in values if not 0 <= dev_id <= 4095]
        for dev_id in out_of_range:
            parser.error(f"the argument 'device_id' must be in range [0, 4095], but got {dev_id}")
        setattr(namespace, self.dest, values)
|
|
32
|
+
|
|
33
|
+
|
|
1
34
|
def add_api_accuracy_checker_argument(parser):
    """Register the options of the single-process API accuracy checker on *parser*.

    Options:
        -api_info / --api_info_file: (required) the api_info json file produced by
            the api param tool.
        -o / --out_path: output directory, default "./".
        -csv_path / --result_csv_path: an existing result CSV to resume from,
            default "" (no resume).
    """
    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
                        help="<Required> The api param tool result file: generate from api param tool, "
                             "a json file.")
    parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
                        help="<optional> The ut task result out path.")
    # Help text previously read "the exit csv for continue" (typo for "existing").
    parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
                        required=False,
                        help="<optional> the existing result csv to continue (resume) from")
|
|
42
|
+
|
|
43
|
+
def multi_add_api_accuracy_checker_argument(parser):
    """Register the options of the multi-device API accuracy checker on *parser*.

    Adds the same options as ``add_api_accuracy_checker_argument`` (reused here
    so the two parsers cannot drift apart), plus ``-d/--device``: one or more
    device ids, validated by ``UniqueDeviceAction`` to be unique and within
    [0, 4095].
    """
    add_api_accuracy_checker_argument(parser)
    # The following options apply only to multi-device runs.
    # The help text states the range UniqueDeviceAction actually enforces
    # (it previously claimed 0-7).
    parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
                        help="<optional> set device id to run ut, must be unique and in range 0-4095",
                        default=[0], required=False, action=UniqueDeviceAction)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def check_args(args):
    """Normalize and validate parsed CLI arguments in place.

    - ``api_info_file`` is made absolute and checked with
      ``check_file_or_directory_path``.
    - ``out_path`` falls back to "./" when it is the empty string, is made
      absolute, and is created with ``create_directory``.
    - ``result_csv_path``, when non-empty, is made absolute and checked.
    """
    args.api_info_file = os.path.abspath(args.api_info_file)
    check_file_or_directory_path(args.api_info_file)

    requested_out = "./" if args.out_path == "" else args.out_path
    args.out_path = os.path.abspath(requested_out)
    create_directory(args.out_path)

    if not args.result_csv_path:
        return
    args.result_csv_path = os.path.abspath(args.result_csv_path)
    check_file_or_directory_path(args.result_csv_path)
|
|
@@ -1,21 +1,37 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
|
|
3
18
|
import mindspore
|
|
4
|
-
import torch
|
|
5
19
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
from
|
|
20
|
+
import torch
|
|
21
|
+
from mindspore._c_expression import typing
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
8
23
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
9
24
|
from msprobe.core.common.file_utils import load_npy
|
|
10
|
-
from msprobe.mindspore.api_accuracy_checker.type_mapping import (
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
|
|
11
26
|
ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
|
|
12
27
|
dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
|
|
13
28
|
dtype_str_to_torch_dtype, type_to_api_info_type_str,
|
|
14
29
|
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
|
|
15
|
-
MINDSPORE_TENSOR_TYPE_STR,
|
|
16
|
-
|
|
17
|
-
|
|
30
|
+
MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
|
|
31
|
+
SLICE_TYPE_STR, TORCH_DTYPE_TYPE_STR,
|
|
32
|
+
float_dtype_str_list, int_dtype_str_list)
|
|
18
33
|
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
34
|
+
from msprobe.mindspore.common.log import logger
|
|
19
35
|
|
|
20
36
|
|
|
21
37
|
class MstensorMetaData:
|
|
@@ -26,6 +42,12 @@ class MstensorMetaData:
|
|
|
26
42
|
self.minimum = minimum
|
|
27
43
|
self.shape = shape
|
|
28
44
|
|
|
45
|
+
|
|
46
|
+
class DtypeMetaData:
    """Deferred dtype description.

    Only the dtype name string is kept. ``ComputeElement.get_parameter`` later
    resolves it to a MindSpore or PyTorch dtype depending on the target platform.
    """

    def __init__(self, dtype_str) -> None:
        # Name of the dtype, looked up later in dtype_str_to_ms_dtype / dtype_str_to_torch_dtype.
        self.dtype_str = dtype_str
|
|
49
|
+
|
|
50
|
+
|
|
29
51
|
class ComputeElement:
|
|
30
52
|
def __init__(self, compute_element_info=None, parameter=None):
|
|
31
53
|
self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
|
|
@@ -118,6 +140,11 @@ class ComputeElement:
|
|
|
118
140
|
for compute_element in self.parameter])
|
|
119
141
|
elif isinstance(self.parameter, self.supported_parameter_type):
|
|
120
142
|
parameter_tmp = self.parameter
|
|
143
|
+
elif isinstance(self.parameter, DtypeMetaData):
|
|
144
|
+
if tensor_platform == Const.MS_FRAMEWORK:
|
|
145
|
+
parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
|
|
146
|
+
else:
|
|
147
|
+
parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
|
|
121
148
|
elif isinstance(self.parameter, MstensorMetaData):
|
|
122
149
|
mstensor_meta_data = self.parameter
|
|
123
150
|
ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
|
|
@@ -130,13 +157,13 @@ class ComputeElement:
|
|
|
130
157
|
parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
|
|
131
158
|
else:
|
|
132
159
|
err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
|
|
133
|
-
|
|
160
|
+
"(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
|
|
134
161
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
135
162
|
|
|
136
163
|
# if necessary, do transfer
|
|
137
164
|
if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
|
|
138
165
|
parameter = self.transfer_to_torch_tensor(parameter_tmp)
|
|
139
|
-
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform ==Const.MS_FRAMEWORK:
|
|
166
|
+
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
|
|
140
167
|
parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
|
|
141
168
|
else:
|
|
142
169
|
parameter = parameter_tmp
|
|
@@ -183,34 +210,38 @@ class ComputeElement:
|
|
|
183
210
|
else:
|
|
184
211
|
type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
|
|
185
212
|
accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
|
|
186
|
-
|
|
213
|
+
self.shape = tuple()
|
|
214
|
+
self.dtype_str = type_str
|
|
187
215
|
if type_str == MINDSPORE_TENSOR_TYPE_STR:
|
|
188
216
|
self._init_from_mstensor_compute_element_info(compute_element_info)
|
|
189
|
-
else:
|
|
217
|
+
else:
|
|
190
218
|
value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
219
|
+
if type_str == MINDSPORE_DTYPE_TYPE_STR:
|
|
220
|
+
self.parameter = DtypeMetaData(value)
|
|
221
|
+
elif type_str == SLICE_TYPE_STR:
|
|
222
|
+
self.parameter = slice(*tuple(value))
|
|
223
|
+
else: # type_str in ("str", "int", "float", "bool")
|
|
224
|
+
self.parameter = value
|
|
194
225
|
|
|
195
226
|
def _init_from_mstensor_compute_element_info(self, compute_element_info):
|
|
196
227
|
'''
|
|
197
228
|
do not load real tensor, only record meta data
|
|
198
229
|
'''
|
|
199
230
|
dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
|
|
200
|
-
|
|
231
|
+
accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
|
|
201
232
|
shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
|
|
202
|
-
|
|
233
|
+
accepted_type=(list,))
|
|
203
234
|
if global_context.get_is_constructed():
|
|
204
235
|
maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
|
|
205
|
-
|
|
236
|
+
accepted_type=(int, float))
|
|
206
237
|
minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
|
|
207
|
-
|
|
238
|
+
accepted_type=(int, float))
|
|
208
239
|
|
|
209
240
|
npy_path = None
|
|
210
241
|
else:
|
|
211
242
|
maximum, minimum = None, None
|
|
212
243
|
data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
|
|
213
|
-
|
|
244
|
+
"data_name field in api_info.json", accepted_type=(str,))
|
|
214
245
|
npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
|
|
215
246
|
mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
|
|
216
247
|
self.parameter = mstensor_meta_data
|
|
@@ -219,9 +250,10 @@ class ComputeElement:
|
|
|
219
250
|
|
|
220
251
|
def _init_with_parameter(self, parameter):
|
|
221
252
|
self.parameter = parameter
|
|
253
|
+
self.shape = tuple()
|
|
222
254
|
if not isinstance(parameter, self.supported_parameter_type):
|
|
223
255
|
err_msg = "ComputeElement._init_with_parameter failed: " \
|
|
224
|
-
|
|
256
|
+
"parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
|
|
225
257
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
226
258
|
if isinstance(parameter, mindspore.Tensor):
|
|
227
259
|
self.shape = tuple(parameter.shape)
|
|
@@ -229,11 +261,14 @@ class ComputeElement:
|
|
|
229
261
|
elif isinstance(parameter, torch.Tensor):
|
|
230
262
|
self.shape = tuple(parameter.shape)
|
|
231
263
|
self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
|
|
264
|
+
elif isinstance(parameter, typing.Type):
|
|
265
|
+
self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
|
|
266
|
+
self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
|
|
267
|
+
elif isinstance(parameter, torch.dtype):
|
|
268
|
+
self.dtype_str = TORCH_DTYPE_TYPE_STR
|
|
269
|
+
self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
|
|
232
270
|
elif isinstance(parameter, tuple):
|
|
233
|
-
self.shape = tuple()
|
|
234
271
|
self.dtype_str = TUPLE_TYPE_STR
|
|
235
272
|
self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
|
|
236
273
|
else:
|
|
237
|
-
self.
|
|
238
|
-
self.dtype_str = \
|
|
239
|
-
TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
|
|
274
|
+
self.dtype_str = type_to_api_info_type_str.get(type(parameter))
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import csv
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
|
|
21
|
+
from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
|
|
22
|
+
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
23
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
24
|
+
from msprobe.mindspore.common.log import logger
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ResultCsvEntry:
    """Mutable record of forward/backward pass statuses and error messages.

    Both pass statuses and the overall message start as ``None``; the
    forward/backward error messages start as empty strings.
    """

    def __init__(self) -> None:
        self.forward_pass_status = self.backward_pass_status = None
        self.forward_err_msg = self.backward_err_msg = ""
        self.overall_err_msg = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def write_csv_header(csv_path, header_func):
    """Write the header row returned by ``header_func()`` to ``csv_path``.

    The row is passed to ``write_csv`` with ``mode="a+"``. This function does
    not check whether a header already exists; callers decide when to call it.
    """
    header_row = header_func()
    logger.debug(f"Writing CSV header: {header_row}")
    write_csv([header_row], csv_path, mode="a+")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_result_csv_header():
    """Return the ordered column names of the result (summary) CSV as a new list."""
    columns = (
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    )
    return list(columns)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_detail_csv_header():
    """Return the ordered column names of the detail CSV as a new list.

    Layout: API name, bench dtype, tested dtype and shape; then one column per
    key of ``compare_algorithms``; then pass status and message.
    """
    header = [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
    ]
    header.extend(compare_algorithms.keys())
    header += [MsCompareConst.DETAIL_CSV_PASS_STATUS, MsCompareConst.DETAIL_CSV_MESSAGE]
    return header
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def check_csv_header(headers, required_constants, csv_path):
    """Verify that every required column name appears in a CSV header row.

    A required name matches a header cell if it is a substring of that cell.
    The original code matches this way because the header row may carry a byte-order mark.

    Args:
        headers: the header row (sequence of strings).
        required_constants: column names that must be present.
        csv_path: path of the CSV, used only in the error message.

    Returns:
        True when all required fields are present. Callers may therefore use
        the result as a condition; previously it returned None, which made such
        callers silently skip their work.

    Raises:
        MsprobeBaseException: with code MISSING_HEADER_ERROR, listing the missing fields.
    """
    missing_constants = [
        const for const in required_constants
        if not any(const in header for header in headers)
    ]

    if missing_constants:
        raise MsprobeBaseException(
            MsprobeBaseException.MISSING_HEADER_ERROR,
            f"{csv_path} 缺少以下必需的表头字段: {missing_constants}"
        )
    return True
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class DataManager:
|
|
81
|
+
def __init__(self, csv_dir, result_csv_path):
    # Accumulated compare results keyed by (api name, forward/backward); flushed by save_results.
    self.results = {}
    # While True, save_results writes the CSV headers before the first rows.
    self.is_first_write = True
    self.csv_dir = csv_dir
    # Names of APIs already processed; pre-filled from result_csv_path when resuming.
    self.api_names_set = set()

    if not result_csv_path:
        # Fresh run: timestamped result CSV in csv_dir; the detail CSV name is
        # derived by replacing "result" with "details" in the file name.
        self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
        detail_file_name = os.path.basename(self.result_out_path).replace("result", "details")
        self.detail_out_path = os.path.join(self.csv_dir, detail_file_name)
    else:
        # Resume: reuse the previous run's output files and skip APIs already listed there.
        self.resume_from_last_csv(result_csv_path)
        self.initialize_api_names_set(result_csv_path)

    # Validate any output file that already exists (detail first, then result).
    for existing_path in (self.detail_out_path, self.result_out_path):
        if existing_path and os.path.exists(existing_path):
            check_file_or_directory_path(existing_path)
|
|
103
|
+
|
|
104
|
+
def initialize_api_names_set(self, result_csv_path):
    """Load the API names already recorded in an existing result CSV.

    Every non-empty value of the "API Name" column of *result_csv_path* is
    added to ``self.api_names_set``, so ``is_unique_api`` reports those APIs
    as already seen when a run is resumed.

    Raises:
        MsprobeBaseException: raised by ``check_csv_header`` when the header row
            lacks a field from ``get_result_csv_header()``.
    """
    csv_data = read_csv(result_csv_path, as_pd=False)
    # An empty file yields an empty header row (which then fails validation).
    headers = csv_data[0] if csv_data else []

    # check_csv_header signals failure by raising. Its return value is not used
    # as a condition: when it returned None, the names were never loaded.
    check_csv_header(headers, get_result_csv_header(), result_csv_path)

    # The header cell may include a byte-order mark, so match by containment.
    api_name_index = next(
        (i for i, header in enumerate(headers) if MsCompareConst.DETAIL_CSV_API_NAME in header),
        None,
    )
    if api_name_index is None:
        logger.warning(f"{result_csv_path} No column contains 'API Name'.")
        return

    for row in csv_data[1:]:  # skip the header row
        if len(row) > api_name_index and row[api_name_index]:
            self.api_names_set.add(row[api_name_index])

    logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")
|
|
134
|
+
|
|
135
|
+
def is_unique_api(self, api_name):
    """Return True the first time *api_name* is seen (and remember it), False afterwards."""
    already_seen = api_name in self.api_names_set
    if not already_seen:
        self.api_names_set.add(api_name)
    return not already_seen
|
|
141
|
+
|
|
142
|
+
def resume_from_last_csv(self, result_csv_path):
    """Point this manager at the output files of a previous run.

    The detail CSV is taken to live next to *result_csv_path*, with "result"
    replaced by "details" in its file name; it is validated if it exists.
    ``is_first_write`` is cleared so ``save_results`` does not write headers again.
    """
    previous_dir = os.path.dirname(result_csv_path)
    detail_file_name = os.path.basename(result_csv_path).replace("result", "details")

    self.csv_dir = previous_dir
    self.detail_out_path = os.path.join(previous_dir, detail_file_name)
    if self.detail_out_path and os.path.exists(self.detail_out_path):
        check_file_or_directory_path(self.detail_out_path)
    self.result_out_path = result_csv_path
    self.is_first_write = False
|
|
154
|
+
|
|
155
|
+
def save_results(self, api_name_str):
    """Write the detailed output and the result summary to CSV, then clear the recorded results.

    On the first call (``self.is_first_write`` is True) the headers are written to both the
    details CSV and the result CSV before any rows. The flag is then cleared so headers are
    written only once.

    Args:
        api_name_str: Name of the API(s) being saved; used only in log messages.
    """
    if self.is_first_write:
        # Write the headers directly.
        logger.info("Writing CSV headers for the first time.")
        # NOTE(review): the header getters are passed uncalled; presumably write_csv_header
        # invokes them itself — confirm against its definition.
        write_csv_header(self.detail_out_path, get_detail_csv_header)
        write_csv_header(self.result_out_path, get_result_csv_header)
        # Mark as written so the headers are not written again.
        self.is_first_write = False

    logger.debug("Starting to write detailed output to CSV.")
    self.to_detail_csv(self.detail_out_path)
    logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

    logger.debug("Starting to write result summary to CSV.")
    self.to_result_csv(self.result_out_path)
    logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

    # Clear the recorded results in preparation for the next call.
    self.clear_results()
|
|
174
|
+
|
|
175
|
+
def record(self, output_list):
    """Group comparison outputs into ``self.results`` keyed by (API name, direction).

    Args:
        output_list: Iterable of 4-tuples ``(api_real_name, forward_or_backward, basic_info,
            compare_result_dict)``, or None. When it is None, nothing is recorded.
            Each ``(basic_info, compare_result_dict)`` pair is appended to the list stored
            under ``(api_real_name, forward_or_backward)``.
    """
    if output_list is None:
        return
    for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
        key = (api_real_name, forward_or_backward)
        bucket = self.results.setdefault(key, [])
        bucket.append((basic_info, compare_result_dict))
        logger.debug(f"Updated self.results for key {key}: {bucket}")
    logger.debug(f"Complete self.results after recording: {self.results}")
|
|
186
|
+
|
|
187
|
+
def clear_results(self):
    """Empty ``self.results`` in place, logging the action at debug level."""
    logger.debug("Clearing self.results data.")
    recorded = self.results
    recorded.clear()
|
|
191
|
+
|
|
192
|
+
def to_detail_csv(self, csv_path):
    """Append one detail row per recorded comparison to ``csv_path``.

    Each row contains, in order:
    - ``api_name``, ``bench_dtype``, ``tested_dtype`` and ``shape`` from ``basic_info``;
    - the ``compare_value`` of every algorithm in ``compare_algorithms``, in that mapping's
      key order;
    - ``status`` and ``err_msg``.

    NOTE(review): every algorithm in ``compare_algorithms`` must be present in each
    ``compare_result_dict``. A missing one makes ``.get`` return None, and the attribute
    access then raises AttributeError.
    """
    logger.debug("Preparing detail CSV headers and rows.")
    algorithm_names = list(compare_algorithms.keys())
    rows = []

    for entries in self.results.values():
        for basic_info, compare_result_dict in entries:
            row = [
                basic_info.api_name,
                basic_info.bench_dtype,
                basic_info.tested_dtype,
                basic_info.shape,
            ]
            row.extend(compare_result_dict.get(name).compare_value for name in algorithm_names)
            row.extend((basic_info.status, basic_info.err_msg))
            rows.append(row)
            logger.debug(f"Detail CSV row added: {row}")

    logger.debug(f"Writing detail CSV to {csv_path}.")
    write_csv(rows, csv_path, mode="a+")
    logger.debug(f"Detail CSV written successfully to {csv_path}.")
|
|
219
|
+
|
|
220
|
+
def to_result_csv(self, csv_path):
    """Append one summary row per API to ``csv_path``.

    For each (API, direction) key in ``self.results``:
    - the direction's status is ``CompareConst.ERROR`` if any recorded result's status is
      not ``CompareConst.PASS``, otherwise ``CompareConst.PASS``;
    - its error message is the separator-less concatenation of the non-passing results'
      ``err_msg``.

    Directions equal to ``Const.FORWARD`` fill the forward fields of a ``ResultCsvEntry``;
    any other direction fills the backward fields. A direction never recorded for an API
    keeps ``ResultCsvEntry``'s defaults (not visible here).

    Each row is ``[api_name, forward_status, backward_status, err_msg]``. ``err_msg`` is
    empty when both statuses are PASS. Also sets ``self.is_first_write`` to False.
    """
    logger.debug("Preparing result CSV data.")

    entries_by_api = {}
    for (api_real_name, forward_or_backward), recorded in self.results.items():
        failed_msgs = [
            basic_info.err_msg
            for basic_info, _ in recorded
            if basic_info.status != CompareConst.PASS
        ]
        if failed_msgs:
            pass_status = CompareConst.ERROR
            combined_msg = "".join(failed_msgs)
        else:
            pass_status = CompareConst.PASS
            combined_msg = ""

        if api_real_name not in entries_by_api:
            entries_by_api[api_real_name] = ResultCsvEntry()
        entry = entries_by_api[api_real_name]

        if forward_or_backward == Const.FORWARD:
            entry.forward_pass_status, entry.forward_err_msg = pass_status, combined_msg
        else:
            entry.backward_pass_status, entry.backward_err_msg = pass_status, combined_msg

    result_rows = []
    for api_name, entry in entries_by_api.items():
        both_passed = (entry.forward_pass_status == CompareConst.PASS and
                       entry.backward_pass_status == CompareConst.PASS)
        if both_passed:
            row_err_msg = ""
        else:
            row_err_msg = entry.forward_err_msg + entry.backward_err_msg
        row = [api_name, entry.forward_pass_status, entry.backward_pass_status, row_err_msg]
        result_rows.append(row)
        logger.debug(f"Result CSV row added: {row}")

    write_csv(result_rows, csv_path, mode="a+")
    logger.debug(f"Result CSV written successfully to {csv_path}.")

    # Mark headers as handled so they are not added again later.
    self.is_first_write = False
|
|
@@ -1,9 +1,33 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
|
|
2
17
|
|
|
18
|
+
from msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker import MultiApiAccuracyChecker
|
|
19
|
+
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.cmd_parser import check_args
|
|
21
|
+
|
|
3
22
|
|
|
4
23
|
def api_checker_main(args):
    """Run ``ApiAccuracyChecker`` end to end for the parsed command-line ``args``.

    The steps are:
    1. validate ``args`` with ``check_args``;
    2. parse ``args.api_info_file``;
    3. run and compare every API.
    """
    check_args(args)
    checker = ApiAccuracyChecker(args)
    checker.parse(args.api_info_file)
    checker.run_and_compare()
|
|
28
|
+
|
|
29
|
+
def mul_api_checker_main(args):
    """Run ``MultiApiAccuracyChecker`` end to end for the parsed command-line ``args``.

    This mirrors ``api_checker_main`` but uses the multi-checker class:
    1. validate ``args`` with ``check_args``;
    2. parse ``args.api_info_file``;
    3. run and compare.
    """
    check_args(args)
    checker = MultiApiAccuracyChecker(args)
    checker.parse(args.api_info_file)
    checker.run_and_compare()
|
|
8
|
-
api_accuracy_checker.to_detail_csv(args.out_path)
|
|
9
|
-
api_accuracy_checker.to_result_csv(args.out_path)
|