mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py (file identified from the hunk content and the +158 -4 entry in the summary above)

@@ -1,8 +1,35 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import re
+import torch
+
+try:
+    import torch_npu
+except ImportError:
+    current_device = "cuda"
+else:
+    current_device = "npu"
 
-from msprobe.core.common.const import FileCheckConst
+from msprobe.core.common.const import FileCheckConst, Const, CompareConst
 from msprobe.core.common.file_utils import FileChecker
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import CompareException
 from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
 from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
 from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
@@ -10,12 +37,21 @@ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
 from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
 
 hf_32_standard_api = ["conv1d", "conv2d"]
+not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
+not_raise_dtype_set = {'type_as'}
+
+PRECISION_MAPPING = {
+    torch.float16: torch.float32,
+    torch.bfloat16: torch.float32,
+    torch.float32: torch.float64
+}
 
 
-class
+class BackwardMessage:
     MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
-    UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation,
-
+    UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
+                                 "skip backward."
+    NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
 
 
 class UtDataInfo:
@@ -68,3 +104,121 @@ def exec_api(api_type, api_name, device, args, kwargs):
         torch_api = NpuOPTemplate(api_name, None, False, device)
         out = torch_api.forward(*args, **kwargs)
     return out
+
+
+def deal_detach(arg, to_detach=True):
+    return arg.detach() if to_detach else arg
+
+
+def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
+    '''
+    Convert the dtype of the benchmark data to raise_dtype.
+    Inputs:
+    api_name: name of the API
+    arg: benchmark input
+    raise_dtype: target dtype to convert to
+    Outputs:
+    arg: benchmark input converted to the target dtype
+    '''
+    if api_name in hf_32_standard_api and arg.dtype == torch.float32:
+        return arg
+    if raise_dtype is None or arg.dtype not in PRECISION_MAPPING or raise_dtype == arg.dtype:
+        return arg
+    return arg.type(raise_dtype)
+
+
+def generate_device_params(input_args, input_kwargs, need_backward, api_name):
+    def recursive_arg_to_device(arg_in, to_detach, depth=0):
+        if depth > Const.MAX_DEPTH:
+            logger.error("The depth of arg_in is too large, please check the arg_in.")
+            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+        if isinstance(arg_in, (list, tuple)):
+            return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth+1) for arg in arg_in)
+        elif isinstance(arg_in, torch.Tensor):
+            if need_backward and arg_in.requires_grad:
+                arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
+                temp_arg_in = arg_in * 1
+                arg_in = temp_arg_in.type_as(arg_in)
+                arg_in.retain_grad()
+                return arg_in
+            else:
+                return deal_detach(arg_in.clone(), to_detach).to(current_device)
+        else:
+            return arg_in
+
+    is_detach = api_name not in not_detach_set
+    device_args = recursive_arg_to_device(input_args, is_detach)
+    device_kwargs = \
+        {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
+    return device_args, device_kwargs
+
+
+def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
+    def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None, depth=0):
+        if depth > Const.MAX_DEPTH:
+            logger.error("The depth of arg_in is too large, please check the arg_in.")
+            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+        if isinstance(arg_in, (list, tuple)):
+            return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype, depth=depth+1)
+                                for arg in arg_in)
+        elif isinstance(arg_in, torch.Tensor):
+            if need_backward and arg_in.requires_grad:
+                arg_in = deal_detach(raise_bench_data_dtype(
+                    api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
+                temp_arg_in = arg_in * 1
+                arg_in = temp_arg_in.type_as(arg_in)
+                arg_in.retain_grad()
+                return arg_in
+            else:
+                return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
+        else:
+            return arg_in
+
+    def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
+        if arg_in.dtype in PRECISION_MAPPING:
+            return True
+        if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
+            return True
+        return False
+
+    def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False, depth=0):
+        if depth > Const.MAX_DEPTH:
+            logger.error("The depth of arg_in is too large, please check the arg_in.")
+            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+        if isinstance(arg_in, (list, tuple)):
+            return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
+                                      arg in arg_in))
+        elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
+            return set([arg_in.dtype])
+        elif isinstance(arg_in, dict) and check_kwargs:
+            return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
+                                      v in arg_in.values()))
+        return set()
+
+    raise_dtype = None
+    need_raise_dtypes = recursive_find_dtypes(input_args)
+    need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
+    if len(need_raise_dtypes) == 1:
+        raise_dtype = PRECISION_MAPPING.get(need_raise_dtypes.pop(), torch.float32)
+    elif len(need_raise_dtypes) >= 2:
+        raise_dtype = torch.float32
+
+    raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
+    is_detach = api_name not in not_detach_set
+    cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
+    cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
+                  key, value in input_kwargs.items()}
+    return cpu_args, cpu_kwargs
+
+
+def record_skip_info(api_full_name, compare, compare_alg_results):
+    result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
+    compare.record_results(result_info)
+
+
+def is_unsupported_api(api_name, is_overflow_check=False):
+    split_name = api_name.split(Const.SEP)[0]
+    flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU)
+    if flag:
+        logger.info(f"{split_name} api is not supported for run ut. SKIP.")
+    return flag
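For context, the new `generate_cpu_params`/`raise_bench_data_dtype` helpers above promote low-precision benchmark inputs to a higher-precision dtype before the CPU reference run. The standalone sketch below (illustration only, not code from the package) mimics that dtype-raising rule using the same `PRECISION_MAPPING` shown in the diff; the helper name `raise_dtype_for_benchmark` is made up for the example.

```python
# Minimal sketch, assuming only that torch is installed; mirrors the
# PRECISION_MAPPING/raise_bench_data_dtype behaviour shown in the hunk above.
import torch

PRECISION_MAPPING = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.float32: torch.float64,
}


def raise_dtype_for_benchmark(arg: torch.Tensor, raise_dtype=None) -> torch.Tensor:
    # Keep the tensor unchanged unless a higher-precision benchmark dtype applies.
    if raise_dtype is None or arg.dtype not in PRECISION_MAPPING or raise_dtype == arg.dtype:
        return arg
    return arg.type(raise_dtype)


if __name__ == "__main__":
    x = torch.randn(4, dtype=torch.float16)
    # A single low-precision input is promoted to float32 for the CPU benchmark run.
    cpu_x = raise_dtype_for_benchmark(x, raise_dtype=PRECISION_MAPPING[x.dtype])
    print(cpu_x.dtype)  # torch.float32
```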
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py (identified from content)

@@ -1,7 +1,21 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import os.path
 import time
-import re
 from multiprocessing import Queue
 from typing import Optional, Union, Dict, Any
 from dataclasses import dataclass
@@ -11,9 +25,8 @@ import torch
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
-from msprobe.pytorch.common.utils import logger
 from msprobe.core.common.file_utils import remove_path
-from msprobe.pytorch.common.utils import save_api_data, load_api_data,
+from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
 
 BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
 
@@ -40,7 +53,7 @@ class ATTL:
         self.dequeue_list = []
         self.message_end = False
         self.kill_progress = False
-        self.
+        self.nfs_path = None
         if self.session_config.nfs_path:
             self.nfs_path = self.session_config.nfs_path
         elif self.session_config.is_benchmark_device:
@@ -57,18 +70,6 @@ class ATTL:
                                             self.session_config.tls_path)
             self.socket_manager.start()
 
-    def check_attl_config(self):
-        if self.session_config.nfs_path:
-            if os.path.exists(self.session_config.nfs_path):
-                return
-            else:
-                raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
-        ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
-        if not re.match(ipv4_pattern, self.session_config.connect_ip):
-            raise Exception(f"host {self.session_config.connect_ip} is invalid.")
-        if not (0 < self.session_config.connect_port <= 65535):
-            raise Exception(f"port {self.session_config.connect_port} is invalid.")
-
     def stop_serve(self):
         if isinstance(self.socket_manager, TCPServer):
             self.socket_manager.stop()
@@ -77,6 +78,11 @@ class ATTL:
         """
         npu major in 'send' (client)
         """
+
+        # if tcp connection lost,
+        if self.socket_manager.signal_exit:
+            raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")
+
         # know receiver receive and go next
         if isinstance(buffer, ApiData):
             buffer = move2target_device(buffer, torch.device('cpu'))
@@ -94,21 +100,21 @@ class ATTL:
             self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
 
     def recv(self, timeout_ms=0) -> Optional[BufferType]:
-        buffer =
-        while buffer
+        buffer = ''
+        while not buffer:
             if timeout_ms > 0:
                 time.sleep(timeout_ms / 1000.0)
-            if buffer
+            if not buffer and not self.data_queue.empty():
                 buffer = self.data_queue.get()
                 break
-            if buffer
+            if not buffer and timeout_ms > 0:  # timeout is the only case we give up and return None
                 break
             if self.message_end and self.data_queue.empty():
                 buffer = b"KILL_CONFIRM"
                 self.kill_progress = True
                 break
             time.sleep(0.1)  # waiting outside the lock before next attempt
-        if buffer
+        if not buffer:
             # this is a result of a timeout
             self.logger.info(f"RECEIVE API DATA TIMED OUT")
         else:
@@ -125,7 +131,7 @@ class ATTL:
         except Exception as e:
             self.logger.warning("there is something error. please check it. %s", e)
         if isinstance(buffer, bytes):
-            return
+            return ''
         if isinstance(buffer, str):
             return buffer
 
@@ -139,7 +145,7 @@ class ATTL:
         file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
 
         try:
-
+            save_pkl(buffer, file_path)
         except Exception as e:
             self.logger.warning("there is something error in save_pt. please check it. %s", e)
 
@@ -155,7 +161,7 @@ class ATTL:
 
         if cur_file is not None:
             try:
-                buffer =
+                buffer = load_pkl(cur_file)
             except Exception as e:
                 self.logger.warning("there is something error. please check it. %s", e)
             remove_path(cur_file)
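The `ATTL.recv` change above switches to an empty-string sentinel with explicit `not buffer` checks, plus a `KILL_CONFIRM` marker once the sender signals the end of the stream. The snippet below is a standalone sketch (not the package API) of that polling pattern, with a plain `queue.Queue` standing in for the multiprocessing data queue; `poll_queue` is a hypothetical name used only for illustration.

```python
# Minimal sketch of the recv polling pattern shown in the hunk above.
import time
from queue import Queue
from typing import Optional, Union


def poll_queue(data_queue: Queue, message_end: bool, timeout_ms: int = 0) -> Optional[Union[str, bytes]]:
    buffer = ''
    while not buffer:
        if timeout_ms > 0:
            time.sleep(timeout_ms / 1000.0)
        if not buffer and not data_queue.empty():
            buffer = data_queue.get()
            break
        if not buffer and timeout_ms > 0:  # timeout is the only case we give up
            break
        if message_end and data_queue.empty():
            buffer = b"KILL_CONFIRM"  # sender is done and the queue is drained
            break
        time.sleep(0.1)
    return buffer or None


q = Queue()
q.put("Functional.conv2d.0.forward")
print(poll_queue(q, message_end=False, timeout_ms=100))
```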
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py (identified from content)

@@ -1,10 +1,24 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import hashlib
 import io
 import struct
 import time
 import os
 import signal
-import sys
 from queue import Queue
 from threading import Thread
 from typing import Union
@@ -13,7 +27,10 @@ from twisted.internet import reactor, protocol, endpoints
 from twisted.protocols.basic import FileSender
 
 from msprobe.pytorch.common.utils import logger
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.
+from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
+    STR_TO_BYTES_ORDER as bytes_order
+
+MAX_SENDING_QUEUE_SIZE = 20
 
 
 class TCPDataItem:
@@ -31,7 +48,6 @@ class TCPDataItem:
 
 
 class TCPClient:
-    MAX_SENDING_QUEUE_SIZE = 20
     ACK_SUCCESS = b"OK___"
     ACK_ERROR = b"ERROR"
     ACK_BUSY = b"BUSY_"
@@ -39,13 +55,13 @@ class TCPClient:
     ACK_STOP_CONFIRM = b"OVER_"
     ACK_KILL_PROCESS = b"KILL_"
 
-    QUEUE_PENDING_TIME =
+    QUEUE_PENDING_TIME = 60
     RESEND_RETRY_TIMES = 2  # maximum number of retransmissions
     RESEND_TIMER_TIME = 5  # timer for ACK receive timeout
    RESEND_PENDING_TIME = 60  # give up the data if it stays pending for more than one minute
 
     def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
-        self.send_queue = Queue(
+        self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE)
         self.resend_dict = dict()
         self.host = host
         self.port = port
@@ -55,7 +71,8 @@ class TCPClient:
         self.signal_exit = False
         self.tcp_manager = ClientProtocol(ack_queue_size=100,
                                           chunk_size=655360,
-                                          check_sum=check_sum
+                                          check_sum=check_sum,
+                                          tls=self.tls_path)
         self.send_thread = Thread(target=self._sending_queue_data)
         self.send_thread.setDaemon(True)
         self.send_thread.start()
@@ -80,8 +97,6 @@ class TCPClient:
             time.sleep(1)
             reactor.stop()
             logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
-            os.kill(os.getpid(), signal.SIGKILL)
-            os.kill(os.getppid(), signal.SIGKILL)
 
         def cur_protocol():
             return self.tcp_manager
@@ -89,14 +104,10 @@ class TCPClient:
         self.factory = MessageClientFactory()
         self.factory.protocol = cur_protocol
         if self.tls_path:
-            from OpenSSL import SSL
             from twisted.internet import ssl
             client_key = os.path.join(self.tls_path, "client.key")
             client_crt = os.path.join(self.tls_path, "client.crt")
-            client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt
-            client_context_ = client_context_factory.getContext()
-            client_context_.set_cipher_list(cipher_list)
-            client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
+            client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
             endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
         else:
             endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
@@ -109,7 +120,11 @@ class TCPClient:
 
     def send_after_queue_empty(self, data):
         while not self._ready_to_exit():
-            self.
+            if not self.tls_path:
+                self.add_to_sending_queue(data)
+            else:
+                for _ in range(MAX_SENDING_QUEUE_SIZE):
+                    self.add_to_sending_queue(data)
             time.sleep(2)
 
     def check_client_alive(self):
@@ -124,8 +139,6 @@ class TCPClient:
            if not self.check_client_alive():
                break
            time.sleep(1)
-        while not self.tcp_manager.kill_process:
-            time.sleep(1)
 
     def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
         if self._ready_to_exit():
@@ -142,7 +155,8 @@ class TCPClient:
             self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
         except Exception as e:
             logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
-                         f"sequence_number: {send_data.sequence_number}, {
+                         f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()},"
+                         f"{str(e)}")
 
     def _send_data(self, data: TCPDataItem):
         self.tcp_manager.send_wrapped_data(data.raw_data,
@@ -159,10 +173,11 @@ class TCPClient:
         while self.send_queue.qsize() > 0:
             if self._ready_to_exit():
                 break
-            if len(self.resend_dict) <
+            if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE:
                 data_obj = self.send_queue.get()
-                self._send_data(data_obj)
                 resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step)
+                logger.debug(f"get {resend_key} from send_queue, and send to server.")
+                self._send_data(data_obj)
                 if resend_key not in self.resend_dict.keys():
                     # Send data for the first time
                     self.resend_dict[resend_key] = data_obj
@@ -233,7 +248,7 @@ class TCPClient:
 class ClientProtocol(protocol.Protocol):
     TIMEOUT = 60 * 10
 
-    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
+    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None):
         self.buffer = io.BytesIO()
         self.is_connected = False
         self.check_sum = check_sum
@@ -244,6 +259,13 @@ class ClientProtocol(protocol.Protocol):
         self.signal_exit = False
         self.defer = None
         self.kill_process = False
+        self.ack = None
+
+        self.timeout_call = None
+
+        self.tls = tls
+        self.send_buffer = b""
+        self.buffer_cnt = 0
 
     def dataReceived(self, data):
         if self.timeout_call.active():
@@ -255,9 +277,11 @@ class ClientProtocol(protocol.Protocol):
         while True:
             if len(self.buffer.getvalue()) >= 29:  # 5 + 8 * 3
                 ack = self.buffer.read(5)
-
-
-
+                self.ack = ack
+                seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0]
+                rank = struct.unpack(unpack_mode, self.buffer.read(8))[0]
+                step = struct.unpack(unpack_mode, self.buffer.read(8))[0]
+                logger.debug(f"receive sequence number: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}")
                 if ack == b"KILL_":
                     self.kill_process = True
                     logger.debug(f"KILL signal received, PID {os.getpid()}")
@@ -276,20 +300,33 @@ class ClientProtocol(protocol.Protocol):
     def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
         length = len(data)
         md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
+        data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \
+                       sequence_number.to_bytes(8, byteorder=bytes_order) + \
+                       rank.to_bytes(8, byteorder=bytes_order) + \
+                       step.to_bytes(8, byteorder=bytes_order) + \
+                       md5_hash.encode() + \
+                       data
+        logger.debug(f"send sequence number: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")
+
         while True:
             if self.defer is None or self.defer.called:
-                self.defer = self.send_large_data(
-                    length.to_bytes(8, byteorder='big') +
-                    sequence_number.to_bytes(8, byteorder='big') +
-                    rank.to_bytes(8, byteorder='big') +
-                    step.to_bytes(8, byteorder='big') +
-                    md5_hash.encode() +
-                    data)
+                self.defer = self.send_large_data(data_meaasge)
                 break
             time.sleep(0.01)
 
     def send_large_data(self, data):
-
+
+        if self.tls:
+            self.send_buffer += data
+            self.buffer_cnt += 1
+            if self.buffer_cnt >= MAX_SENDING_QUEUE_SIZE:
+                d = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport)
+                self.send_buffer = b""
+                self.buffer_cnt = 0
+            else:
+                d = None
+        else:
+            d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
         return d
 
     def connection_timeout(self):
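The client changes above frame every outgoing message as four 8-byte integers (payload length, sequence number, rank, step) followed by an optional MD5 digest and the payload, and parse 29-byte ACK headers (a 5-byte ack code plus three 8-byte integers). The sketch below (illustration only, not the package API) shows that framing with plain `struct`/`to_bytes`; the byte order is assumed big-endian here, matching the removed `byteorder='big'` calls, whereas the package reads it from `STR_TO_BYTES_ORDER`/`STRUCT_UNPACK_MODE` in its utils module.

```python
# Minimal sketch of the message and ACK framing implied by the diff above.
import hashlib
import struct

BYTES_ORDER = "big"
UNPACK_MODE = ">Q"  # assumption: one unsigned 64-bit big-endian integer per field


def wrap_data(data: bytes, sequence_number: int, rank: int, step: int, check_sum: bool = True) -> bytes:
    # length | sequence | rank | step (8 bytes each), then optional md5 digest, then payload.
    md5_hash = hashlib.md5(data).hexdigest() if check_sum else ""
    return (len(data).to_bytes(8, byteorder=BYTES_ORDER)
            + sequence_number.to_bytes(8, byteorder=BYTES_ORDER)
            + rank.to_bytes(8, byteorder=BYTES_ORDER)
            + step.to_bytes(8, byteorder=BYTES_ORDER)
            + md5_hash.encode()
            + data)


def parse_ack(header: bytes):
    # 29 bytes: 5-byte ack code, then sequence number, rank and step.
    ack = header[:5]
    seq, rank, step = (struct.unpack(UNPACK_MODE, header[5 + 8 * i:5 + 8 * (i + 1)])[0] for i in range(3))
    return ack, seq, rank, step


msg = wrap_data(b"tensor-bytes", sequence_number=7, rank=0, step=3)
print(parse_ack(b"OK___" + msg[8:32]))  # (b'OK___', 7, 0, 3)
```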
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py (identified from content)

@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 from collections import namedtuple
 
@@ -12,6 +27,8 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
+
 
 # NPU vs GPU api list
 CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
@@ -75,7 +92,8 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_file):
 
     try:
         # NPU vs CPU
-
+        cpu_args, cpu_kwargs = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
+        cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
         npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
         npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
         npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
@@ -156,7 +174,10 @@ class ConsumerDispatcher:
 
     def start(self, handle_func, config):
         self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)]
-        api_precision_csv_file = [
+        api_precision_csv_file = [
+            ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME,
+            ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME
+        ]
         common_config = CommonCompareConfig(self.compare, handle_func, config)
         for xpu_id, q in enumerate(self.queues):
             p = mp.Process(name="run_ut_process", target=run_ut_process,
@@ -164,8 +185,10 @@ class ConsumerDispatcher:
 
             p.start()
             self.processes.append(p)
-        logger.info(
-
+        logger.info(
+            f'Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}')
+        logger.info(
+            f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
         logger.info("Successfully start unittest process.")
 
     def stop(self):