mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
#
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
7
|
# you may not use this file except in compliance with the License.
|
|
7
8
|
# You may obtain a copy of the License at
|
|
8
9
|
#
|
|
@@ -13,7 +14,6 @@
|
|
|
13
14
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
15
|
# See the License for the specific language governing permissions and
|
|
15
16
|
# limitations under the License.
|
|
16
|
-
"""
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
19
|
import math
|
|
@@ -22,19 +22,28 @@ import numpy
|
|
|
22
22
|
|
|
23
23
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
|
|
24
24
|
from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
|
|
25
|
-
CompareException
|
|
25
|
+
CompareException, get_module_and_atttribute_name, get_attribute
|
|
26
26
|
from msprobe.core.common.file_utils import FileChecker, load_npy
|
|
27
27
|
from msprobe.pytorch.common.log import logger
|
|
28
28
|
from msprobe.pytorch.common.utils import load_pt
|
|
29
|
-
from msprobe.core.common.const import Const, FileCheckConst
|
|
29
|
+
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
30
30
|
|
|
31
31
|
TORCH_TYPE = ["torch.device", "torch.dtype"]
|
|
32
32
|
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
|
|
33
|
-
FLOAT_TYPE = [
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
33
|
+
FLOAT_TYPE = [
|
|
34
|
+
'torch.float32',
|
|
35
|
+
'torch.float',
|
|
36
|
+
'torch.float64',
|
|
37
|
+
'torch.double',
|
|
38
|
+
'torch.float16',
|
|
39
|
+
'torch.half',
|
|
40
|
+
'torch.bfloat16'
|
|
41
|
+
]
|
|
42
|
+
NUMPY_TYPE = [
|
|
43
|
+
"numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
|
|
44
|
+
"numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
|
|
45
|
+
"numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"
|
|
46
|
+
]
|
|
38
47
|
|
|
39
48
|
|
|
40
49
|
def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
|
|
@@ -68,7 +77,8 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
|
|
|
68
77
|
raise Exception("{} is not supported now".format(data_type))
|
|
69
78
|
data = info.get("value")
|
|
70
79
|
try:
|
|
71
|
-
|
|
80
|
+
module_name, attribute_name = get_module_and_atttribute_name(data_type)
|
|
81
|
+
data = get_attribute(module_name, attribute_name)(data)
|
|
72
82
|
except Exception as err:
|
|
73
83
|
logger.error("Failed to convert the type to numpy: %s" % str(err))
|
|
74
84
|
elif data_type == "torch.Size":
|
|
@@ -104,8 +114,9 @@ def gen_real_tensor(data_path, convert_type):
|
|
|
104
114
|
if convert_type:
|
|
105
115
|
ori_dtype = Const.CONVERT.get(convert_type)[0]
|
|
106
116
|
dist_dtype = Const.CONVERT.get(convert_type)[1]
|
|
117
|
+
module_name, attribute_name = get_module_and_atttribute_name(dist_dtype)
|
|
107
118
|
if str(data.dtype) == ori_dtype:
|
|
108
|
-
data = data.type(
|
|
119
|
+
data = data.type(get_attribute(module_name, attribute_name))
|
|
109
120
|
return data
|
|
110
121
|
|
|
111
122
|
|
|
@@ -118,8 +129,12 @@ def gen_random_tensor(info, convert_type):
|
|
|
118
129
|
convert_type: convert ori_type to dist_type flag.
|
|
119
130
|
"""
|
|
120
131
|
check_object_type(info, dict)
|
|
121
|
-
|
|
122
|
-
low_origin
|
|
132
|
+
|
|
133
|
+
low_origin = info.get('Min')
|
|
134
|
+
low = info.get('Min_except_inf_nan', low_origin)
|
|
135
|
+
high_origin = info.get('Max')
|
|
136
|
+
high = info.get('Max_except_inf_nan', high_origin)
|
|
137
|
+
|
|
123
138
|
low_info = [low, low_origin]
|
|
124
139
|
high_info = [high, high_origin]
|
|
125
140
|
data_dtype = info.get('dtype')
|
|
@@ -164,33 +179,35 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
|
|
|
164
179
|
data_dtype = Const.CONVERT.get(convert_type)[1]
|
|
165
180
|
low, low_origin = low_info[0], low_info[1]
|
|
166
181
|
high, high_origin = high_info[0], high_info[1]
|
|
167
|
-
|
|
182
|
+
module_name, attribute_name = get_module_and_atttribute_name(data_dtype)
|
|
183
|
+
dtype = get_attribute(module_name, attribute_name)
|
|
184
|
+
if data_dtype in FLOAT_TYPE:
|
|
168
185
|
if math.isnan(high):
|
|
169
|
-
tensor = torch.
|
|
186
|
+
tensor = torch.full(shape, float('nan'), dtype=dtype)
|
|
170
187
|
return tensor
|
|
171
188
|
#high_origin为新版json中的属性,只有当high_origin不为None,且high为inf或-inf时,原tensor全为inf或-inf
|
|
172
|
-
if high_origin and high in [float(
|
|
173
|
-
tensor = torch.
|
|
189
|
+
if high_origin and high in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
|
|
190
|
+
tensor = torch.full(shape, high, dtype=dtype)
|
|
174
191
|
tensor[-1] = low
|
|
175
192
|
return tensor
|
|
176
193
|
low_scale, high_scale = low, high
|
|
177
|
-
dtype_finfo = torch.finfo(
|
|
194
|
+
dtype_finfo = torch.finfo(dtype)
|
|
178
195
|
#适配老版json high和low为inf或-inf的情况,取dtype的最大值或最小值进行放缩
|
|
179
|
-
if high == float(
|
|
196
|
+
if high == float(CompareConst.INF):
|
|
180
197
|
high_scale = dtype_finfo.max
|
|
181
|
-
elif high == float(
|
|
198
|
+
elif high == float(CompareConst.NEG_INF):
|
|
182
199
|
high_scale = dtype_finfo.min
|
|
183
|
-
if low == float(
|
|
200
|
+
if low == float(CompareConst.INF):
|
|
184
201
|
low_scale = dtype_finfo.max
|
|
185
|
-
elif low == float(
|
|
202
|
+
elif low == float(CompareConst.NEG_INF):
|
|
186
203
|
low_scale = dtype_finfo.min
|
|
187
204
|
|
|
188
205
|
scale = high_scale - low_scale
|
|
189
|
-
rand01 = torch.rand(shape, dtype=
|
|
206
|
+
rand01 = torch.rand(shape, dtype=dtype)
|
|
190
207
|
tensor = rand01 * scale + low_scale
|
|
191
208
|
elif 'int' in data_dtype or 'long' in data_dtype:
|
|
192
209
|
low, high = int(low), int(high)
|
|
193
|
-
tensor = torch.randint(low, high + 1, shape, dtype=
|
|
210
|
+
tensor = torch.randint(low, high + 1, shape, dtype=dtype)
|
|
194
211
|
else:
|
|
195
212
|
logger.error('Dtype is not supported: ' + data_dtype)
|
|
196
213
|
raise NotImplementedError()
|
|
@@ -208,9 +225,9 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
|
|
|
208
225
|
else:
|
|
209
226
|
tmp_tensor[0] = low
|
|
210
227
|
tmp_tensor[-1] = high
|
|
211
|
-
if high_origin in [float(
|
|
228
|
+
if high_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
|
|
212
229
|
tmp_tensor[-1] = high_origin
|
|
213
|
-
if low_origin in [float(
|
|
230
|
+
if low_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
|
|
214
231
|
tmp_tensor[0] = low_origin
|
|
215
232
|
data = tmp_tensor.reshape(shape)
|
|
216
233
|
return data
|
|
@@ -233,7 +250,7 @@ def gen_bool_tensor(low, high, shape):
|
|
|
233
250
|
return data
|
|
234
251
|
|
|
235
252
|
|
|
236
|
-
def gen_args(args_info, api_name,
|
|
253
|
+
def gen_args(args_info, api_name, func_options):
|
|
237
254
|
"""
|
|
238
255
|
Function Description:
|
|
239
256
|
Based on API basic information, generate input parameters: args, for API forward running
|
|
@@ -246,9 +263,20 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p
|
|
|
246
263
|
"""
|
|
247
264
|
check_object_type(args_info, list)
|
|
248
265
|
args_result = []
|
|
266
|
+
|
|
267
|
+
need_grad = func_options.get('need_grad', True)
|
|
268
|
+
convert_type = func_options.get('convert_type', None)
|
|
269
|
+
real_data_path = func_options.get('real_data_path', None)
|
|
270
|
+
depth = func_options.get('depth', 0)
|
|
271
|
+
|
|
272
|
+
if depth > Const.MAX_DEPTH:
|
|
273
|
+
logger.error("The depth of args is too large, please check the input args.")
|
|
274
|
+
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
275
|
+
|
|
249
276
|
for arg in args_info:
|
|
250
277
|
if isinstance(arg, (list, tuple)):
|
|
251
|
-
|
|
278
|
+
func_options['depth'] = depth + 1
|
|
279
|
+
data = gen_args(arg, api_name, func_options)
|
|
252
280
|
elif isinstance(arg, dict):
|
|
253
281
|
data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
|
|
254
282
|
elif arg is None:
|
|
@@ -288,7 +316,8 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
|
|
|
288
316
|
|
|
289
317
|
def gen_torch_kwargs(kwargs_params, key, value):
|
|
290
318
|
if value.get('type') != "torch.device":
|
|
291
|
-
|
|
319
|
+
module_name, attribute_name = get_module_and_atttribute_name(value.get('value'))
|
|
320
|
+
kwargs_params[key] = get_attribute(module_name, attribute_name)
|
|
292
321
|
|
|
293
322
|
|
|
294
323
|
def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
|
|
@@ -327,8 +356,14 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
|
|
|
327
356
|
error_info = f"convert_type params not support {convert_type}."
|
|
328
357
|
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
|
|
329
358
|
kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
|
|
359
|
+
func_options = {
|
|
360
|
+
'need_grad': need_grad,
|
|
361
|
+
'convert_type': convert_type,
|
|
362
|
+
'real_data_path': real_data_path,
|
|
363
|
+
'depth': 0
|
|
364
|
+
}
|
|
330
365
|
if api_info.get("input_args"):
|
|
331
|
-
args_params = gen_args(api_info.get("input_args"), api_name,
|
|
366
|
+
args_params = gen_args(api_info.get("input_args"), api_name, func_options)
|
|
332
367
|
else:
|
|
333
368
|
logger.warning(f'Warning: No args in {api_info} ')
|
|
334
369
|
args_params = []
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import subprocess
|
|
2
19
|
import json
|
|
3
20
|
import os
|
|
@@ -105,7 +122,7 @@ def run_parallel_ut(config):
|
|
|
105
122
|
if output == '':
|
|
106
123
|
break
|
|
107
124
|
if '[ERROR]' in output:
|
|
108
|
-
|
|
125
|
+
logger.warning(output, end='')
|
|
109
126
|
sys.stdout.flush()
|
|
110
127
|
except ValueError as e:
|
|
111
128
|
logger.warning(f"An error occurred while reading subprocess output: {e}")
|
|
@@ -119,7 +136,8 @@ def run_parallel_ut(config):
|
|
|
119
136
|
|
|
120
137
|
for api_info in config.api_files:
|
|
121
138
|
cmd = create_cmd(api_info, next(device_id_cycle))
|
|
122
|
-
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
|
|
139
|
+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
|
|
140
|
+
text=True, bufsize=1, shell=False)
|
|
123
141
|
processes.append(process)
|
|
124
142
|
threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
|
|
125
143
|
|
|
@@ -150,7 +168,8 @@ def run_parallel_ut(config):
|
|
|
150
168
|
logger.error(f"An unexpected error occurred: {e}")
|
|
151
169
|
finally:
|
|
152
170
|
if progress_bar.n < config.total_items:
|
|
153
|
-
logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to
|
|
171
|
+
logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
|
|
172
|
+
"the result CSV file will be utilized to resume the UT task.")
|
|
154
173
|
clean_up()
|
|
155
174
|
progress_bar_thread.join()
|
|
156
175
|
try:
|
|
@@ -173,7 +192,8 @@ def prepare_config(args):
|
|
|
173
192
|
out_path = out_path_checker.common_check()
|
|
174
193
|
split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
|
|
175
194
|
config_path = os.path.realpath(args.config_path) if args.config_path else None
|
|
176
|
-
result_csv_path = args.result_csv_path or os.path.join(
|
|
195
|
+
result_csv_path = args.result_csv_path or os.path.join(
|
|
196
|
+
out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
|
|
177
197
|
if not args.result_csv_path:
|
|
178
198
|
details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
|
|
179
199
|
comparator = Comparator(result_csv_path, details_csv_path, False)
|
|
@@ -190,7 +210,8 @@ def prepare_config(args):
|
|
|
190
210
|
def main():
|
|
191
211
|
parser = argparse.ArgumentParser(description='Run UT in parallel')
|
|
192
212
|
_run_ut_parser(parser)
|
|
193
|
-
parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
213
|
+
parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
214
|
+
help='Number of splits for parallel processing. Range: 1-64')
|
|
194
215
|
args = parser.parse_args()
|
|
195
216
|
config = prepare_config(args)
|
|
196
217
|
run_parallel_ut(config)
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import argparse
|
|
2
19
|
import os
|
|
3
20
|
import sys
|
|
@@ -24,8 +41,8 @@ def check_tensor_overflow(x):
|
|
|
24
41
|
tensor_max = x.cpu().detach().float().numpy().tolist()
|
|
25
42
|
tensor_min = tensor_max
|
|
26
43
|
else:
|
|
27
|
-
tensor_max = torch.
|
|
28
|
-
tensor_min = torch.
|
|
44
|
+
tensor_max = torch.max(x).cpu().detach().float().numpy().tolist()
|
|
45
|
+
tensor_min = torch.min(x).cpu().detach().float().numpy().tolist()
|
|
29
46
|
# inf
|
|
30
47
|
if tensor_max == float('inf') or tensor_min == float('-inf'):
|
|
31
48
|
return True
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import argparse
|
|
2
19
|
import os
|
|
3
20
|
import csv
|
|
@@ -17,8 +34,8 @@ else:
|
|
|
17
34
|
import torch
|
|
18
35
|
from tqdm import tqdm
|
|
19
36
|
|
|
20
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import
|
|
21
|
-
get_validated_result_csv_path, get_validated_details_csv_path, exec_api
|
|
37
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
|
|
38
|
+
get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info
|
|
22
39
|
from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
|
|
23
40
|
from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
|
|
24
41
|
initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
|
|
@@ -26,13 +43,14 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
|
26
43
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
27
44
|
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
28
45
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
29
|
-
from msprobe.core.common.file_utils import
|
|
30
|
-
|
|
46
|
+
from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, \
|
|
47
|
+
create_directory, get_json_contents, read_csv
|
|
31
48
|
from msprobe.pytorch.common.log import logger
|
|
32
49
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
33
50
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
34
51
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
35
52
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
53
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
|
|
36
54
|
|
|
37
55
|
|
|
38
56
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
@@ -46,14 +64,7 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content',
|
|
|
46
64
|
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
47
65
|
|
|
48
66
|
not_backward_list = ['repeat_interleave']
|
|
49
|
-
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
|
|
50
|
-
not_raise_dtype_set = {'type_as'}
|
|
51
67
|
|
|
52
|
-
RAISE_PRECISION = {
|
|
53
|
-
torch.float16: torch.float32,
|
|
54
|
-
torch.bfloat16: torch.float32,
|
|
55
|
-
torch.float32: torch.float64
|
|
56
|
-
}
|
|
57
68
|
|
|
58
69
|
tqdm_params = {
|
|
59
70
|
'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
|
|
@@ -71,98 +82,6 @@ tqdm_params = {
|
|
|
71
82
|
}
|
|
72
83
|
|
|
73
84
|
|
|
74
|
-
def deal_detach(arg, to_detach=True):
|
|
75
|
-
return arg.detach() if to_detach else arg
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
|
|
79
|
-
'''
|
|
80
|
-
将标杆数据的dtype转换为raise_dtype
|
|
81
|
-
输入:
|
|
82
|
-
api_name:api名称
|
|
83
|
-
arg:标杆输入
|
|
84
|
-
raise_dtype:需要转换的dtype
|
|
85
|
-
输出:
|
|
86
|
-
arg: 转换dtype的标杆输入
|
|
87
|
-
'''
|
|
88
|
-
if api_name in hf_32_standard_api and arg.dtype == torch.float32:
|
|
89
|
-
return arg
|
|
90
|
-
if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
|
|
91
|
-
return arg
|
|
92
|
-
return arg.type(raise_dtype)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def generate_device_params(input_args, input_kwargs, need_backward, api_name):
|
|
96
|
-
def recursive_arg_to_device(arg_in, to_detach):
|
|
97
|
-
if isinstance(arg_in, (list, tuple)):
|
|
98
|
-
return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
|
|
99
|
-
elif isinstance(arg_in, torch.Tensor):
|
|
100
|
-
if need_backward and arg_in.requires_grad:
|
|
101
|
-
arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
|
|
102
|
-
temp_arg_in = arg_in * 1
|
|
103
|
-
arg_in = temp_arg_in.type_as(arg_in)
|
|
104
|
-
arg_in.retain_grad()
|
|
105
|
-
return arg_in
|
|
106
|
-
else:
|
|
107
|
-
return deal_detach(arg_in.clone(), to_detach).to(current_device)
|
|
108
|
-
else:
|
|
109
|
-
return arg_in
|
|
110
|
-
|
|
111
|
-
is_detach = api_name not in not_detach_set
|
|
112
|
-
device_args = recursive_arg_to_device(input_args, is_detach)
|
|
113
|
-
device_kwargs = \
|
|
114
|
-
{key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
|
|
115
|
-
return device_args, device_kwargs
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
119
|
-
def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
|
|
120
|
-
if isinstance(arg_in, (list, tuple)):
|
|
121
|
-
return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
|
|
122
|
-
elif isinstance(arg_in, torch.Tensor):
|
|
123
|
-
if need_backward and arg_in.requires_grad:
|
|
124
|
-
arg_in = deal_detach(raise_bench_data_dtype(
|
|
125
|
-
api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
|
|
126
|
-
temp_arg_in = arg_in * 1
|
|
127
|
-
arg_in = temp_arg_in.type_as(arg_in)
|
|
128
|
-
arg_in.retain_grad()
|
|
129
|
-
return arg_in
|
|
130
|
-
else:
|
|
131
|
-
return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
|
|
132
|
-
else:
|
|
133
|
-
return arg_in
|
|
134
|
-
|
|
135
|
-
def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
|
|
136
|
-
if arg_in.dtype in RAISE_PRECISION:
|
|
137
|
-
return True
|
|
138
|
-
if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
|
|
139
|
-
return True
|
|
140
|
-
return False
|
|
141
|
-
|
|
142
|
-
def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
|
|
143
|
-
if isinstance(arg_in, (list, tuple)):
|
|
144
|
-
return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
|
|
145
|
-
elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
|
|
146
|
-
return set([arg_in.dtype])
|
|
147
|
-
elif isinstance(arg_in, dict) and check_kwargs:
|
|
148
|
-
return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
|
|
149
|
-
return set()
|
|
150
|
-
|
|
151
|
-
raise_dtype = None
|
|
152
|
-
need_raise_dtypes = recursive_find_dtypes(input_args)
|
|
153
|
-
need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
|
|
154
|
-
if len(need_raise_dtypes) == 1:
|
|
155
|
-
raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
|
|
156
|
-
elif len(need_raise_dtypes) >= 2:
|
|
157
|
-
raise_dtype = torch.float32
|
|
158
|
-
|
|
159
|
-
raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
|
|
160
|
-
is_detach = api_name not in not_detach_set
|
|
161
|
-
cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
|
|
162
|
-
cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
|
|
163
|
-
return cpu_args, cpu_kwargs
|
|
164
|
-
|
|
165
|
-
|
|
166
85
|
def run_ut(config):
|
|
167
86
|
logger.info("start UT test")
|
|
168
87
|
if config.online_config.is_online:
|
|
@@ -179,10 +98,8 @@ def run_ut(config):
|
|
|
179
98
|
if config.online_config.is_online:
|
|
180
99
|
run_api_online(config, compare)
|
|
181
100
|
else:
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
next(csv_reader)
|
|
185
|
-
api_name_set = {row[0] for row in csv_reader}
|
|
101
|
+
csv_df = read_csv(config.result_csv_path)
|
|
102
|
+
api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
|
|
186
103
|
run_api_offline(config, compare, api_name_set)
|
|
187
104
|
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
188
105
|
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -198,17 +115,23 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
198
115
|
if api_full_name in api_name_set:
|
|
199
116
|
continue
|
|
200
117
|
if is_unsupported_api(api_full_name):
|
|
118
|
+
skip_message = f"API {api_full_name} not support for run ut. SKIP."
|
|
119
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
|
|
120
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
201
121
|
continue
|
|
202
122
|
_, api_name = extract_basic_api_segments(api_full_name)
|
|
203
123
|
if not api_name:
|
|
204
124
|
err_message = f"API {api_full_name} not support for run ut. SKIP."
|
|
205
125
|
logger.error(err_message)
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
compare.record_results(result_info)
|
|
126
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
|
|
127
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
209
128
|
continue
|
|
210
129
|
try:
|
|
211
130
|
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
131
|
+
skip_message = f"API {api_name} in black list or not in white list. SKIP."
|
|
132
|
+
logger.info(skip_message)
|
|
133
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
|
|
134
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
212
135
|
continue
|
|
213
136
|
data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
|
|
214
137
|
is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
|
|
@@ -220,9 +143,8 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
220
143
|
f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
|
|
221
144
|
else:
|
|
222
145
|
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
compare.record_results(result_info)
|
|
146
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
|
|
147
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
226
148
|
finally:
|
|
227
149
|
if is_gpu:
|
|
228
150
|
torch.cuda.empty_cache()
|
|
@@ -327,12 +249,12 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
327
249
|
in_fwd_data_list.append(kwargs)
|
|
328
250
|
need_backward = api_full_name in backward_content
|
|
329
251
|
if not need_grad:
|
|
330
|
-
logger.warning("%s %s" % (api_full_name,
|
|
331
|
-
backward_message +=
|
|
252
|
+
logger.warning("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE))
|
|
253
|
+
backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
|
|
332
254
|
if api_name in not_backward_list:
|
|
333
255
|
need_grad = False
|
|
334
|
-
logger.warning("%s %s" % (api_full_name,
|
|
335
|
-
backward_message +=
|
|
256
|
+
logger.warning("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
|
|
257
|
+
backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
|
|
336
258
|
need_backward = need_backward and need_grad
|
|
337
259
|
if kwargs.get("device"):
|
|
338
260
|
del kwargs["device"]
|
|
@@ -353,13 +275,16 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
353
275
|
if need_backward:
|
|
354
276
|
if need_to_backward(grad_index, out):
|
|
355
277
|
backward_args = backward_content[api_full_name].get("input")
|
|
356
|
-
|
|
278
|
+
func_options = {
|
|
279
|
+
'real_data_path': real_data_path
|
|
280
|
+
}
|
|
281
|
+
grad = gen_args(backward_args, api_name, func_options)[0]
|
|
357
282
|
bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
|
|
358
283
|
bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
|
|
359
284
|
device_grad = grad.clone().detach().to(current_device)
|
|
360
285
|
device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
|
|
361
286
|
else:
|
|
362
|
-
backward_message +=
|
|
287
|
+
backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
|
|
363
288
|
if api_name == "npu_fusion_attention":
|
|
364
289
|
out = out[0]
|
|
365
290
|
device_out = device_out[0]
|
|
@@ -416,7 +341,7 @@ def initialize_save_error_data(error_data_path):
|
|
|
416
341
|
error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
|
|
417
342
|
ability=FileCheckConst.WRITE_ABLE)
|
|
418
343
|
error_data_path = error_data_path_checker.common_check()
|
|
419
|
-
error_data_path =initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
|
|
344
|
+
error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
|
|
420
345
|
return error_data_path
|
|
421
346
|
|
|
422
347
|
|
|
@@ -477,7 +402,8 @@ def preprocess_forward_content(forward_content):
|
|
|
477
402
|
if key not in arg_cache:
|
|
478
403
|
filtered_new_args = [
|
|
479
404
|
{k: v for k, v in arg.items() if k not in ['Max', 'Min']}
|
|
480
|
-
for arg in value['input_args']
|
|
405
|
+
for arg in value['input_args']
|
|
406
|
+
if isinstance(arg, dict)
|
|
481
407
|
]
|
|
482
408
|
arg_cache[key] = (filtered_new_args, value['input_kwargs'])
|
|
483
409
|
|
|
@@ -529,14 +455,14 @@ def run_ut_command(args):
|
|
|
529
455
|
# 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析
|
|
530
456
|
forward_content, backward_content, real_data_path = None, None, None
|
|
531
457
|
if args.api_info_file:
|
|
532
|
-
api_info_file_checker = FileChecker(file_path
|
|
533
|
-
ability
|
|
458
|
+
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
459
|
+
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
534
460
|
checked_api_info = api_info_file_checker.common_check()
|
|
535
461
|
forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
|
|
536
462
|
if args.filter_api:
|
|
537
|
-
logger.info("Start filtering the api in the
|
|
463
|
+
logger.info("Start filtering the api in the api_info_file.")
|
|
538
464
|
forward_content = preprocess_forward_content(forward_content)
|
|
539
|
-
logger.info("Finish filtering the api in the
|
|
465
|
+
logger.info("Finish filtering the api in the api_info_file.")
|
|
540
466
|
|
|
541
467
|
out_path = os.path.realpath(args.out_path) if args.out_path else "./"
|
|
542
468
|
check_path_before_create(out_path)
|