mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import csv
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
|
|
21
|
+
from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
|
|
22
|
+
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
23
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
24
|
+
from msprobe.mindspore.common.log import logger
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ResultCsvEntry:
    """Per-API aggregate of forward/backward check outcomes used to build one result-CSV row."""

    def __init__(self) -> None:
        # Pass statuses remain None until the corresponding direction has been recorded.
        self.forward_pass_status, self.backward_pass_status = None, None
        # Error messages start as empty strings because they are concatenated when a row is built.
        self.forward_err_msg, self.backward_err_msg = "", ""
        self.overall_err_msg = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def write_csv_header(csv_path, header_func):
    """Write a header row to ``csv_path``; intended for the first write to that file.

    Args:
        csv_path: Target CSV file; the row is written with ``mode="a+"``.
        header_func: Zero-argument callable that returns the header row.
    """
    header = header_func()
    logger.debug(f"Writing CSV header: {header}")
    rows = [header]
    write_csv(rows, csv_path, mode="a+")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_result_csv_header():
    """Return the ordered column names of the summary (result) CSV."""
    columns = (
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    )
    return list(columns)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_detail_csv_header():
    """Return the ordered column names of the detail CSV.

    Layout: basic API info, then one column per registered compare algorithm
    (in ``compare_algorithms`` key order), then the pass status and message columns.
    """
    return [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
        *compare_algorithms.keys(),
        MsCompareConst.DETAIL_CSV_PASS_STATUS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def check_csv_header(headers, required_constants, csv_path):
    """Raise if any required column name is absent from ``headers``; return None otherwise.

    A required name counts as present when it is a substring of some header cell. This
    tolerates extra characters such as a byte-order mark on the first cell.

    Raises:
        MsprobeBaseException: with MISSING_HEADER_ERROR, listing every missing name.
    """
    missing_constants = []
    for required in required_constants:
        found = False
        for header in headers:
            if required in header:
                found = True
                break
        if not found:
            missing_constants.append(required)

    if not missing_constants:
        return

    raise MsprobeBaseException(
        MsprobeBaseException.MISSING_HEADER_ERROR,
        f"{csv_path} 缺少以下必需的表头字段: {missing_constants}"
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class DataManager:
    """Collect per-API comparison results and append them to the detail and result CSV files.

    If ``result_csv_path`` is given, a previous run is resumed. Output is appended to that
    result CSV and to its sibling "details" CSV. API names already listed in the result CSV
    are preloaded, so :meth:`is_unique_api` reports them as already seen.
    """

    def __init__(self, csv_dir, result_csv_path):
        """Initialize output paths and bookkeeping.

        Args:
            csv_dir: Directory for new output files. When ``result_csv_path`` is given,
                it is replaced by that file's directory.
            result_csv_path: Result CSV of a previous run to resume from, or a falsy value
                to start a fresh run.
        """
        # (api_real_name, forward_or_backward) -> list of (basic_info, compare_result_dict)
        self.results = {}
        # True until the CSV header rows have been written (resuming sets it to False).
        self.is_first_write = True
        self.csv_dir = csv_dir
        # API names already processed, or already present in a resumed result CSV.
        self.api_names_set = set()
        if result_csv_path:
            # Resume: keep writing to the previous run's files and preload the APIs they contain.
            self.resume_from_last_csv(result_csv_path)
            self.initialize_api_names_set(result_csv_path)
        else:
            # Fresh run: derive the result file name via add_time_as_suffix. The details file
            # reuses that name with "result" replaced by "details".
            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
            self.detail_out_path = os.path.join(
                self.csv_dir,
                os.path.basename(self.result_out_path).replace("result", "details")
            )

        # Run the project's path check on output files that already exist before appending to them.
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)

        if self.result_out_path and os.path.exists(self.result_out_path):
            check_file_or_directory_path(self.result_out_path)

    @staticmethod
    def _find_column_index(headers, column_name):
        """Return the index of the first header cell containing ``column_name``, or None.

        Substring matching is used because the first header cell may carry a byte-order mark.
        """
        for index, header in enumerate(headers):
            if column_name in header:
                return index
        return None

    def initialize_api_names_set(self, result_csv_path):
        """Load the API names already recorded in an existing result CSV into ``api_names_set``.

        Raises:
            MsprobeBaseException: if the CSV header lacks a required column
                (raised by ``check_csv_header``).
        """
        csv_data = read_csv(result_csv_path, as_pd=False)
        # The first row is the header; an empty file yields no headers.
        headers = csv_data[0] if csv_data else []

        # check_csv_header raises when a column is missing and returns None on success, so it
        # must be called as a statement. Used as an `if` condition, it made the loading below
        # unreachable, and resumed runs re-checked every API.
        check_csv_header(headers, get_result_csv_header(), result_csv_path)

        api_name_index = self._find_column_index(headers, MsCompareConst.DETAIL_CSV_API_NAME)
        if api_name_index is None:
            logger.warning(f"{result_csv_path} No column contains 'API Name'.")
            return

        # Skip the header row. Ignore rows too short to hold the name, and rows with an empty name.
        for row in csv_data[1:]:
            if len(row) > api_name_index and row[api_name_index]:
                self.api_names_set.add(row[api_name_index])

        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")

    def is_unique_api(self, api_name):
        """Return False if ``api_name`` was already seen; otherwise record it and return True."""
        if api_name in self.api_names_set:
            return False
        self.api_names_set.add(api_name)
        return True

    def resume_from_last_csv(self, result_csv_path):
        """Point the output paths at a previous run's result CSV and its sibling details CSV."""
        last_dir = os.path.dirname(result_csv_path)

        # Subsequent writes go to the previous run's directory and files.
        self.csv_dir = last_dir
        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)
        self.result_out_path = result_csv_path
        # The resumed files are assumed to already contain their header rows.
        self.is_first_write = False

    def save_results(self, api_name_str):
        """Append the buffered results to the detail and result CSVs, then clear the buffer.

        Header rows are written first on the very first call of a fresh run.
        """
        if self.is_first_write:
            logger.info("Writing CSV headers for the first time.")
            write_csv_header(self.detail_out_path, get_detail_csv_header)
            write_csv_header(self.result_out_path, get_result_csv_header)
            # Prevent the headers from being written again.
            self.is_first_write = False

        logger.debug("Starting to write detailed output to CSV.")
        self.to_detail_csv(self.detail_out_path)
        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

        logger.debug("Starting to write result summary to CSV.")
        self.to_result_csv(self.result_out_path)
        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

        # Reset the buffer for the next API.
        self.clear_results()

    def record(self, output_list):
        """Buffer (api_real_name, forward_or_backward, basic_info, compare_result_dict) tuples.

        A None ``output_list`` is ignored.
        """
        if output_list is None:
            return
        for output in output_list:
            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
            key = (api_real_name, forward_or_backward)
            self.results.setdefault(key, []).append((basic_info, compare_result_dict))
            logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
        logger.debug(f"Complete self.results after recording: {self.results}")

    def clear_results(self):
        """Empty the buffered results."""
        logger.debug("Clearing self.results data.")
        self.results.clear()

    def to_detail_csv(self, csv_path):
        """Append one detail row per buffered result to ``csv_path``.

        Each row holds basic info, one value per compare algorithm (in ``compare_algorithms``
        key order), then status and error message.
        """
        logger.debug("Preparing detail CSV headers and rows.")
        detail_csv = []

        detail_csv_header_compare_result = list(compare_algorithms.keys())

        for results in self.results.values():
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = [
                    basic_info.api_name,
                    basic_info.bench_dtype,
                    basic_info.tested_dtype,
                    basic_info.shape
                ]
                csv_row_compare_result = [
                    compare_result_dict.get(algorithm_name).compare_value
                    for algorithm_name in detail_csv_header_compare_result
                ]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)
                logger.debug(f"Detail CSV row added: {csv_row}")

        logger.debug(f"Writing detail CSV to {csv_path}.")
        write_csv(detail_csv, csv_path, mode="a+")
        logger.debug(f"Detail CSV written successfully to {csv_path}.")

    def to_result_csv(self, csv_path):
        """Append one summary row per API (forward status, backward status, message) to ``csv_path``.

        A direction is PASS only if all of its buffered results are PASS; otherwise it is ERROR
        and its results' error messages are concatenated.
        """
        logger.debug("Preparing result CSV data.")
        result_csv = []

        result_csv_dict = {}
        for key, results in self.results.items():
            api_real_name, forward_or_backward = key
            pass_status = CompareConst.PASS
            overall_err_msg = ""

            for basic_info, _ in results:
                if basic_info.status != CompareConst.PASS:
                    pass_status = CompareConst.ERROR
                overall_err_msg += basic_info.err_msg

            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            if forward_or_backward == Const.FORWARD:
                result_csv_dict[api_real_name].forward_pass_status = pass_status
                result_csv_dict[api_real_name].forward_err_msg = overall_err_msg
            else:
                result_csv_dict[api_real_name].backward_pass_status = pass_status
                result_csv_dict[api_real_name].backward_err_msg = overall_err_msg

        for api_name, entry in result_csv_dict.items():
            overall_err_msg = "" if (entry.forward_pass_status == CompareConst.PASS and
                                     entry.backward_pass_status == CompareConst.PASS) else \
                entry.forward_err_msg + entry.backward_err_msg
            row = [
                api_name,
                entry.forward_pass_status,
                entry.backward_pass_status,
                overall_err_msg
            ]
            result_csv.append(row)
            logger.debug(f"Result CSV row added: {row}")

        write_csv(result_csv, csv_path, mode="a+")
        logger.debug(f"Result CSV written successfully to {csv_path}.")

        # Headers are never needed after data has been written.
        self.is_first_write = False
|
|
@@ -1,9 +1,33 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
|
|
2
17
|
|
|
18
|
+
from msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker import MultiApiAccuracyChecker
|
|
19
|
+
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.cmd_parser import check_args
|
|
21
|
+
|
|
3
22
|
|
|
4
23
|
def api_checker_main(args):
    """Validate CLI arguments, then run single-process API accuracy checking."""
    check_args(args)
    checker = ApiAccuracyChecker(args)
    checker.parse(args.api_info_file)
    checker.run_and_compare()
|
|
28
|
+
|
|
29
|
+
def mul_api_checker_main(args):
    """Validate CLI arguments, then run API accuracy checking across multiple devices."""
    check_args(args)
    multi_checker = MultiApiAccuracyChecker(args)
    multi_checker.parse(args.api_info_file)
    multi_checker.run_and_compare()
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# 标准库导入
|
|
17
|
+
import multiprocessing
|
|
18
|
+
from multiprocessing import Manager
|
|
19
|
+
import os
|
|
20
|
+
import signal
|
|
21
|
+
import sys
|
|
22
|
+
import time
|
|
23
|
+
|
|
24
|
+
# 第三方库导入
|
|
25
|
+
from mindspore import context
|
|
26
|
+
import numpy as np
|
|
27
|
+
from tqdm import tqdm
|
|
28
|
+
|
|
29
|
+
# 本地应用/库特定导入
|
|
30
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
31
|
+
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
|
|
32
|
+
from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
|
|
33
|
+
from msprobe.mindspore.common.log import logger
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
37
|
+
def __init__(self, args):
    """Set up shared multiprocess state and the multi-process result writer.

    NOTE(review): ApiAccuracyChecker.__init__ is not called here. This subclass sets the
    attributes it uses itself; confirm none of the parent's initialization is needed.
    """
    # Keep the parsed CLI arguments for later use (e.g. args.device_id).
    self.args = args
    self.api_infos = {}
    # Device ID currently in use by this process; used for log output.
    self.current_device_id = None

    # Manager-backed Value so the "headers not yet written" flag is shared between processes.
    self.manager = Manager()
    self.is_first_write = self.manager.Value('b', True)

    # The result writer receives the shared flag.
    self.multi_data_manager = MultiDataManager(args.out_path, args.result_csv_path, self.is_first_write)
|
|
52
|
+
|
|
53
|
+
def process_on_device(self, device_id, api_infos, progress_queue):
|
|
54
|
+
"""
|
|
55
|
+
在特定设备上处理一部分API。
|
|
56
|
+
|
|
57
|
+
参数:
|
|
58
|
+
device_id (int): 要使用的设备ID。
|
|
59
|
+
api_infos (list): 包含API名称和对应信息的元组列表。
|
|
60
|
+
progress_queue (multiprocessing.Queue): 用于通信进度更新的队列。
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
# 设置当前设备ID
|
|
64
|
+
self.current_device_id = device_id
|
|
65
|
+
|
|
66
|
+
# 设置 MindSpore context 的 device_id
|
|
67
|
+
context.set_context(device_id=device_id)
|
|
68
|
+
|
|
69
|
+
# 遍历当前进程分配的任务
|
|
70
|
+
for _, (api_name_str, api_info) in enumerate(api_infos):
|
|
71
|
+
logger.debug(f"Processing API: {api_name_str}, Device: {device_id}")
|
|
72
|
+
|
|
73
|
+
if not self.multi_data_manager.is_unique_api(api_name_str):
|
|
74
|
+
logger.debug(f"API {api_name_str} is not unique, skipping.")
|
|
75
|
+
progress_queue.put(1)
|
|
76
|
+
continue
|
|
77
|
+
|
|
78
|
+
# 处理前向
|
|
79
|
+
forward_output_list = self.process_forward(api_name_str, api_info)
|
|
80
|
+
if forward_output_list is not Const.EXCEPTION_NONE:
|
|
81
|
+
self.multi_data_manager.record(forward_output_list)
|
|
82
|
+
|
|
83
|
+
# 处理反向
|
|
84
|
+
backward_output_list = self.process_backward(api_name_str, api_info)
|
|
85
|
+
if backward_output_list is not Const.EXCEPTION_NONE:
|
|
86
|
+
self.multi_data_manager.record(backward_output_list)
|
|
87
|
+
|
|
88
|
+
# 保存结果
|
|
89
|
+
self.multi_data_manager.save_results(api_name_str)
|
|
90
|
+
progress_queue.put(1) # 更新进度
|
|
91
|
+
|
|
92
|
+
def run_and_compare(self):
|
|
93
|
+
# 获取要使用的设备ID列表
|
|
94
|
+
device_ids = self.args.device_id
|
|
95
|
+
|
|
96
|
+
# 按设备数划分要处理的 API 项
|
|
97
|
+
partitioned_api_infos = list(self.api_infos.items())
|
|
98
|
+
|
|
99
|
+
# 在主进程中进行交叉任务切分(基于取模的方式)
|
|
100
|
+
partitioned_api_infos_split = [[] for _ in range(len(device_ids))]
|
|
101
|
+
for idx, api_info in enumerate(partitioned_api_infos):
|
|
102
|
+
device_index = idx % len(device_ids) # 使用取模方法分配任务
|
|
103
|
+
partitioned_api_infos_split[device_index].append(api_info)
|
|
104
|
+
|
|
105
|
+
# 创建一个共享进度队列
|
|
106
|
+
progress_queue = multiprocessing.Queue()
|
|
107
|
+
|
|
108
|
+
# 进度条
|
|
109
|
+
total_tasks = len(partitioned_api_infos) # 计算总任务数
|
|
110
|
+
with tqdm(total=total_tasks, desc="Total Progress", ncols=100) as pbar:
|
|
111
|
+
# 创建多进程
|
|
112
|
+
processes = []
|
|
113
|
+
for index, device_id in enumerate(device_ids):
|
|
114
|
+
process = multiprocessing.Process(target=self.process_on_device,
|
|
115
|
+
args=(device_id, partitioned_api_infos_split[index], progress_queue))
|
|
116
|
+
processes.append(process)
|
|
117
|
+
process.start()
|
|
118
|
+
|
|
119
|
+
# 主进程更新进度条
|
|
120
|
+
completed_tasks = 0
|
|
121
|
+
while completed_tasks < total_tasks:
|
|
122
|
+
try:
|
|
123
|
+
completed_tasks += progress_queue.get(timeout=Const.PROGRESS_TIMEOUT) # 设置超时时间(秒)
|
|
124
|
+
pbar.update(1)
|
|
125
|
+
except multiprocessing.queues.Empty:
|
|
126
|
+
logger.error("Timeout while waiting for progress updates. Skipping remaining tasks.")
|
|
127
|
+
break
|
|
128
|
+
|
|
129
|
+
# 检查子进程状态
|
|
130
|
+
for process in processes:
|
|
131
|
+
if not process.is_alive():
|
|
132
|
+
if process.exitcode != 0:
|
|
133
|
+
logger.error(f"Process {process.pid} exited with code {process.exitcode}.")
|
|
134
|
+
total_tasks -= len(partitioned_api_infos_split[processes.index(process)])
|
|
135
|
+
processes.remove(process)
|
|
136
|
+
|
|
137
|
+
# 确保所有子进程完成或终止
|
|
138
|
+
for process in processes:
|
|
139
|
+
process.join(timeout=Const.PROGRESS_TIMEOUT)
|
|
140
|
+
if process.is_alive():
|
|
141
|
+
logger.error(f"Process {process.pid} did not terminate. Forcing termination.")
|
|
142
|
+
process.terminate()
|
|
143
|
+
|
|
144
|
+
def process_forward(self, api_name_str, api_info):
|
|
145
|
+
"""
|
|
146
|
+
Overrides the parent class's process_forward method to log the device ID when exceptions occur.
|
|
147
|
+
|
|
148
|
+
Parameters:
|
|
149
|
+
api_name_str (str): The name of the API.
|
|
150
|
+
api_info (object): The API information object.
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
list or None: The forward output list or None if an error occurs.
|
|
154
|
+
"""
|
|
155
|
+
if not api_info.check_forward_info():
|
|
156
|
+
logger.debug(
|
|
157
|
+
f"[Device {self.current_device_id}] API: {api_name_str} lacks forward information, skipping forward check.")
|
|
158
|
+
return Const.EXCEPTION_NONE
|
|
159
|
+
|
|
160
|
+
try:
|
|
161
|
+
forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
|
|
162
|
+
except Exception as e:
|
|
163
|
+
logger.warning(
|
|
164
|
+
f"[Device {self.current_device_id}] Exception occurred while getting forward API inputs for {api_name_str}. Skipping forward check. Detailed exception information: {e}.")
|
|
165
|
+
return Const.EXCEPTION_NONE
|
|
166
|
+
|
|
167
|
+
forward_output_list = None
|
|
168
|
+
try:
|
|
169
|
+
forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
|
|
170
|
+
Const.FORWARD)
|
|
171
|
+
except Exception as e:
|
|
172
|
+
logger.warning(
|
|
173
|
+
f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} forward API. Detailed exception information: {e}.")
|
|
174
|
+
return forward_output_list
|
|
175
|
+
|
|
176
|
+
def process_backward(self, api_name_str, api_info):
|
|
177
|
+
"""
|
|
178
|
+
Overrides the parent class's process_backward method to log the device ID when exceptions occur.
|
|
179
|
+
|
|
180
|
+
Parameters:
|
|
181
|
+
api_name_str (str): The name of the API.
|
|
182
|
+
api_info (object): The API information object.
|
|
183
|
+
|
|
184
|
+
Returns:
|
|
185
|
+
list or None: The backward output list or None if an error occurs.
|
|
186
|
+
"""
|
|
187
|
+
if not api_info.check_backward_info():
|
|
188
|
+
logger.debug(
|
|
189
|
+
f"[Device {self.current_device_id}] API: {api_name_str} lacks backward information, skipping backward check.")
|
|
190
|
+
return Const.EXCEPTION_NONE
|
|
191
|
+
|
|
192
|
+
try:
|
|
193
|
+
backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
|
|
194
|
+
except Exception as e:
|
|
195
|
+
logger.warning(
|
|
196
|
+
f"[Device {self.current_device_id}] Exception occurred while getting backward API inputs for {api_name_str}. Skipping backward check. Detailed exception information: {e}.")
|
|
197
|
+
return Const.EXCEPTION_NONE
|
|
198
|
+
|
|
199
|
+
backward_output_list = None
|
|
200
|
+
try:
|
|
201
|
+
backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
|
|
202
|
+
Const.BACKWARD)
|
|
203
|
+
except Exception as e:
|
|
204
|
+
logger.warning(
|
|
205
|
+
f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} backward API. Detailed exception information: {e}.")
|
|
206
|
+
return backward_output_list
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import multiprocessing
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager, ResultCsvEntry, write_csv_header, get_result_csv_header, get_detail_csv_header, check_csv_header
|
|
21
|
+
from msprobe.mindspore.common.log import logger
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MultiDataManager(DataManager):
    """DataManager variant whose CSV writes are serialized across worker processes.

    Every save runs under ``self.lock``. A shared flag ensures that only the
    first saver writes the CSV headers.
    """

    def __init__(self, csv_dir, result_csv_path, shared_is_first_write):
        """
        Args:
            csv_dir: Output directory, forwarded to DataManager.
            result_csv_path: Result CSV path, forwarded to DataManager.
            shared_is_first_write: Shared object exposing a boolean ``.value``
                (for example a ``Manager().Value``). It stays True until the CSV
                headers have been written by some process.
        """
        super().__init__(csv_dir, result_csv_path)

        # Shared across processes; controls which process writes the CSV headers.
        self.shared_is_first_write = shared_is_first_write
        # Serializes saves. NOTE(review): this is only effective across processes
        # if the lock is created before the workers start and is inherited by them.
        self.lock = multiprocessing.Lock()

    def save_results(self, api_name_str):
        """Append the recorded results to the detail and result CSVs, then clear them.

        The whole operation holds ``self.lock``, so only one process writes at a
        time. The first caller to get the lock also writes the CSV headers.
        """
        with self.lock:
            # Flip both the per-instance flag and the shared flag before writing,
            # so no other process writes the headers again.
            if self.is_first_write and self.shared_is_first_write.value:
                self.shared_is_first_write.value = False
                self.is_first_write = False
                logger.info("Writing CSV headers for the first time.")
                write_csv_header(self.detail_out_path, get_detail_csv_header)
                write_csv_header(self.result_out_path, get_result_csv_header)

            # Write the detailed output and the result summary, then clear the
            # recorded results.
            self.to_detail_csv(self.detail_out_path)
            logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

            self.to_result_csv(self.result_out_path)
            logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

            # Reset the records for the next API.
            self.clear_results()

    def clear_results(self):
        """Empty ``self.results``.

        This method does not acquire ``self.lock``. save_results calls it while
        already holding that (non-reentrant) lock.
        """
        logger.debug("Clearing results data.")
        self.results.clear()
|
|
@@ -1,7 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
3
16
|
import mindspore
|
|
17
|
+
import numpy as np
|
|
4
18
|
import torch
|
|
19
|
+
from mindspore._c_expression import typing
|
|
20
|
+
from mindspore.common import dtype as mstype
|
|
5
21
|
|
|
6
22
|
INT8 = "Int8"
|
|
7
23
|
UINT8 = "UInt8"
|
|
@@ -18,7 +34,6 @@ BOOL = "Bool"
|
|
|
18
34
|
BFLOAT16 = "BFloat16"
|
|
19
35
|
INT4 = "Int4"
|
|
20
36
|
|
|
21
|
-
|
|
22
37
|
dtype_str_to_ms_dtype = {
|
|
23
38
|
INT8: mstype.int8,
|
|
24
39
|
UINT8: mstype.uint8,
|
|
@@ -37,7 +52,6 @@ dtype_str_to_ms_dtype = {
|
|
|
37
52
|
}
|
|
38
53
|
# Reverse lookup: MindSpore dtype -> dtype name string (on duplicate values the later key wins).
ms_dtype_to_dtype_str = dict(zip(dtype_str_to_ms_dtype.values(), dtype_str_to_ms_dtype.keys()))
|
|
39
54
|
|
|
40
|
-
|
|
41
55
|
dtype_str_to_np_dtype = {
|
|
42
56
|
INT8: np.int8,
|
|
43
57
|
UINT8: np.uint8,
|
|
@@ -75,6 +89,8 @@ FLOAT_TYPE_STR = "float"
|
|
|
75
89
|
SLICE_TYPE_STR = "slice"
|
|
76
90
|
TUPLE_TYPE_STR = "tuple"
|
|
77
91
|
STR_TYPE_STR = "str"
|
|
92
|
+
# Type-name string for MindSpore dtype arguments (presumably the "type" value found
# in dumped API info -- confirm). Mapped to ``typing.Type`` (imported from
# mindspore._c_expression) in api_info_type_str_to_type.
MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
|
|
93
|
+
# Type-name string for torch dtype arguments.
# NOTE(review): unlike MINDSPORE_DTYPE_TYPE_STR, no entry for this constant appears in
# api_info_type_str_to_type -- confirm whether torch.dtype inputs are handled elsewhere.
TORCH_DTYPE_TYPE_STR = "torch.dtype"
|
|
78
94
|
|
|
79
95
|
api_info_type_str_to_type = {
|
|
80
96
|
MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
|
|
@@ -83,6 +99,7 @@ api_info_type_str_to_type = {
|
|
|
83
99
|
FLOAT_TYPE_STR: float,
|
|
84
100
|
SLICE_TYPE_STR: slice,
|
|
85
101
|
STR_TYPE_STR: str,
|
|
102
|
+
MINDSPORE_DTYPE_TYPE_STR: typing.Type,
|
|
86
103
|
}
|
|
87
104
|
# Reverse lookup: Python/framework type -> api-info type-name string (on duplicate values the later key wins).
type_to_api_info_type_str = dict(zip(api_info_type_str_to_type.values(), api_info_type_str_to_type.keys()))
|
|
88
105
|
|
|
@@ -111,4 +128,4 @@ uint_dtype_str_list = [
|
|
|
111
128
|
UINT16,
|
|
112
129
|
UINT32,
|
|
113
130
|
UINT64,
|
|
114
|
-
]
|
|
131
|
+
]
|