mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -1,12 +1,33 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
import os
|
|
3
17
|
|
|
4
|
-
from msprobe.core.common_config import CommonConfig, BaseConfig
|
|
5
|
-
from msprobe.core.common.file_utils import FileOpen
|
|
6
18
|
from msprobe.core.common.const import Const
|
|
7
|
-
from msprobe.
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, load_json
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
from msprobe.core.common_config import BaseConfig, CommonConfig
|
|
8
23
|
from msprobe.core.grad_probe.constant import level_adp
|
|
9
|
-
from msprobe.core.grad_probe.utils import
|
|
24
|
+
from msprobe.core.grad_probe.utils import check_bounds
|
|
25
|
+
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
26
|
+
DeviceType,
|
|
27
|
+
HandlerType,
|
|
28
|
+
PytorchFreeBenchmarkConst,
|
|
29
|
+
)
|
|
30
|
+
from msprobe.pytorch.hook_module.utils import get_ops
|
|
10
31
|
|
|
11
32
|
|
|
12
33
|
class TensorConfig(BaseConfig):
|
|
@@ -16,7 +37,7 @@ class TensorConfig(BaseConfig):
|
|
|
16
37
|
self.nfs_path = json_config.get("nfs_path", "")
|
|
17
38
|
self.host = json_config.get("host", "")
|
|
18
39
|
self.port = json_config.get("port", -1)
|
|
19
|
-
self.tls_path = json_config.get("tls_path", "")
|
|
40
|
+
self.tls_path = json_config.get("tls_path", "./")
|
|
20
41
|
self.check_config()
|
|
21
42
|
self._check_file_format()
|
|
22
43
|
self._check_tls_path_config()
|
|
@@ -26,13 +47,8 @@ class TensorConfig(BaseConfig):
|
|
|
26
47
|
raise Exception("file_format is invalid")
|
|
27
48
|
|
|
28
49
|
def _check_tls_path_config(self):
|
|
29
|
-
if self.tls_path:
|
|
30
|
-
|
|
31
|
-
raise Exception("tls_path: %s does not exist" % self.tls_path)
|
|
32
|
-
if not os.path.exists(os.path.join(self.tls_path, "client.key")):
|
|
33
|
-
raise Exception("tls_path does not contain client.key")
|
|
34
|
-
if not os.path.exists(os.path.join(self.tls_path, "client.crt")):
|
|
35
|
-
raise Exception("tls_path does not contain client.crt")
|
|
50
|
+
if self.tls_path and not os.path.exists(self.tls_path):
|
|
51
|
+
raise Exception("tls_path: %s does not exist" % self.tls_path)
|
|
36
52
|
|
|
37
53
|
|
|
38
54
|
class StatisticsConfig(BaseConfig):
|
|
@@ -61,23 +77,142 @@ class OverflowCheckConfig(BaseConfig):
|
|
|
61
77
|
|
|
62
78
|
|
|
63
79
|
class FreeBenchmarkCheckConfig(BaseConfig):
|
|
80
|
+
|
|
64
81
|
def __init__(self, json_config):
|
|
65
82
|
super().__init__(json_config)
|
|
66
|
-
self.fuzz_device = json_config.get("fuzz_device")
|
|
67
|
-
self.pert_mode = json_config.get("pert_mode")
|
|
68
|
-
self.handler_type = json_config.get("handler_type")
|
|
69
|
-
self.fuzz_level = json_config.get("fuzz_level")
|
|
70
|
-
self.fuzz_stage = json_config.get("fuzz_stage")
|
|
71
|
-
self.if_preheat = json_config.get("if_preheat")
|
|
72
|
-
self.preheat_step = json_config.get("preheat_step")
|
|
73
|
-
self.max_sample = json_config.get("max_sample")
|
|
83
|
+
self.fuzz_device = json_config.get("fuzz_device", PytorchFreeBenchmarkConst.DEFAULT_DEVICE)
|
|
84
|
+
self.pert_mode = json_config.get("pert_mode", PytorchFreeBenchmarkConst.DEFAULT_MODE)
|
|
85
|
+
self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
|
|
86
|
+
self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
|
|
87
|
+
self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
|
|
88
|
+
self.if_preheat = json_config.get("if_preheat", False)
|
|
89
|
+
self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
|
|
90
|
+
self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
|
|
74
91
|
self.check_freebenchmark_config()
|
|
75
92
|
|
|
76
93
|
def check_freebenchmark_config(self):
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
94
|
+
self._check_pert_mode()
|
|
95
|
+
self._check_fuzz_device()
|
|
96
|
+
self._check_handler_type()
|
|
97
|
+
self._check_fuzz_stage()
|
|
98
|
+
self._check_fuzz_level()
|
|
99
|
+
self._check_if_preheat()
|
|
100
|
+
if self.handler_type == HandlerType.FIX:
|
|
101
|
+
self._check_fix_config()
|
|
102
|
+
if self.if_preheat:
|
|
103
|
+
self._check_preheat_config()
|
|
104
|
+
|
|
105
|
+
def _check_pert_mode(self):
|
|
106
|
+
if self.pert_mode not in PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST:
|
|
107
|
+
msg = (
|
|
108
|
+
f"pert_mode is invalid, it should be one of"
|
|
109
|
+
f" {PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST}"
|
|
110
|
+
)
|
|
111
|
+
logger.error_log_with_exp(
|
|
112
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
def _check_fuzz_device(self):
|
|
116
|
+
if self.fuzz_device not in PytorchFreeBenchmarkConst.DEVICE_LIST:
|
|
117
|
+
msg = (
|
|
118
|
+
f"fuzz_device is invalid, it should be one of"
|
|
119
|
+
f" {PytorchFreeBenchmarkConst.DEVICE_LIST}"
|
|
120
|
+
)
|
|
121
|
+
logger.error_log_with_exp(
|
|
122
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
123
|
+
)
|
|
124
|
+
if (self.fuzz_device == DeviceType.CPU) ^ (
|
|
125
|
+
self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
|
|
126
|
+
):
|
|
127
|
+
msg = (
|
|
128
|
+
f"You neet to and can only set fuzz_device as {DeviceType.CPU} "
|
|
129
|
+
f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
|
|
130
|
+
)
|
|
131
|
+
logger.error_log_with_exp(
|
|
132
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
def _check_handler_type(self):
|
|
136
|
+
if self.handler_type not in PytorchFreeBenchmarkConst.HANDLER_LIST:
|
|
137
|
+
msg = (
|
|
138
|
+
f"handler_type is invalid, it should be one of"
|
|
139
|
+
f" {PytorchFreeBenchmarkConst.HANDLER_LIST}"
|
|
140
|
+
)
|
|
141
|
+
logger.error_log_with_exp(
|
|
142
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
def _check_fuzz_stage(self):
|
|
146
|
+
if self.fuzz_stage not in PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST:
|
|
147
|
+
msg = (
|
|
148
|
+
f"fuzz_stage is invalid, it should be one of"
|
|
149
|
+
f" {PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST}"
|
|
150
|
+
)
|
|
151
|
+
logger.error_log_with_exp(
|
|
152
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
def _check_fuzz_level(self):
|
|
156
|
+
if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
|
|
157
|
+
msg = (
|
|
158
|
+
f"fuzz_level is invalid, it should be one of"
|
|
159
|
+
f" {PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST}"
|
|
160
|
+
)
|
|
161
|
+
logger.error_log_with_exp(
|
|
162
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
def _check_if_preheat(self):
|
|
166
|
+
if not isinstance(self.if_preheat, bool):
|
|
167
|
+
msg = "if_preheat is invalid, it should be a boolean"
|
|
168
|
+
logger.error_log_with_exp(
|
|
169
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
def _check_preheat_config(self):
|
|
173
|
+
if not isinstance(self.preheat_step, int):
|
|
174
|
+
msg = "preheat_step is invalid, it should be an integer"
|
|
175
|
+
logger.error_log_with_exp(
|
|
176
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
177
|
+
)
|
|
178
|
+
if self.preheat_step <= 0:
|
|
179
|
+
msg = "preheat_step must be greater than 0"
|
|
180
|
+
logger.error_log_with_exp(
|
|
181
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
182
|
+
)
|
|
183
|
+
if not isinstance(self.max_sample, int):
|
|
184
|
+
msg = "max_sample is invalid, it should be an integer"
|
|
185
|
+
logger.error_log_with_exp(
|
|
186
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
187
|
+
)
|
|
188
|
+
if self.max_sample <= 0:
|
|
189
|
+
msg = "max_sample must be greater than 0"
|
|
190
|
+
logger.error_log_with_exp(
|
|
191
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
def _check_fix_config(self):
|
|
195
|
+
if self.if_preheat:
|
|
196
|
+
msg = f"Preheating is not supported for {HandlerType.FIX} handler type"
|
|
197
|
+
logger.error_log_with_exp(
|
|
198
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
199
|
+
)
|
|
200
|
+
if self.fuzz_stage not in PytorchFreeBenchmarkConst.FIX_STAGE_LIST:
|
|
201
|
+
msg = (
|
|
202
|
+
f"The fuzz_stage when opening {HandlerType.FIX} handler must be one of "
|
|
203
|
+
f"{PytorchFreeBenchmarkConst.FIX_STAGE_LIST}"
|
|
204
|
+
)
|
|
205
|
+
logger.error_log_with_exp(
|
|
206
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
207
|
+
)
|
|
208
|
+
if self.pert_mode not in PytorchFreeBenchmarkConst.FIX_MODE_LIST:
|
|
209
|
+
msg = (
|
|
210
|
+
f"The pert_mode when opening {HandlerType.FIX} handler must be one of "
|
|
211
|
+
f"{PytorchFreeBenchmarkConst.FIX_MODE_LIST}"
|
|
212
|
+
)
|
|
213
|
+
logger.error_log_with_exp(
|
|
214
|
+
msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
|
|
215
|
+
)
|
|
81
216
|
|
|
82
217
|
|
|
83
218
|
class RunUTConfig(BaseConfig):
|
|
@@ -93,7 +228,7 @@ class RunUTConfig(BaseConfig):
|
|
|
93
228
|
self.host = json_config.get("host", "")
|
|
94
229
|
self.port = json_config.get("port", -1)
|
|
95
230
|
self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
|
|
96
|
-
self.tls_path = json_config.get("tls_path", "")
|
|
231
|
+
self.tls_path = json_config.get("tls_path", "./")
|
|
97
232
|
self.check_run_ut_config()
|
|
98
233
|
|
|
99
234
|
@classmethod
|
|
@@ -118,13 +253,8 @@ class RunUTConfig(BaseConfig):
|
|
|
118
253
|
|
|
119
254
|
@classmethod
|
|
120
255
|
def check_tls_path_config(cls, tls_path):
|
|
121
|
-
if tls_path:
|
|
122
|
-
|
|
123
|
-
raise Exception("tls_path: %s does not exist" % tls_path)
|
|
124
|
-
if not os.path.exists(os.path.join(tls_path, "server.key")):
|
|
125
|
-
raise Exception("tls_path does not contain server.key")
|
|
126
|
-
if not os.path.exists(os.path.join(tls_path, "server.crt")):
|
|
127
|
-
raise Exception("tls_path does not contain server.crt")
|
|
256
|
+
if tls_path and not os.path.exists(tls_path):
|
|
257
|
+
raise Exception("tls_path: %s does not exist" % tls_path)
|
|
128
258
|
|
|
129
259
|
def check_run_ut_config(self):
|
|
130
260
|
RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
|
|
@@ -141,13 +271,13 @@ class GradToolConfig(BaseConfig):
|
|
|
141
271
|
self.param_list = json_config.get("param_list", [])
|
|
142
272
|
self.bounds = json_config.get("bounds", [-1, 0, 1])
|
|
143
273
|
self._check_config()
|
|
144
|
-
|
|
274
|
+
|
|
145
275
|
def _check_config(self):
|
|
146
276
|
if self.grad_level not in level_adp.keys():
|
|
147
277
|
raise Exception(f"grad_level must be one of {level_adp.keys()}")
|
|
148
278
|
if not isinstance(self.param_list, list):
|
|
149
279
|
raise Exception(f"param_list must be a list")
|
|
150
|
-
|
|
280
|
+
check_bounds(self.bounds)
|
|
151
281
|
|
|
152
282
|
|
|
153
283
|
def parse_task_config(task, json_config):
|
|
@@ -178,10 +308,9 @@ def parse_json_config(json_file_path, task):
|
|
|
178
308
|
if not json_file_path:
|
|
179
309
|
config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
180
310
|
json_file_path = os.path.join(config_dir, "config.json")
|
|
181
|
-
|
|
182
|
-
json_config = json.load(file)
|
|
311
|
+
json_config = load_json(json_file_path)
|
|
183
312
|
common_config = CommonConfig(json_config)
|
|
184
|
-
if task
|
|
313
|
+
if task:
|
|
185
314
|
task_config = parse_task_config(task, json_config)
|
|
186
315
|
else:
|
|
187
316
|
task_config = parse_task_config(common_config.task, json_config)
|
msprobe/pytorch/service.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import functools
|
|
2
17
|
import os
|
|
3
18
|
|
|
@@ -6,6 +21,7 @@ import torch
|
|
|
6
21
|
from msprobe.core.common.const import Const
|
|
7
22
|
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
8
23
|
from msprobe.core.common.file_utils import create_directory
|
|
24
|
+
from msprobe.core.common.utils import print_tools_ends_info
|
|
9
25
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
10
26
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
11
27
|
from msprobe.core.data_dump.scope import BaseScope
|
|
@@ -16,7 +32,10 @@ from msprobe.pytorch.hook_module.api_registry import api_register
|
|
|
16
32
|
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
17
33
|
from msprobe.pytorch.module_processer import ModuleProcesser
|
|
18
34
|
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
35
|
+
|
|
19
36
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
37
|
+
if torch_version_above_or_equal_2:
|
|
38
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
|
|
20
39
|
|
|
21
40
|
HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])
|
|
22
41
|
|
|
@@ -32,6 +51,7 @@ class Service:
|
|
|
32
51
|
self.first_start = True
|
|
33
52
|
self.current_rank = None
|
|
34
53
|
self.dump_iter_dir = None
|
|
54
|
+
self.should_stop_service = False
|
|
35
55
|
self.attl = None
|
|
36
56
|
|
|
37
57
|
@staticmethod
|
|
@@ -39,14 +59,29 @@ class Service:
|
|
|
39
59
|
logger.info_on_rank_0("Data needed ends here.")
|
|
40
60
|
api_register.api_originality()
|
|
41
61
|
|
|
62
|
+
@staticmethod
|
|
63
|
+
def is_registered_backward_hook(module):
|
|
64
|
+
if hasattr(module, '_backward_hooks') and \
|
|
65
|
+
len(module._backward_hooks) > 0 and \
|
|
66
|
+
module._is_full_backward_hook is False:
|
|
67
|
+
return True
|
|
68
|
+
return False
|
|
69
|
+
|
|
70
|
+
def check_register_full_backward_hook(self, module):
|
|
71
|
+
if self.is_registered_backward_hook(module):
|
|
72
|
+
module._backward_hooks.clear()
|
|
73
|
+
module._is_full_backward_hook = None
|
|
74
|
+
logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
|
|
75
|
+
|
|
42
76
|
def build_hook(self, module_type, name):
|
|
43
77
|
def pre_hook(api_or_module_name, module, args, kwargs):
|
|
78
|
+
if not self.should_execute_hook():
|
|
79
|
+
return args, kwargs
|
|
80
|
+
|
|
44
81
|
if module_type == BaseScope.Module_Type_Module:
|
|
45
82
|
api_or_module_name = module.mindstudio_reserved_name
|
|
46
83
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
47
84
|
|
|
48
|
-
if not self.switch:
|
|
49
|
-
return args, kwargs
|
|
50
85
|
if self.config.online_run_ut:
|
|
51
86
|
return None, None
|
|
52
87
|
if self.data_collector:
|
|
@@ -55,13 +90,13 @@ class Service:
|
|
|
55
90
|
return args, kwargs
|
|
56
91
|
|
|
57
92
|
def forward_hook(api_or_module_name, module, args, kwargs, output):
|
|
93
|
+
if not self.should_execute_hook():
|
|
94
|
+
return None
|
|
95
|
+
|
|
58
96
|
if module_type == BaseScope.Module_Type_Module:
|
|
59
97
|
api_or_module_name = module.mindstudio_reserved_name
|
|
60
98
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
61
99
|
|
|
62
|
-
if not self.switch:
|
|
63
|
-
return None
|
|
64
|
-
|
|
65
100
|
if self.config.online_run_ut:
|
|
66
101
|
if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
|
|
67
102
|
return None
|
|
@@ -80,18 +115,14 @@ class Service:
|
|
|
80
115
|
return forward_hook(api_or_module_name, module, args, {}, output)
|
|
81
116
|
|
|
82
117
|
def backward_hook(api_or_module_name, module, grad_input, grad_output):
|
|
118
|
+
if not self.should_execute_hook():
|
|
119
|
+
return
|
|
120
|
+
|
|
83
121
|
if module_type == BaseScope.Module_Type_Module:
|
|
84
122
|
api_or_module_name = module.mindstudio_reserved_name
|
|
85
123
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
86
124
|
|
|
87
|
-
if not self.switch:
|
|
88
|
-
return
|
|
89
|
-
|
|
90
125
|
if self.config.online_run_ut:
|
|
91
|
-
if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
|
|
92
|
-
return
|
|
93
|
-
api_data = ApiData(name[:-1], grad_input, {}, grad_output, self.current_iter, self.current_rank)
|
|
94
|
-
self.attl_send(api_data)
|
|
95
126
|
return
|
|
96
127
|
|
|
97
128
|
if self.data_collector:
|
|
@@ -105,26 +136,15 @@ class Service:
|
|
|
105
136
|
pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
|
|
106
137
|
forward_hook_fn = functools.partial(forward_hook, forward_name_template)
|
|
107
138
|
backward_hook_fn = functools.partial(backward_hook, backward_name_template)
|
|
108
|
-
forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
|
|
139
|
+
forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
|
|
140
|
+
forward_name_template)
|
|
109
141
|
return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
|
|
110
142
|
|
|
111
|
-
def step(self):
|
|
112
|
-
self.current_iter += 1
|
|
113
|
-
self.data_collector.update_iter(self.current_iter)
|
|
114
|
-
|
|
115
|
-
ModuleProcesser.reset_module_stats()
|
|
116
|
-
HOOKModule.reset_module_stats()
|
|
117
|
-
|
|
118
143
|
def start(self, model, api_origin=False):
|
|
119
|
-
self.
|
|
120
|
-
if self.config.step and self.current_iter > max(self.config.step):
|
|
121
|
-
if self.config.online_run_ut:
|
|
122
|
-
# send stop signal if online_run_ut
|
|
123
|
-
self.attl_stop()
|
|
124
|
-
self.stop()
|
|
125
|
-
raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
|
|
126
|
-
if self.config.step and self.current_iter not in self.config.step:
|
|
144
|
+
if self.need_stop_service():
|
|
127
145
|
return
|
|
146
|
+
|
|
147
|
+
self.model = model
|
|
128
148
|
if self.first_start:
|
|
129
149
|
try:
|
|
130
150
|
self.current_rank = get_rank_if_initialized()
|
|
@@ -138,6 +158,8 @@ class Service:
|
|
|
138
158
|
self.first_start = False
|
|
139
159
|
if api_origin:
|
|
140
160
|
api_register.api_modularity()
|
|
161
|
+
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
162
|
+
run_ut_dispatch(self.attl, True)
|
|
141
163
|
self.switch = True
|
|
142
164
|
logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
|
|
143
165
|
if self.config.level != "L2" and not self.config.online_run_ut:
|
|
@@ -145,6 +167,8 @@ class Service:
|
|
|
145
167
|
logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
146
168
|
|
|
147
169
|
def stop(self):
|
|
170
|
+
if self.should_stop_service:
|
|
171
|
+
return
|
|
148
172
|
if self.config.level == "L2":
|
|
149
173
|
return
|
|
150
174
|
if self.config.step and self.current_iter not in self.config.step:
|
|
@@ -152,10 +176,47 @@ class Service:
|
|
|
152
176
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
153
177
|
return
|
|
154
178
|
self.switch = False
|
|
155
|
-
if self.config.online_run_ut:
|
|
179
|
+
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
180
|
+
run_ut_dispatch(self.attl, False)
|
|
156
181
|
return
|
|
157
182
|
self.data_collector.write_json()
|
|
158
183
|
|
|
184
|
+
def step(self):
|
|
185
|
+
if self.should_stop_service:
|
|
186
|
+
return
|
|
187
|
+
self.current_iter += 1
|
|
188
|
+
self.data_collector.update_iter(self.current_iter)
|
|
189
|
+
|
|
190
|
+
ModuleProcesser.reset_module_stats()
|
|
191
|
+
HOOKModule.reset_module_stats()
|
|
192
|
+
self.data_collector.data_writer.reset_cache()
|
|
193
|
+
|
|
194
|
+
def need_stop_service(self):
|
|
195
|
+
if self.should_stop_service:
|
|
196
|
+
return True
|
|
197
|
+
end_service = self.config.step and self.current_iter > max(self.config.step) or \
|
|
198
|
+
self.data_collector and self.data_collector.data_processor.is_terminated
|
|
199
|
+
if end_service:
|
|
200
|
+
if self.config.online_run_ut:
|
|
201
|
+
# send stop signal if online_run_ut
|
|
202
|
+
self.attl_stop()
|
|
203
|
+
if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
|
|
204
|
+
api_register.api_originality()
|
|
205
|
+
self.switch = False
|
|
206
|
+
self.should_stop_service = True
|
|
207
|
+
print_tools_ends_info()
|
|
208
|
+
return True
|
|
209
|
+
if self.config.step and self.current_iter not in self.config.step:
|
|
210
|
+
return True
|
|
211
|
+
return False
|
|
212
|
+
|
|
213
|
+
def should_execute_hook(self):
|
|
214
|
+
if not self.switch:
|
|
215
|
+
return False
|
|
216
|
+
if self.data_collector and self.data_collector.data_processor.is_terminated:
|
|
217
|
+
return False
|
|
218
|
+
return True
|
|
219
|
+
|
|
159
220
|
def create_dirs(self):
|
|
160
221
|
create_directory(self.config.dump_path)
|
|
161
222
|
self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
|
|
@@ -187,14 +248,16 @@ class Service:
|
|
|
187
248
|
prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
|
|
188
249
|
module.__class__.__name__ + Const.SEP
|
|
189
250
|
|
|
190
|
-
pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2
|
|
191
|
-
|
|
251
|
+
pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
|
|
252
|
+
BaseScope.Module_Type_Module, prefix)
|
|
192
253
|
if torch_version_above_or_equal_2:
|
|
193
254
|
module.register_forward_hook(forward_hook, with_kwargs=True)
|
|
194
255
|
else:
|
|
256
|
+
self.check_register_full_backward_hook(module)
|
|
195
257
|
module.register_full_backward_hook(
|
|
196
258
|
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
197
259
|
module.register_forward_hook(forward_hook_torch_version_below_2)
|
|
260
|
+
self.check_register_full_backward_hook(module)
|
|
198
261
|
module.register_full_backward_hook(backward_hook)
|
|
199
262
|
|
|
200
263
|
module.register_forward_pre_hook(
|
|
@@ -204,11 +267,13 @@ class Service:
|
|
|
204
267
|
if torch_version_above_or_equal_2:
|
|
205
268
|
module.register_full_backward_pre_hook(
|
|
206
269
|
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
|
|
270
|
+
self.check_register_full_backward_hook(module)
|
|
207
271
|
module.register_full_backward_hook(
|
|
208
272
|
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
209
273
|
|
|
210
274
|
if self.config.level in ["mix", "L1", "L2"]:
|
|
211
|
-
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)
|
|
275
|
+
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
|
|
276
|
+
self.config.online_run_ut)
|
|
212
277
|
api_register.api_modularity()
|
|
213
278
|
|
|
214
279
|
if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
|