mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,16 +1,45 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import abc
|
|
17
|
+
|
|
2
18
|
import numpy as np
|
|
3
19
|
from msprobe.core.common.utils import format_value
|
|
4
20
|
from msprobe.core.common.const import Const, CompareConst
|
|
5
21
|
from msprobe.core.common.log import logger
|
|
6
22
|
|
|
23
|
+
from msprobe.core.common.utils import CompareException
|
|
24
|
+
|
|
7
25
|
|
|
8
26
|
def handle_inf_nan(n_value, b_value):
|
|
27
|
+
def convert_to_float(value):
|
|
28
|
+
try:
|
|
29
|
+
if isinstance(value, np.ndarray):
|
|
30
|
+
return value.astype(float)
|
|
31
|
+
else:
|
|
32
|
+
return float(value)
|
|
33
|
+
except ValueError as e:
|
|
34
|
+
logger.error('\n'.join(e.args))
|
|
35
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
36
|
+
|
|
37
|
+
n_value_convert, b_value_convert = convert_to_float(n_value), convert_to_float(b_value)
|
|
9
38
|
"""处理inf和nan的数据"""
|
|
10
|
-
n_inf = np.isinf(
|
|
11
|
-
b_inf = np.isinf(
|
|
12
|
-
n_nan = np.isnan(
|
|
13
|
-
b_nan = np.isnan(
|
|
39
|
+
n_inf = np.isinf(n_value_convert)
|
|
40
|
+
b_inf = np.isinf(b_value_convert)
|
|
41
|
+
n_nan = np.isnan(n_value_convert)
|
|
42
|
+
b_nan = np.isnan(b_value_convert)
|
|
14
43
|
n_invalid = np.any(n_inf) or np.any(n_nan)
|
|
15
44
|
b_invalid = np.any(b_inf) or np.any(b_nan)
|
|
16
45
|
if n_invalid or b_invalid:
|
|
@@ -35,7 +64,11 @@ def get_error_type(n_value, b_value, error_flag):
|
|
|
35
64
|
if not n_value.shape: # 判断数据是否为标量
|
|
36
65
|
return n_value, b_value, False
|
|
37
66
|
|
|
38
|
-
|
|
67
|
+
try:
|
|
68
|
+
n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
|
|
69
|
+
except CompareException:
|
|
70
|
+
logger.error('Numpy data is unreadable, please check!')
|
|
71
|
+
return CompareConst.UNREADABLE, CompareConst.UNREADABLE, True
|
|
39
72
|
if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
|
|
40
73
|
return CompareConst.NAN, CompareConst.NAN, True
|
|
41
74
|
return n_value, b_value, False
|
|
@@ -58,7 +91,9 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
|
|
|
58
91
|
"""获取异常情况的错误信息"""
|
|
59
92
|
if error_flag:
|
|
60
93
|
if n_value == CompareConst.READ_NONE:
|
|
61
|
-
if error_file:
|
|
94
|
+
if error_file == 'no_bench_data':
|
|
95
|
+
return 'Bench does not have data file.'
|
|
96
|
+
elif error_file is not None:
|
|
62
97
|
return "Dump file: {} not found.".format(error_file)
|
|
63
98
|
return CompareConst.NO_BENCH
|
|
64
99
|
if n_value == CompareConst.NONE:
|
|
@@ -67,6 +102,8 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
|
|
|
67
102
|
return "Shape of NPU and bench Tensor do not match. Skipped."
|
|
68
103
|
if n_value == CompareConst.NAN:
|
|
69
104
|
return "The position of inf or nan in NPU and bench Tensor do not match."
|
|
105
|
+
if n_value == CompareConst.UNREADABLE:
|
|
106
|
+
return "The npy data is unable to be read or compared, please check dump data files."
|
|
70
107
|
else:
|
|
71
108
|
if not n_value.shape:
|
|
72
109
|
return "This is type of scalar data, can not compare."
|
|
@@ -78,10 +115,8 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
|
|
|
78
115
|
|
|
79
116
|
def npy_data_check(n_value, b_value):
|
|
80
117
|
error_message = ""
|
|
81
|
-
if n_value
|
|
82
|
-
error_message += "Dump file not
|
|
83
|
-
if n_value == "" or b_value == "":
|
|
84
|
-
error_message += "Dump file not found.\n"
|
|
118
|
+
if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
|
|
119
|
+
error_message += "Dump file is not ndarray.\n"
|
|
85
120
|
|
|
86
121
|
# 检查 n_value 和 b_value 是否为空
|
|
87
122
|
if not error_message and (n_value.size == 0 or b_value.size == 0):
|
|
@@ -96,8 +131,13 @@ def npy_data_check(n_value, b_value):
|
|
|
96
131
|
error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
|
|
97
132
|
|
|
98
133
|
if not error_message:
|
|
99
|
-
|
|
100
|
-
|
|
134
|
+
try:
|
|
135
|
+
n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
|
|
136
|
+
except CompareException:
|
|
137
|
+
logger.error('Numpy data is unreadable, please check!')
|
|
138
|
+
return True, 'Numpy data is unreadable, please check!'
|
|
139
|
+
# handle_inf_nan 会返回'Nan'或ndarray类型,使用类型判断是否存在无法处理的nan/inf数据
|
|
140
|
+
if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
|
|
101
141
|
error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
|
|
102
142
|
if error_message == "":
|
|
103
143
|
error_flag = False
|
|
@@ -146,14 +186,14 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
146
186
|
|
|
147
187
|
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
148
188
|
if error_flag:
|
|
149
|
-
if n_value == CompareConst.READ_NONE:
|
|
150
|
-
return CompareConst.
|
|
189
|
+
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
190
|
+
return CompareConst.UNSUPPORTED, ''
|
|
151
191
|
if n_value == CompareConst.NONE:
|
|
152
192
|
return CompareConst.UNSUPPORTED, ''
|
|
153
193
|
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
154
194
|
return CompareConst.SHAPE_UNMATCH, ''
|
|
155
195
|
if n_value == CompareConst.NAN:
|
|
156
|
-
return
|
|
196
|
+
return CompareConst.N_A, ''
|
|
157
197
|
|
|
158
198
|
if not n_value.shape:
|
|
159
199
|
return CompareConst.UNSUPPORTED, ''
|
|
@@ -184,17 +224,20 @@ class GetMaxAbsErr(TensorComparisonBasic):
|
|
|
184
224
|
"""计算最大绝对误差"""
|
|
185
225
|
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
186
226
|
if error_flag:
|
|
187
|
-
if n_value == CompareConst.READ_NONE:
|
|
188
|
-
return CompareConst.
|
|
227
|
+
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
228
|
+
return CompareConst.UNSUPPORTED, ""
|
|
189
229
|
if n_value == CompareConst.NONE:
|
|
190
230
|
return 0, ""
|
|
191
231
|
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
192
232
|
return CompareConst.SHAPE_UNMATCH, ""
|
|
193
233
|
if n_value == CompareConst.NAN:
|
|
194
|
-
return
|
|
234
|
+
return CompareConst.N_A, ""
|
|
195
235
|
|
|
196
236
|
temp_res = n_value - b_value
|
|
197
237
|
max_value = np.max(np.abs(temp_res))
|
|
238
|
+
if np.isnan(max_value):
|
|
239
|
+
message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
|
|
240
|
+
return CompareConst.NAN, message
|
|
198
241
|
return format_value(max_value), ""
|
|
199
242
|
|
|
200
243
|
|
|
@@ -214,20 +257,20 @@ class GetMaxRelativeErr(TensorComparisonBasic):
|
|
|
214
257
|
"""计算最大相对误差"""
|
|
215
258
|
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
216
259
|
if error_flag:
|
|
217
|
-
if n_value == CompareConst.READ_NONE:
|
|
218
|
-
return CompareConst.
|
|
260
|
+
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
261
|
+
return CompareConst.UNSUPPORTED, ''
|
|
219
262
|
if n_value == CompareConst.NONE:
|
|
220
263
|
return 0, ''
|
|
221
264
|
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
222
265
|
return CompareConst.SHAPE_UNMATCH, ''
|
|
223
266
|
if n_value == CompareConst.NAN:
|
|
224
|
-
return
|
|
267
|
+
return CompareConst.N_A, ''
|
|
225
268
|
|
|
226
269
|
if relative_err is None:
|
|
227
270
|
relative_err = get_relative_err(n_value, b_value)
|
|
228
271
|
max_relative_err = np.max(np.abs(relative_err))
|
|
229
272
|
if np.isnan(max_relative_err):
|
|
230
|
-
message = 'Cannot compare by MaxRelativeError, the data contains nan in dump data.'
|
|
273
|
+
message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
|
|
231
274
|
return CompareConst.NAN, message
|
|
232
275
|
return format_value(max_relative_err), ''
|
|
233
276
|
|
|
@@ -236,14 +279,14 @@ class GetThousandErrRatio(TensorComparisonBasic):
|
|
|
236
279
|
"""计算相对误差小于千分之一的比例"""
|
|
237
280
|
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
238
281
|
if error_flag:
|
|
239
|
-
if n_value == CompareConst.READ_NONE:
|
|
240
|
-
return CompareConst.
|
|
282
|
+
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
283
|
+
return CompareConst.UNSUPPORTED, ""
|
|
241
284
|
if n_value == CompareConst.NONE:
|
|
242
285
|
return 0, ""
|
|
243
286
|
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
244
287
|
return CompareConst.SHAPE_UNMATCH, ""
|
|
245
288
|
if n_value == CompareConst.NAN:
|
|
246
|
-
return
|
|
289
|
+
return CompareConst.N_A, ""
|
|
247
290
|
|
|
248
291
|
if not n_value.shape:
|
|
249
292
|
return CompareConst.NAN, ""
|
|
@@ -258,14 +301,14 @@ class GetFiveThousandErrRatio(TensorComparisonBasic):
|
|
|
258
301
|
"""计算相对误差小于千分之五的比例"""
|
|
259
302
|
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
260
303
|
if error_flag:
|
|
261
|
-
if n_value == CompareConst.READ_NONE:
|
|
262
|
-
return CompareConst.
|
|
304
|
+
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
305
|
+
return CompareConst.UNSUPPORTED, ""
|
|
263
306
|
if n_value == CompareConst.NONE:
|
|
264
307
|
return 0, ""
|
|
265
308
|
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
266
309
|
return CompareConst.SHAPE_UNMATCH, ""
|
|
267
310
|
if n_value == CompareConst.NAN:
|
|
268
|
-
return
|
|
311
|
+
return CompareConst.N_A, ""
|
|
269
312
|
|
|
270
313
|
if not n_value.shape:
|
|
271
314
|
return CompareConst.NAN, ""
|
|
@@ -273,7 +316,8 @@ class GetFiveThousandErrRatio(TensorComparisonBasic):
|
|
|
273
316
|
relative_err = get_relative_err(n_value, b_value)
|
|
274
317
|
if not np.size(relative_err):
|
|
275
318
|
return CompareConst.NAN, ""
|
|
276
|
-
return format_value(
|
|
319
|
+
return format_value(
|
|
320
|
+
np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
|
|
277
321
|
|
|
278
322
|
|
|
279
323
|
class CompareOps:
|