mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,35 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import os
|
|
2
19
|
import re
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
import torch_npu
|
|
24
|
+
except ImportError:
|
|
25
|
+
current_device = "cuda"
|
|
26
|
+
else:
|
|
27
|
+
current_device = "npu"
|
|
3
28
|
|
|
4
|
-
from msprobe.core.common.const import FileCheckConst
|
|
29
|
+
from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
5
30
|
from msprobe.core.common.file_utils import FileChecker
|
|
31
|
+
from msprobe.core.common.log import logger
|
|
32
|
+
from msprobe.core.common.utils import CompareException
|
|
6
33
|
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
|
|
7
34
|
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
|
|
8
35
|
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
|
|
@@ -10,11 +37,20 @@ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
|
10
37
|
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
11
38
|
|
|
12
39
|
# APIs whose float32 benchmark inputs are left in float32 by
# raise_bench_data_dtype instead of being raised via PRECISION_MAPPING.
# NOTE(review): "hf_32" presumably refers to HF32 convolution math — confirm.
hf_32_standard_api = [
    "conv1d",
    "conv2d",
]

# APIs whose inputs are cloned but NOT detached when building device/CPU inputs.
not_detach_set = {
    "resize_",
    "resize_as_",
    "set_",
    "transpose_",
    "t_",
    "squeeze_",
    "unsqueeze_",
}

# APIs for which generate_cpu_params never raises the benchmark dtype.
not_raise_dtype_set = {"type_as"}

# Benchmark precision upgrade: input dtype -> dtype used for the CPU benchmark.
PRECISION_MAPPING = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.float32: torch.float64,
}
|
|
13
48
|
|
|
14
49
|
|
|
15
|
-
class
|
|
50
|
+
class BackwardMessage:
    """Message strings explaining why an API's backward pass is not run."""

    MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
    UNSUPPORT_BACKWARD_MESSAGE = (
        "function with out=... arguments don't support automatic differentiation, "
        "skip backward."
    )
    NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
|
|
19
55
|
|
|
20
56
|
|
|
@@ -68,3 +104,110 @@ def exec_api(api_type, api_name, device, args, kwargs):
|
|
|
68
104
|
torch_api = NpuOPTemplate(api_name, None, False, device)
|
|
69
105
|
out = torch_api.forward(*args, **kwargs)
|
|
70
106
|
return out
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def deal_detach(arg, to_detach=True):
    """Return ``arg.detach()`` when ``to_detach`` is truthy, otherwise ``arg`` itself."""
    if not to_detach:
        return arg
    return arg.detach()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
    """
    Cast a benchmark (reference) input tensor to ``raise_dtype``.

    Args:
        api_name: name of the API being checked.
        arg: benchmark input tensor.
        raise_dtype: dtype to convert to; ``None`` means "leave as is".

    Returns:
        ``arg`` unchanged when no conversion applies, otherwise ``arg.type(raise_dtype)``.

    No conversion happens when:
      * ``arg`` is float32 and ``api_name`` is in ``hf_32_standard_api``;
      * ``raise_dtype`` is None or already equals ``arg.dtype``;
      * ``arg.dtype`` is not a key of ``PRECISION_MAPPING``.
    """
    source_dtype = arg.dtype
    if source_dtype == torch.float32 and api_name in hf_32_standard_api:
        # These APIs keep float32 benchmark inputs in float32.
        return arg
    should_cast = (
        raise_dtype is not None
        and source_dtype in PRECISION_MAPPING
        and raise_dtype != source_dtype
    )
    if not should_cast:
        return arg
    return arg.type(raise_dtype)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def generate_device_params(input_args, input_kwargs, need_backward, api_name):
    """
    Copy an API's positional and keyword inputs onto ``current_device``.

    Each tensor is cloned, detached (unless ``api_name`` is in
    ``not_detach_set``; the ``out`` keyword is never detached) and moved to
    ``current_device``. Lists and tuples are rebuilt with their original
    container type; any other value is returned unchanged. When
    ``need_backward`` is true and a tensor requires grad, the moved copy is
    made grad-tracking and a ``* 1`` non-leaf copy with ``retain_grad()`` is
    returned in its place.

    Returns:
        (device_args, device_kwargs)

    Raises:
        CompareException: if list/tuple nesting exceeds ``Const.MAX_DEPTH``.
    """

    def _to_device(value, detach_flag, depth=0):
        if depth > Const.MAX_DEPTH:
            logger.error("The depth of arg_in is too large, please check the arg_in.")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
        if isinstance(value, (list, tuple)):
            return type(value)(_to_device(item, detach_flag, depth=depth + 1) for item in value)
        if not isinstance(value, torch.Tensor):
            return value
        track_grad = need_backward and value.requires_grad
        copied = deal_detach(value.clone(), detach_flag).to(current_device)
        if not track_grad:
            return copied
        copied.requires_grad_()
        # `* 1` produces a non-leaf tensor derived from the grad-tracking copy;
        # type_as keeps the copy's dtype and retain_grad() keeps its .grad after backward.
        grad_copy = (copied * 1).type_as(copied)
        grad_copy.retain_grad()
        return grad_copy

    should_detach = api_name not in not_detach_set
    device_args = _to_device(input_args, should_detach)
    device_kwargs = {
        key: _to_device(value, key != "out" and should_detach)
        for key, value in input_kwargs.items()
    }
    return device_args, device_kwargs
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
    """
    Build CPU benchmark copies of an API's positional and keyword inputs.

    Each tensor is cloned, passed through ``raise_bench_data_dtype`` with a
    common target dtype (chosen below), and detached unless ``api_name`` is in
    ``not_detach_set`` (the ``out`` keyword is never detached). Lists and
    tuples are rebuilt with their original container type; other values are
    returned unchanged. When ``need_backward`` is true and a tensor requires
    grad, the copy is made grad-tracking and a ``* 1`` non-leaf copy with
    ``retain_grad()`` is returned in its place.

    Target dtype selection:
      * gather the dtypes of tensors in ``input_args`` (through lists/tuples)
        and in ``input_kwargs`` (through lists/tuples/dicts) whose dtype is a
        ``PRECISION_MAPPING`` key (for kwargs, float16/bfloat16 also count);
      * exactly one distinct dtype -> ``PRECISION_MAPPING[dtype]`` (float32 fallback);
      * two or more dtypes        -> ``torch.float32``;
      * none found, or ``api_name`` in ``not_raise_dtype_set`` -> no cast.

    Returns:
        (cpu_args, cpu_kwargs)

    Raises:
        CompareException: if nesting exceeds ``Const.MAX_DEPTH``.
    """

    def _check_depth(depth):
        # Shared recursion guard for both walkers below.
        if depth > Const.MAX_DEPTH:
            logger.error("The depth of arg_in is too large, please check the arg_in.")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)

    def _to_cpu(value, detach_flag, raise_dtype=None, depth=0):
        _check_depth(depth)
        if isinstance(value, (list, tuple)):
            return type(value)(
                _to_cpu(item, detach_flag, raise_dtype=raise_dtype, depth=depth + 1)
                for item in value
            )
        if not isinstance(value, torch.Tensor):
            return value
        track_grad = need_backward and value.requires_grad
        bench = deal_detach(
            raise_bench_data_dtype(api_name, value.clone(), raise_dtype=raise_dtype), detach_flag)
        if not track_grad:
            return bench
        bench.requires_grad_()
        grad_copy = (bench * 1).type_as(bench)
        grad_copy.retain_grad()
        return grad_copy

    def _counts_for_raise(tensor, check_kwargs):
        # NOTE(review): torch.half is torch.float16, so with the current
        # PRECISION_MAPPING keys the kwargs-only test adds nothing; kept as-is.
        return tensor.dtype in PRECISION_MAPPING or (
            check_kwargs and tensor.dtype in (torch.half, torch.bfloat16))

    def _collect_dtypes(value, check_kwargs=False, depth=0):
        _check_depth(depth)
        if isinstance(value, (list, tuple)):
            children = value
        elif isinstance(value, torch.Tensor):
            return {value.dtype} if _counts_for_raise(value, check_kwargs) else set()
        elif isinstance(value, dict) and check_kwargs:
            # Dicts are only descended into while scanning keyword arguments.
            children = value.values()
        else:
            return set()
        found = set()
        for child in children:
            found |= _collect_dtypes(child, check_kwargs=check_kwargs, depth=depth + 1)
        return found

    found_dtypes = _collect_dtypes(input_args)
    found_dtypes.update(_collect_dtypes(input_kwargs, check_kwargs=True))

    if api_name in not_raise_dtype_set or not found_dtypes:
        target_dtype = None
    elif len(found_dtypes) == 1:
        (only_dtype,) = found_dtypes
        target_dtype = PRECISION_MAPPING.get(only_dtype, torch.float32)
    else:
        target_dtype = torch.float32

    should_detach = api_name not in not_detach_set
    cpu_args = _to_cpu(input_args, should_detach, raise_dtype=target_dtype)
    cpu_kwargs = {
        key: _to_cpu(value, key != "out" and should_detach, raise_dtype=target_dtype)
        for key, value in input_kwargs.items()
    }
    return cpu_args, cpu_kwargs
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def record_skip_info(api_full_name, compare, compare_alg_results):
    """
    Pass a "skipped" result tuple for ``api_full_name`` to ``compare.record_results``.

    The tuple is ``(api_full_name, SKIP, SKIP, [compare_alg_results], None, 0)``
    with ``SKIP`` taken from ``CompareConst``.
    """
    skip_flag = CompareConst.SKIP
    compare.record_results((api_full_name, skip_flag, skip_flag, [compare_alg_results], None, 0))
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import glob
|
|
2
17
|
import os.path
|
|
3
18
|
import time
|
|
@@ -41,6 +56,7 @@ class ATTL:
|
|
|
41
56
|
self.message_end = False
|
|
42
57
|
self.kill_progress = False
|
|
43
58
|
self.check_attl_config()
|
|
59
|
+
self.nfs_path = None
|
|
44
60
|
if self.session_config.nfs_path:
|
|
45
61
|
self.nfs_path = self.session_config.nfs_path
|
|
46
62
|
elif self.session_config.is_benchmark_device:
|
|
@@ -77,6 +93,11 @@ class ATTL:
|
|
|
77
93
|
"""
|
|
78
94
|
npu major in 'send' (client)
|
|
79
95
|
"""
|
|
96
|
+
|
|
97
|
+
# if tcp connection lost,
|
|
98
|
+
if self.socket_manager.signal_exit:
|
|
99
|
+
raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")
|
|
100
|
+
|
|
80
101
|
# know receiver receive and go next
|
|
81
102
|
if isinstance(buffer, ApiData):
|
|
82
103
|
buffer = move2target_device(buffer, torch.device('cpu'))
|
|
@@ -1,10 +1,24 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import hashlib
|
|
2
17
|
import io
|
|
3
18
|
import struct
|
|
4
19
|
import time
|
|
5
20
|
import os
|
|
6
21
|
import signal
|
|
7
|
-
import sys
|
|
8
22
|
from queue import Queue
|
|
9
23
|
from threading import Thread
|
|
10
24
|
from typing import Union
|
|
@@ -13,7 +27,10 @@ from twisted.internet import reactor, protocol, endpoints
|
|
|
13
27
|
from twisted.protocols.basic import FileSender
|
|
14
28
|
|
|
15
29
|
from msprobe.pytorch.common.utils import logger
|
|
16
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.
|
|
30
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import struct_unpack_mode as unpack_mode, \
|
|
31
|
+
str_to_bytes_order as bytes_order
|
|
32
|
+
|
|
33
|
+
# Shared bound for the client send path: capacity of TCPClient.send_queue,
# the resend_dict size below which new items are dequeued, and (in TLS mode)
# the number of messages ClientProtocol.send_large_data batches per flush.
MAX_SENDING_QUEUE_SIZE = 20
|
|
17
34
|
|
|
18
35
|
|
|
19
36
|
class TCPDataItem:
|
|
@@ -31,7 +48,6 @@ class TCPDataItem:
|
|
|
31
48
|
|
|
32
49
|
|
|
33
50
|
class TCPClient:
|
|
34
|
-
MAX_SENDING_QUEUE_SIZE = 20
|
|
35
51
|
ACK_SUCCESS = b"OK___"
|
|
36
52
|
ACK_ERROR = b"ERROR"
|
|
37
53
|
ACK_BUSY = b"BUSY_"
|
|
@@ -39,13 +55,13 @@ class TCPClient:
|
|
|
39
55
|
ACK_STOP_CONFIRM = b"OVER_"
|
|
40
56
|
ACK_KILL_PROCESS = b"KILL_"
|
|
41
57
|
|
|
42
|
-
QUEUE_PENDING_TIME =
|
|
58
|
+
QUEUE_PENDING_TIME = 60
|
|
43
59
|
RESEND_RETRY_TIMES = 2 # 最大重传数
|
|
44
60
|
RESEND_TIMER_TIME = 5 # 接收ACK超时定时器
|
|
45
61
|
RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据
|
|
46
62
|
|
|
47
63
|
def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
|
|
48
|
-
self.send_queue = Queue(
|
|
64
|
+
self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE)
|
|
49
65
|
self.resend_dict = dict()
|
|
50
66
|
self.host = host
|
|
51
67
|
self.port = port
|
|
@@ -55,7 +71,8 @@ class TCPClient:
|
|
|
55
71
|
self.signal_exit = False
|
|
56
72
|
self.tcp_manager = ClientProtocol(ack_queue_size=100,
|
|
57
73
|
chunk_size=655360,
|
|
58
|
-
check_sum=check_sum
|
|
74
|
+
check_sum=check_sum,
|
|
75
|
+
tls=self.tls_path)
|
|
59
76
|
self.send_thread = Thread(target=self._sending_queue_data)
|
|
60
77
|
self.send_thread.setDaemon(True)
|
|
61
78
|
self.send_thread.start()
|
|
@@ -67,6 +84,15 @@ class TCPClient:
|
|
|
67
84
|
def run_reactor():
|
|
68
85
|
reactor.run(installSignalHandlers=False)
|
|
69
86
|
|
|
87
|
+
def check_tls_path(self):
    """
    Locate the client TLS key and certificate inside ``self.tls_path``.

    Returns:
        (client_key, client_crt): paths to ``client.key`` and ``client.crt``.

    Raises:
        FileNotFoundError: if either path is missing or is not a regular file.
            This is a subclass of ``Exception``, so callers that caught the
            previous bare ``Exception`` still work.
    """
    client_key = os.path.join(self.tls_path, "client.key")
    client_crt = os.path.join(self.tls_path, "client.crt")
    for label, path in (("client_key", client_key), ("client_crt", client_crt)):
        # isfile rather than exists: a directory with the expected name would
        # otherwise pass here and only fail when the SSL context factory loads it.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{label}: {path} is not exists.")
    return client_key, client_crt
|
|
95
|
+
|
|
70
96
|
def start(self):
|
|
71
97
|
def conn_callback(cur_protocol):
|
|
72
98
|
if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
|
|
@@ -80,8 +106,6 @@ class TCPClient:
|
|
|
80
106
|
time.sleep(1)
|
|
81
107
|
reactor.stop()
|
|
82
108
|
logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
|
|
83
|
-
os.kill(os.getpid(), signal.SIGKILL)
|
|
84
|
-
os.kill(os.getppid(), signal.SIGKILL)
|
|
85
109
|
|
|
86
110
|
def cur_protocol():
|
|
87
111
|
return self.tcp_manager
|
|
@@ -89,14 +113,9 @@ class TCPClient:
|
|
|
89
113
|
self.factory = MessageClientFactory()
|
|
90
114
|
self.factory.protocol = cur_protocol
|
|
91
115
|
if self.tls_path:
|
|
92
|
-
from OpenSSL import SSL
|
|
93
116
|
from twisted.internet import ssl
|
|
94
|
-
client_key =
|
|
95
|
-
|
|
96
|
-
client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
|
|
97
|
-
client_context_ = client_context_factory.getContext()
|
|
98
|
-
client_context_.set_cipher_list(cipher_list)
|
|
99
|
-
client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
|
|
117
|
+
client_key, client_crt = self.check_tls_path()
|
|
118
|
+
client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
|
|
100
119
|
endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
|
|
101
120
|
else:
|
|
102
121
|
endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
|
|
@@ -109,7 +128,11 @@ class TCPClient:
|
|
|
109
128
|
|
|
110
129
|
def send_after_queue_empty(self, data):
    """
    Enqueue ``data`` every 2 seconds until ``self._ready_to_exit()`` is true.

    With TLS enabled, each round enqueues ``MAX_SENDING_QUEUE_SIZE`` copies
    instead of one. NOTE(review): this appears intended to fill a whole TLS
    batch, since ClientProtocol.send_large_data only flushes after
    MAX_SENDING_QUEUE_SIZE messages — confirm.
    """
    while not self._ready_to_exit():
        copies = MAX_SENDING_QUEUE_SIZE if self.tls_path else 1
        for _ in range(copies):
            self.add_to_sending_queue(data)
        time.sleep(2)
|
|
114
137
|
|
|
115
138
|
def check_client_alive(self):
|
|
@@ -124,8 +147,6 @@ class TCPClient:
|
|
|
124
147
|
if not self.check_client_alive():
|
|
125
148
|
break
|
|
126
149
|
time.sleep(1)
|
|
127
|
-
while not self.tcp_manager.kill_process:
|
|
128
|
-
time.sleep(1)
|
|
129
150
|
|
|
130
151
|
def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
|
|
131
152
|
if self._ready_to_exit():
|
|
@@ -142,7 +163,8 @@ class TCPClient:
|
|
|
142
163
|
self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
|
|
143
164
|
except Exception as e:
|
|
144
165
|
logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
|
|
145
|
-
f"sequence_number: {send_data.sequence_number}, {
|
|
166
|
+
f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()},"
|
|
167
|
+
f"{str(e)}")
|
|
146
168
|
|
|
147
169
|
def _send_data(self, data: TCPDataItem):
|
|
148
170
|
self.tcp_manager.send_wrapped_data(data.raw_data,
|
|
@@ -159,10 +181,11 @@ class TCPClient:
|
|
|
159
181
|
while self.send_queue.qsize() > 0:
|
|
160
182
|
if self._ready_to_exit():
|
|
161
183
|
break
|
|
162
|
-
if len(self.resend_dict) <
|
|
184
|
+
if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE:
|
|
163
185
|
data_obj = self.send_queue.get()
|
|
164
|
-
self._send_data(data_obj)
|
|
165
186
|
resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step)
|
|
187
|
+
logger.debug(f"get {resend_key} from send_queue, and send to server.")
|
|
188
|
+
self._send_data(data_obj)
|
|
166
189
|
if resend_key not in self.resend_dict.keys():
|
|
167
190
|
# Send data for the first time
|
|
168
191
|
self.resend_dict[resend_key] = data_obj
|
|
@@ -233,7 +256,7 @@ class TCPClient:
|
|
|
233
256
|
class ClientProtocol(protocol.Protocol):
|
|
234
257
|
TIMEOUT = 60 * 10
|
|
235
258
|
|
|
236
|
-
def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
|
|
259
|
+
def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None):
|
|
237
260
|
self.buffer = io.BytesIO()
|
|
238
261
|
self.is_connected = False
|
|
239
262
|
self.check_sum = check_sum
|
|
@@ -244,6 +267,13 @@ class ClientProtocol(protocol.Protocol):
|
|
|
244
267
|
self.signal_exit = False
|
|
245
268
|
self.defer = None
|
|
246
269
|
self.kill_process = False
|
|
270
|
+
self.ack = None
|
|
271
|
+
|
|
272
|
+
self.timeout_call = None
|
|
273
|
+
|
|
274
|
+
self.tls = tls
|
|
275
|
+
self.send_buffer = b""
|
|
276
|
+
self.buffer_cnt = 0
|
|
247
277
|
|
|
248
278
|
def dataReceived(self, data):
|
|
249
279
|
if self.timeout_call.active():
|
|
@@ -255,9 +285,11 @@ class ClientProtocol(protocol.Protocol):
|
|
|
255
285
|
while True:
|
|
256
286
|
if len(self.buffer.getvalue()) >= 29: # 5 + 8 * 3
|
|
257
287
|
ack = self.buffer.read(5)
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
288
|
+
self.ack = ack
|
|
289
|
+
seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0]
|
|
290
|
+
rank = struct.unpack(unpack_mode, self.buffer.read(8))[0]
|
|
291
|
+
step = struct.unpack(unpack_mode, self.buffer.read(8))[0]
|
|
292
|
+
logger.debug(f"receive 流水号: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}")
|
|
261
293
|
if ack == b"KILL_":
|
|
262
294
|
self.kill_process = True
|
|
263
295
|
logger.debug(f"接收到KILL信号, PID {os.getpid()}")
|
|
@@ -276,20 +308,33 @@ class ClientProtocol(protocol.Protocol):
|
|
|
276
308
|
def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
    """
    Frame ``data`` and pass it to ``send_large_data`` once the previous send has fired.

    Frame layout: four 8-byte integers (payload length, sequence number, rank,
    step) encoded with ``bytes_order``, then the hex MD5 digest of the payload
    (only when ``self.check_sum`` is set, otherwise empty), then the payload.

    Polls every 10 ms while ``self.defer`` exists and has not been called yet.
    """
    payload_len = len(data)
    digest = hashlib.md5(data).hexdigest() if self.check_sum else ""
    header = b"".join(
        field.to_bytes(8, byteorder=bytes_order)
        for field in (payload_len, sequence_number, rank, step)
    )
    frame = header + digest.encode() + data
    logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {payload_len}")

    while self.defer is not None and not self.defer.called:
        time.sleep(0.01)
    self.defer = self.send_large_data(frame)
|
|
290
324
|
|
|
291
325
|
def send_large_data(self, data):
    """Hand *data* to ``self.file_sender`` for transfer over ``self.transport``.

    Without TLS, the bytes are transferred immediately and the result of
    ``beginFileTransfer`` is returned.

    With TLS, messages are accumulated in ``self.send_buffer``. Only when
    ``MAX_SENDING_QUEUE_SIZE`` messages have been buffered is the whole buffer
    transferred in one go; the buffer and counter are then reset and the
    transfer result is returned. Otherwise None is returned.
    """
    if not self.tls:
        return self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)

    self.send_buffer += data
    self.buffer_cnt += 1
    if self.buffer_cnt < MAX_SENDING_QUEUE_SIZE:
        # Still buffering: nothing has been sent yet.
        return None

    transfer = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport)
    self.send_buffer = b""
    self.buffer_cnt = 0
    return transfer
|
|
294
339
|
|
|
295
340
|
def connection_timeout(self):
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import time
|
|
2
17
|
from collections import namedtuple
|
|
3
18
|
|
|
@@ -12,6 +27,8 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TE
|
|
|
12
27
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
|
|
13
28
|
from msprobe.pytorch.common.log import logger
|
|
14
29
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
|
|
30
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
|
|
31
|
+
|
|
15
32
|
|
|
16
33
|
# NPU vs GPU api list
|
|
17
34
|
CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
|
|
@@ -75,7 +92,8 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_
|
|
|
75
92
|
|
|
76
93
|
try:
|
|
77
94
|
# NPU vs CPU
|
|
78
|
-
|
|
95
|
+
cpu_args, cpu_kwargs = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
|
|
96
|
+
cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
|
|
79
97
|
npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
|
|
80
98
|
npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
|
|
81
99
|
npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
|
|
@@ -156,7 +174,10 @@ class ConsumerDispatcher:
|
|
|
156
174
|
|
|
157
175
|
def start(self, handle_func, config):
|
|
158
176
|
self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)]
|
|
159
|
-
api_precision_csv_file = [
|
|
177
|
+
api_precision_csv_file = [
|
|
178
|
+
ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME,
|
|
179
|
+
ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME
|
|
180
|
+
]
|
|
160
181
|
common_config = CommonCompareConfig(self.compare, handle_func, config)
|
|
161
182
|
for xpu_id, q in enumerate(self.queues):
|
|
162
183
|
p = mp.Process(name="run_ut_process", target=run_ut_process,
|
|
@@ -164,8 +185,10 @@ class ConsumerDispatcher:
|
|
|
164
185
|
|
|
165
186
|
p.start()
|
|
166
187
|
self.processes.append(p)
|
|
167
|
-
logger.info(
|
|
168
|
-
|
|
188
|
+
logger.info(
|
|
189
|
+
f'Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}')
|
|
190
|
+
logger.info(
|
|
191
|
+
f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
|
|
169
192
|
logger.info("Successfully start unittest process.")
|
|
170
193
|
|
|
171
194
|
def stop(self):
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from functools import wraps
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from torch.utils._python_dispatch import TorchDispatchMode
|
|
21
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
22
|
+
from msprobe.pytorch.common.utils import get_tensor_rank
|
|
23
|
+
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.pytorch.common.log import logger
|
|
25
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def singleton(cls):
    """Turn *cls* into a zero-argument factory that always returns one shared instance.

    The instance is created lazily by calling ``cls()`` on the first call and
    cached. Every later call returns that same cached object.
    """
    cache = {}

    @wraps(cls)
    def get_instance():
        if cls in cache:
            return cache[cls]
        cache[cls] = cls()
        return cache[cls]

    return get_instance
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@singleton
class Counter:
    """Registry of per-aten-op call indices.

    Because of the ``singleton`` decorator, ``Counter()`` always returns the
    same object.
    """

    def __init__(self) -> None:
        # aten op name -> index that will be assigned to the next reported call.
        self.index_dict = dict()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Shared call-index registry used by every AccuracyCheckerDispatch instance.
counter = Counter()
# Op filtering configuration stored in the same directory as this module.
yaml_path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "torch_ops_config.yaml",
)
yaml_file = load_yaml(yaml_path)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AccuracyCheckerDispatch(TorchDispatchMode):
    """Dispatch mode that reports every executed aten op to an ATTL transport.

    Each intercepted op is executed normally. Unless the op's name is listed in
    ``aten_ops_blacklist``, its inputs and outputs are packed into an
    ``ApiData`` named ``<Const.ATEN><Const.SEP><op><Const.SEP><index>``. The
    ``ApiData`` is then uploaded (when ``attl.nfs_path`` is set) or sent via
    ``attl.send``.
    """

    def __init__(self, attl):
        super(AccuracyCheckerDispatch, self).__init__()
        self.attl = attl
        # Module-level shared counter, so op indices keep increasing across
        # the instances created for each wrapped call.
        self.counter = counter
        # aten op names that are executed but never reported.
        self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist', [])
        # aten op names before which the AutogradFunctionality dispatch key is
        # removed from the thread-local exclude set (see enable_autogard).
        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard', [])

    def __torch_dispatch__(self, func, types, args=None, kwargs=None):
        """Run *func* and report it unless it is blacklisted; return its result."""
        # The hook may be invoked with args/kwargs left as None; normalise them
        # so the unpacking below and the "device" lookup cannot raise TypeError.
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}

        # Keep only the leading component of the op's __name__ (split on Const.SEP).
        aten_api = func.__name__.split(Const.SEP)[0]
        self.enable_autogard(aten_api)
        if aten_api in self.aten_ops_blacklist:
            return func(*args, **kwargs)

        res = func(*args, **kwargs)
        cur_rank = get_tensor_rank(args, res)
        cur_api_number = self.counter.index_dict.setdefault(aten_api, 0)
        api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
        logger.info(f"tools is dumping api: {api_name}")
        api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
        # Drop the "device" kwarg from the reported data.
        if "device" in api_data.kwargs:
            api_data.kwargs.pop("device")
        if self.attl.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)
        self.counter.index_dict[aten_api] += 1

        return res

    def enable_autogard(self, aten_api):
        """Un-exclude AutogradFunctionality in TLS dispatch state for listed ops."""
        if aten_api in self.npu_adjust_autogard:
            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def dispatch4data(func, attl, status):
    """Wrap *func* so that, when *status* is truthy, it runs under AccuracyCheckerDispatch.

    With a falsy *status*, the wrapper simply forwards to *func*. In both cases
    the wrapper returns *func*'s result and carries *func*'s metadata via
    ``functools.wraps``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        if status:
            with AccuracyCheckerDispatch(attl):
                return func(*args, **kwargs)
        return func(*args, **kwargs)

    return wrapper
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def run_ut_dispatch(attl, status):
    """
    This function is called by online_run_ut.
    It enables or disables dispatch for the torch.autograd.backward function.

    Calls are idempotent toggles: any wrapper installed by a previous call is
    removed first. Enabling therefore never stacks wrappers, and disabling
    restores the original torch.autograd.backward.

    Args:
        attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
        status (bool): True means enable dispatch, False means disable dispatch.
    """
    original = torch.autograd.backward
    # Peel off wrappers installed by earlier calls. Previously a disable call
    # wrapped the still-enabled wrapper, so dispatch kept running, and every
    # call added one more layer around backward.
    while getattr(original, "_accuracy_checker_original", None) is not None:
        original = original._accuracy_checker_original

    if not status:
        torch.autograd.backward = original
        return

    wrapped = dispatch4data(original, attl, status)
    # Mark our wrapper so a later call can find and remove it.
    wrapped._accuracy_checker_original = original
    torch.autograd.backward = wrapped
|