mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,10 +1,24 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
1
17
|
import os
|
|
2
18
|
import time
|
|
3
|
-
import json
|
|
4
19
|
from multiprocessing import Pool
|
|
5
20
|
|
|
6
21
|
import torch
|
|
7
|
-
|
|
8
22
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
9
23
|
|
|
10
24
|
try:
|
|
@@ -14,15 +28,15 @@ except ImportError:
|
|
|
14
28
|
else:
|
|
15
29
|
is_npu = True
|
|
16
30
|
|
|
17
|
-
from msprobe.core.common.file_utils import
|
|
31
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, load_yaml, FileOpen, create_directory
|
|
18
32
|
from msprobe.core.common.const import Const, CompareConst
|
|
19
33
|
from msprobe.pytorch.common.log import logger
|
|
20
|
-
from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call,
|
|
21
|
-
DispatchRunParam, DisPatchDataInfo
|
|
22
|
-
from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu,
|
|
34
|
+
from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, \
|
|
35
|
+
TimeStatistics, DispatchRunParam, DisPatchDataInfo
|
|
36
|
+
from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, \
|
|
37
|
+
COMPARE_LOGO
|
|
23
38
|
from msprobe.pytorch.online_dispatch.compare import Comparator
|
|
24
|
-
from msprobe.core.common.
|
|
25
|
-
|
|
39
|
+
from msprobe.core.common.utils import check_str_param, safe_get_value
|
|
26
40
|
|
|
27
41
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
28
42
|
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
|
|
@@ -42,7 +56,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
42
56
|
|
|
43
57
|
self.device_id = torch_npu._C._npu_getDevice()
|
|
44
58
|
self.dump_mode = dump_mode
|
|
45
|
-
self.dump_api_list = api_list
|
|
59
|
+
self.dump_api_list = api_list or []
|
|
46
60
|
self.debug_flag = debug
|
|
47
61
|
self.api_index = 0
|
|
48
62
|
self.single_api_index_dict = {}
|
|
@@ -51,14 +65,13 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
51
65
|
self.all_summary = []
|
|
52
66
|
self.call_stack_list = []
|
|
53
67
|
self.process_num = process_num
|
|
54
|
-
self.
|
|
68
|
+
self.tag = tag
|
|
55
69
|
self.check_param()
|
|
70
|
+
self.filter_dump_api()
|
|
56
71
|
dir_name = self.get_dir_name(tag)
|
|
57
72
|
self.root_path = os.path.join(os.path.realpath(dump_path), dir_name)
|
|
58
73
|
self.root_cpu_path = os.path.join(self.root_path, f'cpu')
|
|
59
74
|
self.root_npu_path = os.path.join(self.root_path, f'npu')
|
|
60
|
-
check_path_before_create(self.root_cpu_path)
|
|
61
|
-
check_path_before_create(self.root_npu_path)
|
|
62
75
|
create_directory(self.root_cpu_path)
|
|
63
76
|
create_directory(self.root_npu_path)
|
|
64
77
|
|
|
@@ -67,7 +80,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
67
80
|
self.comparator = Comparator(self.result_csv_path, self.detail_csv_path, False)
|
|
68
81
|
|
|
69
82
|
self.aten_ops_blacklist = []
|
|
70
|
-
self.
|
|
83
|
+
self.npu_adjust_autograd = []
|
|
71
84
|
yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
|
|
72
85
|
self.get_ops(yaml_path)
|
|
73
86
|
|
|
@@ -76,8 +89,8 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
76
89
|
self.pool = Pool(process_num)
|
|
77
90
|
if debug:
|
|
78
91
|
logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
|
|
79
|
-
|
|
80
|
-
|
|
92
|
+
f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
|
|
93
|
+
f'process[{process_num}]')
|
|
81
94
|
|
|
82
95
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
83
96
|
super().__exit__(exc_type, exc_val, exc_tb)
|
|
@@ -119,7 +132,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
119
132
|
output_num = output_num + 1
|
|
120
133
|
total_num = total_num + 1
|
|
121
134
|
logger.info(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
|
|
122
|
-
|
|
135
|
+
f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')
|
|
123
136
|
|
|
124
137
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
125
138
|
if not is_npu:
|
|
@@ -134,7 +147,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
134
147
|
logger.error(f"Please check the func name {func.__name__}!")
|
|
135
148
|
return func(*args, **kwargs)
|
|
136
149
|
|
|
137
|
-
self.
|
|
150
|
+
self.enable_autograd(aten_api)
|
|
138
151
|
if aten_api in self.aten_ops_blacklist:
|
|
139
152
|
npu_out = func(*args, **kwargs)
|
|
140
153
|
return npu_out
|
|
@@ -151,24 +164,31 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
151
164
|
|
|
152
165
|
if self.debug_flag:
|
|
153
166
|
logger.info(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
|
|
154
|
-
|
|
155
|
-
|
|
167
|
+
f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
|
|
168
|
+
f'Count[{self.api_index}], Sys[{get_sys_info()}]')
|
|
156
169
|
|
|
157
170
|
cpu_args = []
|
|
158
171
|
cpu_kwargs = []
|
|
159
172
|
data_to_cpu(args, 0, cpu_args)
|
|
160
173
|
data_to_cpu(kwargs, 0, cpu_kwargs)
|
|
161
|
-
|
|
162
|
-
|
|
174
|
+
|
|
175
|
+
cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
|
|
176
|
+
cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
|
|
163
177
|
|
|
164
178
|
with TimeStatistics("NPU RUN", run_param):
|
|
165
179
|
npu_out = func(*args, **kwargs)
|
|
166
180
|
npu_out_cpu = []
|
|
167
181
|
data_to_cpu(npu_out, 0, npu_out_cpu)
|
|
168
|
-
npu_out_cpu = npu_out_cpu
|
|
182
|
+
npu_out_cpu = safe_get_value(npu_out_cpu, 0, "npu_out_cpu")
|
|
169
183
|
|
|
170
184
|
with TimeStatistics("CPU RUN", run_param):
|
|
171
|
-
|
|
185
|
+
try:
|
|
186
|
+
cpu_out = func(*cpu_args, **cpu_kwargs)
|
|
187
|
+
except RuntimeError as e:
|
|
188
|
+
self.api_index -= 1
|
|
189
|
+
logger.warning(f"RuntimeError: {e}")
|
|
190
|
+
logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
|
|
191
|
+
return npu_out
|
|
172
192
|
|
|
173
193
|
if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]:
|
|
174
194
|
cpu_out = cpu_out.float()
|
|
@@ -216,7 +236,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
216
236
|
def get_ops(self, file_path):
|
|
217
237
|
yaml_file = load_yaml(file_path)
|
|
218
238
|
self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
|
|
219
|
-
self.
|
|
239
|
+
self.npu_adjust_autograd = yaml_file.get('npu_adjust_autograd')
|
|
220
240
|
|
|
221
241
|
def filter_dump_api(self):
|
|
222
242
|
if self.dump_mode != Const.LIST or not self.dump_api_list:
|
|
@@ -260,6 +280,17 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
260
280
|
if not isinstance(self.dump_api_list, list):
|
|
261
281
|
logger.error('The type of parameter "api_list" can only be list.')
|
|
262
282
|
raise DispatchException(DispatchException.INVALID_PARAMETER)
|
|
283
|
+
if not all(isinstance(item, str) for item in self.dump_api_list):
|
|
284
|
+
logger.error('The type of parameter in "api_list" can only be str.')
|
|
285
|
+
raise DispatchException(DispatchException.INVALID_PARAMETER)
|
|
286
|
+
if len(self.dump_api_list) > Const.STEP_RANK_MAXIMUM_VALUE:
|
|
287
|
+
logger.error('The length of parameter "api_list" should not be greater '
|
|
288
|
+
f'than {Const.STEP_RANK_MAXIMUM_VALUE}.')
|
|
289
|
+
raise DispatchException(DispatchException.INVALID_PARAMETER)
|
|
290
|
+
for item in self.dump_api_list:
|
|
291
|
+
check_str_param(item)
|
|
292
|
+
if self.tag is not None:
|
|
293
|
+
check_str_param(self.tag)
|
|
263
294
|
if not isinstance(self.debug_flag, bool):
|
|
264
295
|
logger.error('The type of parameter "debug" can only be bool.')
|
|
265
296
|
raise DispatchException(DispatchException.INVALID_PARAMETER)
|
|
@@ -267,6 +298,6 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
267
298
|
logger.error('The type of parameter "process_num" can only be int and it should not be less than 0.')
|
|
268
299
|
raise DispatchException(DispatchException.INVALID_PARAMETER)
|
|
269
300
|
|
|
270
|
-
def
|
|
271
|
-
if aten_api in self.
|
|
272
|
-
torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
|
|
301
|
+
def enable_autograd(self, aten_api):
|
|
302
|
+
if aten_api in self.npu_adjust_autograd:
|
|
303
|
+
torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
|
|
@@ -1,11 +1,26 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
3
16
|
import copy
|
|
17
|
+
import json
|
|
18
|
+
import os
|
|
4
19
|
from datetime import datetime, timezone
|
|
5
20
|
|
|
6
21
|
import torch
|
|
22
|
+
from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
|
|
7
23
|
from msprobe.pytorch.common.log import logger
|
|
8
|
-
from msprobe.core.common.file_utils import FileOpen, save_npy
|
|
9
24
|
|
|
10
25
|
|
|
11
26
|
class DispatchRunParam:
|
|
@@ -55,7 +70,7 @@ class TimeStatistics:
|
|
|
55
70
|
if self.debug:
|
|
56
71
|
self.time = datetime.now(tz=timezone.utc)
|
|
57
72
|
logger.info(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
|
|
58
|
-
|
|
73
|
+
f'Id[{self.index}]')
|
|
59
74
|
|
|
60
75
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
61
76
|
if self.debug:
|
|
@@ -92,10 +107,8 @@ def dump_data(data, prefix, dump_path):
|
|
|
92
107
|
def save_temp_summary(api_index, single_api_summary, path, lock):
|
|
93
108
|
summary_path = os.path.join(path, f'summary.json')
|
|
94
109
|
lock.acquire()
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
f.write('\n')
|
|
98
|
-
lock.release()
|
|
110
|
+
data = [api_index, single_api_summary]
|
|
111
|
+
save_json(summary_path, data, mode='a')
|
|
99
112
|
|
|
100
113
|
|
|
101
114
|
def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
|
|
@@ -152,4 +165,3 @@ def dispatch_multiprocess(run_param, dispatch_data_info):
|
|
|
152
165
|
|
|
153
166
|
def error_call(err):
|
|
154
167
|
logger.error(f'multiprocess {err}')
|
|
155
|
-
|
|
@@ -1,9 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import logging
|
|
17
|
+
from collections import namedtuple
|
|
2
18
|
from functools import wraps
|
|
19
|
+
|
|
3
20
|
import torch
|
|
4
|
-
from prettytable import PrettyTable
|
|
5
|
-
from collections import namedtuple
|
|
6
21
|
from msprobe.pytorch.common.log import logger
|
|
22
|
+
from msprobe.pytorch.online_dispatch.utils import check_idx_valid
|
|
23
|
+
from prettytable import PrettyTable
|
|
24
|
+
|
|
7
25
|
|
|
8
26
|
def func_log_wrapper():
|
|
9
27
|
def _out_wrapper(func):
|
|
@@ -13,9 +31,9 @@ def func_log_wrapper():
|
|
|
13
31
|
x = func(*kargs, **kwargs)
|
|
14
32
|
logger.info(f"end to run: {func.__name__}")
|
|
15
33
|
return x
|
|
16
|
-
|
|
34
|
+
|
|
17
35
|
return _in_wrapper
|
|
18
|
-
|
|
36
|
+
|
|
19
37
|
return _out_wrapper
|
|
20
38
|
|
|
21
39
|
|
|
@@ -31,7 +49,7 @@ class SingleBenchmarkCompareStandard:
|
|
|
31
49
|
torch.bfloat16: 2 ** -7,
|
|
32
50
|
torch.float32: 2 ** -14,
|
|
33
51
|
torch.float64: 2 ** -14}
|
|
34
|
-
|
|
52
|
+
|
|
35
53
|
def get_error_thd(self, dtype):
|
|
36
54
|
if dtype in self.error_thd.keys():
|
|
37
55
|
if dtype == torch.float64:
|
|
@@ -42,12 +60,12 @@ class SingleBenchmarkCompareStandard:
|
|
|
42
60
|
"in fp16, bf16, fp32. "
|
|
43
61
|
)
|
|
44
62
|
return None
|
|
45
|
-
|
|
63
|
+
|
|
46
64
|
def get_eb_thd(self, dtype):
|
|
47
65
|
if dtype in self.eb_thd.keys():
|
|
48
66
|
return self.eb_thd.get(dtype)
|
|
49
67
|
return None
|
|
50
|
-
|
|
68
|
+
|
|
51
69
|
|
|
52
70
|
class SingleBenchmarkAccuracyResult:
|
|
53
71
|
def __init__(
|
|
@@ -82,7 +100,7 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
82
100
|
@func_log_wrapper()
|
|
83
101
|
def check_output_size(cls, npu_out, bench_out):
|
|
84
102
|
acc_result = None
|
|
85
|
-
if npu_out.numel() == 0 and bench_out.
|
|
103
|
+
if npu_out.numel() == 0 and bench_out.numel() == 0:
|
|
86
104
|
info = (
|
|
87
105
|
"The npu_output is [], and it is same as benchmark_output, "
|
|
88
106
|
"the result of data_compare is Pass"
|
|
@@ -99,14 +117,14 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
99
117
|
logging.error(error_info)
|
|
100
118
|
acc_result = SingleBenchmarkAccuracyResult(result=False)
|
|
101
119
|
return acc_result
|
|
102
|
-
|
|
120
|
+
|
|
103
121
|
@classmethod
@func_log_wrapper()
def check_output_invalid_value(cls, output):
    """Tell whether ``output`` contains any NaN or infinite element.

    Both reductions are always evaluated. The NaN result is returned when it
    is truthy; otherwise the infinity result is returned.
    """
    nan_found = torch.isnan(output).any()
    inf_found = torch.isinf(output).any()
    if nan_found:
        return nan_found
    return inf_found
|
|
109
|
-
|
|
127
|
+
|
|
110
128
|
@classmethod
|
|
111
129
|
@func_log_wrapper()
|
|
112
130
|
def precision_compare_for_case(cls, npu_out, bench_out, benchmark_standard: SingleBenchmarkCompareStandard):
|
|
@@ -119,19 +137,19 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
119
137
|
if acc_result:
|
|
120
138
|
failed_info = "比对数据的shape不一致"
|
|
121
139
|
return CompareResultInfo(acc_result, error_thd, eb_thd, failed_info)
|
|
122
|
-
|
|
140
|
+
|
|
123
141
|
if cls.check_output_invalid_value(bench_out):
|
|
124
142
|
logging.info("The benchmark result contains nan/inf value. ")
|
|
125
143
|
failed_info = "标杆结果存在nan值或inf值, 依照单标杆标准该用例通过"
|
|
126
144
|
acc_result = SingleBenchmarkAccuracyResult(result=True)
|
|
127
145
|
return CompareResultInfo(acc_result, error_thd, eb_thd, failed_info)
|
|
128
|
-
|
|
146
|
+
|
|
129
147
|
if cls.check_output_invalid_value(npu_out):
|
|
130
148
|
logging.info("The NPU result contains nan/inf value. ")
|
|
131
149
|
failed_info = "NPU结果存在nan值或inf值, 依照单标杆标准该用例不通过"
|
|
132
150
|
acc_result = SingleBenchmarkAccuracyResult(result=False)
|
|
133
151
|
return CompareResultInfo(acc_result, error_thd, eb_thd, failed_info)
|
|
134
|
-
|
|
152
|
+
|
|
135
153
|
data_type = npu_out.dtype
|
|
136
154
|
if data_type not in [torch.float16, torch.float32, torch.float64, torch.bfloat16]:
|
|
137
155
|
acc_result = cls.compute_binary_diff(npu_out, bench_out)
|
|
@@ -159,7 +177,6 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
159
177
|
acc_result.get_result(eb_thd, error_thd)
|
|
160
178
|
return CompareResultInfo(acc_result, error_thd, eb_thd, None)
|
|
161
179
|
|
|
162
|
-
|
|
163
180
|
@classmethod
|
|
164
181
|
@func_log_wrapper()
|
|
165
182
|
def compute_binary_diff(cls, npu_out, bench_out):
|
|
@@ -167,7 +184,7 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
167
184
|
if result:
|
|
168
185
|
logger.info("二进制精度比对通过, 无需单标杆比对法验证")
|
|
169
186
|
return SingleBenchmarkAccuracyResult(result=result, max_abs_diff=0, max_rel_diff=0, error_balance=0)
|
|
170
|
-
|
|
187
|
+
|
|
171
188
|
@classmethod
|
|
172
189
|
@func_log_wrapper()
|
|
173
190
|
def compute_error_balance(cls, npu_out, bench_out, benchmark_standard: SingleBenchmarkCompareStandard):
|
|
@@ -176,11 +193,11 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
176
193
|
abs_mask_idx = torch.where(torch.abs(bench_out) < benchmark_standard.small_value, ones, zeros)
|
|
177
194
|
abs_mask_idx = abs_mask_idx.type(torch.bool)
|
|
178
195
|
diff_value = torch.subtract(npu_out, bench_out)
|
|
179
|
-
diff_value_rel = diff_value / (torch.abs(bench_out) + torch.finfo(torch.float).eps
|
|
196
|
+
diff_value_rel = diff_value / (torch.abs(bench_out) + torch.finfo(torch.float).eps)
|
|
180
197
|
rel_and_abs = torch.where(abs_mask_idx, diff_value, diff_value_rel)
|
|
181
198
|
eb_float = float(torch.mean(rel_and_abs))
|
|
182
199
|
return eb_float
|
|
183
|
-
|
|
200
|
+
|
|
184
201
|
@classmethod
|
|
185
202
|
@func_log_wrapper()
|
|
186
203
|
def compute_abs_diff(cls, npu_out, bench_out, error_thd, benchmark_standard: SingleBenchmarkCompareStandard):
|
|
@@ -200,15 +217,16 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
200
217
|
err_for_max = torch.where(abs_err_idx == 1, diff_abs, zeros)
|
|
201
218
|
logging.debug("err_for_max for abs %s", err_for_max)
|
|
202
219
|
max_abs_idx = torch.argmax(err_for_max)
|
|
203
|
-
|
|
220
|
+
if check_idx_valid(diff_abs, max_abs_idx):
|
|
221
|
+
max_abs_diff = diff_abs[max_abs_idx]
|
|
204
222
|
elif torch.sum(abs_mask_idx) > 0:
|
|
205
223
|
err_for_max = torch.where(abs_mask_idx == 1, diff_abs, zeros)
|
|
206
224
|
logging.debug("error_for_max for abs %s", err_for_max)
|
|
207
225
|
max_abs_idx = torch.argmax(err_for_max)
|
|
208
|
-
if err_for_max.max() != 0:
|
|
226
|
+
if err_for_max.max() != 0 and check_idx_valid(diff_abs, max_abs_idx):
|
|
209
227
|
max_abs_diff = diff_abs[max_abs_idx]
|
|
210
228
|
return (float(max_abs_diff), int(max_abs_idx) if torch.is_tensor(max_abs_idx) else max_abs_idx)
|
|
211
|
-
|
|
229
|
+
|
|
212
230
|
@classmethod
|
|
213
231
|
@func_log_wrapper()
|
|
214
232
|
def compute_rel_diff(cls, npu_out, bench_out, error_thd, benchmark_standard: SingleBenchmarkCompareStandard):
|
|
@@ -221,7 +239,7 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
221
239
|
diff_abs = torch.abs(diff_value)
|
|
222
240
|
|
|
223
241
|
rel_mask_idx = torch.where(torch.abs(bench_out) >= benchmark_standard.small_value, ones, zeros)
|
|
224
|
-
rel_err = diff_abs / (torch.abs(bench_out) + torch.finfo(torch.float).eps
|
|
242
|
+
rel_err = diff_abs / (torch.abs(bench_out) + torch.finfo(torch.float).eps)
|
|
225
243
|
diff_rel = rel_err
|
|
226
244
|
rel_err_idx = torch.where(rel_err > error_thd, ones, zeros)
|
|
227
245
|
rel_err_idx = rel_err_idx * rel_mask_idx
|
|
@@ -230,19 +248,20 @@ class SingleBenchmarkAccuracyCompare:
|
|
|
230
248
|
err_for_max = torch.where(rel_err_idx == 1, diff_rel, zeros)
|
|
231
249
|
logging.debug("error_for_max for rel %s", err_for_max)
|
|
232
250
|
max_rel_idx = torch.argmax(err_for_max)
|
|
233
|
-
|
|
251
|
+
if check_idx_valid(diff_rel, max_rel_idx):
|
|
252
|
+
max_rel_diff = diff_rel[max_rel_idx]
|
|
234
253
|
elif torch.sum(rel_mask_idx > 0):
|
|
235
254
|
err_for_max = torch.where(rel_mask_idx == 1, diff_rel, zeros)
|
|
236
255
|
logging.debug("err_for_max for rel %s", err_for_max)
|
|
237
256
|
max_rel_idx = torch.argmax(err_for_max)
|
|
238
|
-
if torch.sum(err_for_max) != 0:
|
|
257
|
+
if torch.sum(err_for_max) != 0 and check_idx_valid(diff_rel, max_rel_idx):
|
|
239
258
|
max_rel_diff = diff_rel[max_rel_idx]
|
|
240
259
|
return (float(max_rel_diff), int(max_rel_idx) if torch.is_tensor(max_rel_idx) else max_rel_idx)
|
|
241
260
|
|
|
242
261
|
|
|
243
262
|
class SingleBenchSummary:
|
|
244
263
|
def __init__(self, precision_result: SingleBenchmarkAccuracyResult, npu_dtype=None,
|
|
245
|
-
|
|
264
|
+
bench_dtype=None, shape=None, error_thd=None, eb_thd=None, failed_info=None):
|
|
246
265
|
self.npu_dtype = npu_dtype
|
|
247
266
|
self.bench_dtype = bench_dtype
|
|
248
267
|
self.shape = shape
|
|
@@ -261,12 +280,13 @@ class SingleBenchSummary:
|
|
|
261
280
|
return "PASS"
|
|
262
281
|
else:
|
|
263
282
|
return "FAILED"
|
|
264
|
-
|
|
283
|
+
|
|
265
284
|
def get_result_msg(self):
|
|
266
285
|
result_str = ""
|
|
267
286
|
if self.failed_info:
|
|
268
|
-
|
|
269
|
-
|
|
287
|
+
result_str = self.failed_info
|
|
288
|
+
return result_str
|
|
289
|
+
|
|
270
290
|
if self.result:
|
|
271
291
|
result_str += "误差均衡性EB: %s <= 阈值%s\n" % (self.error_balance, self.eb_thd)
|
|
272
292
|
result_str += "最大绝对误差: %s <= 阈值%s\n" % (self.max_abs_diff, self.error_thd)
|
|
@@ -290,7 +310,7 @@ class SingleBenchSummary:
|
|
|
290
310
|
self.max_rel_diff,
|
|
291
311
|
)
|
|
292
312
|
return result_str
|
|
293
|
-
|
|
313
|
+
|
|
294
314
|
def print_detail_table(self):
|
|
295
315
|
table = PrettyTable()
|
|
296
316
|
table.title = "Single Benchmark Metrics Info"
|
|
@@ -307,7 +327,7 @@ class SingleBenchSummary:
|
|
|
307
327
|
return [self.bench_dtype, self.npu_dtype, self.shape, self.error_balance,
|
|
308
328
|
self.max_abs_diff, self.max_abs_idx, self.max_rel_diff, self.max_rel_idx,
|
|
309
329
|
self.eb_thd, self.error_thd, self.result, self.failed_info]
|
|
310
|
-
|
|
330
|
+
|
|
311
331
|
|
|
312
332
|
def single_benchmark_compare(npu_out: torch.Tensor, bench_out: torch.Tensor, high_precision: bool = True):
|
|
313
333
|
benchmark_standard = SingleBenchmarkCompareStandard(high_precision)
|
|
@@ -322,8 +342,9 @@ def single_benchmark_compare(npu_out: torch.Tensor, bench_out: torch.Tensor, hig
|
|
|
322
342
|
failed_info
|
|
323
343
|
) = (compare_results.accuracy_result, compare_results.error_threshold,
|
|
324
344
|
compare_results.eb_threshold, compare_results.failed_information)
|
|
325
|
-
|
|
326
|
-
summary = SingleBenchSummary(precision_result, str(npu_out.dtype), str(bench_out.dtype), tuple(npu_out.shape),
|
|
345
|
+
|
|
346
|
+
summary = SingleBenchSummary(precision_result, str(npu_out.dtype), str(bench_out.dtype), tuple(npu_out.shape),
|
|
347
|
+
error_thd, eb_thd, failed_info)
|
|
327
348
|
result = summary.result
|
|
328
349
|
details = summary.to_column_value()
|
|
329
350
|
return result, details
|
|
@@ -349,7 +370,7 @@ def calc_status_details_dict(npu_out, bench_out, summary):
|
|
|
349
370
|
summary.failed_info = "bench and npu_output dict keys are different."
|
|
350
371
|
return False, summary.to_column_value()
|
|
351
372
|
else:
|
|
352
|
-
status, details = single_benchmark_compare_wrap(list(bench_out.values(), list(npu_out.values()))
|
|
373
|
+
status, details = single_benchmark_compare_wrap(list(bench_out.values()), list(npu_out.values()))
|
|
353
374
|
return status, details
|
|
354
375
|
|
|
355
376
|
|
|
@@ -1,7 +1,23 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import inspect
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
2
19
|
import psutil
|
|
3
20
|
import torch
|
|
4
|
-
import numpy as np
|
|
5
21
|
|
|
6
22
|
try:
|
|
7
23
|
import torch_npu
|
|
@@ -11,6 +27,7 @@ else:
|
|
|
11
27
|
pta_cpu_device = torch.device("cpu")
|
|
12
28
|
|
|
13
29
|
from msprobe.core.common.const import CompareConst
|
|
30
|
+
from msprobe.pytorch.common.log import logger
|
|
14
31
|
|
|
15
32
|
cpu_device = torch._C.device("cpu")
|
|
16
33
|
COLOR_RED = '\033[31m'
|
|
@@ -31,24 +48,26 @@ COMPARE_LOGO = '''
|
|
|
31
48
|
|_|
|
|
32
49
|
'''
|
|
33
50
|
|
|
34
|
-
CSV_COLUMN_NAME = [
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
51
|
+
# Ordered column identifiers taken from CompareConst.
# NOTE(review): presumably the header row of the comparison CSV; confirm
# against the code that writes the file.
CSV_COLUMN_NAME = [
    # NPU / bench name, dtype and shape
    CompareConst.NPU_NAME, CompareConst.BENCH_NAME,
    CompareConst.NPU_DTYPE, CompareConst.BENCH_DTYPE,
    CompareConst.NPU_SHAPE, CompareConst.BENCH_SHAPE,
    # NPU max / min / mean, then bench max / min / mean
    CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN,
    CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
    # Comparison metrics and verdict
    CompareConst.COSINE, CompareConst.MAX_ABS_ERR,
    CompareConst.MAX_RELATIVE_ERR, CompareConst.ACCURACY,
    # Stack and error message
    CompareConst.STACK, CompareConst.ERROR_MESSAGE,
]
|
|
52
71
|
|
|
53
72
|
FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
|
|
54
73
|
BOOL_TYPE = [bool, np.uint8]
|
|
@@ -58,8 +77,11 @@ INT_TYPE = [np.int32, np.int64]
|
|
|
58
77
|
def get_callstack():
    """Describe the call stack above the caller of this function.

    Returns a list with one ``[path, line, func, code]`` entry per frame,
    where ``line`` is converted to ``str`` and ``code`` is the stripped first
    context line. When the frame has no context, the falsy context value is
    stored unchanged. A frame whose context raises ``IndexError`` is logged
    and left out of the result.
    """
    frames = []
    # Index 0 is this function and index 1 is its caller; both are skipped.
    for frame_info in inspect.stack()[2:]:
        _, path, line, func, code, _ = frame_info
        try:
            snippet = code[0].strip() if code else code
            frames.append([path, str(line), func, snippet])
        except IndexError:
            logger.error("Failed to get callstack for code:{} index out of range".format(code))
    return frames
|
|
64
86
|
|
|
65
87
|
|
|
@@ -125,3 +147,9 @@ class DispatchException(Exception):
|
|
|
125
147
|
|
|
126
148
|
def __str__(self):
    """Return the message stored in ``self.err_msg`` as the string form."""
    return self.err_msg
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def check_idx_valid(data, idx):
    """Check that ``idx`` addresses an element of the non-empty ``data``.

    Returns ``True`` only when ``data`` is not ``None``, ``data.numel()`` is
    positive, and ``0 <= idx < data.numel()``. Returns ``False`` otherwise.
    """
    if data is None:
        return False
    element_count = data.numel()
    if not element_count > 0:
        return False
    # bool() collapses a tensor-valued comparison into a plain Python bool.
    return bool(0 <= idx < element_count)
|