mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -16,11 +16,10 @@
|
|
|
16
16
|
import abc
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
|
-
|
|
19
|
+
|
|
20
20
|
from msprobe.core.common.const import Const, CompareConst
|
|
21
21
|
from msprobe.core.common.log import logger
|
|
22
|
-
|
|
23
|
-
from msprobe.core.common.utils import CompareException
|
|
22
|
+
from msprobe.core.common.utils import CompareException, format_value
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def handle_inf_nan(n_value, b_value):
|
|
@@ -53,66 +52,66 @@ def handle_inf_nan(n_value, b_value):
|
|
|
53
52
|
return n_value, b_value
|
|
54
53
|
|
|
55
54
|
|
|
56
|
-
def
|
|
57
|
-
"""判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag"""
|
|
55
|
+
def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
|
|
56
|
+
"""判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag和error_msg"""
|
|
57
|
+
err_msg = ""
|
|
58
58
|
if error_flag:
|
|
59
|
-
|
|
59
|
+
if error_file == "no_bench_data":
|
|
60
|
+
err_msg = "Bench does not have data file."
|
|
61
|
+
elif error_file:
|
|
62
|
+
err_msg = f"Dump file: {error_file} not found."
|
|
63
|
+
else:
|
|
64
|
+
err_msg = CompareConst.NO_BENCH
|
|
65
|
+
error_flag = True
|
|
66
|
+
return CompareConst.READ_NONE, CompareConst.READ_NONE, error_flag, err_msg
|
|
67
|
+
|
|
60
68
|
if n_value.size == 0: # 判断读取到的数据是否为空
|
|
61
|
-
|
|
69
|
+
err_msg = "This is empty data, can not compare."
|
|
70
|
+
error_flag = True
|
|
71
|
+
return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
|
|
72
|
+
if not n_value.shape: # 判断数据是否为0维张量
|
|
73
|
+
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
|
|
74
|
+
f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
|
|
75
|
+
error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
|
|
76
|
+
return n_value, b_value, error_flag, err_msg
|
|
62
77
|
if n_value.shape != b_value.shape: # 判断NPU和bench的数据结构是否一致
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
return
|
|
78
|
+
err_msg = "Shape of NPU and bench tensor do not match. Skipped."
|
|
79
|
+
error_flag = True
|
|
80
|
+
return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, error_flag, err_msg
|
|
66
81
|
|
|
67
82
|
try:
|
|
68
83
|
n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
|
|
69
84
|
except CompareException:
|
|
70
85
|
logger.error('Numpy data is unreadable, please check!')
|
|
71
|
-
|
|
86
|
+
err_msg = "Data is unreadable."
|
|
87
|
+
error_flag = True
|
|
88
|
+
return CompareConst.UNREADABLE, CompareConst.UNREADABLE, error_flag, err_msg
|
|
72
89
|
if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
|
|
73
|
-
|
|
74
|
-
|
|
90
|
+
err_msg = "The position of inf or nan in NPU and bench Tensor do not match."
|
|
91
|
+
error_flag = True
|
|
92
|
+
return CompareConst.NAN, CompareConst.NAN, error_flag, err_msg
|
|
93
|
+
|
|
94
|
+
if n_value.dtype != b_value.dtype: # 判断数据的dtype是否一致
|
|
95
|
+
err_msg = "Dtype of NPU and bench tensor do not match."
|
|
96
|
+
error_flag = False
|
|
97
|
+
return n_value, b_value, error_flag, err_msg
|
|
98
|
+
|
|
99
|
+
return n_value, b_value, error_flag, err_msg
|
|
75
100
|
|
|
76
101
|
|
|
77
102
|
def reshape_value(n_value, b_value):
|
|
78
103
|
"""返回reshape后的数据"""
|
|
79
|
-
if not n_value.shape: #
|
|
104
|
+
if not n_value.shape: # 判断数据是否为0维tensor, 如果0维tensor,不会转成1维tensor,直接返回
|
|
80
105
|
if n_value.dtype == bool:
|
|
81
106
|
n_value = n_value.astype(float)
|
|
82
107
|
b_value = b_value.astype(float)
|
|
83
108
|
return n_value, b_value
|
|
84
109
|
|
|
85
|
-
n_value = n_value.reshape(-1).astype(float)
|
|
110
|
+
n_value = n_value.reshape(-1).astype(float) # 32转64为了防止某些数转dataframe时出现误差
|
|
86
111
|
b_value = b_value.reshape(-1).astype(float)
|
|
87
112
|
return n_value, b_value
|
|
88
113
|
|
|
89
114
|
|
|
90
|
-
def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
|
|
91
|
-
"""获取异常情况的错误信息"""
|
|
92
|
-
if error_flag:
|
|
93
|
-
if n_value == CompareConst.READ_NONE:
|
|
94
|
-
if error_file == 'no_bench_data':
|
|
95
|
-
return 'Bench does not have data file.'
|
|
96
|
-
elif error_file is not None:
|
|
97
|
-
return "Dump file: {} not found.".format(error_file)
|
|
98
|
-
return CompareConst.NO_BENCH
|
|
99
|
-
if n_value == CompareConst.NONE:
|
|
100
|
-
return "This is empty data, can not compare."
|
|
101
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
102
|
-
return "Shape of NPU and bench Tensor do not match. Skipped."
|
|
103
|
-
if n_value == CompareConst.NAN:
|
|
104
|
-
return "The position of inf or nan in NPU and bench Tensor do not match."
|
|
105
|
-
if n_value == CompareConst.UNREADABLE:
|
|
106
|
-
return "The npy data is unable to be read or compared, please check dump data files."
|
|
107
|
-
else:
|
|
108
|
-
if not n_value.shape:
|
|
109
|
-
return "This is type of scalar data, can not compare."
|
|
110
|
-
if n_value.dtype != b_value.dtype:
|
|
111
|
-
logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
|
|
112
|
-
return "Dtype of NPU and bench Tensor do not match."
|
|
113
|
-
return ""
|
|
114
|
-
|
|
115
|
-
|
|
116
115
|
def npy_data_check(n_value, b_value):
|
|
117
116
|
error_message = ""
|
|
118
117
|
if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
|
|
@@ -170,10 +169,25 @@ def statistics_data_check(result_dict):
|
|
|
170
169
|
class TensorComparisonBasic(abc.ABC):
|
|
171
170
|
"""NPU和bench中npy数据的比较模板"""
|
|
172
171
|
@abc.abstractmethod
|
|
173
|
-
def apply(self, n_value, b_value,
|
|
172
|
+
def apply(self, n_value, b_value, relative_err):
|
|
174
173
|
raise NotImplementedError
|
|
175
174
|
|
|
176
175
|
|
|
176
|
+
def get_relative_err(n_value, b_value):
|
|
177
|
+
"""计算相对误差"""
|
|
178
|
+
with np.errstate(divide='ignore', invalid='ignore'):
|
|
179
|
+
if b_value.dtype not in CompareConst.FLOAT_TYPE:
|
|
180
|
+
n_value, b_value = n_value.astype(float), b_value.astype(float)
|
|
181
|
+
|
|
182
|
+
n_value_copy = n_value.copy()
|
|
183
|
+
b_value_copy = b_value.copy()
|
|
184
|
+
zero_mask = (b_value_copy == 0)
|
|
185
|
+
b_value_copy[zero_mask] += Const.FLOAT_EPSILON
|
|
186
|
+
n_value_copy[zero_mask] += Const.FLOAT_EPSILON
|
|
187
|
+
relative_err = np.divide((n_value_copy - b_value_copy), b_value_copy)
|
|
188
|
+
return np.abs(relative_err)
|
|
189
|
+
|
|
190
|
+
|
|
177
191
|
class GetCosineSimilarity(TensorComparisonBasic):
|
|
178
192
|
"""计算cosine相似度"""
|
|
179
193
|
@staticmethod
|
|
@@ -184,140 +198,67 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
184
198
|
return round(float(result), 6)
|
|
185
199
|
return result
|
|
186
200
|
|
|
187
|
-
def apply(self, n_value, b_value,
|
|
188
|
-
if error_flag:
|
|
189
|
-
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
190
|
-
return CompareConst.UNSUPPORTED, ''
|
|
191
|
-
if n_value == CompareConst.NONE:
|
|
192
|
-
return CompareConst.UNSUPPORTED, ''
|
|
193
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
194
|
-
return CompareConst.SHAPE_UNMATCH, ''
|
|
195
|
-
if n_value == CompareConst.NAN:
|
|
196
|
-
return CompareConst.N_A, ''
|
|
197
|
-
|
|
201
|
+
def apply(self, n_value, b_value, relative_err):
|
|
198
202
|
if not n_value.shape:
|
|
199
|
-
return CompareConst.UNSUPPORTED,
|
|
203
|
+
return CompareConst.UNSUPPORTED, ""
|
|
200
204
|
|
|
201
|
-
with np.errstate(divide=
|
|
205
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
|
202
206
|
if len(n_value) == 1:
|
|
203
|
-
return CompareConst.UNSUPPORTED, "This tensor
|
|
207
|
+
return CompareConst.UNSUPPORTED, "This is a 1-d tensor of length 1."
|
|
204
208
|
num = n_value.dot(b_value)
|
|
205
209
|
a_norm = np.linalg.norm(n_value)
|
|
206
210
|
b_norm = np.linalg.norm(b_value)
|
|
207
211
|
|
|
208
212
|
if a_norm <= Const.FLOAT_EPSILON and b_norm <= Const.FLOAT_EPSILON:
|
|
209
|
-
return 1.0,
|
|
213
|
+
return 1.0, ""
|
|
210
214
|
if a_norm <= Const.FLOAT_EPSILON:
|
|
211
|
-
return CompareConst.NAN,
|
|
215
|
+
return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in npu dump data."
|
|
212
216
|
if b_norm <= Const.FLOAT_EPSILON:
|
|
213
|
-
return CompareConst.NAN,
|
|
217
|
+
return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data."
|
|
214
218
|
|
|
215
219
|
cos = num / (a_norm * b_norm)
|
|
216
220
|
if np.isnan(cos):
|
|
217
|
-
return CompareConst.NAN,
|
|
221
|
+
return CompareConst.NAN, "Cannot compare by Cosine Similarity, the dump data has NaN."
|
|
218
222
|
result = format_value(cos)
|
|
219
223
|
result = self.correct_data(result)
|
|
220
|
-
return
|
|
224
|
+
return result, ""
|
|
221
225
|
|
|
222
226
|
|
|
223
227
|
class GetMaxAbsErr(TensorComparisonBasic):
|
|
224
228
|
"""计算最大绝对误差"""
|
|
225
|
-
def apply(self, n_value, b_value,
|
|
226
|
-
if error_flag:
|
|
227
|
-
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
228
|
-
return CompareConst.UNSUPPORTED, ""
|
|
229
|
-
if n_value == CompareConst.NONE:
|
|
230
|
-
return 0, ""
|
|
231
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
232
|
-
return CompareConst.SHAPE_UNMATCH, ""
|
|
233
|
-
if n_value == CompareConst.NAN:
|
|
234
|
-
return CompareConst.N_A, ""
|
|
235
|
-
|
|
229
|
+
def apply(self, n_value, b_value, relative_err):
|
|
236
230
|
temp_res = n_value - b_value
|
|
237
231
|
max_value = np.max(np.abs(temp_res))
|
|
238
232
|
if np.isnan(max_value):
|
|
239
|
-
|
|
240
|
-
return CompareConst.NAN,
|
|
233
|
+
msg = "Cannot compare by MaxAbsError, the data contains nan/inf/-inf in dump data."
|
|
234
|
+
return CompareConst.NAN, msg
|
|
241
235
|
return format_value(max_value), ""
|
|
242
236
|
|
|
243
237
|
|
|
244
|
-
def get_relative_err(n_value, b_value):
|
|
245
|
-
"""计算相对误差"""
|
|
246
|
-
with np.errstate(divide='ignore', invalid='ignore'):
|
|
247
|
-
if b_value.dtype not in CompareConst.FLOAT_TYPE:
|
|
248
|
-
n_value, b_value = n_value.astype(float), b_value.astype(float)
|
|
249
|
-
zero_mask = (b_value == 0)
|
|
250
|
-
b_value[zero_mask] += np.finfo(b_value.dtype).eps
|
|
251
|
-
n_value[zero_mask] += np.finfo(b_value.dtype).eps
|
|
252
|
-
relative_err = np.divide((n_value - b_value), b_value)
|
|
253
|
-
return np.abs(relative_err)
|
|
254
|
-
|
|
255
|
-
|
|
256
238
|
class GetMaxRelativeErr(TensorComparisonBasic):
|
|
257
239
|
"""计算最大相对误差"""
|
|
258
|
-
def apply(self, n_value, b_value,
|
|
259
|
-
if error_flag:
|
|
260
|
-
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
261
|
-
return CompareConst.UNSUPPORTED, ''
|
|
262
|
-
if n_value == CompareConst.NONE:
|
|
263
|
-
return 0, ''
|
|
264
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
265
|
-
return CompareConst.SHAPE_UNMATCH, ''
|
|
266
|
-
if n_value == CompareConst.NAN:
|
|
267
|
-
return CompareConst.N_A, ''
|
|
268
|
-
|
|
269
|
-
if relative_err is None:
|
|
270
|
-
relative_err = get_relative_err(n_value, b_value)
|
|
240
|
+
def apply(self, n_value, b_value, relative_err):
|
|
271
241
|
max_relative_err = np.max(np.abs(relative_err))
|
|
272
242
|
if np.isnan(max_relative_err):
|
|
273
|
-
|
|
274
|
-
return CompareConst.NAN,
|
|
275
|
-
return format_value(max_relative_err),
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
class GetThousandErrRatio(TensorComparisonBasic):
|
|
279
|
-
"""计算相对误差小于千分之一的比例"""
|
|
280
|
-
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
281
|
-
if error_flag:
|
|
282
|
-
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
283
|
-
return CompareConst.UNSUPPORTED, ""
|
|
284
|
-
if n_value == CompareConst.NONE:
|
|
285
|
-
return 0, ""
|
|
286
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
287
|
-
return CompareConst.SHAPE_UNMATCH, ""
|
|
288
|
-
if n_value == CompareConst.NAN:
|
|
289
|
-
return CompareConst.N_A, ""
|
|
243
|
+
msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
|
|
244
|
+
return CompareConst.NAN, msg
|
|
245
|
+
return format_value(max_relative_err), ""
|
|
290
246
|
|
|
291
|
-
if not n_value.shape:
|
|
292
|
-
return CompareConst.NAN, ""
|
|
293
|
-
if relative_err is None:
|
|
294
|
-
relative_err = get_relative_err(n_value, b_value)
|
|
295
|
-
if not np.size(relative_err):
|
|
296
|
-
return CompareConst.NAN, ""
|
|
297
|
-
return format_value(np.sum(relative_err < CompareConst.THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
class GetFiveThousandErrRatio(TensorComparisonBasic):
|
|
301
|
-
"""计算相对误差小于千分之五的比例"""
|
|
302
|
-
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
303
|
-
if error_flag:
|
|
304
|
-
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
305
|
-
return CompareConst.UNSUPPORTED, ""
|
|
306
|
-
if n_value == CompareConst.NONE:
|
|
307
|
-
return 0, ""
|
|
308
|
-
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
309
|
-
return CompareConst.SHAPE_UNMATCH, ""
|
|
310
|
-
if n_value == CompareConst.NAN:
|
|
311
|
-
return CompareConst.N_A, ""
|
|
312
247
|
|
|
248
|
+
class GetErrRatio(TensorComparisonBasic):
|
|
249
|
+
"""计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
|
|
250
|
+
def __init__(self, threshold):
|
|
251
|
+
self.threshold = threshold
|
|
252
|
+
|
|
253
|
+
def apply(self, n_value, b_value, relative_err):
|
|
313
254
|
if not n_value.shape:
|
|
314
|
-
return CompareConst.
|
|
315
|
-
|
|
316
|
-
relative_err = get_relative_err(n_value, b_value)
|
|
255
|
+
return CompareConst.UNSUPPORTED, ""
|
|
256
|
+
|
|
317
257
|
if not np.size(relative_err):
|
|
318
258
|
return CompareConst.NAN, ""
|
|
319
|
-
|
|
320
|
-
|
|
259
|
+
|
|
260
|
+
ratio = np.sum(relative_err < self.threshold) / np.size(relative_err)
|
|
261
|
+
return format_value(ratio), ""
|
|
321
262
|
|
|
322
263
|
|
|
323
264
|
class CompareOps:
|
|
@@ -325,15 +266,36 @@ class CompareOps:
|
|
|
325
266
|
"cosine_similarity": GetCosineSimilarity(),
|
|
326
267
|
"max_abs_error": GetMaxAbsErr(),
|
|
327
268
|
"max_relative_error": GetMaxRelativeErr(),
|
|
328
|
-
"one_thousand_err_ratio":
|
|
329
|
-
"five_thousand_err_ratio":
|
|
269
|
+
"one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
|
|
270
|
+
"five_thousand_err_ratio": GetErrRatio(CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD)
|
|
330
271
|
}
|
|
331
272
|
|
|
332
273
|
|
|
333
|
-
def
|
|
274
|
+
def error_value_process(n_value):
|
|
275
|
+
if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
|
|
276
|
+
return CompareConst.UNSUPPORTED, ""
|
|
277
|
+
if n_value == CompareConst.NONE:
|
|
278
|
+
return 0, ""
|
|
279
|
+
if n_value == CompareConst.SHAPE_UNMATCH:
|
|
280
|
+
return CompareConst.SHAPE_UNMATCH, ""
|
|
281
|
+
if n_value == CompareConst.NAN:
|
|
282
|
+
return CompareConst.N_A, ""
|
|
283
|
+
return CompareConst.N_A, ""
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def compare_ops_apply(n_value, b_value, error_flag, err_msg):
|
|
334
287
|
result_list = []
|
|
288
|
+
if error_flag:
|
|
289
|
+
result, msg = error_value_process(n_value)
|
|
290
|
+
result_list = [result] * len(CompareOps.compare_ops)
|
|
291
|
+
err_msg += msg * len(CompareOps.compare_ops)
|
|
292
|
+
return result_list, err_msg
|
|
293
|
+
|
|
294
|
+
relative_err = get_relative_err(n_value, b_value)
|
|
295
|
+
n_value, b_value = reshape_value(n_value, b_value)
|
|
296
|
+
|
|
335
297
|
for op in CompareOps.compare_ops.values():
|
|
336
|
-
result, msg = op.apply(n_value, b_value,
|
|
337
|
-
err_msg += msg
|
|
298
|
+
result, msg = op.apply(n_value, b_value, relative_err)
|
|
338
299
|
result_list.append(result)
|
|
300
|
+
err_msg += msg
|
|
339
301
|
return result_list, err_msg
|