mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
import math
|
|
20
|
+
import torch
|
|
21
|
+
import numpy
|
|
22
|
+
|
|
23
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
|
|
24
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, \
|
|
25
|
+
get_full_data_path, CompareException
|
|
26
|
+
from msprobe.pytorch.common.log import logger
|
|
27
|
+
from msprobe.core.common.const import Const
|
|
28
|
+
|
|
29
|
+
# Type-name whitelists used when rebuilding dumped API arguments from json.
# torch meta-types handled by gen_torch_kwargs (device entries are skipped).
TORCH_TYPE = ["torch.device", "torch.dtype"]
# Recorded types that deserialize to a torch tensor (real or random).
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
# Floating dtypes generated via torch.rand scaling in gen_common_tensor.
FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16',
              'torch.half', 'torch.bfloat16']
# numpy scalar types accepted for non-tensor values; anything else raises.
NUMPY_TYPE = ["numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
              "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
              "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
    """
    Function Description:
        Materialize one argument value from its recorded basic information.
    Parameter:
        info: arg basic information. Dict
        api_name: API name
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: root directory for dumped real data, if any.
    Return:
        the rebuilt argument (tensor, numpy scalar, torch.Size, slice, or raw value).
    """
    check_object_type(info, dict)
    data_type = info.get('type')
    full_path = get_full_data_path(info.get('datapath', info.get('data_name')), real_data_path)

    if data_type in TENSOR_DATA_LIST:
        # Prefer the dumped real tensor; fall back to a random one from Min/Max.
        data = gen_real_tensor(full_path, convert_type) if full_path \
            else gen_random_tensor(info, convert_type)
        # hf32-standard APIs are compared after rounding fp32 to hf32 precision.
        if data.dtype == torch.float32 and api_name in hf_32_standard_api:
            data = fp32_to_hf32_to_fp32(data)
        if need_grad and info.get('requires_grad'):
            data.requires_grad_(True)
            # Clone through a multiply so the leaf keeps its grad after backward.
            clone = data * 1
            data = clone.type_as(data)
            data.retain_grad()
        return data

    if data_type.startswith("numpy"):
        if data_type not in NUMPY_TYPE:
            raise Exception("{} is not supported now".format(data_type))
        data = info.get("value")
        try:
            # data_type is whitelisted in NUMPY_TYPE, so eval only sees numpy names.
            data = eval(data_type)(data)
        except Exception as err:
            logger.error("Failed to convert the type to numpy: %s" % str(err))
        return data

    if data_type == "torch.Size":
        return torch.Size(info.get("value"))

    data = info.get('value')
    if info.get("type") == "slice":
        data = slice(*data)
    return data
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def gen_real_tensor(data_path, convert_type):
    """
    Function Description:
        Based on API data path, generate input parameters real data
    Parameter:
        data_path: API data path (a .pt or .npy file)
        convert_type: convert ori_type to dist_type flag.
    Return:
        torch.Tensor on CPU, converted to the dist dtype when requested.
    Raise:
        CompareException: when the file is neither a .pt nor a .npy file.
    """
    data_path = os.path.realpath(data_path)
    check_file_or_directory_path(data_path)
    if not data_path.endswith(('.pt', '.npy')):
        error_info = f"The file: {data_path} is not a pt or numpy file."
        raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
    if data_path.endswith('.pt'):
        # map_location='cpu' so tensors dumped on an NPU/GPU device still load
        # on a CPU-only host (plain torch.load would try the saved device).
        data = torch.load(data_path, map_location='cpu').cpu()
    else:
        data_np = numpy.load(data_path)
        data = torch.from_numpy(data_np)
    if convert_type:
        ori_dtype = Const.CONVERT.get(convert_type)[0]
        dist_dtype = Const.CONVERT.get(convert_type)[1]
        if str(data.dtype) == ori_dtype:
            # dist_dtype comes from the project's CONVERT table (torch dtype names).
            data = data.type(eval(dist_dtype))
    return data
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def gen_random_tensor(info, convert_type):
    """
    Function Description:
        Based on API MAX and MIN, generate input parameters random data
    Parameter:
        info: API data info
        convert_type: convert ori_type to dist_type flag.
    Return:
        random tensor matching the recorded dtype, shape and value range.
    Raise:
        CompareException: when Min/Max are not int or float.
    """
    check_object_type(info, dict)
    low, high = info.get('Min'), info.get('Max')
    low_origin, high_origin = info.get('Min_origin'), info.get('Max_origin')
    data_dtype = info.get('dtype')
    shape = tuple(info.get('shape'))
    if not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
        error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
    # Bool tensors take a dedicated path; everything else goes through the
    # common int/float generator with the (value, original-value) pairs.
    if data_dtype == "torch.bool":
        return gen_bool_tensor(low, high, shape)
    return gen_common_tensor([low, low_origin], [high, high_origin], shape, data_dtype, convert_type)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def fp32_to_hf32_to_fp32(input_tensor):
    """Round a float32 tensor to hf32 precision and return it as float32.

    Works on the raw int32 bit pattern: add-and-shift rounds to nearest on
    bit 12, then the low 12 mantissa bits are cleared, emulating the
    precision loss of an hf32 round trip.
    """
    bits = input_tensor.detach().numpy().view(numpy.int32)
    half_rounded = numpy.right_shift(numpy.right_shift(bits, 11) + 1, 1)
    truncated = numpy.left_shift(half_rounded, 12)
    return torch.from_numpy(truncated.view(numpy.float32))
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
    """
    Function Description:
        Based on API basic information, generate int or float tensor
    Parameter:
        low_info: [low, low_origin], low is the minimum value in the tensor removed inf and nan,
        low_origin is the original minimum value in the tensor
        high_info: [high, high_origin], high is the maximum value in the tensor removed inf and nan,
        high_origin is the original maximum value in the tensor
        shape:The shape of Tensor
        data_dtype: The data type of Tensor
        convert_type: convert ori_type to dist_type flag.
    Raise:
        NotImplementedError: when data_dtype is neither a float nor an int/long type.
    """
    if convert_type:
        # Swap in the dist dtype when this tensor's dtype matches the ori dtype.
        ori_dtype = Const.CONVERT.get(convert_type)[0]
        if ori_dtype == data_dtype:
            data_dtype = Const.CONVERT.get(convert_type)[1]
    low, low_origin = low_info[0], low_info[1]
    high, high_origin = high_info[0], high_info[1]
    if data_dtype in FLOAT_TYPE:
        # data_dtype is whitelisted in FLOAT_TYPE, so eval only sees torch
        # dtype names; evaluate once instead of once per use.
        torch_dtype = eval(data_dtype)
        if math.isnan(high):
            # Every recorded value was NaN: reproduce an all-NaN tensor.
            # torch.full is the public equivalent of the previously used
            # private torch._C._VariableFunctionsClass.full.
            return torch.full(shape, float('nan'), dtype=torch_dtype)
        # high_origin only exists in new-format json; when it is set and high
        # is +/-inf, the original tensor was entirely +/-inf.
        if high_origin and high in [float('inf'), float('-inf')]:
            tensor = torch.full(shape, high, dtype=torch_dtype)
            tensor[-1] = low
            return tensor
        low_scale, high_scale = low, high
        dtype_finfo = torch.finfo(torch_dtype)
        # Old-format json may carry +/-inf in high/low themselves; clamp the
        # scaling bounds to the dtype's max/min in that case.
        if high == float('inf'):
            high_scale = dtype_finfo.max
        elif high == float('-inf'):
            high_scale = dtype_finfo.min
        if low == float('inf'):
            low_scale = dtype_finfo.max
        elif low == float('-inf'):
            low_scale = dtype_finfo.min

        scale = high_scale - low_scale
        rand01 = torch.rand(shape, dtype=torch_dtype)
        tensor = rand01 * scale + low_scale
    elif 'int' in data_dtype or 'long' in data_dtype:
        # randint's upper bound is exclusive, hence high + 1 for inclusive range.
        low, high = int(low), int(high)
        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
    else:
        logger.error('Dtype is not supported: ' + data_dtype)
        raise NotImplementedError()
    if tensor.nelement() == 0:
        return tensor
    # Pin the recorded extremes (and NaN, when present) into the flattened
    # tensor so the generated data reproduces the original value range exactly.
    tmp_tensor = tensor.reshape(-1)
    if high_origin and math.isnan(high_origin):
        if tmp_tensor.numel() <= 2:
            tmp_tensor[0] = float('nan')
            tmp_tensor[-1] = high
        else:
            tmp_tensor[0] = low
            tmp_tensor[1] = float('nan')
            tmp_tensor[-1] = high
    else:
        tmp_tensor[0] = low
        tmp_tensor[-1] = high
        if high_origin in [float('inf'), float('-inf')]:
            tmp_tensor[-1] = high_origin
        if low_origin in [float('inf'), float('-inf')]:
            tmp_tensor[0] = low_origin
    data = tmp_tensor.reshape(shape)
    return data
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def gen_bool_tensor(low, high, shape):
    """
    Function Description:
        Based on API basic information, generate bool tensor
    Parameter:
        low: The minimum value in Tensor
        high: The max value in Tensor
        shape:The shape of Tensor
    Return:
        bool tensor obtained by drawing ints in [low, high] and testing > 0.
    """
    # Normalize the bounds: cast to int and order them ascending.
    lo, hi = sorted((int(low), int(high)))
    # randint's upper bound is exclusive, hence hi + 1 for an inclusive range.
    raw = torch.randint(lo, hi + 1, shape)
    return raw.gt(0)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: args, for API forward running
    Parameter:
        args_info: per-argument basic information. List
        api_name: API name
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Raise:
        NotImplementedError: when an entry is not list/tuple/dict/None.
    """
    check_object_type(args_info, list)
    generated = []
    for arg in args_info:
        # Nested sequences recurse; dicts describe one concrete argument.
        if isinstance(arg, (list, tuple)):
            value = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
        elif isinstance(arg, dict):
            value = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
        elif arg is None:
            value = None
        else:
            logger.warning(f'Warning: {arg} is not supported')
            raise NotImplementedError()
        generated.append(value)
    return generated
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def gen_kwargs(api_info, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: kwargs, for API forward running
    Parameter:
        api_info: API basic information. Dict
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        the "input_kwargs" dict with every entry materialized in place.
    """
    check_object_type(api_info, dict)
    kwargs_params = api_info.get("input_kwargs")
    for key, value in kwargs_params.items():
        if isinstance(value, (list, tuple)):
            kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path)
        elif value is None:
            kwargs_params[key] = None
        elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
            # BUG FIX: the previous call gen_data(value, True, convert_type,
            # real_data_path) dropped the api_name positional slot, shifting
            # convert_type into need_grad and real_data_path into convert_type
            # (a path string then crashed Const.CONVERT lookups). api_name is
            # not known here, so pass None (never in hf_32_standard_api).
            kwargs_params[key] = gen_data(value, None, True, convert_type, real_data_path)
        elif value.get('type') in TORCH_TYPE:
            gen_torch_kwargs(kwargs_params, key, value)
        else:
            kwargs_params[key] = value.get('value')
    return kwargs_params
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def gen_torch_kwargs(kwargs_params, key, value):
    """Materialize a torch meta-typed kwarg (e.g. a torch.dtype) in place.

    torch.device entries are deliberately left untouched: kwargs_params keeps
    whatever it already holds for that key.
    """
    kind = value.get('type')
    if kind == "torch.device":
        return
    kwargs_params[key] = eval(value.get('value'))
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None):
    """
    Function Description:
        When kwargs value is list, generate the list of kwargs result
    Parameter:
        kwargs_item_value: kwargs value before to generate. List
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        list with each entry materialized.
    """
    kwargs_item_result = []
    for item in kwargs_item_value:
        if item.get('type') in TENSOR_DATA_LIST:
            # BUG FIX: the previous call gen_data(item, False, convert_type,
            # real_data_path) dropped the api_name positional slot, shifting
            # convert_type into need_grad and real_data_path into convert_type.
            # api_name is unknown here, so pass None; need_grad stays False.
            item_value = gen_data(item, None, False, convert_type, real_data_path)
        elif item.get('type') == "torch.Size":
            item_value = torch.Size(item.get('value'))
        else:
            item_value = item.get('value')
        kwargs_item_result.append(item_value)
    return kwargs_item_result
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: args, kwargs, for API forward running
    Parameter:
        api_info: API basic information. Dict
        api_name: API name
        need_grad: set grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        (args_params, kwargs_params) tuple ready for the forward call.
    Raise:
        CompareException: when convert_type is not a key of Const.CONVERT.
    """
    check_object_type(api_info, dict)
    if convert_type and convert_type not in Const.CONVERT:
        error_info = f"convert_type params not support {convert_type}."
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
    kwargs_params = gen_kwargs(api_info, convert_type, real_data_path)
    input_args = api_info.get("input_args")
    if not input_args:
        logger.warning(f'Warning: No args in {api_info} ')
        return [], kwargs_params
    return gen_args(input_args, api_name, need_grad, convert_type, real_data_path), kwargs_params
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import argparse
|
|
6
|
+
import time
|
|
7
|
+
import signal
|
|
8
|
+
import threading
|
|
9
|
+
from collections import namedtuple
|
|
10
|
+
from itertools import cycle
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \
|
|
13
|
+
get_validated_details_csv_path, preprocess_forward_content
|
|
14
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
15
|
+
from msprobe.pytorch.common import parse_json_info_forward_backward
|
|
16
|
+
from msprobe.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \
|
|
17
|
+
check_path_before_create, create_directory
|
|
18
|
+
from msprobe.pytorch.common.log import logger
|
|
19
|
+
from msprobe.core.common.const import FileCheckConst
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def split_json_file(input_file, num_splits, filter_api):
    """Split one dumped api_info json into num_splits temp files.

    Forward items are renamed to "<name>.forward", partitioned across the
    splits, and every split file also carries the complete backward data
    (renamed to "<name>.backward") plus the input file's non-"data" metadata,
    so each part is a self-contained api_info json.

    Parameter:
        input_file: path of the dumped api_info json
        num_splits: number of temp files to produce (one per worker process)
        filter_api: when True, pre-filter forward items via preprocess_forward_content
    Return:
        (list of temp file names, total number of forward items)
    """
    # real_data_path is also parsed here but not used by this function.
    forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
    if filter_api:
        forward_data = preprocess_forward_content(forward_data)
    # Tag every entry with its direction. list(...) snapshots the keys because
    # the dicts are mutated (pop + insert) while being iterated.
    for data_name in list(forward_data.keys()):
        forward_data[f"{data_name}.forward"] = forward_data.pop(data_name)
    for data_name in list(backward_data.keys()):
        backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)

    # Keep the original file's metadata (everything except "data") to copy
    # into each split file.
    with FileOpen(input_file, 'r') as file:
        input_data = json.load(file)
        input_data.pop("data")

    items = list(forward_data.items())
    total_items = len(items)
    # NOTE(review): integer division — when total_items < num_splits the first
    # chunks come out empty and the last chunk absorbs everything; confirm intended.
    chunk_size = total_items // num_splits
    split_files = []

    for i in range(num_splits):
        start = i * chunk_size
        # The last split takes the remainder of the division.
        end = (i + 1) * chunk_size if i < num_splits - 1 else total_items

        split_forward_data = dict(items[start:end])
        temp_data = {
            **input_data,
            "data":{
                **split_forward_data,
                **backward_data
            }
        }
        split_filename = f"temp_part{i}.json"
        with FileOpen(split_filename, 'w') as split_file:
            json.dump(temp_data, split_file)
        split_files.append(split_filename)

    return split_files, total_items
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def signal_handler(signum, frame):
    """Translate a termination signal into KeyboardInterrupt.

    Raising KeyboardInterrupt lets the normal except/finally cleanup paths run
    instead of the process dying immediately on SIGINT/SIGTERM.
    """
    logger.warning(f'Signal handler called with signal {signum}')
    raise KeyboardInterrupt()


# Install the handler for both Ctrl-C and external kill requests at import time.
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Immutable bundle of everything one parallel accuracy-checker run needs:
# the split api_info files, output/result-csv paths, worker count, the device
# ids cycled across workers, and the flags forwarded to each run_ut.py
# subprocess (save_error_data, jit_compile, real_data_path), plus the total
# item count used to size the progress bar.
ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
                                                   'save_error_data_flag', 'jit_compile_flag', 'device_id',
                                                   'result_csv_path', 'total_items', 'real_data_path'])
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def run_parallel_ut(config):
    """Run all split api-info files as parallel run_ut.py subprocesses.

    One subprocess is started per file in config.api_files, with device ids
    assigned round-robin from config.device_id.  A daemon thread per worker
    forwards its '[ERROR]' output lines, and a separate thread updates a tqdm
    progress bar from the shared result CSV.  On completion or interruption,
    workers are terminated, the temporary split files are deleted, and the
    aggregated pretest result is printed.
    """
    processes = []  # Popen handles of all spawned run_ut.py workers
    device_id_cycle = cycle(config.device_id)  # round-robin device assignment
    if config.save_error_data_flag:
        logger.info("UT task error datas will be saved")
    logger.info(f"Starting parallel UT with {config.num_splits} processes")
    progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")

    def create_cmd(api_info, dev_id):
        # Build one run_ut.py command line; optional flags are spliced in
        # only when the corresponding config field is set.
        dirname, filename = os.path.split(os.path.abspath(__file__))
        run_ut_path = os.path.join(dirname, "run_ut.py")
        cmd = [
            sys.executable, run_ut_path,
            '-api_info', api_info,
            *(['-o', config.out_path] if config.out_path else []),
            '-d', str(dev_id),
            *(['-j'] if config.jit_compile_flag else []),
            *(['-save_error_data'] if config.save_error_data_flag else []),
            '-csv_path', config.result_csv_path,
            *(['-real_data_path', config.real_data_path] if config.real_data_path else [])
        ]
        return cmd

    def read_process_output(process):
        # Mirror only '[ERROR]' lines from a worker's stdout; runs in a
        # daemon thread until the worker exits or its stdout closes.
        try:
            while True:
                if process.poll() is not None:
                    break
                output = process.stdout.readline()
                if output == '':
                    break
                if '[ERROR]' in output:
                    print(output, end='')
                    sys.stdout.flush()
        except ValueError as e:
            # readline() on a closed/invalid pipe raises ValueError
            logger.warning(f"An error occurred while reading subprocess output: {e}")

    def update_progress_bar(progress_bar, result_csv_path):
        # Poll the shared result CSV about once a second while any worker is
        # alive; each data row (header excluded) counts as one finished item.
        while any(process.poll() is None for process in processes):
            try:
                with open(result_csv_path, 'r') as result_file:
                    completed_items = len(result_file.readlines()) - 1
                progress_bar.update(completed_items - progress_bar.n)
            except FileNotFoundError:
                logger.warning(f"Result CSV file not found: {result_csv_path}.")
            except Exception as e:
                logger.error(f"An unexpected error occurred while reading result CSV: {e}")
            time.sleep(1)

    # Spawn one worker per split file, each with its own stdout-forwarding thread.
    for api_info in config.api_files:
        cmd = create_cmd(api_info, next(device_id_cycle))
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1)
        processes.append(process)
        threading.Thread(target=read_process_output, args=(process,), daemon=True).start()

    progress_bar_thread = threading.Thread(target=update_progress_bar, args=(progress_bar, config.result_csv_path))
    progress_bar_thread.start()

    def clean_up():
        # Best-effort teardown: close the bar, stop every worker, and remove
        # the temporary split files produced by split_json_file.
        progress_bar.close()
        for process in processes:
            try:
                process.terminate()
                process.wait(timeout=1)
            except subprocess.TimeoutExpired:
                process.kill()
        for file in config.api_files:
            check_link(file)
            try:
                os.remove(file)
            except FileNotFoundError:
                logger.warning(f"File not found and could not be deleted: {file}")

    try:
        # Block until every worker has exited (KeyboardInterrupt may arrive
        # here via signal_handler).
        for process in processes:
            process.communicate(timeout=None)
    except KeyboardInterrupt:
        logger.warning("Interrupted by user, terminating processes and cleaning up...")
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
    finally:
        if progress_bar.n < config.total_items:
            logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.")
        clean_up()
        progress_bar_thread.join()
        try:
            # Both CSV arguments point at the result file: only the summary
            # is printed here, details were written by the workers.
            comparator = Comparator(config.result_csv_path, config.result_csv_path, False)
            comparator.print_pretest_result()
        except FileNotFoundError as e:
            logger.error(f"Error: {e}")
        except Exception as e:
            logger.error(f"An unexpected error occurred: {e}")
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def prepare_config(args):
    """Validate CLI arguments and assemble a ParallelUTConfig.

    Checks and resolves the api-info file and output directory, splits the
    api-info JSON into ``args.num_splits`` temp files, and determines the
    result/details CSV paths (freshly timestamped, or validated from
    ``args.result_csv_path`` when resuming a previous run).

    Returns:
        ParallelUTConfig consumed by run_parallel_ut.
    """
    check_link(args.api_info_file)
    api_info = os.path.realpath(args.api_info_file)
    check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX)
    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
    out_path = out_path_checker.common_check()
    split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)

    if not args.result_csv_path:
        # Take the timestamp once: calling time.strftime separately for the two
        # paths could straddle a second boundary and yield mismatched names.
        timestamp = time.strftime('%Y%m%d%H%M%S')
        result_csv_path = os.path.join(out_path, f"accuracy_checking_result_{timestamp}.csv")
        details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{timestamp}.csv")
        # Instantiated for its side effect: the Comparator creates the
        # result/details CSV files the workers will append to.
        Comparator(result_csv_path, details_csv_path, False)
    else:
        # Resume mode: reuse (and validate) the CSVs from a previous run.
        result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
        details_csv_path = get_validated_details_csv_path(result_csv_path)
    logger.info(f"UT task result will be saved in {result_csv_path}")
    logger.info(f"UT task details will be saved in {details_csv_path}")
    return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data,
                            args.jit_compile, args.device_id, result_csv_path,
                            total_items, args.real_data_path)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def main():
    """CLI entry point: parse arguments, build the config, run parallel UT."""
    parser = argparse.ArgumentParser(description='Run UT in parallel')
    _run_ut_parser(parser)
    parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64')
    parsed = parser.parse_args()
    run_parallel_ut(prepare_config(parsed))
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# Script entry point.
if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
# Backend detection: torch_npu is only importable on Ascend NPU machines,
# so a failed import means we are running on a GPU/CPU host.
try:
    import torch_npu
except ImportError:
    is_gpu = True
else:
    is_gpu = False
|
|
11
|
+
import torch
|
|
12
|
+
from tqdm import tqdm
|
|
13
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info
|
|
14
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
|
|
15
|
+
from msprobe.core.common.file_check import check_link
|
|
16
|
+
from msprobe.pytorch.common.log import logger
|
|
17
|
+
|
|
18
|
+
def check_tensor_overflow(x):
    """Return True when x (a tensor or a plain bool/int/float) contains
    inf, -inf or NaN; any other value (including empty or bool tensors)
    reports False."""
    if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool:
        if len(x.shape) == 0:
            # 0-d tensor: a single element serves as both extremes.
            tensor_max = x.cpu().detach().float().numpy().tolist()
            tensor_min = tensor_max
        else:
            # NOTE(review): the private torch._C entry points look deliberate
            # (presumably to bypass the tool's own API hooks) — kept as-is.
            tensor_max = torch._C._VariableFunctionsClass.max(x).cpu().detach().float().numpy().tolist()
            tensor_min = torch._C._VariableFunctionsClass.min(x).cpu().detach().float().numpy().tolist()
        if tensor_max == float('inf') or tensor_min == float('-inf'):
            return True
        # NaN is the only value that compares unequal to itself.
        return tensor_max != tensor_max or tensor_min != tensor_min
    if isinstance(x, (bool, int, float)):
        return x == float('inf') or x == float('-inf') or x != x
    return False
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def check_data_overflow(x):
    """Recursively scan x — possibly nested tuples/lists — and return True
    if any leaf triggers check_tensor_overflow."""
    if isinstance(x, (tuple, list)) and x:
        return any(check_data_overflow(element) for element in x)
    # Leaves (and empty containers) fall through to the scalar/tensor check.
    return check_tensor_overflow(x)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def run_overflow_check(forward_file):
    """Replay every API recorded in forward_file and classify its overflow.

    Each entry is re-executed via run_torch_api; per-API failures are logged
    and swallowed so one broken API does not abort the whole sweep.

    Args:
        forward_file: path to the dumped forward api-info JSON.
    """
    logger.info("start UT test")
    forward_content = get_json_contents(forward_file)
    for api_full_name, api_info_dict in tqdm(forward_content.items()):
        try:
            run_torch_api(api_full_name, api_info_dict)
        except Exception as err:
            # Parse with '.' separators, consistent with run_torch_api; the
            # previous '_' split raised IndexError inside this handler for
            # dot-separated names with no underscore.
            api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0]
            if "not implemented for 'Half'" in str(err):
                logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
                               f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
            elif "expected scalar type Long" in str(err):
                logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                               f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
            else:
                logger.error(f"Run {api_full_name} UT Error: {err}")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def run_torch_api(api_full_name, api_info_dict):
    """Execute one recorded API on both CPU and NPU and compare overflow.

    The NPU overflow flag is cleared first, then the same call is made with
    CPU args and with device-converted args; if CPU and NPU agree on whether
    an overflow occurred it is logged as "normal", otherwise "abnormal".
    Requires an Ascend environment (torch.npu / torch_npu).
    """
    torch.npu.clear_npu_overflow_flag()
    # Name format appears to be "<Type>.<name>.<idx>.<direction>" — the type
    # is the first dot field, the api name is everything between the first
    # dot and the last two dot fields.
    api_type = api_full_name.split(".")[0]
    api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0]
    args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='')
    if not need_grad:
        logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
                       % api_full_name)
    # NPU params are derived BEFORE the 'device' kwarg is dropped from the
    # CPU call below — presumably generate_device_params consumes it; confirm.
    npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
    if kwargs.get("device"):
        del kwargs["device"]
    out = exec_api(api_type, api_name, args, kwargs)
    npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs)
    cpu_overflow = check_data_overflow(out)
    npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
    if cpu_overflow == npu_overflow:
        # Matching results: overflow (or its absence) is reproducible on CPU.
        logger.warning("The %s overflow is a normal overflow." % api_full_name)
    else:
        logger.warning("The %s overflow is an abnormal overflow." % api_full_name)
    return
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _run_overflow_check_parser(parser):
|
|
94
|
+
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="",
|
|
95
|
+
help="<Required> The api param tool result file: generate from api param tool, "
|
|
96
|
+
"a json file.",
|
|
97
|
+
required=True)
|
|
98
|
+
parser.add_argument("-j", "--jit_compile", dest="jit_compile", help="<optional> whether to turn on jit compile",
|
|
99
|
+
default=False, required=False)
|
|
100
|
+
parser.add_argument("-d", "--device", dest="device_id", type=int, help="<optional> set NPU device id to run ut",
|
|
101
|
+
default=0, required=False)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _run_overflow_check(parser=None):
    """CLI wrapper: build (or reuse) the argument parser and dispatch."""
    parser = parser or argparse.ArgumentParser()
    _run_overflow_check_parser(parser)
    _run_overflow_check_command(parser.parse_args(sys.argv[1:]))
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _run_overflow_check_command(args):
    """Configure the NPU from parsed args and run the overflow check.

    Sets jit-compile mode, binds the requested NPU device (re-raising as
    NotImplementedError on failure), validates the api-info path, and
    hands off to run_overflow_check.  Requires an Ascend environment.
    """
    torch.npu.set_compile_mode(jit_compile=args.jit_compile)
    npu_device = "npu:" + str(args.device_id)
    check_link(args.api_info_file)
    api_info = os.path.realpath(args.api_info_file)
    try:
        torch.npu.set_device(npu_device)
    except Exception as error:
        logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
        # Chain the original failure so the device error stays visible.
        raise NotImplementedError from error
    run_overflow_check(api_info)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Script entry point; the completion message is only logged when the
# check runs to the end without raising.
if __name__ == '__main__':
    _run_overflow_check()
    logger.info("UT task completed.")
|