mindstudio-probe 8.2.1__py3-none-any.whl → 8.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +46 -37
- msprobe/README.md +3 -1
- msprobe/core/common/file_utils.py +80 -25
- msprobe/core/common/framework_adapter.py +7 -6
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +4 -16
- msprobe/core/compare/find_first/utils.py +1 -1
- msprobe/core/hook_manager.py +16 -3
- msprobe/core/service.py +16 -5
- msprobe/docs/02.config_introduction.md +14 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +3 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
- msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/common/utils.py +22 -2
- msprobe/pytorch/compare/utils.py +1 -2
- msprobe/pytorch/debugger/debugger_config.py +10 -0
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +24 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +9 -3
- msprobe/pytorch/hook_module/api_register.py +6 -1
- msprobe/pytorch/pt_config.py +57 -2
- msprobe/pytorch/pytorch_service.py +11 -2
- msprobe/visualization/builder/graph_builder.py +1 -0
- msprobe/visualization/utils.py +11 -1
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +0 -3
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/common/utils.py
CHANGED

@@ -39,7 +39,6 @@ except ImportError:
 else:
     is_gpu = False
 
-
 torch_without_guard_version = torch.__version__ >= '2.1'
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
@@ -416,7 +415,8 @@ def is_recomputation():
 
     # Identify indices in the call stack where the specific function is being executed
     for idx, frame_info in enumerate(call_stack):
-        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+        if (frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward' and
+                "megatron" in frame_info.filename):
             backward_function_indices.append(idx)
 
     # Check if the execution is within 'torch/autograd/function.py' file
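
Note on the rewritten condition above: Python binds `and` tighter than `or`, so the new parenthesized test groups as in this equivalent sketch (the helper names are illustrative, not from the source):

    # Equivalent grouping of the new condition in is_recomputation():
    is_plain_backward = frame_info.function == Const.BACKWARD
    is_megatron_ckpt_backward = (frame_info.function == 'checkpoint_function_backward'
                                 and "megatron" in frame_info.filename)
    if is_plain_backward or is_megatron_ckpt_backward:
        backward_function_indices.append(idx)
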
@@ -471,3 +471,23 @@ def register_forward_hook(module, forward_hook):
         module.register_forward_hook(forward_hook, with_kwargs=True)
     else:
         module.register_forward_hook(forward_hook)
+
+
+def save_api_data(api_data):
+    """Save data to io stream"""
+    try:
+        io_buff = io.BytesIO()
+        torch.save(api_data, io_buff)
+    except Exception as e:
+        raise RuntimeError(f"save api_data to io_buff failed") from e
+    return io_buff
+
+
+def load_api_data(api_data_bytes):
+    """Load data from bytes stream"""
+    try:
+        buffer = io.BytesIO(api_data_bytes)
+        buffer = torch.load(buffer, map_location="cpu")
+    except Exception as e:
+        raise RuntimeError(f"load api_data from bytes failed") from e
+    return buffer
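
The two new helpers serialize arbitrary API data through an in-memory buffer. A minimal round-trip sketch (the payload is illustrative):

    import torch
    from msprobe.pytorch.common.utils import save_api_data, load_api_data

    api_data = {"input": torch.randn(2, 3)}
    buf = save_api_data(api_data)             # io.BytesIO holding the torch.save output
    restored = load_api_data(buf.getvalue())  # torch.load with map_location="cpu"
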
msprobe/pytorch/compare/utils.py
CHANGED

@@ -27,8 +27,7 @@ def read_pt_data(dir_path, file_name):
         return None
 
     data_path = os.path.join(dir_path, file_name)
-    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                               FileCheckConst.PT_SUFFIX, False)
+    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX)
     data_path = path_checker.common_check()
     try:
         # detach because numpy can not process gradient information

msprobe/pytorch/debugger/debugger_config.py
CHANGED

@@ -48,6 +48,16 @@ class DebuggerConfig:
             "max_sample": task_config.max_sample
         }
 
+        self.online_run_ut = False
+        if self.task == Const.TENSOR:
+            # dump api tensor and collaborate with online run_ut
+            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
+            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
+            self.tls_path = task_config.tls_path if task_config.tls_path else ""
+            self.host = task_config.host if task_config.host else ""
+            self.port = task_config.port if task_config.port else -1
+            self.online_run_ut_recompute = task_config.online_run_ut_recompute \
+                if isinstance(task_config.online_run_ut_recompute, bool) else False
 
         self.check()
         self._check_statistics_config(task_config)
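
For reference, a sketch of the "tensor" task keys that DebuggerConfig now consumes; the key names come from this diff (and are validated by TensorConfig in pt_config.py below), while the values are purely illustrative:

    tensor_task_config = {
        "online_run_ut": True,             # dump API tensors for the online run_ut checker
        "nfs_path": "",                    # shared-directory transport; empty falls back to TCP
        "tls_path": "/path/to/certs",      # directory holding client.key, client.crt, ca.crt
        "host": "127.0.0.1",               # IPv4 address of the run_ut peer
        "port": 8080,                      # validated as 0 < port <= 65535
        "online_run_ut_recompute": False,  # non-bool values fall back to False
    }
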
msprobe/pytorch/dump/module_dump/hook_wrapper.py
CHANGED

@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from functools import wraps
+from typing import Any, Callable
 
 import torch
 from torch.utils.hooks import BackwardHook
@@ -21,6 +22,9 @@ from torch.utils.hooks import BackwardHook
 from msprobe.core.common.const import Const
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
 def wrap_setup_backward_hook(func):
@@ -92,3 +96,23 @@ def wrap_setup_backward_hook(func):
 def wrap_setup_input_output_hook():
     BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
     BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
+
+
+def get_apply_func_wrapper(original_func: Callable) -> Callable:
+    @wraps(original_func)
+    def wrapped_apply(*args, **kwargs) -> Any:
+        api_register = get_api_register()
+        if api_register:
+            api_register.restore_inner_used_api()
+        result = original_func(*args, **kwargs)
+        if api_register:
+            api_register.register_inner_used_api()
+        return result
+
+    return wrapped_apply
+
+
+def wrap_backward_hook_function_apply():
+    if torch_version_above_or_equal_2:
+        original_apply = torch.nn.modules._functions.BackwardHookFunction.apply
+        torch.nn.modules._functions.BackwardHookFunction.apply = get_apply_func_wrapper(original_apply)
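
get_apply_func_wrapper follows a plain un-patch/re-patch pattern: restore the pristine inner-used torch APIs, run the original apply, then re-install the instrumented versions. A generic sketch of the same idea (the names here are illustrative):

    from functools import wraps

    def swap_around_call(original_fn, restore, re_register):
        # Undo a monkey-patch for the duration of one call, then redo it.
        @wraps(original_fn)
        def wrapped(*args, **kwargs):
            restore()      # analogous to api_register.restore_inner_used_api()
            result = original_fn(*args, **kwargs)
            re_register()  # analogous to api_register.register_inner_used_api()
            return result
        return wrapped
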
msprobe/pytorch/dump/module_dump/module_processer.py
CHANGED

@@ -13,21 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
 import sys
+import threading
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.common.runtime import Runtime
 from msprobe.core.common.utils import ModuleQueue, ThreadSafe
-from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
-from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import (
+    wrap_setup_input_output_hook,
+    wrap_backward_hook_function_apply
+)
+
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
@@ -63,6 +67,7 @@ def wrap_forward_with_hook_safety(module):
                 hook_fn = list(module._forward_hooks.values())[0]
                 hook_fn(module, args, kwargs, exception_output)
             raise e
+
     if torch_version_above_or_equal_21:
         module.forward = wrapped_forward
@@ -80,6 +85,7 @@ class ModuleProcesser:
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         wrap_setup_input_output_hook()
+        wrap_backward_hook_function_apply()
         try:
             from megatron.core.pipeline_parallel import schedules
             origin_func_id = id(schedules.deallocate_output_tensor)

msprobe/pytorch/hook_module/api_register.py
CHANGED

@@ -43,7 +43,6 @@ else:
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
-_inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
 dist_data_collect_func = {}
@@ -85,6 +84,12 @@ if not is_gpu:
         mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
         dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
 
+_inner_used_api = {
+    Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_TENSOR: (
+        torch.Tensor, "view_as"
+    )
+}
+
 
 @parameter_adapter
 def tensor_module_forward(module, *args, **kwargs):

msprobe/pytorch/pt_config.py
CHANGED

@@ -35,15 +35,48 @@ from msprobe.pytorch.hook_module.utils import get_ops
 class TensorConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
+        self.online_run_ut = json_config.get("online_run_ut", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.tls_path = json_config.get("tls_path", "./")
+        self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
         self.check_config()
         self._check_summary_mode()
         self._check_file_format()
-
+        if self.online_run_ut:
+            self._check_online_run_ut()
 
     def _check_file_format(self):
         if self.file_format is not None and self.file_format not in ["npy", "bin"]:
             raise Exception("file_format is invalid")
 
+    def _check_online_run_ut(self):
+        if not isinstance(self.online_run_ut, bool):
+            raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
+
+        if not isinstance(self.online_run_ut_recompute, bool):
+            raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
+
+        if self.nfs_path:
+            check_file_or_directory_path(self.nfs_path, isdir=True)
+            return
+
+        if self.tls_path:
+            check_file_or_directory_path(self.tls_path, isdir=True)
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
+            crl_path = os.path.join(self.tls_path, "crl.pem")
+            if os.path.exists(crl_path):
+                check_file_or_directory_path(crl_path)
+
+        if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
+            raise Exception(f"host: {self.host} is invalid.")
+
+        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
+            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
+
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
@@ -80,6 +113,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
         self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
         self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
         self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
+        self.list = json_config.get("list")
         self.if_preheat = json_config.get("if_preheat", False)
         self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
         self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
@@ -146,6 +180,11 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             logger.error_log_with_exp(
                 msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
             )
+        if self.fuzz_stage == Const.BACKWARD and not self.list:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR,
+                f"When fuzz_stage is set to {Const.BACKWARD}, the parameters list must not be empty."
+            )
 
     def _check_fuzz_level(self):
         if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
@@ -218,7 +257,12 @@ class RunUTConfig(BaseConfig):
         self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
         self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
         self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
-
+        self.is_online = json_config.get("is_online", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
+        self.tls_path = json_config.get("tls_path", "./")
         self.check_run_ut_config()
 
     @classmethod
@@ -236,11 +280,22 @@ class RunUTConfig(BaseConfig):
         if not os.path.exists(error_data_path):
             raise Exception("error_data_path: %s does not exist" % error_data_path)
 
+    @classmethod
+    def check_nfs_path_config(cls, nfs_path):
+        if nfs_path:
+            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
+
+    @classmethod
+    def check_tls_path_config(cls, tls_path):
+        if tls_path:
+            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
         RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
         RunUTConfig.check_error_data_path_config(self.error_data_path)
+        RunUTConfig.check_nfs_path_config(self.nfs_path)
+        RunUTConfig.check_tls_path_config(self.tls_path)
 
 
 class GradToolConfig(BaseConfig):
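
Correspondingly, a sketch of the new "run_ut" keys that RunUTConfig reads; the key names and defaults are taken from the hunk above, the values are illustrative:

    run_ut_config = {
        "is_online": True,    # default False
        "nfs_path": "",       # default ""; must be a readable directory when set
        "host": "127.0.0.1",  # default ""
        "port": 8080,         # default -1
        "rank_list": [0, 1],  # default Const.DEFAULT_LIST
        "tls_path": "./",     # default "./"; must be a readable directory when set
    }
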
msprobe/pytorch/pytorch_service.py
CHANGED

@@ -15,8 +15,9 @@
 
 from msprobe.core.common.utils import Const
 from msprobe.core.service import BaseService
+from msprobe.pytorch.attl_manager import ATTLManager
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
@@ -24,6 +25,9 @@ from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
 
+if torch_version_above_or_equal_2:
+    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
+
 
 class PytorchService(BaseService):
     @property
@@ -41,10 +45,12 @@ class PytorchService(BaseService):
         self.logger = logger
         self.api_register = get_api_register()
         self.module_processor = ModuleProcesser(self.data_collector.scope)
-        self.hook_manager = PytorchHookManager(self.data_collector, self.config)
+        self.attl_manager = ATTLManager(self.config)
+        self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
         self.api_template = ApiTemplate
 
     def _register_hook(self):
+        self.attl_manager.attl_init()
         if self._is_mix_level:
             register_optimizer_hook(self.data_collector)
 
@@ -59,6 +65,9 @@ class PytorchService(BaseService):
             self.module_processor.register_module_hook(self.model, self.build_hook)
             self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
 
+    def _run_ut_dispatch(self, status):
+        if torch_version_above_or_equal_2:
+            run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
 
     def _reset_status(self):
         super()._reset_status()

msprobe/visualization/builder/graph_builder.py
CHANGED

@@ -74,6 +74,7 @@ class GraphBuilder:
         config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
         config.graph_b.step = config.step
         config.graph_b.rank = config.rank
+        config.graph_b.compare_mode = config.compare_mode
         node_to_db(config.graph_b, filename)
         config_to_db(config, filename)
 
msprobe/visualization/utils.py
CHANGED

@@ -152,7 +152,8 @@ def load_parallel_param(input_param):
 
 
 def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
-
+    pattern = re.compile(r'^[a-z\-]+$')
+    params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size, parallel_param.vpp]
     ranks = check_and_return_dir_contents(dump_path, Const.RANK)
     if len(ranks) != parallel_param.rank_size:
         logger.error(f'{log_prefix} The parallel param "rank_size" error, '
@@ -161,6 +162,12 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if any(x is None for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(isinstance(x, bool) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must not be bool!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(not isinstance(x, int) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be int!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
     if any(x <= 0 for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be greater than 0!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
@@ -185,6 +192,9 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if not isinstance(parallel_param.order, str):
         logger.error(f'{log_prefix} The parallel params "order" must be of string type!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if not pattern.match(parallel_param.order):
+        logger.error(f'{log_prefix} The parallel params "order" must consist only of lowercase letters and "-"!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
 
 
 class ParallelParam:
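
The new pattern admits only lowercase letters and hyphens, so a Megatron-style order string such as "tp-cp-ep-dp-pp" passes while digits, uppercase, or underscores are rejected. A quick sketch:

    import re

    pattern = re.compile(r'^[a-z\-]+$')
    assert pattern.match("tp-cp-ep-dp-pp")  # accepted
    assert pattern.match("tp_pp") is None   # underscore: rejected
    assert pattern.match("TP-PP") is None   # uppercase: rejected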