mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
#
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
7
|
# you may not use this file except in compliance with the License.
|
|
7
8
|
# You may obtain a copy of the License at
|
|
8
9
|
#
|
|
@@ -13,10 +14,11 @@
|
|
|
13
14
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
15
|
# See the License for the specific language governing permissions and
|
|
15
16
|
# limitations under the License.
|
|
16
|
-
|
|
17
|
+
|
|
17
18
|
import os
|
|
18
19
|
import re
|
|
19
20
|
from collections import namedtuple
|
|
21
|
+
import importlib
|
|
20
22
|
|
|
21
23
|
import torch
|
|
22
24
|
|
|
@@ -96,7 +98,8 @@ def cross_entropy_process(api_info_dict):
|
|
|
96
98
|
Return api_info_dict:
|
|
97
99
|
api_info_dict: Processed argument of the API.
|
|
98
100
|
"""
|
|
99
|
-
if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1
|
|
101
|
+
if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
|
|
102
|
+
and 'Min' in api_info_dict['input_args'][1]:
|
|
100
103
|
if api_info_dict['input_args'][1]['Min'] <= 0:
|
|
101
104
|
# The second argument in cross_entropy should be -100 or not less than 0
|
|
102
105
|
api_info_dict['input_args'][1]['Min'] = 0
|
|
@@ -109,18 +112,6 @@ def initialize_save_path(save_path, dir_name):
|
|
|
109
112
|
return data_path
|
|
110
113
|
|
|
111
114
|
|
|
112
|
-
def get_real_data_path(file_path):
|
|
113
|
-
targets = ['forward_real_data', 'backward_real_data', 'ut_error_data\d+']
|
|
114
|
-
pattern = re.compile(r'({})'.format('|'.join(targets)))
|
|
115
|
-
match = pattern.search(file_path)
|
|
116
|
-
if match:
|
|
117
|
-
target_index = match.start()
|
|
118
|
-
target_path = file_path[target_index:]
|
|
119
|
-
return target_path
|
|
120
|
-
else:
|
|
121
|
-
raise DumpException(DumpException.INVALID_PATH_ERROR)
|
|
122
|
-
|
|
123
|
-
|
|
124
115
|
def get_full_data_path(data_path, real_data_path):
|
|
125
116
|
if not data_path:
|
|
126
117
|
return data_path
|
|
@@ -137,7 +128,10 @@ class UtDataProcessor:
|
|
|
137
128
|
self.index = 0
|
|
138
129
|
self._save_recursive(api_name, element)
|
|
139
130
|
|
|
140
|
-
def _save_recursive(self, api_name, element):
|
|
131
|
+
def _save_recursive(self, api_name, element, depth=0):
|
|
132
|
+
if depth > Const.MAX_DEPTH:
|
|
133
|
+
logger.error(f"Maximum depth of {Const.MAX_DEPTH} exceeded for {api_name}")
|
|
134
|
+
raise DumpException(DumpException.RECURSION_LIMIT_ERROR)
|
|
141
135
|
if isinstance(element, torch.Tensor):
|
|
142
136
|
api_args = api_name + Const.SEP + str(self.index)
|
|
143
137
|
create_directory(self.save_path)
|
|
@@ -153,10 +147,10 @@ class UtDataProcessor:
|
|
|
153
147
|
self.index += 1
|
|
154
148
|
elif isinstance(element, (list, tuple)):
|
|
155
149
|
for item in element:
|
|
156
|
-
self._save_recursive(api_name, item)
|
|
150
|
+
self._save_recursive(api_name, item, depth=depth+1)
|
|
157
151
|
elif isinstance(element, dict):
|
|
158
152
|
for value in element.values():
|
|
159
|
-
self._save_recursive(api_name, value)
|
|
153
|
+
self._save_recursive(api_name, value, depth=depth+1)
|
|
160
154
|
else:
|
|
161
155
|
self.index += 1
|
|
162
156
|
|
|
@@ -211,4 +205,42 @@ def extract_detailed_api_segments(full_api_name_with_direction_status):
|
|
|
211
205
|
else:
|
|
212
206
|
full_api_name = None
|
|
213
207
|
return api_name, full_api_name, direction_status
|
|
214
|
-
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def get_module_and_atttribute_name(attribute):
|
|
211
|
+
'''
|
|
212
|
+
Function Description:
|
|
213
|
+
Get the module and attribute name.
|
|
214
|
+
Parameter:
|
|
215
|
+
name: Attribute of a module. Example: torch.float16
|
|
216
|
+
Return:
|
|
217
|
+
module_name: Name of the module. Example: torch.
|
|
218
|
+
attribute_name: Name of the attribute. Example: float16.
|
|
219
|
+
'''
|
|
220
|
+
try:
|
|
221
|
+
module_name, attribute_name = attribute.split(Const.SEP)
|
|
222
|
+
except ValueError as e:
|
|
223
|
+
logger.error(f"Failed to get module and attribute name from {attribute}")
|
|
224
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
225
|
+
return module_name, attribute_name
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def get_attribute(module_name, attribute_name):
|
|
229
|
+
'''
|
|
230
|
+
Function Description:
|
|
231
|
+
Get the attribute of the module.
|
|
232
|
+
Parameter:
|
|
233
|
+
module_name: Name of the module.
|
|
234
|
+
attribute_name: Name of the attribute.
|
|
235
|
+
'''
|
|
236
|
+
attribute = None
|
|
237
|
+
if module_name not in Const.MODULE_WHITE_LIST:
|
|
238
|
+
logger.error(f"Module {module_name} is not in white list")
|
|
239
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
240
|
+
try:
|
|
241
|
+
module = importlib.import_module(module_name)
|
|
242
|
+
attribute = getattr(module, attribute_name)
|
|
243
|
+
except (ImportError, AttributeError) as e:
|
|
244
|
+
logger.error(f"Failed to get attribute {attribute_name} from module {module_name}: {e}")
|
|
245
|
+
raise CompareException(CompareException.INVALID_ATTRIBUTE_ERROR) from e
|
|
246
|
+
return attribute
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
# 定义比对算法及比对标准
|
|
2
19
|
import torch
|
|
3
20
|
import numpy as np
|
|
@@ -142,7 +159,7 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
|
|
|
142
159
|
输出:
|
|
143
160
|
inf_nan_err_ratio:npu输出和golden输出的inf、nan不一致的比例
|
|
144
161
|
'''
|
|
145
|
-
|
|
162
|
+
_, abs_gpu_with_eps = get_abs_bench_with_eps(bench_output, dtype)
|
|
146
163
|
golden_same_dtype = bench_output.astype(device_output.dtype)
|
|
147
164
|
a_min = np.finfo(device_output.dtype).min if dtype != torch.bfloat16 else CompareConst.BFLOAT16_MIN
|
|
148
165
|
a_max = np.finfo(device_output.dtype).max if dtype != torch.bfloat16 else CompareConst.BFLOAT16_MAX
|
|
@@ -209,5 +226,5 @@ def get_ulp_err(bench_output, device_output, dtype):
|
|
|
209
226
|
|
|
210
227
|
|
|
211
228
|
def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
|
|
212
|
-
return
|
|
229
|
+
return (device_output.astype(data_type) - bench_output).astype(data_type) * \
|
|
213
230
|
np.exp2(-eb + exponent_num).astype(data_type)
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import argparse
|
|
2
19
|
import math
|
|
3
20
|
import os
|
|
@@ -7,7 +24,7 @@ from collections import namedtuple
|
|
|
7
24
|
import torch
|
|
8
25
|
import pandas as pd
|
|
9
26
|
|
|
10
|
-
from msprobe.core.common.file_utils import write_csv
|
|
27
|
+
from msprobe.core.common.file_utils import write_csv, read_csv
|
|
11
28
|
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
12
29
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
|
|
13
30
|
API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
|
|
@@ -23,12 +40,12 @@ from msprobe.core.common.utils import CompareException
|
|
|
23
40
|
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
24
41
|
|
|
25
42
|
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
|
|
26
|
-
|
|
43
|
+
BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_value_inf_nan_consistency',
|
|
27
44
|
'rmse_inf_nan_consistency',
|
|
28
45
|
'max_rel_inf_nan_consistency',
|
|
29
46
|
'mean_rel_inf_nan_consistency',
|
|
30
47
|
'eb_inf_nan_consistency'])
|
|
31
|
-
|
|
48
|
+
UNSUPPORTED_MESSAGE = 'This data type does not support benchmark compare.'
|
|
32
49
|
|
|
33
50
|
DEFAULT_THRESHOLD = 1
|
|
34
51
|
|
|
@@ -154,11 +171,11 @@ class BenchmarkStandard(Standard):
|
|
|
154
171
|
self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
|
|
155
172
|
else CompareConst.ERROR
|
|
156
173
|
self.check_result_list.append(self.rmse_status)
|
|
157
|
-
self.max_rel_err_status = self._get_status(
|
|
158
|
-
|
|
174
|
+
self.max_rel_err_status = self._get_status(
|
|
175
|
+
self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency else CompareConst.ERROR
|
|
159
176
|
self.check_result_list.append(self.max_rel_err_status)
|
|
160
|
-
self.mean_rel_err_status = self._get_status(
|
|
161
|
-
else CompareConst.ERROR
|
|
177
|
+
self.mean_rel_err_status = self._get_status(
|
|
178
|
+
self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency else CompareConst.ERROR
|
|
162
179
|
self.check_result_list.append(self.mean_rel_err_status)
|
|
163
180
|
self.eb_status = self._get_status(self.eb_ratio, 'eb')
|
|
164
181
|
if CompareConst.ERROR in self.check_result_list:
|
|
@@ -187,7 +204,8 @@ class BenchmarkStandard(Standard):
|
|
|
187
204
|
self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
|
|
188
205
|
self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
|
|
189
206
|
self.compare_message += max_rel_message
|
|
190
|
-
self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(
|
|
207
|
+
self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(
|
|
208
|
+
ApiPrecisionCompareColumn.MEAN_REL_ERR,
|
|
191
209
|
self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
|
|
192
210
|
self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
|
|
193
211
|
self.compare_message += mean_rel_message
|
|
@@ -196,8 +214,9 @@ class BenchmarkStandard(Standard):
|
|
|
196
214
|
self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
|
|
197
215
|
self.compare_message += eb_message
|
|
198
216
|
|
|
199
|
-
return
|
|
200
|
-
|
|
217
|
+
return BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
|
|
218
|
+
max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
|
|
219
|
+
eb_inf_nan_consistency)
|
|
201
220
|
|
|
202
221
|
|
|
203
222
|
class ULPStandard(Standard):
|
|
@@ -269,12 +288,12 @@ def api_precision_compare(config):
|
|
|
269
288
|
logger.info(f"Compare task result will be saved in {config.result_csv_path}")
|
|
270
289
|
logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
|
|
271
290
|
try:
|
|
272
|
-
npu_data =
|
|
291
|
+
npu_data = read_csv(config.npu_csv_path)
|
|
273
292
|
except Exception as err:
|
|
274
293
|
logger.error(f"Open npu csv Error: %s" % str(err))
|
|
275
294
|
check_csv_columns(npu_data.columns, "npu_csv")
|
|
276
295
|
try:
|
|
277
|
-
gpu_data =
|
|
296
|
+
gpu_data = read_csv(config.gpu_csv_path)
|
|
278
297
|
except Exception as err:
|
|
279
298
|
logger.error(f"Open gpu csv Error: %s" % str(err))
|
|
280
299
|
check_csv_columns(gpu_data.columns, "gpu_csv")
|
|
@@ -292,8 +311,10 @@ def api_precision_compare(config):
|
|
|
292
311
|
|
|
293
312
|
def online_api_precision_compare(online_config):
|
|
294
313
|
rank = online_config.rank
|
|
295
|
-
result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
|
|
296
|
-
|
|
314
|
+
result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
|
|
315
|
+
"_rank*.csv", f"_rank{rank}.csv")
|
|
316
|
+
details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
|
|
317
|
+
"_rank*.csv", f"_rank{rank}.csv")
|
|
297
318
|
detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
|
|
298
319
|
result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
|
|
299
320
|
if not os.path.exists(result_csv_path):
|
|
@@ -315,6 +336,7 @@ def online_api_precision_compare(online_config):
|
|
|
315
336
|
def analyse_csv(npu_data, gpu_data, config):
|
|
316
337
|
forward_status, backward_status = [], []
|
|
317
338
|
last_api_name, last_api_dtype, last_api_full_name = None, None, None
|
|
339
|
+
last_api_skip_message = ''
|
|
318
340
|
for _, row_npu in npu_data.iterrows():
|
|
319
341
|
message = ''
|
|
320
342
|
compare_column = ApiPrecisionOutputColumn()
|
|
@@ -328,7 +350,7 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
328
350
|
compare_column.compare_result = CompareConst.SKIP
|
|
329
351
|
compare_column.compare_message = err_message
|
|
330
352
|
write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
|
|
331
|
-
write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP,
|
|
353
|
+
write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, err_message]],
|
|
332
354
|
config.result_csv_path)
|
|
333
355
|
continue
|
|
334
356
|
if row_gpu.empty:
|
|
@@ -355,19 +377,19 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
355
377
|
|
|
356
378
|
if last_api_name is not None and api_full_name != last_api_name:
|
|
357
379
|
if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
|
|
358
|
-
message =
|
|
380
|
+
message = UNSUPPORTED_MESSAGE
|
|
359
381
|
write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
|
|
360
382
|
print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
|
|
361
|
-
forward_status, backward_status = [], []
|
|
362
|
-
message = ''
|
|
363
383
|
else:
|
|
364
384
|
forward_result = get_api_checker_result(forward_status)
|
|
365
385
|
backward_result = get_api_checker_result(backward_status)
|
|
366
386
|
message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
387
|
+
message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
|
|
367
388
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
368
389
|
print_test_success(last_api_name, forward_result, backward_result)
|
|
369
|
-
|
|
370
|
-
|
|
390
|
+
last_api_skip_message = ''
|
|
391
|
+
forward_status, backward_status = [], []
|
|
392
|
+
message = ''
|
|
371
393
|
|
|
372
394
|
is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
|
|
373
395
|
last_api_name = api_full_name
|
|
@@ -378,6 +400,8 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
378
400
|
|
|
379
401
|
if direction_status == 'forward':
|
|
380
402
|
forward_status.append(new_status)
|
|
403
|
+
last_api_skip_message = str(row_npu[ApiPrecisionCompareColumn.MESSAGE]) if new_status == CompareConst.SKIP \
|
|
404
|
+
else ''
|
|
381
405
|
elif direction_status == 'backward':
|
|
382
406
|
backward_status.append(new_status)
|
|
383
407
|
else:
|
|
@@ -385,15 +409,17 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
385
409
|
|
|
386
410
|
if last_api_name is not None:
|
|
387
411
|
if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
|
|
388
|
-
message =
|
|
412
|
+
message = UNSUPPORTED_MESSAGE
|
|
389
413
|
write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
|
|
390
414
|
print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
|
|
391
415
|
else:
|
|
392
416
|
forward_result = get_api_checker_result(forward_status)
|
|
393
417
|
backward_result = get_api_checker_result(backward_status)
|
|
394
418
|
message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
419
|
+
message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
|
|
395
420
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
396
421
|
print_test_success(last_api_name, forward_result, backward_result)
|
|
422
|
+
last_api_skip_message = ''
|
|
397
423
|
|
|
398
424
|
|
|
399
425
|
def get_api_status(row_npu, row_gpu, api_name, compare_column):
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
# 进行比对及结果展示
|
|
2
19
|
import os
|
|
3
20
|
from collections import namedtuple
|
|
@@ -127,8 +144,12 @@ class Comparator:
|
|
|
127
144
|
return test_rows
|
|
128
145
|
|
|
129
146
|
def write_csv_title(self):
|
|
130
|
-
summary_test_rows = [
|
|
131
|
-
|
|
147
|
+
summary_test_rows = [
|
|
148
|
+
[self.COLUMN_API_NAME,
|
|
149
|
+
self.COLUMN_FORWARD_SUCCESS,
|
|
150
|
+
self.COLUMN_BACKWARD_SUCCESS,
|
|
151
|
+
"Message"]
|
|
152
|
+
]
|
|
132
153
|
for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list):
|
|
133
154
|
if not os.path.exists(save_path):
|
|
134
155
|
write_csv(summary_test_rows, save_path)
|
|
@@ -240,13 +261,15 @@ class Comparator:
|
|
|
240
261
|
def _compare_core(self, api_name, bench_output, device_output):
|
|
241
262
|
compare_column = CompareColumn()
|
|
242
263
|
if not isinstance(bench_output, type(device_output)):
|
|
243
|
-
|
|
264
|
+
status = CompareConst.ERROR
|
|
265
|
+
message = "bench and npu output type is different."
|
|
244
266
|
elif isinstance(bench_output, dict):
|
|
245
267
|
b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
|
|
246
268
|
if b_keys != n_keys:
|
|
247
|
-
|
|
269
|
+
status = CompareConst.ERROR
|
|
270
|
+
message = "bench and npu output dict keys are different."
|
|
248
271
|
else:
|
|
249
|
-
status,
|
|
272
|
+
status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
|
|
250
273
|
list(device_output.values()))
|
|
251
274
|
elif isinstance(bench_output, torch.Tensor):
|
|
252
275
|
copy_bench_out = bench_output.detach().clone()
|
|
@@ -254,19 +277,20 @@ class Comparator:
|
|
|
254
277
|
compare_column.bench_type = str(copy_bench_out.dtype)
|
|
255
278
|
compare_column.npu_type = str(copy_device_output.dtype)
|
|
256
279
|
compare_column.shape = tuple(device_output.shape)
|
|
257
|
-
status,
|
|
280
|
+
status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
|
|
258
281
|
compare_column)
|
|
259
282
|
elif isinstance(bench_output, (bool, int, float, str)):
|
|
260
283
|
compare_column.bench_type = str(type(bench_output))
|
|
261
284
|
compare_column.npu_type = str(type(device_output))
|
|
262
|
-
status,
|
|
285
|
+
status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
|
|
263
286
|
elif bench_output is None:
|
|
264
|
-
|
|
287
|
+
status = CompareConst.SKIP
|
|
288
|
+
message = "Bench output is None, skip this test."
|
|
265
289
|
else:
|
|
266
|
-
|
|
267
|
-
|
|
290
|
+
status = CompareConst.ERROR
|
|
291
|
+
message = "Unexpected output type in compare_core: {}".format(type(bench_output))
|
|
268
292
|
|
|
269
|
-
return status,
|
|
293
|
+
return status, compare_column, message
|
|
270
294
|
|
|
271
295
|
def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
|
|
272
296
|
cpu_shape = bench_output.shape
|
|
@@ -330,21 +354,23 @@ class Comparator:
|
|
|
330
354
|
compare_column.max_ulp_error = np.max(ulp_err)
|
|
331
355
|
compare_column.mean_ulp_error = np.mean(ulp_err)
|
|
332
356
|
if dtype == torch.float32:
|
|
333
|
-
compare_column.ulp_error_proportion =
|
|
357
|
+
compare_column.ulp_error_proportion = \
|
|
358
|
+
np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
|
|
334
359
|
else:
|
|
335
|
-
compare_column.ulp_error_proportion =
|
|
360
|
+
compare_column.ulp_error_proportion = \
|
|
361
|
+
np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
|
|
336
362
|
else:
|
|
337
363
|
dtype_config = precision_configs.get(dtype)
|
|
338
364
|
small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
|
|
339
365
|
abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
|
|
340
366
|
compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
|
|
341
367
|
rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
|
|
342
|
-
compare_column.
|
|
343
|
-
compare_column.
|
|
368
|
+
compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
|
|
369
|
+
compare_column.eb = get_error_balance(bench_output, device_output)
|
|
344
370
|
if rel_err.size == 0:
|
|
345
371
|
return CompareConst.ERROR, compare_column, "Relative error result list is empty."
|
|
346
|
-
compare_column.
|
|
347
|
-
compare_column.
|
|
372
|
+
compare_column.max_rel_error = get_max_rel_err(rel_err)
|
|
373
|
+
compare_column.mean_rel_error = get_mean_rel_err(rel_err)
|
|
348
374
|
|
|
349
375
|
cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
|
|
350
376
|
compare_column.cosine_sim = cos_res
|
|
@@ -363,7 +389,8 @@ class Comparator:
|
|
|
363
389
|
hundred_res, hundred_status = get_rel_err_ratio(rel_err_orign, CompareConst.HUNDRED_RATIO_THRESHOLD)
|
|
364
390
|
compare_column.rel_err_hundredth = hundred_res
|
|
365
391
|
if not hundred_status:
|
|
366
|
-
message += "Relative error is greater than 0.01, consider as error,
|
|
392
|
+
message += "Relative error is greater than 0.01, consider as error, " \
|
|
393
|
+
"skip other check and set to SPACE.\n"
|
|
367
394
|
return CompareConst.ERROR, compare_column, message
|
|
368
395
|
thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
|
|
369
396
|
compare_column.rel_err_thousandth = thousand_res
|
|
@@ -373,14 +400,17 @@ class Comparator:
|
|
|
373
400
|
return CompareConst.PASS, compare_column, message
|
|
374
401
|
message += "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"
|
|
375
402
|
return CompareConst.WARNING, compare_column, message
|
|
376
|
-
ten_thousand_res, ten_thousand_status = get_rel_err_ratio(
|
|
403
|
+
ten_thousand_res, ten_thousand_status = get_rel_err_ratio(
|
|
404
|
+
rel_err_orign, CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
|
|
377
405
|
compare_column.rel_err_ten_thousandth = ten_thousand_res
|
|
378
406
|
if dtype in [torch.float32, torch.float64]:
|
|
379
407
|
if not thousand_status:
|
|
380
|
-
message += "Relative error is greater than 0.001, consider as error,
|
|
408
|
+
message += "Relative error is greater than 0.001, consider as error, " \
|
|
409
|
+
"skip other check and set to SPACE.\n"
|
|
381
410
|
return CompareConst.ERROR, compare_column, message
|
|
382
411
|
if not ten_thousand_status:
|
|
383
|
-
message += "Relative error is greater than 0.0001, consider as warning,
|
|
412
|
+
message += "Relative error is greater than 0.0001, consider as warning, " \
|
|
413
|
+
"skip other check and set to SPACE.\n"
|
|
384
414
|
return CompareConst.WARNING, compare_column, message
|
|
385
415
|
message += "Relative error is less than 0.0001, consider as pass.\n"
|
|
386
416
|
return CompareConst.PASS, compare_column, message
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
from msprobe.core.common.const import CompareConst
|
|
2
19
|
|
|
3
20
|
|
|
@@ -12,11 +29,11 @@ class CompareColumn:
|
|
|
12
29
|
self.rel_err_thousandth = CompareConst.SPACE
|
|
13
30
|
self.rel_err_ten_thousandth = CompareConst.SPACE
|
|
14
31
|
self.error_rate = CompareConst.SPACE
|
|
15
|
-
self.
|
|
16
|
-
self.
|
|
32
|
+
self.eb = CompareConst.SPACE
|
|
33
|
+
self.rmse = CompareConst.SPACE
|
|
17
34
|
self.small_value_err_ratio = CompareConst.SPACE
|
|
18
|
-
self.
|
|
19
|
-
self.
|
|
35
|
+
self.max_rel_error = CompareConst.SPACE
|
|
36
|
+
self.mean_rel_error = CompareConst.SPACE
|
|
20
37
|
self.inf_nan_error_ratio = CompareConst.SPACE
|
|
21
38
|
self.rel_err_ratio = CompareConst.SPACE
|
|
22
39
|
self.abs_err_ratio = CompareConst.SPACE
|
|
@@ -26,8 +43,8 @@ class CompareColumn:
|
|
|
26
43
|
|
|
27
44
|
def to_column_value(self, is_pass, message):
|
|
28
45
|
return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
|
|
29
|
-
self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.
|
|
30
|
-
self.small_value_err_ratio, self.
|
|
46
|
+
self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.eb, self.rmse,
|
|
47
|
+
self.small_value_err_ratio, self.max_rel_error, self.mean_rel_error, self.inf_nan_error_ratio,
|
|
31
48
|
self.rel_err_ratio, self.abs_err_ratio, self.max_ulp_error, self.mean_ulp_error,
|
|
32
49
|
self.ulp_error_proportion, is_pass, message]
|
|
33
50
|
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import time
|
|
2
19
|
import os
|
|
3
20
|
import math
|
|
@@ -32,7 +49,8 @@ threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
|
|
|
32
49
|
apis_threshold = load_yaml(threshold_yaml_path)
|
|
33
50
|
|
|
34
51
|
|
|
35
|
-
DETAIL_TEST_ROWS = [
|
|
52
|
+
DETAIL_TEST_ROWS = [
|
|
53
|
+
[
|
|
36
54
|
"API Name", "Bench Dtype", "DEVICE Dtype", "Shape",
|
|
37
55
|
"余弦相似度",
|
|
38
56
|
"最大绝对误差",
|
|
@@ -53,7 +71,8 @@ DETAIL_TEST_ROWS = [[
|
|
|
53
71
|
"ULP误差大于阈值占比",
|
|
54
72
|
"Status",
|
|
55
73
|
"Message"
|
|
56
|
-
|
|
74
|
+
]
|
|
75
|
+
]
|
|
57
76
|
|
|
58
77
|
|
|
59
78
|
precision_configs = {
|
|
@@ -154,11 +173,11 @@ class ApiPrecisionCompareColumn:
|
|
|
154
173
|
def to_required_columns():
|
|
155
174
|
return [ApiPrecisionCompareColumn.API_NAME, ApiPrecisionCompareColumn.DEVICE_DTYPE,
|
|
156
175
|
ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE, ApiPrecisionCompareColumn.RMSE,
|
|
157
|
-
ApiPrecisionCompareColumn.MAX_REL_ERR, ApiPrecisionCompareColumn.MEAN_REL_ERR,
|
|
158
|
-
ApiPrecisionCompareColumn.
|
|
159
|
-
ApiPrecisionCompareColumn.
|
|
160
|
-
ApiPrecisionCompareColumn.
|
|
161
|
-
ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
|
|
176
|
+
ApiPrecisionCompareColumn.MAX_REL_ERR, ApiPrecisionCompareColumn.MEAN_REL_ERR,
|
|
177
|
+
ApiPrecisionCompareColumn.EB, ApiPrecisionCompareColumn.ERROR_RATE,
|
|
178
|
+
ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO, ApiPrecisionCompareColumn.REL_ERR_RATIO,
|
|
179
|
+
ApiPrecisionCompareColumn.ABS_ERR_RATIO, ApiPrecisionCompareColumn.MEAN_ULP_ERR,
|
|
180
|
+
ApiPrecisionCompareColumn.ULP_ERR_PROPORTION, ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
|
|
162
181
|
|
|
163
182
|
@staticmethod
|
|
164
183
|
def get_detail_csv_title():
|
|
@@ -175,7 +194,8 @@ class ApiPrecisionCompareColumn:
|
|
|
175
194
|
ApiPrecisionCompareColumn.MEAN_ULP_ERR, ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
|
|
176
195
|
ApiPrecisionCompareColumn.ULP_ERR_PROPORTION_RATIO, ApiPrecisionCompareColumn.ULP_ERR_STATUS,
|
|
177
196
|
ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH, ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH_STATUS,
|
|
178
|
-
ApiPrecisionCompareColumn.FINAL_RESULT, ApiPrecisionCompareColumn.ALGORITHM,
|
|
197
|
+
ApiPrecisionCompareColumn.FINAL_RESULT, ApiPrecisionCompareColumn.ALGORITHM,
|
|
198
|
+
ApiPrecisionCompareColumn.MESSAGE]
|
|
179
199
|
|
|
180
200
|
@staticmethod
|
|
181
201
|
def get_result_csv_title():
|