mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
import numpy as np
|
|
20
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import convert_str_to_float
|
|
21
|
+
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \
|
|
22
|
+
get_finite_and_infinite_mask, get_small_value_mask
|
|
23
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BaseCompare(ABC):
    """Abstract base for comparing a benchmark output with a device output.

    Holds the two raw outputs plus the result-column object, and offers the
    shared statistics (absolute/relative error, finite vs inf/nan masks,
    small-value masks and dtype-dependent thresholds).  The template method
    ``compare`` drives the flow ``_pre_compare`` -> ``_compute_metrics`` ->
    ``_post_compare``; subclasses implement ``_pre_compare`` (and normally
    override ``_compute_metrics``) with algorithm-specific logic.

    Attributes:
        bench_output: benchmark (reference) output array.
        device_output: device output array.
        compare_column: object that accumulates comparison results.
        dtype: data type of the outputs, used to select thresholds.
        compare_algorithm: algorithm identifier; assigned by subclasses.

    See Also:
        InputData: container supplying the constructor arguments.
        StandardConfig: source of the small-value threshold configuration.
    """

    def __init__(self, input_data):
        # Unpack the InputData container into instance state.
        self.bench_output = input_data.bench_output
        self.device_output = input_data.device_output
        self.compare_column = input_data.compare_column
        self.dtype = input_data.dtype
        # Concrete subclasses assign their algorithm identifier here.
        self.compare_algorithm = None

    @staticmethod
    def stat_small_value_mask(abs_bench, both_finite_mask, small_value):
        # Thin wrapper kept so subclasses have a stable call site.
        return get_small_value_mask(abs_bench, both_finite_mask, small_value)

    @staticmethod
    def _get_rel_err(abs_err, abs_bench_with_eps):
        # Relative error; the denominator already carries eps protection.
        return abs_err / abs_bench_with_eps

    @staticmethod
    def _get_normal_value_mask(both_finite_mask, small_value_mask):
        # "Normal" values are finite and outside the small-value regime.
        return np.logical_and(both_finite_mask, np.logical_not(small_value_mask))

    @abstractmethod
    def _pre_compare(self):
        # Subclasses compute the per-algorithm statistics here.
        raise NotImplementedError

    def get_small_value_threshold(self):
        """Return the (small_value, small_value_atol) pair for ``self.dtype``."""
        threshold = StandardConfig.get_small_value(self.dtype, self.compare_algorithm)
        atol = StandardConfig.get_small_value_atol(self.dtype, self.compare_algorithm)
        return threshold, atol

    def stat_abs_bench_with_eps(self):
        """Return ``(|bench|, |bench| + eps)`` for division-safe denominators."""
        return get_abs_bench_with_eps(self.bench_output, self.dtype)

    def stat_abs_error(self):
        """Return the element-wise absolute error between the two outputs."""
        return get_abs_err(self.bench_output, self.device_output)

    def stat_finite_and_infinite_mask(self):
        """Return ``(both_finite_mask, inf_nan_mask)`` over the two outputs."""
        return get_finite_and_infinite_mask(self.bench_output, self.device_output)

    def compare(self):
        """Template method: prepare statistics, compute metrics, record them."""
        self._pre_compare()
        self._post_compare(self._compute_metrics())

    def _compute_metrics(self):
        # Default: no metrics; subclasses return their metric dictionary.
        return {}

    def _post_compare(self, metrics):
        # Push the computed metrics into the shared compare column.
        self.compare_column.update(metrics)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class BasePrecisionCompare(ABC):
    """Abstract base for comparing NPU vs GPU precision-metric rows.

    Subclasses implement ``_compute_ratio`` (producing a metrics dict plus
    inf/nan-consistency information) and ``_get_status`` (deriving the
    pass/warn/error status from those metrics).  ``compare`` is the template
    method tying the two together and recording the result in
    ``compare_column``.

    Note:
        The class now inherits from ``ABC`` so the ``@abstractmethod``
        decorators are actually enforced; as a plain class the decorators had
        no effect and the incomplete base could be instantiated directly.

    Attributes:
        row_npu: metric row produced on the NPU (mapping-like; read via .get).
        row_gpu: metric row produced on the GPU benchmark (mapping-like).
        dtype: data type of the compared outputs.
        compare_column: object that accumulates comparison results.
        compare_algorithm: algorithm name; assigned by concrete subclasses.
    """

    def __init__(self, input_data):
        self.row_npu = input_data.row_npu
        self.row_gpu = input_data.row_gpu
        self.dtype = input_data.dtype
        self.compare_column = input_data.compare_column
        # Concrete subclasses assign their algorithm identifier here.
        self.compare_algorithm = None

    @abstractmethod
    def _get_status(self, metrics, inf_nan_consistency):
        """Attach pass/warn/error status fields to *metrics* and return it."""
        pass

    @abstractmethod
    def _compute_ratio(self):
        """Return ``(metrics_dict, inf_nan_consistency)`` for this algorithm."""
        pass

    def compare(self):
        """Run the comparison and return the overall compare result."""
        metrics, inf_nan_consistency = self._compute_ratio()
        compare_result = self._post_compare(metrics, inf_nan_consistency)
        return compare_result

    def _get_and_convert_values(self, column_name):
        """Fetch *column_name* from both rows and convert each to float.

        Raises:
            ValueError: if either row lacks the column (value is None).
        """
        npu_value = self.row_npu.get(column_name)
        gpu_value = self.row_gpu.get(column_name)
        if npu_value is None:
            raise ValueError(f"NPU value for column '{column_name}' is None.")
        if gpu_value is None:
            raise ValueError(f"GPU value for column '{column_name}' is None.")
        npu_value = convert_str_to_float(npu_value)
        gpu_value = convert_str_to_float(gpu_value)
        return npu_value, gpu_value

    def _post_compare(self, metrics, inf_nan_consistency):
        """Derive status, record metrics, and return the compare result."""
        metrics = self._get_status(metrics, inf_nan_consistency)
        metrics.update({'compare_algorithm': self.compare_algorithm})
        self.compare_column.update(metrics)
        compare_result = metrics.get('compare_result')
        return compare_result
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
import math
|
|
19
|
+
from collections import namedtuple
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
|
|
23
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare, BasePrecisionCompare
|
|
24
|
+
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_small_value_err_ratio, get_rel_err, \
|
|
25
|
+
get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err
|
|
26
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
|
|
27
|
+
is_inf_or_nan
|
|
28
|
+
from msprobe.core.common.const import CompareConst
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Per-metric inf/nan consistency flags collected during a benchmark compare.
BenchmarkInfNanConsistency = namedtuple(
    'BenchmarkInfNanConsistency',
    'small_value_inf_nan_consistency rmse_inf_nan_consistency '
    'max_rel_inf_nan_consistency mean_rel_inf_nan_consistency '
    'eb_inf_nan_consistency',
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class BenchmarkCompare(BaseCompare):
    """Benchmark-standard accuracy comparison of bench vs device outputs.

    Computes the five benchmark-standard metrics -- small-value error ratio,
    RMSE, error balance, maximum relative error and mean relative error --
    from the shared statistics prepared by ``_pre_compare``.

    Attributes set during ``_pre_compare``:
        abs_bench / abs_bench_with_eps: |bench| and |bench| + eps.
        both_finite_mask / inf_nan_mask: finiteness masks over both outputs.
        abs_err: element-wise absolute error.
        small_value / small_value_atol: dtype-dependent small-value thresholds.
        small_value_mask: where the benchmark value counts as "small".
        rel_err: relative error (small values and inf/nan positions excluded).
        abs_err_greater_mask: where abs_err exceeds the small-value tolerance.

    See Also:
        BaseCompare: supplies the shared statistic helpers and compare flow.
    """

    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.BENCHMARK

    def _get_abs_err_greater_mask(self, small_value_atol):
        # True where the absolute error exceeds the small-value tolerance.
        return np.greater(self.abs_err, small_value_atol)

    def _compute_rel_err(self):
        # Relative error restricted to normal values via the two masks.
        return get_rel_err(self.abs_err, self.abs_bench_with_eps,
                           self.small_value_mask, self.inf_nan_mask)

    def _pre_compare(self):
        """Prepare every statistic the metric computation relies on."""
        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
        self.abs_err = self.stat_abs_error()
        self.small_value, self.small_value_atol = self.get_small_value_threshold()
        self.small_value_mask = self.stat_small_value_mask(
            self.abs_bench, self.both_finite_mask, self.small_value)
        self.rel_err = self._compute_rel_err()
        self.abs_err_greater_mask = self._get_abs_err_greater_mask(self.small_value_atol)

    def _compute_metrics(self):
        """Compute the benchmark-standard error metrics.

        Returns:
            dict: with keys
                - "small_value_err_ratio": share of errors among small values.
                - "max_rel_error": maximum relative error.
                - "mean_rel_error": mean relative error.
                - "rmse": root mean square error (inf/nan and small values
                  excluded from the average).
                - "eb": error balance between the two outputs.
        """
        excluded_mask = np.logical_or(self.inf_nan_mask, self.small_value_mask)
        return {
            "small_value_err_ratio": get_small_value_err_ratio(
                self.small_value_mask, self.abs_err_greater_mask),
            "max_rel_error": get_max_rel_err(self.rel_err),
            "mean_rel_error": get_mean_rel_err(self.rel_err),
            "rmse": get_rmse(self.abs_err, excluded_mask),
            "eb": get_error_balance(self.bench_output, self.device_output),
        }
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class BenchmarkPrecisionCompare(BasePrecisionCompare):
    """
    Benchmark-standard precision comparison.

    Forms ratios between NPU-side and GPU-side error metrics (small value
    error rate, RMSE, max/mean relative error, error balance) and grades
    each ratio against the benchmark thresholds from StandardConfig.
    """

    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.BENCHMARK_COMPARE_ALGORITHM_NAME

    @staticmethod
    def get_final_status(status_list):
        """Collapse per-metric statuses into one verdict: ERROR > WARNING > PASS."""
        if CompareConst.ERROR in status_list:
            return CompareConst.ERROR
        if CompareConst.WARNING in status_list:
            return CompareConst.WARNING
        return CompareConst.PASS

    def _calc_ratio(self, column_name):
        """
        Compute the NPU/GPU ratio for one metric column.

        Returns a (ratio, inf_nan_consistency, message) triple. When either
        value is inf/nan the dedicated consistency check decides the result;
        otherwise the plain ratio is returned with consistency True.
        """
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)
        return calc_ratio(npu_value, gpu_value, str(self.dtype)), True, ""

    def _compute_ratio(self):
        """
        Compute all metric ratios and their inf/nan consistency flags.

        Returns:
            tuple: (metrics dict keyed by CompareConst ratio keys plus the
            concatenated compare message, BenchmarkInfNanConsistency of the
            five per-metric consistency flags).
        """
        # Fixed order: small value, rmse, max rel err, mean rel err, eb.
        # Messages are concatenated in this same order.
        ratio_specs = (
            (ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE, CompareConst.SMALL_VALUE_ERR_RATIO),
            (ApiPrecisionCompareColumn.RMSE, CompareConst.RMSE_RATIO),
            (ApiPrecisionCompareColumn.MAX_REL_ERR, CompareConst.MAX_REL_ERR_RATIO),
            (ApiPrecisionCompareColumn.MEAN_REL_ERR, CompareConst.MEAN_REL_ERR_RATIO),
            (ApiPrecisionCompareColumn.EB, CompareConst.EB_RATIO),
        )
        metrics = {}
        consistency_flags = []
        message_parts = []
        for column, ratio_key in ratio_specs:
            ratio, consistent, message = self._calc_ratio(column)
            metrics[ratio_key] = ratio
            consistency_flags.append(consistent)
            message_parts.append(message)
        metrics[CompareConst.COMPARE_MESSAGE] = "".join(message_parts)
        return metrics, BenchmarkInfNanConsistency(*consistency_flags)

    def _get_threshold(self, metric):
        """Look up the benchmark error threshold for the given metric name."""
        return StandardConfig.get_benchmark_threshold(metric)

    def _get_single_metric_status(self, ratio, metric):
        """PASS unless the (finite) ratio exceeds the metric's error threshold."""
        if is_inf_or_nan(ratio):
            return CompareConst.PASS
        return CompareConst.ERROR if ratio > self._get_threshold(metric) else CompareConst.PASS

    def _get_status(self, metrics, inf_nan_consistency):
        """
        Grade each metric ratio, then attach per-metric statuses and the
        overall compare result to the metrics dict.

        A metric whose inf/nan consistency flag is False is graded ERROR
        without consulting its threshold.
        """
        checks = (
            (CompareConst.SMALL_VALUE_ERR_RATIO, CompareConst.SMALL_VALUE,
             inf_nan_consistency.small_value_inf_nan_consistency, CompareConst.SMALL_VALUE_ERR_STATUS),
            (CompareConst.RMSE_RATIO, CompareConst.RMSE,
             inf_nan_consistency.rmse_inf_nan_consistency, CompareConst.RMSE_STATUS),
            (CompareConst.MAX_REL_ERR_RATIO, CompareConst.MAX_REL_ERR,
             inf_nan_consistency.max_rel_inf_nan_consistency, CompareConst.MAX_REL_ERR_STATUS),
            (CompareConst.MEAN_REL_ERR_RATIO, CompareConst.MEAN_REL_ERR,
             inf_nan_consistency.mean_rel_inf_nan_consistency, CompareConst.MEAN_REL_ERR_STATUS),
            (CompareConst.EB_RATIO, CompareConst.EB,
             inf_nan_consistency.eb_inf_nan_consistency, CompareConst.EB_STATUS),
        )
        statuses = {}
        for ratio_key, metric_name, consistent, status_key in checks:
            if consistent:
                statuses[status_key] = self._get_single_metric_status(metrics.get(ratio_key), metric_name)
            else:
                statuses[status_key] = CompareConst.ERROR
        # NOTE(review): the EB status is reported but deliberately excluded
        # from the overall verdict, matching the original implementation —
        # confirm this matches the benchmark standard.
        gating_statuses = [
            statuses[CompareConst.SMALL_VALUE_ERR_STATUS],
            statuses[CompareConst.RMSE_STATUS],
            statuses[CompareConst.MAX_REL_ERR_STATUS],
            statuses[CompareConst.MEAN_REL_ERR_STATUS],
        ]
        metrics.update(statuses)
        metrics.update({CompareConst.COMPARE_RESULT: self.get_final_status(gating_statuses)})
        return metrics
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import compare_bool_tensor
|
|
19
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BinaryCompare(BaseCompare):
    """
    Comparison strategy for boolean tensor outputs.

    Compares a device output against a benchmark output element-wise and
    reports a single metric: the error rate (proportion of mismatching
    elements). Inputs are taken from an InputData instance carrying the
    benchmark output, device output, compare column, and dtype; both
    outputs are expected to be boolean tensors.

    See Also:
        BaseCompare: base class providing the comparison workflow.
        compare_bool_tensor: element-wise boolean tensor comparison.
    """

    def __init__(self, input_data):
        super().__init__(input_data)

    def _pre_compare(self):
        # Boolean comparison needs no masks or preprocessing.
        pass

    def _compute_metrics(self):
        """
        Compute the error rate between benchmark and device outputs.

        Delegates to compare_bool_tensor and keeps only the error rate,
        discarding its auxiliary return values.

        Returns:
            dict: {"error_rate": proportion of mismatched elements}.
        """
        error_rate = compare_bool_tensor(self.bench_output, self.device_output)[0]
        return {"error_rate": error_rate}
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.const import CompareConst
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class StandardConfig:
    """
    Central registry of precision-comparison thresholds.

    Maps torch dtypes to small-value thresholds, absolute/relative
    tolerances, and accumulative-error bounds, and exposes the per-metric
    error/warning thresholds used by the benchmark comparison standard.
    Unknown dtypes fall back to each table's "default" entry.

    Example:
        >>> StandardConfig.get_rtol(torch.float32)
        9.5367431640625e-07
    """
    _small_value = {
        torch.float16: 2**-10,
        torch.bfloat16: 2**-10,
        torch.float32: 2**-20,
        "default": 2**-20
    }
    _threshold_small_value_atol = {
        torch.float16: 2**-16,
        # NOTE(review): bfloat16 uses decimal 1e-16 while float16 uses
        # 2**-16 (~1.5e-5) — confirm this asymmetry is intended.
        torch.bfloat16: 1e-16,
        torch.float32: 2**-30,
        "default": 2**-30
    }
    _benchmark_small_value_atol = {
        torch.float16: 1e-16,
        torch.bfloat16: 1e-16,
        torch.float32: 2**-30,
        "default": 2**-30
    }
    _rtol = {
        torch.float16: 2**-10,
        torch.bfloat16: 2**-8,
        torch.float32: 2**-20,
        "default": 2**-20
    }
    _accumulative_error_bound = {
        torch.float16: 2**-8,
        torch.bfloat16: 2**-7,
        torch.float32: 2**-11,
        "default": 2**-11
    }
    _small_value_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _rmse_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _max_rel_err_threshold = {
        'error_threshold': 10,
        'warning_threshold': 1,
        "default": 1
    }
    _mean_rel_err_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _eb_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    # Keyed by str(dtype) (e.g. 'torch.float16'), unlike the tables above;
    # lookups below accept both forms. ("minmum"/"samll" typos kept for
    # backward compatibility with any external references.)
    _minmum_err = {
        'torch.float16': 2**-11,
        'torch.bfloat16': 2**-8,
        'torch.float32': 2**-14,
        'default': 2**-14
    }
    _accumulative_error_eb_threshold = {
        'torch.float16': 2**-20,
        'torch.bfloat16': 2**-7,
        'torch.float32': 2**-14,
        'default': 2**-14
    }

    _fp32_mean_ulp_err_threshold = 64
    ulp_err_proportion_ratio = 1
    _fp32_ulp_err_proportion = 0.05
    _fp16_ulp_err_proportion = 0.001
    _special_samll_value = 1

    @classmethod
    def get_small_value(cls, dtype, standard):
        """Small-value cutoff for *dtype*; a fixed value for the accumulative-error standard."""
        if standard == CompareConst.ACCUMULATIVE_ERROR_COMPARE:
            return cls._special_samll_value
        return cls._small_value.get(dtype, cls._small_value["default"])

    @classmethod
    def get_small_value_atol(cls, dtype, standard):
        """Absolute tolerance for small values, selected by comparison standard."""
        standard_dict = {
            CompareConst.ABSOLUTE_THRESHOLD: cls._threshold_small_value_atol,
            CompareConst.BENCHMARK: cls._benchmark_small_value_atol
        }
        small_value_atol_standard = standard_dict.get(standard, cls._benchmark_small_value_atol)
        return small_value_atol_standard.get(dtype, small_value_atol_standard["default"])

    @classmethod
    def get_rtol(cls, dtype):
        """Relative tolerance for *dtype*."""
        return cls._rtol.get(dtype, cls._rtol["default"])

    @classmethod
    def get_small_value_threshold(cls, threshold_type):
        """Small-value-error ratio threshold ('error_threshold'/'warning_threshold')."""
        return cls._small_value_threshold.get(threshold_type, cls._small_value_threshold["default"])

    @classmethod
    def get_rmse_threshold(cls, threshold_type):
        """RMSE ratio threshold ('error_threshold'/'warning_threshold')."""
        return cls._rmse_threshold.get(threshold_type, cls._rmse_threshold["default"])

    @classmethod
    def get_max_rel_err_threshold(cls, threshold_type):
        """Max relative error ratio threshold ('error_threshold'/'warning_threshold')."""
        return cls._max_rel_err_threshold.get(threshold_type, cls._max_rel_err_threshold["default"])

    @classmethod
    def get_mean_rel_err_threshold(cls, threshold_type):
        """Mean relative error ratio threshold ('error_threshold'/'warning_threshold')."""
        return cls._mean_rel_err_threshold.get(threshold_type, cls._mean_rel_err_threshold["default"])

    @classmethod
    def get_eb_threshold(cls, threshold_type):
        """Error-balance ratio threshold ('error_threshold'/'warning_threshold')."""
        return cls._eb_threshold.get(threshold_type, cls._eb_threshold["default"])

    @classmethod
    def get_benchmark_threshold(cls, metric):
        """
        Error threshold for *metric* under the benchmark standard.

        Raises:
            ValueError: if *metric* is not a known benchmark metric.
            (Previously an unknown metric crashed with an opaque
            ``TypeError: 'NoneType' object is not callable``.)
        """
        metric_threshold_functions = {
            'small_value': cls.get_small_value_threshold,
            'rmse': cls.get_rmse_threshold,
            'max_rel_err': cls.get_max_rel_err_threshold,
            'mean_rel_err': cls.get_mean_rel_err_threshold,
            'eb': cls.get_eb_threshold
        }

        threshold_func = metric_threshold_functions.get(metric)
        if threshold_func is None:
            raise ValueError(f"Unknown benchmark metric: {metric}")
        return threshold_func('error_threshold')

    @classmethod
    def get_fp32_mean_ulp_err_threshold(cls):
        """Mean ULP error threshold for float32."""
        return cls._fp32_mean_ulp_err_threshold

    @classmethod
    def get_ulp_err_proportion_ratio_threshold(cls):
        """ULP error proportion ratio threshold (dtype-independent)."""
        return cls.ulp_err_proportion_ratio

    @classmethod
    def get_fp32_ulp_err_proportion_threshold(cls):
        """ULP error proportion threshold for float32."""
        return cls._fp32_ulp_err_proportion

    @classmethod
    def get_fp16_ulp_err_proportion_threshold(cls):
        """ULP error proportion threshold for half-precision dtypes."""
        return cls._fp16_ulp_err_proportion

    @classmethod
    def get_ulp_threshold(cls, dtype):
        """
        ULP thresholds for *dtype*.

        Returns:
            tuple: (mean_ulp_err_threshold or None for non-fp32,
                    ulp_err_proportion_threshold,
                    ulp_err_proportion_ratio_threshold)
        """
        ratio_threshold = cls.get_ulp_err_proportion_ratio_threshold()
        if dtype == torch.float32:
            return (cls.get_fp32_mean_ulp_err_threshold(),
                    cls.get_fp32_ulp_err_proportion_threshold(),
                    ratio_threshold)
        return None, cls.get_fp16_ulp_err_proportion_threshold(), ratio_threshold

    @classmethod
    def _lookup_dtype_or_str(cls, table, dtype):
        """Look up *dtype* in a str-keyed table, accepting torch dtypes too.

        Fix: these tables are keyed by str(dtype) ('torch.float16'), so a
        torch.dtype argument used to fall through to the default silently.
        """
        if dtype in table:
            return table[dtype]
        return table.get(str(dtype), table["default"])

    @classmethod
    def get_minmum_err(cls, dtype):
        """Minimum representable error for *dtype* (dtype object or str form)."""
        return cls._lookup_dtype_or_str(cls._minmum_err, dtype)

    @classmethod
    def get_accumulative_error_bound(cls, dtype):
        """Accumulative error bound for *dtype*."""
        return cls._accumulative_error_bound.get(dtype, cls._accumulative_error_bound["default"])

    @classmethod
    def get_accumulative_error_eb_threshold(cls, dtype):
        """Accumulative-error EB threshold for *dtype* (dtype object or str form)."""
        return cls._lookup_dtype_or_str(cls._accumulative_error_eb_threshold, dtype)
|