mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
- msprobe/README.md +7 -5
- msprobe/core/common/const.py +6 -0
- msprobe/core/common/db_manager.py +35 -4
- msprobe/core/common/file_utils.py +105 -27
- msprobe/core/common/framework_adapter.py +7 -6
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/utils.py +14 -3
- msprobe/core/compare/find_first/analyzer.py +8 -7
- msprobe/core/compare/find_first/graph.py +11 -3
- msprobe/core/compare/find_first/utils.py +2 -1
- msprobe/core/compare/highlight.py +13 -6
- msprobe/core/compare/multiprocessing_compute.py +17 -10
- msprobe/core/compare/utils.py +14 -5
- msprobe/core/data_dump/data_collector.py +18 -21
- msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
- msprobe/core/data_dump/json_writer.py +18 -8
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +37 -3
- msprobe/core/service.py +18 -5
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +7 -5
- msprobe/docs/02.config_introduction.md +14 -1
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/06.data_dump_MindSpore.md +1 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +15 -80
- msprobe/docs/22.visualization_MindSpore.md +20 -104
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/cell_processor.py +33 -5
- msprobe/mindspore/compare/common_dir_compare.py +22 -26
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/mindspore/debugger/precision_debugger.py +1 -1
- msprobe/mindspore/dump/cell_dump_process.py +73 -62
- msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
- msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/common/utils.py +22 -2
- msprobe/pytorch/compare/utils.py +3 -3
- msprobe/pytorch/debugger/debugger_config.py +10 -0
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
- msprobe/pytorch/hook_module/api_register.py +6 -1
- msprobe/pytorch/monitor/module_hook.py +28 -9
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/pt_config.py +57 -2
- msprobe/pytorch/pytorch_service.py +11 -2
- msprobe/visualization/builder/graph_builder.py +170 -64
- msprobe/visualization/builder/graph_merger.py +0 -1
- msprobe/visualization/builder/msprobe_adapter.py +1 -1
- msprobe/visualization/db_utils.py +25 -2
- msprobe/visualization/graph/base_node.py +0 -24
- msprobe/visualization/graph/graph.py +5 -14
- msprobe/visualization/graph_service.py +29 -53
- msprobe/visualization/utils.py +11 -1
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/common/utils.py
CHANGED
```diff
@@ -39,7 +39,6 @@ except ImportError:
 else:
     is_gpu = False
 
-
 torch_without_guard_version = torch.__version__ >= '2.1'
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
@@ -416,7 +415,8 @@ def is_recomputation():
 
     # Identify indices in the call stack where the specific function is being executed
    for idx, frame_info in enumerate(call_stack):
-        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+        if (frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward' and
+                "megatron" in frame_info.filename):
            backward_function_indices.append(idx)
 
    # Check if the execution is within 'torch/autograd/function.py' file
@@ -471,3 +471,23 @@ def register_forward_hook(module, forward_hook):
         module.register_forward_hook(forward_hook, with_kwargs=True)
     else:
         module.register_forward_hook(forward_hook)
+
+
+def save_api_data(api_data):
+    """Save data to io stream"""
+    try:
+        io_buff = io.BytesIO()
+        torch.save(api_data, io_buff)
+    except Exception as e:
+        raise RuntimeError(f"save api_data to io_buff failed") from e
+    return io_buff
+
+
+def load_api_data(api_data_bytes):
+    """Load data from bytes stream"""
+    try:
+        buffer = io.BytesIO(api_data_bytes)
+        buffer = torch.load(buffer, map_location="cpu")
+    except Exception as e:
+        raise RuntimeError(f"load api_data from bytes failed") from e
+    return buffer
```
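The new `save_api_data`/`load_api_data` pair serializes API data through an in-memory `io.BytesIO` buffer instead of a file. A minimal round-trip sketch (the payload shape here is illustrative, not msprobe's actual wire format):

```python
import torch
from msprobe.pytorch.common.utils import save_api_data, load_api_data

# Round-trip a tensor payload through the in-memory buffer:
payload = {"args": (torch.ones(2, 2),), "kwargs": {"alpha": 1.0}}
io_buff = save_api_data(payload)              # io.BytesIO filled via torch.save
restored = load_api_data(io_buff.getvalue())  # bytes -> torch.load(..., map_location="cpu")
assert torch.equal(restored["args"][0], payload["args"][0])
```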
msprobe/pytorch/compare/utils.py
CHANGED
```diff
@@ -27,15 +27,15 @@ def read_pt_data(dir_path, file_name):
         return None
 
     data_path = os.path.join(dir_path, file_name)
-    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                               FileCheckConst.PT_SUFFIX, False)
+    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX)
     data_path = path_checker.common_check()
     try:
         # detach because numpy can not process gradient information
         data_value = load_pt(data_path, to_cpu=True).detach()
     except RuntimeError as e:
         # catch the exception raised by load_pt here
-
+        data_path_file_name = os.path.basename(data_path)
+        logger.error(f"Failed to load the .pt file at {data_path_file_name}.")
         raise CompareException(CompareException.INVALID_FILE_ERROR) from e
     except AttributeError as e:
         # catch the exception raised by the detach method here
```
msprobe/pytorch/debugger/debugger_config.py
CHANGED
```diff
@@ -48,6 +48,16 @@ class DebuggerConfig:
             "max_sample": task_config.max_sample
         }
 
+        self.online_run_ut = False
+        if self.task == Const.TENSOR:
+            # dump api tensor and collaborate with online run_ut
+            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
+            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
+            self.tls_path = task_config.tls_path if task_config.tls_path else ""
+            self.host = task_config.host if task_config.host else ""
+            self.port = task_config.port if task_config.port else -1
+            self.online_run_ut_recompute = task_config.online_run_ut_recompute \
+                if isinstance(task_config.online_run_ut_recompute, bool) else False
 
         self.check()
         self._check_statistics_config(task_config)
```
msprobe/pytorch/dump/module_dump/hook_wrapper.py
CHANGED
```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from functools import wraps
+from typing import Any, Callable
 
 import torch
 from torch.utils.hooks import BackwardHook
@@ -21,11 +22,17 @@ from torch.utils.hooks import BackwardHook
 from msprobe.core.common.const import Const
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
 def wrap_setup_backward_hook(func):
-    def requires_clone(tensor):
-
+    def requires_clone(tensor, need_check_leaf=False):
+        need_clone = isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
+        if need_check_leaf:
+            need_clone &= tensor.grad_fn is not None
+        return need_clone
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def parse_tensor(item, tensor_list):
@@ -39,20 +46,20 @@
             parse_tensor(value, tensor_list)
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
-    def rebuild_args(item, tensor_iter):
-        if requires_clone(item):
+    def rebuild_args(item, tensor_iter, need_check_leaf=False):
+        if requires_clone(item, need_check_leaf):
             result = next(tensor_iter)
             if hasattr(result, "_base") and result._base is not None:
                 if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
                     torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
-                    return result
+            return result
         if isinstance(item, list):
             for index, value in enumerate(item):
-                item[index] = rebuild_args(value, tensor_iter)
+                item[index] = rebuild_args(value, tensor_iter, need_check_leaf=True)
             return item
         if isinstance(item, dict):
             for key, value in item.items():
-                item[key] = rebuild_args(value, tensor_iter)
+                item[key] = rebuild_args(value, tensor_iter, need_check_leaf=True)
             return item
         if isinstance(item, tuple):
             if hasattr(item, '_fields'):
@@ -89,3 +96,23 @@ def wrap_setup_backward_hook(func):
 def wrap_setup_input_output_hook():
     BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
     BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
+
+
+def get_apply_func_wrapper(original_func: Callable) -> Callable:
+    @wraps(original_func)
+    def wrapped_apply(*args, **kwargs) -> Any:
+        api_register = get_api_register()
+        if api_register:
+            api_register.restore_inner_used_api()
+        result = original_func(*args, **kwargs)
+        if api_register:
+            api_register.register_inner_used_api()
+        return result
+
+    return wrapped_apply
+
+
+def wrap_backward_hook_function_apply():
+    if torch_version_above_or_equal_2:
+        original_apply = torch.nn.modules._functions.BackwardHookFunction.apply
+        torch.nn.modules._functions.BackwardHookFunction.apply = get_apply_func_wrapper(original_apply)
```
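`get_apply_func_wrapper` restores msprobe's hooked inner APIs before calling the original `BackwardHookFunction.apply` and re-registers them afterwards. A standalone sketch of the same wrap-and-swap idiom (the wrapped target and callbacks below are illustrative, not msprobe APIs):

```python
import math
from functools import wraps
from typing import Any, Callable

def make_wrapper(original_func: Callable, before: Callable, after: Callable) -> Callable:
    """Run `before` on entry and `after` on exit, preserving the wrapped signature metadata."""
    @wraps(original_func)
    def wrapped(*args, **kwargs) -> Any:
        before()
        result = original_func(*args, **kwargs)
        after()
        return result
    return wrapped

# Monkey-patch in place, keeping a handle to the original for restoration:
original_sqrt = math.sqrt
math.sqrt = make_wrapper(original_sqrt, lambda: print("enter"), lambda: print("exit"))
print(math.sqrt(4.0))      # prints enter / exit / 2.0
math.sqrt = original_sqrt  # restore the unwrapped function
```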
msprobe/pytorch/dump/module_dump/module_processer.py
CHANGED
```diff
@@ -13,23 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
 import sys
+import threading
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.common.runtime import Runtime
 from msprobe.core.common.utils import ModuleQueue, ThreadSafe
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
-from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import (
+    wrap_setup_input_output_hook,
+    wrap_backward_hook_function_apply
+)
+
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
-
+torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
+if torch_version_above_or_equal_21:
     from torch.utils.checkpoint import _StopRecomputationError
 
 
@@ -61,7 +67,8 @@ def wrap_forward_with_hook_safety(module):
                 hook_fn = list(module._forward_hooks.values())[0]
                 hook_fn(module, args, kwargs, exception_output)
             raise e
-
+
+    if torch_version_above_or_equal_21:
         module.forward = wrapped_forward
 
 
@@ -78,10 +85,13 @@ class ModuleProcesser:
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         wrap_setup_input_output_hook()
+        wrap_backward_hook_function_apply()
         try:
             from megatron.core.pipeline_parallel import schedules
             origin_func_id = id(schedules.deallocate_output_tensor)
             schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+            schedules.forward_step = wrap_megatron_step(schedules.forward_step)
+            schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
             for module in list(sys.modules.values()):
                 if module.__name__ == 'schedules':
                     continue
@@ -258,14 +268,16 @@ class ModuleProcesser:
             ModuleProcesser.module_stack[tid] = []
 
         if self.module_stack[tid]:
-            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
+            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1] if not is_megatron() \
+                else [self.module_stack[tid][-1], get_micro_step()]
         else:
             parent_name = ModuleProcesser.module_queue.find_last(full_name)
-            ModuleProcesser.module_node[full_name] = parent_name
+            ModuleProcesser.module_node[full_name] = parent_name if not is_megatron() \
+                else [parent_name, get_micro_step()]
 
         ModuleProcesser.module_queue.add_name(full_name)
         ModuleProcesser.module_stack[tid].append(full_name)
-        ModuleProcesser.api_parent_node[tid] = full_name
+        ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
         if self.scope:
             self.scope.begin_module(full_name)
 
@@ -273,14 +285,15 @@ class ModuleProcesser:
         tid = threading.get_ident()
         if torch_version_above_or_equal_2 or is_forward:
             ModuleProcesser.module_queue.remove_name(full_name)
-            ModuleProcesser.api_parent_node[tid] = None
+            ModuleProcesser.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
             if self.module_stack.get(tid):
                 ModuleProcesser.module_stack[tid].pop()
             if self.module_stack.get(tid):
-                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
+                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1] if not is_megatron() \
+                    else [ModuleProcesser.module_stack[tid][-1], get_micro_step()]
             if self.scope:
                 self.scope.end_module(full_name)
         else:
            if self.scope:
                 self.scope.begin_module(full_name)
-            ModuleProcesser.api_parent_node[tid] = full_name
+            ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
```
msprobe/pytorch/hook_module/api_register.py
CHANGED
```diff
@@ -43,7 +43,6 @@ else:
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
-_inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
 dist_data_collect_func = {}
@@ -85,6 +84,12 @@ if not is_gpu:
     mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
     dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
 
+_inner_used_api = {
+    Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_TENSOR: (
+        torch.Tensor, "view_as"
+    )
+}
+
 
 @parameter_adapter
 def tensor_module_forward(module, *args, **kwargs):
```
msprobe/pytorch/monitor/module_hook.py
CHANGED
```diff
@@ -19,12 +19,14 @@ import importlib
 from collections import defaultdict
 from datetime import datetime
 from functools import partial
+from itertools import cycle
 
 import pytz
 import torch
 import torch.distributed as dist
 import pandas as pd
 from torch.utils.hooks import BackwardHook
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json, make_dir
@@ -229,6 +231,8 @@ class TrainerMon:
         self.duplicate_param = {}
         self.name2tag = {}
         self.param_name_call_id = {}
+        self.flat_prefix_names = []
+        self.flat_prefix_reverse_iter = None
         self.call_id = 0
         self.module_struct = defaultdict(dict)
         self.grad_accs = []
@@ -945,13 +949,20 @@ class TrainerMon:
         return False
 
     def _register_chunk(self, model_chunk, prefix):
+        if isinstance(model_chunk, FSDP):
+            if not model_chunk._use_orig_params:
+                raise ValueError("Only Support fsdp1 with use_orig_params=True")
+            self.fsdp_wrapped_module = True
         for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
-            if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
-                self.fsdp_wrapped_module = True
             if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor":
                 self.fsdp2_wrapped_module = True
+            if self.fsdp_wrapped_module:  # FSDP1 must record the full flat-weight prefix names, unrestricted by target, for unpacking the flat params later
+                flat_prefix_name, _ = param_name.rsplit(MonitorConst.FSDP_FLAT_SEP, 1)
+                if flat_prefix_name not in self.flat_prefix_names:
+                    self.flat_prefix_names.append(flat_prefix_name)
+
             if self._is_target_param(param_name, param, prefix):
                 name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
@@ -975,6 +986,8 @@ class TrainerMon:
                 k: get_summary_writer_tag_name(name, k, self.rank)
                 for k in keywords
             }
+        if self.fsdp_wrapped_module:
+            self.flat_prefix_reverse_iter = cycle(reversed(self.flat_prefix_names))  # post_backward_hook is called in reverse order
 
     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
@@ -1224,17 +1237,22 @@ class TrainerMon:
         In every forward phase, FSDP repeatedly re-registers hooks on AccumulateGrad, so hooks registered by the monitor tool cannot take effect;
         therefore _post_backward_hook is patched to collect gradients after backward and before reduce_scatter.
         """
+
         def patch_post_backward_hook(_post_backward_hook):
             def wrapper(state, handle, *unused):
                 grad_dict = {}
-
-
-
-
+                local_names = handle.flat_param._fqns
+                offsets = handle._get_flat_param_offsets()
+                shapes = handle.flat_param._shapes
+                flat_prefix = next(self.flat_prefix_reverse_iter)
+                for local_name, (start, end), local_shape in zip(local_names, offsets, shapes):
+                    grad_clip = handle.flat_param.grad[start:end + 1]
+                    grad = grad_clip.reshape(local_shape)
+                    total_name = f"{flat_prefix}{MonitorConst.FSDP_FLAT_SEP}{local_name}"
+                    if total_name not in self.origin2squash:
+                        logger.warning(f"{total_name} not in model.named_parameters(), skip.")
                         continue
-
-                    offset += limit
-                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    tag = self.name2tag.get(self.origin2squash[total_name], {}).get(MonitorConst.PRE_GRAD)
                     if tag is None:
                         continue
                     grad_dict[tag] = grad
@@ -1242,6 +1260,7 @@ class TrainerMon:
                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                 out = _post_backward_hook(state, handle, *unused)
                 return out
+
             return wrapper
 
         logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
```
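The patched `_post_backward_hook` recovers per-parameter gradients from FSDP1's flat parameter by slicing `flat_param.grad` with inclusive `(start, end)` offsets and reshaping each slice to its original shape. A self-contained sketch of that slicing arithmetic on synthetic data (not real FSDP internals):

```python
import torch

# A flat gradient buffer holding two parameters of shapes (2, 3) and (4,):
shapes = [(2, 3), (4,)]
flat_grad = torch.arange(10, dtype=torch.float32)

# Inclusive (start, end) offsets per parameter, mirroring
# handle.flat_param.grad[start:end + 1] in the patch above:
offsets, cursor = [], 0
for shape in shapes:
    numel = int(torch.Size(shape).numel())
    offsets.append((cursor, cursor + numel - 1))
    cursor += numel

grads = [flat_grad[start:end + 1].reshape(shape)
         for (start, end), shape in zip(offsets, shapes)]
print(grads[0].shape, grads[1].shape)  # torch.Size([2, 3]) torch.Size([4])
```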
msprobe/pytorch/online_dispatch/dispatch.py
CHANGED
```diff
@@ -17,7 +17,7 @@ import json
 import os
 import time
 import multiprocessing
-from multiprocessing import Pool
+from multiprocessing import Pool, Lock
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -39,6 +39,7 @@ from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, ge
 from msprobe.pytorch.online_dispatch.compare import Comparator
 from msprobe.core.common.utils import check_str_param, safe_get_value
 
+child_global_lock = None
 current_time = time.strftime("%Y%m%d%H%M%S")
 RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
 DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
@@ -86,14 +87,14 @@ class PtdbgDispatch(TorchDispatchMode):
         yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
         self.get_ops(yaml_path)
 
-        self.lock = None
+        self.lock = Lock() if process_num > 0 else None
         max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
         if process_num > max_process_num:
             logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
             raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
                                     f'but got {process_num}!')
         if process_num > 0:
-            self.pool = Pool(process_num)
+            self.pool = Pool(process_num, initializer=self._init_child_process, initargs=(self.lock,))
         if debug:
             logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
                         f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
@@ -114,18 +115,17 @@ class PtdbgDispatch(TorchDispatchMode):
             logger.error("Please check train log, An exception may have occurred!")
             return
         check_file_or_directory_path(summary_path, False)
-
-
-
-
-
-
-
-
-
-
-
-        fp_handle.close()
+        with FileOpen(summary_path, "r") as fp_handle:
+            while True:
+                json_line_data = fp_handle.readline()
+                if json_line_data == '\n':
+                    continue
+                if len(json_line_data) == 0:
+                    break
+                msg = json.loads(json_line_data)
+                if len(msg) < 2:
+                    raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
+                self.all_summary[msg[0]] = msg[1]
 
         if self.debug_flag:
             input_num = 0
@@ -163,11 +163,16 @@ class PtdbgDispatch(TorchDispatchMode):
 
         call_stack = get_callstack()
         self.call_stack_list.append(call_stack)
-
-        if
-
-
-        self.single_api_index_dict
+
+        self.lock.acquire() if self.process_num > 0 else None
+        try:
+            self.api_index += 1
+            if aten_api not in self.single_api_index_dict:
+                self.single_api_index_dict[aten_api] = 1
+            else:
+                self.single_api_index_dict[aten_api] += 1
+        finally:
+            self.lock.release() if self.process_num > 0 else None
 
         run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
 
@@ -180,7 +185,7 @@ class PtdbgDispatch(TorchDispatchMode):
         cpu_kwargs = []
         data_to_cpu(args, 0, cpu_args)
         data_to_cpu(kwargs, 0, cpu_kwargs)
-
+
         cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
         cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
 
@@ -194,7 +199,12 @@ class PtdbgDispatch(TorchDispatchMode):
         try:
             cpu_out = func(*cpu_args, **cpu_kwargs)
         except RuntimeError as e:
-            self.
+            self.lock.acquire() if self.process_num > 0 else None
+            try:
+                self.api_index -= 1
+                self.single_api_index_dict[aten_api] -= 1
+            finally:
+                self.lock.release() if self.process_num > 0 else None
             logger.warning(f"RuntimeError: {e}")
             logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
             return npu_out
@@ -215,7 +225,7 @@ class PtdbgDispatch(TorchDispatchMode):
         run_param.process_flag = True
         if self.check_fun(func, run_param):
             data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
-
+                                         child_global_lock)
             self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
                                   error_callback=error_call)
         else:
@@ -233,12 +243,20 @@ class PtdbgDispatch(TorchDispatchMode):
                 return True
         return False
 
+    @staticmethod
+    def _init_child_process(lock):
+        global child_global_lock
+        child_global_lock = lock
+
     def get_dir_name(self, tag):
         # guarantee file uniqueness
         time.sleep(1)
-
+        # Time format: year-month-day-hour-minute-second-millisecond (to the thousandth of a second)
+        time_now = time.strftime("%Y%m%d%H%M%S%f", time.localtime(time.time()))[:-3]  # keep the first 3 millisecond digits
+
         if tag is None or not isinstance(tag, str):
             logger.warning('There is not tag or the type of tag is not string.')
+            # Directory name format: msprobe_rank{device_id}_{millisecond timestamp}
             dir_name = f'msprobe_rank{self.device_id}_{time_now}'
         else:
             dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
```
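Replacing `Pool(process_num)` with `Pool(process_num, initializer=..., initargs=(lock,))` is the standard way to hand a `multiprocessing.Lock` to workers: the initializer runs once in each child and stores the lock in a module-level global, since a lock cannot be pickled into `apply_async` arguments. A minimal standalone sketch of the pattern:

```python
from multiprocessing import Pool, Lock

child_global_lock = None  # set in each worker by the initializer

def init_child(lock):
    global child_global_lock
    child_global_lock = lock

def work(i):
    with child_global_lock:  # serialize the critical section across workers
        return i * i

if __name__ == "__main__":
    lock = Lock()
    with Pool(2, initializer=init_child, initargs=(lock,)) as pool:
        print(pool.map(work, range(4)))  # [0, 1, 4, 9]
```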
msprobe/pytorch/pt_config.py
CHANGED
```diff
@@ -35,15 +35,48 @@ from msprobe.pytorch.hook_module.utils import get_ops
 class TensorConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
+        self.online_run_ut = json_config.get("online_run_ut", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.tls_path = json_config.get("tls_path", "./")
+        self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
         self.check_config()
         self._check_summary_mode()
         self._check_file_format()
-
+        if self.online_run_ut:
+            self._check_online_run_ut()
 
     def _check_file_format(self):
         if self.file_format is not None and self.file_format not in ["npy", "bin"]:
             raise Exception("file_format is invalid")
 
+    def _check_online_run_ut(self):
+        if not isinstance(self.online_run_ut, bool):
+            raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
+
+        if not isinstance(self.online_run_ut_recompute, bool):
+            raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
+
+        if self.nfs_path:
+            check_file_or_directory_path(self.nfs_path, isdir=True)
+            return
+
+        if self.tls_path:
+            check_file_or_directory_path(self.tls_path, isdir=True)
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
+            crl_path = os.path.join(self.tls_path, "crl.pem")
+            if os.path.exists(crl_path):
+                check_file_or_directory_path(crl_path)
+
+        if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
+            raise Exception(f"host: {self.host} is invalid.")
+
+        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
+            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
+
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
@@ -80,6 +113,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
         self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
         self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
         self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
+        self.list = json_config.get("list")
         self.if_preheat = json_config.get("if_preheat", False)
         self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
         self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
@@ -146,6 +180,11 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             logger.error_log_with_exp(
                 msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
             )
+        if self.fuzz_stage == Const.BACKWARD and not self.list:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR,
+                f"When fuzz_stage is set to {Const.BACKWARD}, the parameters list must not be empty."
+            )
 
     def _check_fuzz_level(self):
         if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
@@ -218,7 +257,12 @@ class RunUTConfig(BaseConfig):
         self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
         self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
         self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
-
+        self.is_online = json_config.get("is_online", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
+        self.tls_path = json_config.get("tls_path", "./")
         self.check_run_ut_config()
 
     @classmethod
@@ -236,11 +280,22 @@ class RunUTConfig(BaseConfig):
         if not os.path.exists(error_data_path):
             raise Exception("error_data_path: %s does not exist" % error_data_path)
 
+    @classmethod
+    def check_nfs_path_config(cls, nfs_path):
+        if nfs_path:
+            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
+
+    @classmethod
+    def check_tls_path_config(cls, tls_path):
+        if tls_path:
+            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
         RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
         RunUTConfig.check_error_data_path_config(self.error_data_path)
+        RunUTConfig.check_nfs_path_config(self.nfs_path)
+        RunUTConfig.check_tls_path_config(self.tls_path)
 
 
 class GradToolConfig(BaseConfig):
```
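The new `TensorConfig` fields map one-to-one onto keys of the tensor task's JSON config. A hedged sketch of a config dict enabling online run_ut (key names come from the diff above; the values are illustrative):

```python
# Keys read in TensorConfig.__init__ via json_config.get(...):
tensor_task_config = {
    "online_run_ut": True,             # triggers _check_online_run_ut()
    "nfs_path": "",                    # if set, a shared directory is used and checked first
    "host": "127.0.0.1",               # otherwise validated against Const.ipv4_pattern
    "port": 8080,                      # must satisfy 0 < port <= 65535
    "tls_path": "./",                  # directory holding client.key, client.crt, ca.crt (crl.pem optional)
    "online_run_ut_recompute": False,  # must be a bool
}
```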
msprobe/pytorch/pytorch_service.py
CHANGED
```diff
@@ -15,8 +15,9 @@
 
 from msprobe.core.common.utils import Const
 from msprobe.core.service import BaseService
+from msprobe.pytorch.attl_manager import ATTLManager
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
@@ -24,6 +25,9 @@ from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
 
+if torch_version_above_or_equal_2:
+    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
+
 
 class PytorchService(BaseService):
     @property
@@ -41,10 +45,12 @@ class PytorchService(BaseService):
         self.logger = logger
         self.api_register = get_api_register()
         self.module_processor = ModuleProcesser(self.data_collector.scope)
-        self.
+        self.attl_manager = ATTLManager(self.config)
+        self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
         self.api_template = ApiTemplate
 
     def _register_hook(self):
+        self.attl_manager.attl_init()
         if self._is_mix_level:
             register_optimizer_hook(self.data_collector)
 
@@ -59,6 +65,9 @@
         self.module_processor.register_module_hook(self.model, self.build_hook)
         self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
 
+    def _run_ut_dispatch(self, status):
+        if torch_version_above_or_equal_2:
+            run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
 
     def _reset_status(self):
         super()._reset_status()
```