mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc/在线精度比对.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AdvisorConst:
    """
    Constants used by the advisor: text separators, labels of the summary
    fields, suggestion messages, and API/name keywords.
    """

    # text symbol
    NEW_LINE = "\n"
    COLON = ": "

    # advisor summary key
    SUSPECT_NODES = "Suspect Nodes"
    LINE = "Line"
    ADVISOR_SUGGEST = "Expert Advice"

    NO_ERROR_API = "NA"

    # advisor message
    NO_ERR_SUGGEST = "All data in comparison result meets the accuracy requirements."
    FORWARD_INPUT_SUGGEST = (
        "1. Analyze the model to view the input source.\n"
        "2. Check whether an inplace API causes the output result to overwrite the input result. "
        "That is, the fault is actually caused by a computation error.\n"
        "3. The fault may be caused by memory corruption and further analysis is required."
    )
    FORWARD_OUTPUT_SUGGEST = "This is a forward API computation error. Check the computation implementation."
    BACKWARD_INPUT_SUGGEST = "Check whether the forward computation result is affected."
    BACKWARD_OUTPUT_SUGGEST = "This is a backward API computation error. Check the computation implementation."
    BATCH_NORM_SUGGEST = (
        "Torch API batch_norm input not fixed, the following suggestions may fix it:\n"
        "1. If use torch.nn.functional.batch_norm, you can set parameter training=False.\n"
        "2. If use torch.nn.BatchNormXXX, you can set parameter affine=False.\n"
        "3. Use seed_all(mode=True) to enable deterministic computing."
    )
    DETERMINISTIC_SUGGEST = (
        "This torch api may be uncertainty in the calculation, "
        "can seed_all(mode=True) to enable deterministic computing."
    )

    # API names and name fragments matched against comparison results
    FUNC_BATCH_NORM = "Functional_batch_norm"
    FORWARD_INPUT_1 = "forward_input.1"
    NEED_DETERMINISTIC_API = ["conv2d", "conv3d", "matmul", "nll_loss", "layer_norm", "lstm"]
    BATCH_NORM = "batch_norm"

    # name keyword
    INPUT = "input"
    OUTPUT = "output"
    FORWARD = "forward"
    BACKWARD = "backward"
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import os
|
|
18
|
+
import time
|
|
19
|
+
|
|
20
|
+
from msprobe.pytorch.advisor.advisor_const import AdvisorConst
|
|
21
|
+
from msprobe.pytorch.common.log import logger
|
|
22
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
23
|
+
from msprobe.core.common.file_check import change_mode
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AdvisorResult:
    """
    One advisor finding (suspect node, result line and advice text), plus
    helpers that log it and persist a batch of messages to a summary file.
    """

    def __init__(self, node, line, message):
        # print_advisor_log concatenates node and message with str labels,
        # so both are expected to be str; line is passed through str().
        self.suspect_node = node
        self.line = line
        self.advisor_message = message

    @staticmethod
    def gen_summary_file(out_path, message_list):
        """
        Write every entry of message_list, each followed by a newline, to a
        new file named ``advisor_<YYYYmmddHHMMSS>.txt`` (local time) in out_path.

        The file is opened with Const.WRITE_FLAGS / Const.WRITE_MODES and its
        permissions are then set to FileCheckConst.DATA_FILE_AUTHORITY.
        An IOError is logged rather than raised.
        """
        timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
        result_file = os.path.join(out_path, f'advisor_{timestamp}.txt')
        try:
            descriptor = os.open(result_file, Const.WRITE_FLAGS, Const.WRITE_MODES)
            with os.fdopen(descriptor, 'w+') as output_file:
                output_file.truncate(0)
                output_file.writelines([message + AdvisorConst.NEW_LINE for message in message_list])
                change_mode(result_file, FileCheckConst.DATA_FILE_AUTHORITY)
        except IOError as io_error:
            logger.error("Failed to save %s, the reason is %s." % (result_file, io_error))
        else:
            logger.info("The advisor summary is saved in: %s" % result_file)

    def print_advisor_log(self):
        """
        Log the line, suspect node and advice of this finding and return the
        three logged ``"<label>: <value>"`` strings in that order.
        """
        logger.info("The summary of the expert advice is as follows: ")
        labelled_fields = (
            (AdvisorConst.LINE, str(self.line)),
            (AdvisorConst.SUSPECT_NODES, self.suspect_node),
            (AdvisorConst.ADVISOR_SUGGEST, self.advisor_message),
        )
        message_list = [label + AdvisorConst.COLON + value for label, value in labelled_fields]
        for message in message_list:
            logger.info(message)
        return message_list
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import yaml
|
|
3
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path
|
|
4
|
+
from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps
|
|
5
|
+
from msprobe.core.common.file_check import FileOpen
|
|
6
|
+
|
|
7
|
+
# Union of the API names from the functional, Tensor and torch wrap lists.
# Config.validate rejects any white_list entry that is not in this set.
WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Config:
    """
    Validated api_accuracy_checker configuration loaded from a YAML file.

    Every top-level key of the YAML mapping is checked by ``validate`` and the
    validated values are exposed as attributes, e.g. ``config.white_list``.
    """

    def __init__(self, yaml_file):
        check_file_or_directory_path(yaml_file, False)
        with FileOpen(yaml_file, 'r') as file:
            config = yaml.safe_load(file)
        if not isinstance(config, dict):
            # An empty file loads as None and a list/scalar top level has no
            # key/value pairs; fail with a clear message instead of an
            # AttributeError on .items().
            raise ValueError(f"{yaml_file} must contain a mapping of config keys")
        self.config = {key: self.validate(key, value) for key, value in config.items()}

    def __getattr__(self, item):
        """
        Expose config entries as attributes.

        Only called when normal attribute lookup fails. Dunder names, and any
        name on an instance whose ``config`` mapping is not set yet (e.g. one
        created via ``__new__`` by copy/pickle), raise AttributeError so that
        hasattr/copy/pickle work. Other unknown names raise KeyError, as before.
        """
        # Read through __dict__: accessing self.config here would re-enter
        # __getattr__('config') and recurse forever when it is missing.
        config = self.__dict__.get('config')
        if config is None or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
        return config[item]

    def __str__(self):
        return '\n'.join(f"{key}={value}" for key, value in self.config.items())

    @staticmethod
    def validate(key, value):
        """
        Check one config entry and return it unchanged.

        Raises:
            ValueError: unknown key, wrong value type (bool is not accepted as
                int), negative precision, or white_list entries that are not
                strings / not in WrapApi.
        """
        validators = {
            'white_list': list,
            'error_data_path': str,
            'precision': int
        }
        if key not in validators:
            raise ValueError(f"{key} must be one of {validators.keys()}")
        expected_type = validators[key]
        # bool is a subclass of int, so it must be rejected explicitly.
        if not isinstance(value, expected_type) or (expected_type is int and isinstance(value, bool)):
            raise ValueError(f"{key} must be {expected_type.__name__} type")
        if key == 'precision' and value < 0:
            raise ValueError("precision must be greater than or equal to 0")
        if key == 'white_list':
            if not all(isinstance(i, str) for i in value):
                raise ValueError("All elements in white_list must be of str type")
            invalid_api = [i for i in value if i not in WrapApi]
            if invalid_api:
                raise ValueError(
                    f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list")
        return value
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Load config.yaml once at import time. It is looked up in the directory two
# levels above this module's resolved (symlink-free) path.
_this_module = os.path.realpath(__file__)
cur_path = os.path.dirname(os.path.dirname(_this_module))
yaml_path = os.path.join(cur_path, "config.yaml")
msCheckerConfig = Config(yaml_path)
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import json
|
|
18
|
+
import os
|
|
19
|
+
import re
|
|
20
|
+
import csv
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
import torch_npu
|
|
26
|
+
except ImportError:
|
|
27
|
+
IS_GPU = True
|
|
28
|
+
else:
|
|
29
|
+
IS_GPU = False
|
|
30
|
+
|
|
31
|
+
from msprobe.pytorch.common.log import logger
|
|
32
|
+
from msprobe.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory
|
|
33
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
34
|
+
from msprobe.core.common.utils import CompareException
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class DumpException(CompareException):
    """CompareException subtype used by this module, e.g. when a dump data path cannot be resolved."""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def write_csv(data, filepath):
    """
    Append ``data`` (an iterable of rows) to the CSV file at ``filepath``.

    The file is opened through FileOpen in append mode with 'utf-8-sig'
    encoding.
    """
    with FileOpen(filepath, 'a', encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def check_object_type(check_object, allow_type):
    """
    Function Description:
        Ensure check_object is an instance of allow_type.
    Parameter:
        check_object: the object to be checked
        allow_type: a type or tuple of types accepted by isinstance
    Exception Description:
        logs an error and raises CompareException(INVALID_DATA_ERROR) on mismatch
    """
    if isinstance(check_object, allow_type):
        return
    logger.error(f"{check_object} not of {allow_type} type")
    raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def check_file_or_directory_path(path, isdir=False):
    """
    Function Description:
        Validate a path before it is used.
        - isdir truthy: path must exist, be a directory and be writable.
        - isdir falsy: path must be an existing regular file and be readable.
    Parameter:
        path: the path to check
        isdir: whether path is expected to be a directory (else a file)
    Exception Description:
        logs an error and raises CompareException(INVALID_PATH_ERROR) on the
        first failed check
    """
    def _reject(message):
        logger.error(message)
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if not isdir:
        if not os.path.isfile(path):
            _reject('{} is an invalid file or non-exist.'.format(path))
        if not os.access(path, os.R_OK):
            _reject('The path {} does not have permission to read. Please check the path permission'.format(path))
        return

    if not os.path.exists(path):
        _reject('The path {} is not exist.'.format(path))
    if not os.path.isdir(path):
        _reject('The path {} is not a directory.'.format(path))
    if not os.access(path, os.W_OK):
        _reject('The path {} does not have permission to write. Please check the path permission'.format(path))
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def get_json_contents(file_path):
    """
    Read file_path and return its parsed JSON content, which must be a JSON object.

    Raises CompareException(INVALID_FILE_ERROR) (after logging) when the
    content cannot be parsed or is not a dict.
    """
    raw_bytes = get_file_content_bytes(file_path)
    try:
        contents = json.loads(raw_bytes)
    except ValueError as error:
        logger.error('Failed to load "%s". %s' % (file_path, str(error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
    if isinstance(contents, dict):
        return contents
    logger.error('Json file %s, content is not a dictionary!' % file_path)
    raise CompareException(CompareException.INVALID_FILE_ERROR)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_file_content_bytes(file):
    """Return the whole content of ``file`` as bytes, read through FileOpen in 'rb' mode."""
    with FileOpen(file, 'rb') as handle:
        content = handle.read()
    return content
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class SoftlinkCheckException(Exception):
    """Exception for soft-link check failures (judging by its name; not raised in the visible part of this module)."""
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def check_need_convert(api_name):
    """
    Return the Const.CONVERT_API key whose API list contains api_name, or None.

    If several keys list the API, the last one in iteration order wins.
    """
    matching_keys = [convert_key for convert_key, api_names in Const.CONVERT_API.items()
                     if api_name in api_names]
    return matching_keys[-1] if matching_keys else None
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def api_info_preprocess(api_name, api_info_dict):
    """
    Function Description:
        Preprocesses the API information.
    Parameter:
        api_name: Name of the API.
        api_info_dict: argument of the API.
    Return:
        convert_type: result of check_need_convert(api_name) (None if no conversion applies).
        api_info_dict: the argument info, passed through cross_entropy_process for 'cross_entropy'.
    """
    convert_type = check_need_convert(api_name)
    processed_info = cross_entropy_process(api_info_dict) if api_name == 'cross_entropy' else api_info_dict
    return convert_type, processed_info
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def cross_entropy_process(api_info_dict):
    """
    Function Description:
        Preprocesses the cross_entropy API information: clamps the recorded
        'Min' of the second positional argument to 0 when it is <= 0.
    Parameter:
        api_info_dict: argument of the API (modified in place).
    Return api_info_dict:
        api_info_dict: the same dict object, processed.
    """
    if 'args' not in api_info_dict or len(api_info_dict['args']) <= 1:
        return api_info_dict
    target_info = api_info_dict['args'][1]
    # Only dict-shaped argument metadata can carry a 'Min' entry; anything else
    # (e.g. None) previously crashed on the membership test / indexing.
    if not isinstance(target_info, dict):
        return api_info_dict
    min_value = target_info.get('Min')
    # NOTE(review): 'Min' may be recorded as None in some dumps; comparing None
    # with 0 would raise TypeError, so such entries are left unchanged.
    if min_value is not None and min_value <= 0:
        # The second argument in cross_entropy should be -100 or not less than 0
        target_info['Min'] = 0
    return api_info_dict
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def initialize_save_path(save_path, dir_name):
    """
    Ensure ``save_path/dir_name`` exists, then run FileChecker's common_check
    on it as a directory.

    A new directory is created with mode FileCheckConst.DATA_DIR_AUTHORITY.
    An existing one is kept as-is and only a warning is logged.
    """
    data_path = os.path.join(save_path, dir_name)
    if not os.path.exists(data_path):
        os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
    else:
        logger.warning(f"{data_path} already exists, it will be overwritten")
    FileChecker(data_path, FileCheckConst.DIR).common_check()
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def write_pt(file_path, tensor):
    """
    Save ``tensor`` with torch.save to a new file and return its real path.

    The file's permissions are set to FileCheckConst.DATA_FILE_AUTHORITY.
    Raises ValueError if ``file_path`` already exists (nothing is overwritten).
    """
    if os.path.exists(file_path):
        raise ValueError(f"File {file_path} already exists")
    torch.save(tensor, file_path)
    resolved_path = os.path.realpath(file_path)
    change_mode(resolved_path, FileCheckConst.DATA_FILE_AUTHORITY)
    return resolved_path
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def get_real_data_path(file_path):
    """
    Return the suffix of ``file_path`` starting at the first occurrence of
    'forward_real_data', 'backward_real_data' or 'ut_error_data<digits>'.

    Raises DumpException(INVALID_PATH_ERROR) when none of them occurs.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    targets = ['forward_real_data', 'backward_real_data', r'ut_error_data\d+']
    pattern = re.compile(r'({})'.format('|'.join(targets)))
    match = pattern.search(file_path)
    if not match:
        raise DumpException(DumpException.INVALID_PATH_ERROR)
    return file_path[match.start():]
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def get_full_data_path(data_path, real_data_path):
    """
    Join ``data_path`` onto ``real_data_path`` and return the resolved real path.

    A falsy ``data_path`` (None, "") is returned unchanged.
    """
    if data_path:
        return os.path.realpath(os.path.join(real_data_path, data_path))
    return data_path
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class UtDataProcessor:
    """
    Save every torch.Tensor found in a (possibly nested) list/tuple/dict
    element as ``<api_name><Const.SEP><index>.pt`` under ``save_path``.

    Each leaf, tensor or not, consumes one index, so a file's index reflects
    the leaf's position in the traversal.
    """

    def __init__(self, save_path):
        self.save_path = save_path
        self.index = 0

    def save_tensors_in_element(self, api_name, element):
        # Numbering restarts for every top-level element.
        self.index = 0
        self._save_recursive(api_name, element)

    def _save_recursive(self, api_name, element):
        if isinstance(element, torch.Tensor):
            api_args = api_name + Const.SEP + str(self.index)
            create_directory(self.save_path)
            target = os.path.join(self.save_path, f'{api_args}.pt')
            write_pt(target, element.contiguous().cpu().detach())
            self.index += 1
        elif isinstance(element, (list, tuple)):
            for item in element:
                self._save_recursive(api_name, item)
        elif isinstance(element, dict):
            for value in element.values():
                self._save_recursive(api_name, value)
        else:
            # Any other leaf (None, bool, int, float, str, slice or other
            # objects) is not saved but still consumes an index.
            self.index += 1
|
|
File without changes
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# 定义比对算法及比对标准
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
|
|
6
|
+
from msprobe.core.common.const import CompareConst
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Module-level default threshold (value 1). It is not referenced in the
# visible part of this module. NOTE(review): confirm which comparisons use it.
DEFAULT_THRESHOLD = 1
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Cosine similarity
def cosine_sim(bench_output, device_output):
    """Compute the cosine similarity between bench and device outputs.

    Both inputs are flattened with ``reshape(-1)``. They are expected to be
    numpy arrays, since ``astype`` is called on them. Each vector is divided
    by its own maximum absolute value before the dot product. This leaves the
    cosine unchanged and presumably reduces the risk of overflow.

    Args:
        bench_output: benchmark (golden) output array.
        device_output: device output array.

    Returns:
        A tuple ``(cos, is_pass, msg)``:
        - shape mismatch: ``(-1, False, msg)``;
        - empty data, single-element data, or both sides all zero:
          ``(CompareConst.SPACE, True, msg)``;
        - exactly one side all zero: ``(CompareConst.SPACE, False, msg)``;
        - otherwise: ``cos`` clipped to [-1, 1] and ``is_pass = cos > 0.99``.
          ``msg`` is set only when the result is NaN.

    Unlike earlier versions, this no longer calls ``np.seterr``. Numpy
    warnings are suppressed only inside this function, and the process-wide
    error state is left untouched.
    """
    msg = ""
    n_value = device_output.reshape(-1)
    b_value = bench_output.reshape(-1)
    if n_value.shape != b_value.shape:
        msg = f"Shape of device and bench outputs don't match. device: {n_value.shape}, bench: {b_value.shape}."
        return -1, False, msg
    if n_value.size == 0:
        # np.max below has no identity for zero-size arrays and would raise.
        msg = "Device and bench outputs are empty. Cosine similarity is not computed."
        return CompareConst.SPACE, True, msg
    if len(n_value) == 1:
        msg = "All the data in device dump data is scalar. Please refer to other compare algorithms."
        return CompareConst.SPACE, True, msg
    eps = np.finfo(float).eps
    # inf/NaN in the data can produce inf/inf or 0*inf below. Silence those
    # warnings locally instead of mutating numpy's global error state.
    with np.errstate(divide="ignore", invalid="ignore"):
        n_value_max = np.max(np.abs(n_value))
        b_value_max = np.max(np.abs(b_value))
        if n_value_max <= eps and b_value_max <= eps:
            msg = "All the data in device and bench outputs are zero."
            return CompareConst.SPACE, True, msg
        if n_value_max <= eps:
            msg = "All the data is zero in device dump data."
            return CompareConst.SPACE, False, msg
        if b_value_max <= eps:
            msg = "All the data is zero in bench dump data."
            return CompareConst.SPACE, False, msg
        n_value = n_value.astype(float) / n_value_max
        b_value = b_value.astype(float) / b_value_max
        cos = np.dot(n_value, b_value) / (np.linalg.norm(n_value) * np.linalg.norm(b_value))
    if np.isnan(cos):
        msg = "Dump data has NaN when comparing with Cosine Similarity."
    cos = np.clip(cos, -1, 1)
    return cos, cos > 0.99, msg
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_rmse(abs_err, inf_nan_mask):
    """Root-mean-square of ``abs_err``, discounting inf/NaN positions.

    Positions set in ``inf_nan_mask`` contribute 0 to the mean of squares.
    The mean is then rescaled by
    ``size / (size - masked_count + 1e-4) + 1e-4``, so it approximates the
    mean over the unmasked elements only. The inner ``1e-4`` keeps the
    denominator non-zero when every element is masked.
    """
    total = abs_err.size
    masked_count = np.sum(inf_nan_mask)
    squared = np.square(np.where(inf_nan_mask, 0, abs_err))
    scale = total / (total - masked_count + 0.0001) + 0.0001
    return np.sqrt(np.mean(squared) * scale)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_error_balance(bench_data, device_data):
    """Error balance: how one-sided the signs of ``device - bench`` are.

    ``bench_data`` is first cast to ``device_data``'s dtype. The result is
    ``|#(diff > 0) - #(diff < 0)| / bench_data.size``, or 0 for empty input.
    """
    total = bench_data.size
    diff = device_data - bench_data.astype(device_data.dtype)
    above = np.sum(np.greater(diff, 0))
    below = np.sum(np.less(diff, 0))
    if total <= 0:
        return 0
    return abs(above - below) / total
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_small_value_err_ratio(small_value_mask, abs_err_greater_mask):
    """Share of small-value positions that are also flagged in ``abs_err_greater_mask``.

    Returns 0 when ``small_value_mask`` selects nothing.
    """
    small_count = np.sum(small_value_mask)
    if small_count == 0:
        return 0
    flagged = np.logical_and(small_value_mask, abs_err_greater_mask)
    return np.sum(flagged) / small_count
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask):
    """Element-wise relative error ``abs_err / abs_bench_with_eps``.

    Positions that are small values or inf/NaN are replaced by -1, marking
    them as excluded.
    """
    excluded = np.logical_or(small_value_mask, inf_nan_mask)
    return np.where(excluded, -1, abs_err / abs_bench_with_eps)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_abs_err(bench_data, device_data):
    """Element-wise absolute difference ``|device_data - bench_data|``."""
    difference = device_data - bench_data
    return np.abs(difference)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def get_rel_err_origin(abs_err, b_value):
    """Plain relative error ``|abs_err / b_value|`` (no epsilon, no masking)."""
    ratio = abs_err / b_value
    return np.abs(ratio)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_max_abs_err(abs_err):
    """Return ``(max_abs_err, passed)``, where ``passed`` means the maximum is below 1e-3."""
    largest = abs_err.max()
    return largest, largest < 0.001
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Maximum relative error
def get_max_rel_err(rel_err):
    """Return the maximum relative error, never a negative placeholder.

    Negative entries, such as the -1 used for excluded positions, are never
    reported. The function returns 0 when the maximum is negative (or NaN).

    It also returns 0 when ``rel_err`` is empty. Previously ``np.max`` raised
    in that case; returning 0 matches how ``get_mean_rel_err`` treats an
    empty input.
    """
    if np.size(rel_err) == 0:
        return 0
    # Compute the maximum once instead of twice.
    max_rel_err = np.max(rel_err)
    return max_rel_err if max_rel_err >= 0 else 0
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# Mean relative error
def get_mean_rel_err(rel_err):
    """Mean of the non-negative relative errors; 0 when there are none.

    Negative entries (e.g. the -1 placeholders for excluded positions) are
    dropped before averaging.
    """
    kept = rel_err[rel_err >= 0]
    if kept.size == 0:
        return 0
    return np.mean(kept)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_rel_err_ratio(rel_err, thresholding):
    """Return ``(ratio, passed)``.

    ``ratio`` is the share of ``rel_err`` entries strictly below
    ``thresholding``. Negative placeholder values also count as below. An
    empty input gives a ratio of 1.

    ``passed`` is ``ratio > 1 - thresholding``.
    """
    total = np.size(rel_err)
    if total == 0:
        ratio = 1
    else:
        ratio = np.divide(np.sum(rel_err < thresholding), total)
    return ratio, ratio > (1 - thresholding)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_finite_and_infinite_mask(bench_output, device_output):
    """Return ``(both_finite_mask, inf_nan_mask)``.

    ``bench_output`` is cast to ``device_output``'s dtype before checking, so
    bench values that overflow that dtype count as non-finite.
    ``inf_nan_mask`` is the logical complement of ``both_finite_mask``.
    """
    bench_finite = np.isfinite(bench_output.astype(device_output.dtype))
    device_finite = np.isfinite(device_output)
    both_finite_mask = np.logical_and(device_finite, bench_finite)
    return both_finite_mask, np.logical_not(both_finite_mask)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold):
    """Mask of finite positions whose ``abs_bench`` is <= ``small_value_threshold``."""
    at_or_below = np.less_equal(abs_bench, small_value_threshold)
    return np.logical_and(at_or_below, both_finite_mask)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def get_abs_bench_with_eps(bench, dtype):
    """Return ``(|bench|, |bench| + eps)``.

    ``eps`` is the machine epsilon of ``bench``'s own numpy dtype. When
    ``dtype`` is ``torch.bfloat16``, ``CompareConst.BFLOAT16_EPS`` is used
    instead, presumably because numpy has no native bfloat16 ``finfo``.
    """
    if dtype != torch.bfloat16:
        eps = np.finfo(bench.dtype).eps
    else:
        eps = CompareConst.BFLOAT16_EPS
    magnitude = np.abs(bench)
    return magnitude, magnitude + eps
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
    """Check that NPU and golden outputs agree at inf/NaN positions.

    Used by the absolute-threshold method of the new precision standard.

    Method:
    - Both outputs are clipped to the finite range of the device dtype, so an
      inf is compared as that dtype's min/max. For bfloat16 the range comes
      from ``CompareConst``.
    - The relative error is computed against ``|golden| + eps``.
    - A position passes when that error is <= ``rtol``, or when both the
      device output and the clipped golden output are NaN there.

    Args:
        inf_nan_mask: boolean mask of positions where either output is inf/NaN.
        bench_output: golden output.
        device_output: NPU output.
        dtype: torch dtype of the NPU output.
        rtol: relative tolerance a masked position must meet.

    Returns:
        The fraction of masked positions that fail, or 0 when
        ``inf_nan_mask`` selects nothing.
    """
    _, bench_abs_with_eps = get_abs_bench_with_eps(bench_output, dtype)
    bench_as_device = bench_output.astype(device_output.dtype)
    if dtype != torch.bfloat16:
        lower = np.finfo(device_output.dtype).min
        upper = np.finfo(device_output.dtype).max
    else:
        lower = CompareConst.BFLOAT16_MIN
        upper = CompareConst.BFLOAT16_MAX
    golden_clip = np.clip(bench_as_device, lower, upper)
    device_clip = np.clip(device_output, lower, upper)
    clipped_rel_err = np.abs(device_clip - golden_clip) / bench_abs_with_eps
    passed = np.logical_or(
        np.less_equal(clipped_rel_err, rtol),
        np.logical_and(np.isnan(device_output), np.isnan(golden_clip)),
    )
    # Only failures at inf/NaN positions are counted.
    failed_count = np.sum(np.logical_and(np.logical_not(passed), inf_nan_mask))
    masked_count = np.sum(inf_nan_mask)
    return 0 if masked_count == 0 else failed_count / masked_count
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def check_small_value(abs_err, small_value_mask, small_value_atol):
    """Return the share of small-value positions whose absolute error is too large.

    This is part of the new precision standard. Inside the small-value
    region, the absolute error (not the relative error) is compared against a
    tolerance.

    Args:
        abs_err: absolute error between the NPU and golden outputs.
        small_value_mask: boolean mask of the small-value positions.
        small_value_atol: absolute-error tolerance for small values.

    Returns:
        The fraction of small-value positions with
        ``abs_err > small_value_atol``, or 0 when ``small_value_mask``
        selects nothing.
    """
    small_value_count = np.sum(small_value_mask)
    if small_value_count == 0:
        return 0
    err_mask = np.logical_and(np.greater(abs_err, small_value_atol), small_value_mask)
    return np.sum(err_mask) / small_value_count
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def check_norm_value(normal_value_mask, rel_err, rtol):
    """Return the share of normal-value positions whose relative error is too large.

    This is part of the new precision standard. Outside the small-value
    region, the relative error is compared against ``rtol``.

    Args:
        normal_value_mask: boolean mask of the normal-value positions.
        rel_err: relative error between the NPU and golden outputs.
        rtol: relative-error tolerance.

    Returns:
        The fraction of normal-value positions with ``rel_err > rtol``, or 0
        when ``normal_value_mask`` selects nothing.
    """
    normal_count = np.sum(normal_value_mask)
    if normal_count == 0:
        return 0
    err_mask = np.logical_and(np.greater(rel_err, rtol), normal_value_mask)
    return np.sum(err_mask) / normal_count
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def get_ulp_err(bench_output, device_output, dtype):
    """Return the absolute ULP error of ``device_output`` against ``bench_output``.

    Method:
    - The per-element exponent ``eb`` is ``floor(log2(|bench|))``, using 0
      where bench is 0.
    - ``eb`` is clamped from below at ``min_eb``, taken from
      ``ULP_PARAMETERS[dtype]``.
    - The scaling is delegated to ``calc_ulp_err``. It runs in float64 for
      ``torch.float32`` and in float32 for every other dtype.

    NOTE(review): ``dtype`` must be a key of ``ULP_PARAMETERS``. Otherwise
    ``parameters`` is None and ``.get`` raises AttributeError.
    """
    parameters = ULP_PARAMETERS.get(dtype)
    # Configured values are indexed with [0], so the fallback must be a
    # sequence too; indexing the bare int DEFAULT_THRESHOLD raised TypeError.
    min_eb = parameters.get('min_eb', [DEFAULT_THRESHOLD])[0]
    exponent_num = parameters.get('exponent_num', [DEFAULT_THRESHOLD])[0]
    abs_bench = np.abs(bench_output)
    # np.where evaluates both branches, so log2 still sees the zeros.
    # Silence the expected divide-by-zero warning; those entries become 0.
    with np.errstate(divide="ignore"):
        eb = np.where(abs_bench == 0, 0, np.floor(np.log2(abs_bench)))
    eb = np.maximum(eb, min_eb)

    calc_dtype = np.float64 if dtype == torch.float32 else np.float32
    ulp_err = calc_ulp_err(bench_output, device_output, eb, exponent_num, calc_dtype)
    return np.abs(ulp_err)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
    """Return the signed ULP error ``(device - bench) * 2**(exponent_num - eb)``.

    The device output, the difference and the scale factor are all cast to
    ``data_type``.
    """
    diff = (device_output.astype(data_type) - bench_output).astype(data_type)
    scale = np.exp2(-eb + exponent_num).astype(data_type)
    return diff * scale
|