mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/pytorch/__init__.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
4
2
|
# All rights reserved.
|
|
5
3
|
#
|
|
6
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,8 +14,12 @@
|
|
|
16
14
|
# limitations under the License.
|
|
17
15
|
|
|
18
16
|
|
|
19
|
-
|
|
20
|
-
from .common.utils import seed_all
|
|
17
|
+
import torch
|
|
21
18
|
from .compare.distributed_compare import compare_distributed
|
|
22
19
|
from .compare.pt_compare import compare
|
|
23
|
-
from .
|
|
20
|
+
from .common.utils import seed_all
|
|
21
|
+
from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end
|
|
22
|
+
|
|
23
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
24
|
+
if torch_version_above_or_equal_2:
|
|
25
|
+
from msprobe.pytorch.monitor.module_hook import TrainerMon
|
|
@@ -16,10 +16,18 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
+
from collections import namedtuple
|
|
19
20
|
from msprobe.core.common.file_utils import load_yaml, check_file_or_directory_path
|
|
21
|
+
from msprobe.core.common.utils import is_int
|
|
20
22
|
from msprobe.pytorch.pt_config import RunUTConfig
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
26
|
+
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
27
|
+
'black_list', 'error_data_path', 'online_config'])
|
|
28
|
+
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
29
|
+
|
|
30
|
+
|
|
23
31
|
class Config:
|
|
24
32
|
def __init__(self, yaml_file):
|
|
25
33
|
check_file_or_directory_path(yaml_file, False)
|
|
@@ -50,6 +58,8 @@ class Config:
|
|
|
50
58
|
raise ValueError(f"{key} must be one of {validators.keys()}")
|
|
51
59
|
if not isinstance(value, validators.get(key)):
|
|
52
60
|
raise ValueError(f"{key} must be {validators[key].__name__} type")
|
|
61
|
+
if key == 'precision' and not is_int(value):
|
|
62
|
+
raise ValueError("precision must be an integer")
|
|
53
63
|
if key == 'precision' and (value < 0 or value > 20):
|
|
54
64
|
raise ValueError("precision must be greater than or equal to 0 and less than 21")
|
|
55
65
|
if key == 'white_list':
|
|
@@ -68,3 +78,55 @@ class Config:
|
|
|
68
78
|
cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
69
79
|
yaml_path = os.path.join(cur_path, "config.yaml")
|
|
70
80
|
msCheckerConfig = Config(yaml_path)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class CheckerConfig:
|
|
84
|
+
def __init__(self, task_config=None):
|
|
85
|
+
self.white_list = msCheckerConfig.white_list
|
|
86
|
+
self.black_list = msCheckerConfig.black_list
|
|
87
|
+
self.error_data_path = msCheckerConfig.error_data_path
|
|
88
|
+
self.is_online = msCheckerConfig.is_online
|
|
89
|
+
self.nfs_path = msCheckerConfig.nfs_path
|
|
90
|
+
self.host = msCheckerConfig.host
|
|
91
|
+
self.port = msCheckerConfig.port
|
|
92
|
+
self.rank_list = msCheckerConfig.rank_list
|
|
93
|
+
self.tls_path = msCheckerConfig.tls_path
|
|
94
|
+
|
|
95
|
+
if task_config:
|
|
96
|
+
self.load_config(task_config)
|
|
97
|
+
|
|
98
|
+
def load_config(self, task_config):
|
|
99
|
+
self.white_list = task_config.white_list
|
|
100
|
+
self.black_list = task_config.black_list
|
|
101
|
+
self.error_data_path = task_config.error_data_path
|
|
102
|
+
self.is_online = task_config.is_online
|
|
103
|
+
self.nfs_path = task_config.nfs_path
|
|
104
|
+
self.host = task_config.host
|
|
105
|
+
self.port = task_config.port
|
|
106
|
+
self.rank_list = task_config.rank_list
|
|
107
|
+
self.tls_path = task_config.tls_path
|
|
108
|
+
|
|
109
|
+
def get_online_config(self):
|
|
110
|
+
return OnlineConfig(
|
|
111
|
+
is_online=self.is_online,
|
|
112
|
+
nfs_path=self.nfs_path,
|
|
113
|
+
host=self.host,
|
|
114
|
+
port=self.port,
|
|
115
|
+
rank_list=self.rank_list,
|
|
116
|
+
tls_path=self.tls_path
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
def get_run_ut_config(self, **config_params):
|
|
120
|
+
return RunUtConfig(
|
|
121
|
+
forward_content=config_params.get('forward_content'),
|
|
122
|
+
backward_content=config_params.get('backward_content'),
|
|
123
|
+
result_csv_path=config_params.get('result_csv_path'),
|
|
124
|
+
details_csv_path=config_params.get('details_csv_path'),
|
|
125
|
+
save_error_data=config_params.get('save_error_data'),
|
|
126
|
+
is_continue_run_ut=config_params.get('is_continue_run_ut'),
|
|
127
|
+
real_data_path=config_params.get('real_data_path'),
|
|
128
|
+
white_list=self.white_list,
|
|
129
|
+
black_list=self.black_list,
|
|
130
|
+
error_data_path=config_params.get('error_data_path'),
|
|
131
|
+
online_config=self.get_online_config()
|
|
132
|
+
)
|
|
@@ -72,38 +72,53 @@ def check_need_convert(api_name):
|
|
|
72
72
|
return convert_type
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
def
|
|
75
|
+
def cross_entropy_process(api_info_dict):
|
|
76
76
|
"""
|
|
77
77
|
Function Description:
|
|
78
|
-
Preprocesses the API information.
|
|
78
|
+
Preprocesses the cross_entropy API information.
|
|
79
79
|
Parameter:
|
|
80
|
-
api_name: Name of the API.
|
|
81
80
|
api_info_dict: argument of the API.
|
|
82
81
|
Return api_info_dict:
|
|
83
|
-
convert_type: Type of conversion.
|
|
84
82
|
api_info_dict: Processed argument of the API.
|
|
85
83
|
"""
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
api_info_dict
|
|
89
|
-
|
|
84
|
+
if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
|
|
85
|
+
and 'Min' in api_info_dict['input_args'][1]:
|
|
86
|
+
if api_info_dict['input_args'][1]['Min'] <= 0:
|
|
87
|
+
# The second argument in cross_entropy should be -100 or not less than 0
|
|
88
|
+
api_info_dict['input_args'][1]['Min'] = 0
|
|
89
|
+
return api_info_dict
|
|
90
90
|
|
|
91
91
|
|
|
92
|
-
def
|
|
92
|
+
def histc_process(api_info_dict):
|
|
93
|
+
input_args = api_info_dict['input_args']
|
|
94
|
+
if input_args and input_args[0].get('dtype'):
|
|
95
|
+
dtype = input_args[0]['dtype']
|
|
96
|
+
if dtype in Const.TORCH_INT_DTYPE:
|
|
97
|
+
api_info_dict['input_args'][0]['dtype'] = Const.TORCH_FLOAT32
|
|
98
|
+
return api_info_dict
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
API_PROCESS_MAP = {
|
|
102
|
+
'cross_entropy': cross_entropy_process,
|
|
103
|
+
'histc': histc_process
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def api_info_preprocess(api_name, api_info_dict):
|
|
93
108
|
"""
|
|
94
109
|
Function Description:
|
|
95
|
-
Preprocesses the
|
|
110
|
+
Preprocesses the API information.
|
|
96
111
|
Parameter:
|
|
112
|
+
api_name: Name of the API.
|
|
97
113
|
api_info_dict: argument of the API.
|
|
98
114
|
Return api_info_dict:
|
|
115
|
+
convert_type: Type of conversion.
|
|
99
116
|
api_info_dict: Processed argument of the API.
|
|
100
117
|
"""
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
api_info_dict['input_args'][1]['Min'] = 0
|
|
106
|
-
return api_info_dict
|
|
118
|
+
convert_type = check_need_convert(api_name)
|
|
119
|
+
if api_name in API_PROCESS_MAP:
|
|
120
|
+
api_info_dict = API_PROCESS_MAP[api_name](api_info_dict)
|
|
121
|
+
return convert_type, api_info_dict
|
|
107
122
|
|
|
108
123
|
|
|
109
124
|
def initialize_save_path(save_path, dir_name):
|
|
@@ -16,10 +16,12 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
# 定义比对算法及比对标准
|
|
19
|
+
import math
|
|
19
20
|
import torch
|
|
20
21
|
import numpy as np
|
|
21
22
|
|
|
22
23
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
|
|
24
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
|
|
23
25
|
from msprobe.core.common.const import CompareConst
|
|
24
26
|
|
|
25
27
|
|
|
@@ -179,13 +181,13 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
|
|
|
179
181
|
|
|
180
182
|
def check_small_value(abs_err, small_value_mask, small_value_atol):
|
|
181
183
|
'''
|
|
182
|
-
|
|
184
|
+
新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
|
|
183
185
|
输入:
|
|
184
|
-
|
|
186
|
+
abs_err:npu输出和golden输出的绝对误差
|
|
185
187
|
normal_value_mask:npu输出和golden输出的正常值mask
|
|
186
|
-
|
|
188
|
+
atol:绝对误差的阈值
|
|
187
189
|
输出:
|
|
188
|
-
|
|
190
|
+
abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
|
|
189
191
|
'''
|
|
190
192
|
greater_mask = np.greater(abs_err, small_value_atol)
|
|
191
193
|
err_mask = np.logical_and(greater_mask, small_value_mask)
|
|
@@ -195,13 +197,13 @@ def check_small_value(abs_err, small_value_mask, small_value_atol):
|
|
|
195
197
|
|
|
196
198
|
def check_norm_value(normal_value_mask, rel_err, rtol):
|
|
197
199
|
'''
|
|
198
|
-
|
|
200
|
+
新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
|
|
199
201
|
输入:
|
|
200
|
-
|
|
202
|
+
rel_err:npu输出和golden输出的相对误差
|
|
201
203
|
normal_value_mask:npu输出和golden输出的正常值mask
|
|
202
|
-
|
|
204
|
+
rtol:相对误差的阈值
|
|
203
205
|
输出:
|
|
204
|
-
|
|
206
|
+
rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
|
|
205
207
|
'''
|
|
206
208
|
err_mask = np.greater(rel_err, rtol)
|
|
207
209
|
err_mask = np.logical_and(err_mask, normal_value_mask)
|
|
@@ -228,3 +230,34 @@ def get_ulp_err(bench_output, device_output, dtype):
|
|
|
228
230
|
def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
|
|
229
231
|
return (device_output.astype(data_type) - bench_output).astype(data_type) * \
|
|
230
232
|
np.exp2(-eb + exponent_num).astype(data_type)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def calc_ratio(x, y, dtype):
|
|
236
|
+
"""
|
|
237
|
+
Calculate the ratio between NPU and GPU statistical values.
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
x (float): Statistical value from the NPU side
|
|
241
|
+
y (float): Statistical value from the GPU side
|
|
242
|
+
dtype: Data type used to determine the minimum error value
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
float: The ratio of NPU to GPU statistical values
|
|
246
|
+
|
|
247
|
+
Notes:
|
|
248
|
+
- Takes absolute values of both x and y for calculation
|
|
249
|
+
- Uses StandardConfig.get_minmum_err(dtype) to get minimum error for the specified dtype
|
|
250
|
+
- Prevents division by zero by ensuring denominator is not less than minimum error
|
|
251
|
+
- Returns |x| / max(|y|, minimum_error)
|
|
252
|
+
"""
|
|
253
|
+
x, y = abs(x), abs(y)
|
|
254
|
+
minmum_err = StandardConfig.get_minmum_err(dtype)
|
|
255
|
+
err_y = max(y, minmum_err)
|
|
256
|
+
return x / err_y
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def compare_bool_tensor(bench_output, device_output):
|
|
260
|
+
error_nums = (bench_output != device_output).sum()
|
|
261
|
+
error_rate = float(error_nums / bench_output.size)
|
|
262
|
+
result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
|
|
263
|
+
return error_rate, result, ""
|