mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/dump/module_dump/module_dump.py
@@ -13,75 +13,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import torch
-from msprobe.core.common.const import Const
-from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_register import get_api_register

-torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
-

 class ModuleDumper:
     def __init__(self, service):
         self.service = service
-        self.hook_handle_list = []
         self.api_register = get_api_register()

     def start_module_dump(self, module, dump_name):
+        if hasattr(module, 'msprobe_hook') and not hasattr(module, 'msprobe_module_dump'):
+            logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
+            return
+
+        ModuleProcesser.enable_module_dump = True
         self.api_register.restore_all_api()
-
+        if not hasattr(module, 'msprobe_module_dump'):
+            self.service.module_processor.register_module_hook(module, self.service.build_hook,
+                                                               recursive=False, module_names=[dump_name])
+            setattr(module, 'msprobe_module_dump', True)

     def stop_module_dump(self):
+        ModuleProcesser.enable_module_dump = False
         self.api_register.register_all_api()
-        for hook_handle in self.hook_handle_list:
-            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
-                hook_handle.remove()
-        self.hook_handle_list.clear()
-
-    def register_hook(self, module, dump_name):
-        prefix_name = (
-            BaseScope.Module_Type_Module + Const.SEP +
-            dump_name + Const.SEP +
-            module.__class__.__name__ + Const.SEP
-        )
-        module_processor = self.service.module_processor
-        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
-            BaseScope.Module_Type_Module,
-            prefix_name
-        )
-
-        if module_processor.has_register_backward_hook(module):
-            logger.warning(
-                f"The {dump_name} module has registered deprecated register_backward_hook,"
-                f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
-            )
-        if torch_version_above_or_equal_2:
-            forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
-        else:
-            if not module_processor.has_register_backward_hook(module):
-                backward_hook_handle = module.register_full_backward_hook(
-                    module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
-                )
-                self.hook_handle_list.append(backward_hook_handle)
-            forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
-        self.hook_handle_list.append(forward_hook_handle)
-        if not module_processor.has_register_backward_hook(module):
-            backward_hook_handle = module.register_full_backward_hook(backward_hook)
-            self.hook_handle_list.append(backward_hook_handle)
-
-        forward_pre_hook_handle = module.register_forward_pre_hook(
-            module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
-        )
-        forward_hook_handle = module.register_forward_hook(
-            module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
-        )
-        self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
-        if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module):
-            backward_pre_hook_handle = module.register_full_backward_pre_hook(
-                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
-            )
-            backward_hook_handle = module.register_full_backward_hook(
-                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
-            )
-            self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
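Note: the rewritten `ModuleDumper` no longer tracks and removes hook handles; hooks stay registered and are gated by the class-level `ModuleProcesser.enable_module_dump` switch. A minimal, self-contained sketch of that flag-gating pattern (the `Probe` class and toy model below are illustrative, not msprobe code):

```python
import torch
import torch.nn as nn


class Probe:
    # Class-level switch checked inside the hook, mirroring
    # ModuleProcesser.enable_module_dump in the hunk above.
    enabled = False

    @staticmethod
    def forward_hook(module, args, output):
        if not Probe.enabled:
            return output  # hook stays registered but is a no-op
        print(f"dump {module.__class__.__name__}: {tuple(output.shape)}")
        return output


model = nn.Linear(4, 2)
model.register_forward_hook(Probe.forward_hook)

model(torch.randn(1, 4))   # silent: dumping disabled
Probe.enabled = True       # start_module_dump flips the switch on
model(torch.randn(1, 4))   # prints the output shape
Probe.enabled = False      # stop_module_dump flips it back off
```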
msprobe/pytorch/dump/module_dump/module_processer.py
@@ -13,16 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import wraps
+from collections import OrderedDict

 import torch
-from torch.utils.hooks import BackwardHook
+from torch.utils.hooks import BackwardHook, RemovableHandle

 from msprobe.core.common.const import Const
-from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_float8_tensor, replace_last_occurrence
+from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook

 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
@@ -39,43 +39,40 @@ def replace_checkpoint():
     torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop


+def wrap_megatron_deallocate(func):
+    def wrapper_func(out, deallocate_pipeline_outputs=False):
+        if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None:
+            out_clone = out.clone()
+            out.data = torch.empty((1,), device=out.device, dtype=out.dtype, )
+            return func(out_clone, deallocate_pipeline_outputs)
+        return func(out, deallocate_pipeline_outputs)
+    return wrapper_func
+
+
 class ModuleProcesser:
     module_count = {}
     module_stack = []
     api_parent_node = ""
     module_node = {}
+    module_bw_hook_kernels = {}
+    module_with_backward_hook = {}
+    enable_module_dump = False

     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
-
-        BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
+        wrap_setup_input_output_hook()
         replace_checkpoint()
+        try:
+            from megatron.core.pipeline_parallel import schedules
+            schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+            logger.info_on_rank_0("Patch megatron method success.")
+        except ImportError:
+            logger.info_on_rank_0("No megatron find.")
+        except Exception as e:
+            logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}")

     @staticmethod
-    def clone_return_value(func):
-        @wraps(func)
-        def clone_return_value_func(*args, **kwargs):
-            result = func(*args, **kwargs)
-            return ModuleProcesser.clone_if_tensor(result)
-
-        return clone_return_value_func
-
-    @staticmethod
-    @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH)
-    def clone_if_tensor(result):
-        if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
-            return result.clone()
-        elif type(result) is tuple:
-            return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif type(result) is list:
-            return list(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif type(result) is dict:
-            return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
-        else:
-            return result
-
-    @staticmethod
-    def module_count_func(module_name):
+    def set_and_get_calls_number(module_name):
         if module_name not in ModuleProcesser.module_count:
             ModuleProcesser.module_count[module_name] = 0
         else:
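Note: `wrap_megatron_deallocate` clones a view tensor before Megatron shrinks its storage, so data the dump hooks still hold stays valid. A standalone sketch of the wrap-and-clone idea, with a stand-in for `deallocate_output_tensor` (the stand-in is illustrative, not Megatron's real implementation):

```python
import torch


def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    """Stand-in for megatron.core.pipeline_parallel.schedules:
    refuses views, then frees the activation storage."""
    if deallocate_pipeline_outputs:
        assert out._base is None, "counter-productive to free a view"
        out.data = torch.empty((1,), device=out.device, dtype=out.dtype)


def wrap_deallocate(func):
    def wrapper(out, deallocate_pipeline_outputs=False):
        # A view (out._base is not None) shares storage with tensors the
        # probe may still reference, so hand a clone to func and shrink
        # only the caller's copy of the view.
        if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and out._base is not None:
            out_clone = out.clone()
            out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
            return func(out_clone, deallocate_pipeline_outputs)
        return func(out, deallocate_pipeline_outputs)
    return wrapper


deallocate_output_tensor = wrap_deallocate(deallocate_output_tensor)
view = torch.ones(2, 3)[0]   # a view, so view._base is not None
deallocate_output_tensor(view, deallocate_pipeline_outputs=True)  # no assert
```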
@@ -89,13 +86,19 @@ class ModuleProcesser:
             module._is_full_backward_hook is False

     @staticmethod
-    def get_modules_and_names(models):
+    def get_modules_and_names(models, recursive, module_names):
         modules_and_names_with_index = {}
         if isinstance(models, (list, tuple)):
+            if not recursive and len(module_names) != len(models):
+                return modules_and_names_with_index
             for index, model in enumerate(models):
-                modules_and_names_with_index[str(index)] = model.named_modules()
+                modules_and_names_with_index[str(index)] = model.named_modules() if recursive else \
+                    [(module_names[index], model)]
         else:
-            modules_and_names_with_index["-1"] = models.named_modules()
+            if not recursive and len(module_names) != 1:
+                return modules_and_names_with_index
+            modules_and_names_with_index["-1"] = models.named_modules() if recursive else \
+                [(module_names[0], models)]
         return modules_and_names_with_index

     @classmethod
@@ -104,107 +107,134 @@ class ModuleProcesser:
         cls.module_stack = []
         cls.api_parent_node = ""
         cls.module_node = {}
+        cls.module_bw_hook_kernels = {}
+        cls.enable_module_dump = False

-    def register_module_hook(self, models, build_hook):
-
-
+    def register_module_hook(self, models, build_hook, recursive=True, module_names=None):
+        if module_names is None:
+            module_names = []
+
+        modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
         for index, modules_and_names in modules_and_names_with_index.items():
             model = models if index == "-1" else models[int(index)]
             for name, module in modules_and_names:
-                if module == model:
+                if recursive and module == model:
+                    continue
+                if not is_torch_nn_module(module):
+                    logger.warning(
+                        f"The module dump does not support {type(module)} type. "
+                        f"The data dump for this module will be skipped."
+                    )
                     continue
                 if module.__class__.__name__ == "FullyShardedDataParallel":
                     continue
+                setattr(module, 'msprobe_hook', True)
                 module_index = (index + Const.SEP) if index != "-1" else ""
-                prefix_name =
-
-
-
-                    prefix_name
-                )
+                prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
+                              f'{module.__class__.__name__}{Const.SEP}'
+
+                forward_pre_hook = self.build_module_hook(prefix_name, build_hook)

                 if self.has_register_backward_hook(module):
                     logger.warning(
                         f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
                         f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
                     )
+                    ModuleProcesser.module_with_backward_hook[prefix_name] = True
+                register_forward_pre_hook(module, forward_pre_hook)
+
+    def build_module_hook(self, module_name, build_data_hook):
+        def forward_pre_hook(module, args, kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+
+            if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
+                return (args, kwargs) if torch_version_above_or_equal_2 else args
+
+            index = ModuleProcesser.set_and_get_calls_number(module_name)
+            full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
+            full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}'
+
+            self.set_construct_info_in_pre_hook(full_forward_name)
+
+            if not hasattr(module, 'msprobe_forward_hook'):
+                forward_hooks_dict = getattr(module, '_forward_hooks', OrderedDict())
+                handle = RemovableHandle(forward_hooks_dict)
+                forward_hooks_dict[handle.id] = forward_hook
+                forward_hooks_dict.move_to_end(handle.id, last=False)
+                if torch_version_above_or_equal_2:
+                    forward_hooks_with_kwargs_dict = getattr(module, '_forward_hooks_with_kwargs', OrderedDict())
+                    forward_hooks_with_kwargs_dict[handle.id] = True
+
+                setattr(module, 'msprobe_forward_hook', True)
+
+            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
+
+            def get_backward_pre_hook(full_backward_name):
+                def backward_pre_hook_fn(module, grad_output):
+                    self.set_construct_info_in_pre_hook(full_backward_name)
+                return backward_pre_hook_fn
+
+            def get_backward_hook(backward_data_hook, full_backward_name):
+                def backward_hook_fn(module, grad_input, grad_output):
+                    new_output = backward_data_hook(module, grad_input, grad_output)
+                    self.set_construct_info_in_hook(full_backward_name, is_forward=False)
+                    return new_output
+                return backward_hook_fn
+
+            if not ModuleProcesser.module_with_backward_hook.get(module_name):
+                backward_pre_hook = get_backward_pre_hook(full_backward_name)
+                backward_hook = get_backward_hook(hook_set.backward_hook, full_backward_name)
                 if torch_version_above_or_equal_2:
-                    module
+                    bw_hook = BackwardHook(module, [backward_hook], [backward_pre_hook])
                 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                index = None
-                pass
-            full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
-            if self.module_stack:
-                ModuleProcesser.module_node[full_name] = self.module_stack[-1]
+                    bw_hook = BackwardHook(module, [backward_hook])
+                ModuleProcesser.module_bw_hook_kernels[full_forward_name] = bw_hook
+                args = bw_hook.setup_input_hook(args)
+            return (args, kwargs) if torch_version_above_or_equal_2 else args
+
+        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
+            if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
+                return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+
+            index = ModuleProcesser.module_count.get(module_name)
+            full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
+
+            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_name)
+            hook_result = hook_set.forward_hook(module, args, kwargs_or_output, output_or_kwargs)
+            self.set_construct_info_in_hook(full_name)
+
+            if hook_result is not None:
+                result = hook_result
             else:
-
+                result = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output

-            ModuleProcesser.
-            if
-
-            if self.scope:
-                self.scope.begin_module(full_name)
+            bw_hook = ModuleProcesser.module_bw_hook_kernels.get(full_name)
+            if bw_hook:
+                result = bw_hook.setup_output_hook(result)

-
+            return result
+
+        return forward_pre_hook
+
+    def set_construct_info_in_pre_hook(self, full_name):
+        if self.module_stack:
+            ModuleProcesser.module_node[full_name] = self.module_stack[-1]
+        else:
+            ModuleProcesser.module_node[full_name] = None
+        ModuleProcesser.module_stack.append(full_name)
+        ModuleProcesser.api_parent_node = full_name
+        if self.scope:
+            self.scope.begin_module(full_name)
+
+    def set_construct_info_in_hook(self, full_name, is_forward=True):
+        if torch_version_above_or_equal_2 or is_forward:
             if self.module_stack:
                 ModuleProcesser.module_stack.pop()
-            if self.module_stack
-                ModuleProcesser.api_parent_node = self.module_stack[-1]
-            else:
-                ModuleProcesser.api_parent_node = None
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                raise RuntimeError(f"module reserve name is None when pop")
-            current_name = module.mindstudio_reserved_name.pop()
+            ModuleProcesser.api_parent_node = ModuleProcesser.module_stack[-1] if self.module_stack else None
             if self.scope:
-                self.scope.end_module(
-
-        def backward_hook(module, input, output=None):
-            try:
-                index = ModuleProcesser.module_count_func(name_prefix)
-            except IndexError as e:
-                index = None
-                pass
-            full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
-            forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
-            ModuleProcesser.module_node[full_name] = replace_last_occurrence(
-                ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
-            ModuleProcesser.api_parent_node = None
+                self.scope.end_module(full_name)
+        else:
             if self.scope:
                 self.scope.begin_module(full_name)
-
-            if torch_version_above_or_equal_2:
-                if Const.START in start_or_stop:
-                    return pre_hook
-                else:
-                    return end_hook
-            else:
-                if Const.FORWARD in name_prefix and Const.START in start_or_stop:
-                    return pre_hook
-                elif Const.BACKWARD in name_prefix:
-                    return backward_hook
-                else:
-                    return end_hook
+            ModuleProcesser.api_parent_node = full_name
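Note: the new `forward_pre_hook` registers msprobe's `forward_hook` lazily and moves it to the front of the module's `_forward_hooks` dict with `move_to_end(..., last=False)`, so the probe sees the module's raw output before any user hooks mutate it. A minimal sketch of that prepend trick using standard PyTorch internals (the toy hooks are mine):

```python
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle


def user_hook(module, args, output):
    print("user hook")


def probe_hook(module, args, output):
    print("probe hook (runs first)")


model = nn.Linear(4, 2)
model.register_forward_hook(user_hook)

# Create a handle against the module's private hook dict by hand and move
# the new entry to the front, as the diff does for msprobe's forward hook.
hooks = model._forward_hooks            # an OrderedDict[int, Callable]
handle = RemovableHandle(hooks)
hooks[handle.id] = probe_hook
hooks.move_to_end(handle.id, last=False)

model(torch.randn(1, 4))  # "probe hook (runs first)", then "user hook"
```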
msprobe/pytorch/grad_probe/grad_stat_csv.py
@@ -17,6 +17,7 @@ from abc import ABC, abstractmethod
 from collections import namedtuple
 import hashlib
 from functools import wraps
+import zlib
 import torch
 from msprobe.core.grad_probe.constant import GradConst

@@ -74,8 +75,8 @@ class CsvMd5(CsvItem):
     def generate_csv_content(csv_content_input):
         grad = csv_content_input.grad
         tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
-        md5_hash = hashlib.md5(tensor_bytes).hexdigest()
-        return [md5_hash]
+        md5_hash = f"{zlib.crc32(tensor_bytes):08x}"
+        return [md5_hash]


 @register_csv_item(GradConst.DISTRIBUTION)
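Note: `CsvMd5` now fingerprints gradients with CRC32 rather than MD5; it is far cheaper on large tensors and sufficient for detecting value drift between runs, though not collision-resistant. A quick side-by-side sketch:

```python
import hashlib
import zlib

import torch

grad = torch.randn(1024, 1024)
tensor_bytes = grad.cpu().detach().float().numpy().tobytes()

old_hash = hashlib.md5(tensor_bytes).hexdigest()   # 32 hex chars, cryptographic
new_hash = f"{zlib.crc32(tensor_bytes):08x}"       # 8 hex chars, zero-padded
print(old_hash, new_hash)
```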
msprobe/pytorch/hook_module/api_register.py
@@ -15,21 +15,36 @@

 import functools
 import os
+import inspect

 import torch
 import torch.distributed as dist

 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.api_registry import ApiRegistry
+from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import (
     torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
 )
 from msprobe.pytorch.function_factory import npu_custom_functions
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.hook_module.utils import dynamic_import_op
+from msprobe.core.common.file_utils import load_yaml
+
+try:
+    import mindspeed.ops
+except ImportError:
+    mindspeed_enable = False
+else:
+    mindspeed_enable = True


 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'

+_inner_used_api = {}
+_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
+_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+
 _api_types = {
     Const.PT_FRAMEWORK: {
         Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
@@ -57,10 +72,11 @@ if not is_gpu:
                                                      torch_npu.distributed.distributed_c10d))
         }
     )
-
-
-
-
+    if mindspeed_enable:
+        _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: (mindspeed.ops, (mindspeed.ops,))})
+        mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED)
+        mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
+        dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)


 @parameter_adapter
@@ -70,7 +86,15 @@ def tensor_module_forward(module, *args, **kwargs):

 def dist_module_forward(module, *args, **kwargs):
     handle = module.api_func(*args, **kwargs)
-    if module.api_name in ["isend", "irecv"]:
+    try:
+        bound = inspect.signature(module.api_func).bind(*args, **kwargs)
+        bound.apply_defaults()
+        use_async_op_flag = bound.arguments.get("async_op", False)
+    except Exception as e:
+        use_async_op_flag = False
+        logger.warning(f"fail to get dist api's func signature because {e}, no wait")
+
+    if use_async_op_flag or module.api_name in ["isend", "irecv"]:
         if handle and hasattr(handle, 'wait'):
             handle.wait()
     if module.api_name == "batch_isend_irecv":
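Note: `dist_module_forward` now detects `async_op` whether it was passed positionally or by keyword, by binding the actual call against the wrapped function's signature. The same technique in isolation (`all_reduce_like` is a toy stand-in, not a torch.distributed API):

```python
import inspect


def all_reduce_like(tensor, op="sum", group=None, async_op=False):
    """Toy stand-in for a torch.distributed collective."""


def get_async_op_flag(func, *args, **kwargs):
    try:
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()                 # materialize async_op's default
        return bound.arguments.get("async_op", False)
    except Exception:
        return False                           # e.g. C functions without signatures


print(get_async_op_flag(all_reduce_like, "t"))                     # False
print(get_async_op_flag(all_reduce_like, "t", "sum", None, True))  # True (positional)
print(get_async_op_flag(all_reduce_like, "t", async_op=True))      # True (keyword)
```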
msprobe/pytorch/hook_module/hook_module.py
@@ -21,9 +21,8 @@ import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks

-from msprobe.
-
-torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+from msprobe.core.common.runtime import Runtime
+from msprobe.pytorch.common.utils import is_float8_tensor, register_forward_pre_hook, register_forward_hook


 class HOOKModule(nn.Module):
@@ -41,16 +40,14 @@ class HOOKModule(nn.Module):
         if not self.stop_hook:
             self.forward_data_collected = False

+            if not Runtime.is_running:
+                return
             prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
             if callable(hook_build_func):
-
-
-
-
-            else:
-                self.register_forward_pre_hook(forward_pre_hook)
-                self.register_forward_hook(forward_hook)
-                self.register_backward_hook(backward_hook)
+                hook_set = hook_build_func(prefix)
+                register_forward_pre_hook(self, hook_set.forward_pre_hook)
+                register_forward_hook(self, hook_set.forward_hook)
+                self.register_backward_hook(hook_set.backward_hook)

     def __call__(self, *args, **kwargs):
         changed = False
@@ -79,13 +76,7 @@ class HOOKModule(nn.Module):
         if len(self._backward_hooks) > 0:
             full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
         for hook in self._forward_pre_hooks.values():
-
-            if result_args is not None:
-                if not isinstance(result_args, tuple):
-                    result_args = (result_args,)
-                args = result_args
-            if result_kwargs is not None:
-                kwargs = result_kwargs
+            hook(self, args, kwargs)
         bw_hook = None
         if len(full_backward_hooks) > 0:
             bw_hook = full_hooks.BackwardHook(self, full_backward_hooks)
msprobe/pytorch/hook_module/jit_script_wrapper.py (new file)
@@ -0,0 +1,33 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+
+def wrap_jit_script_func():
+    def patched_script(*args, **kwargs):
+        all_api_registered = api_register.all_api_registered
+        if all_api_registered:
+            api_register.restore_all_api()
+        result = original_script(*args, **kwargs)
+        if all_api_registered:
+            api_register.register_all_api()
+        return result
+
+    original_script = torch.jit.script
+    api_register = get_api_register()
+    torch.jit.script = patched_script
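Note: `wrap_jit_script_func` temporarily restores the original torch APIs while `torch.jit.script` compiles, then re-registers the hooks, since scripting the hooked wrappers would fail or capture the probes in the graph. A generic sketch of the unpatch-call-repatch pattern (the `Registry` class is an illustrative stand-in for msprobe's api_register, and the try/finally is my addition for safety):

```python
import torch


class Registry:
    """Stand-in that swaps torch.relu for a 'hooked' version."""

    def __init__(self):
        self._orig_relu = torch.relu
        self.all_api_registered = False

    def register_all_api(self):
        torch.relu = lambda x: self._orig_relu(x)  # pretend this wrapper dumps data
        self.all_api_registered = True

    def restore_all_api(self):
        torch.relu = self._orig_relu
        self.all_api_registered = False


registry = Registry()
original_script = torch.jit.script


def patched_script(*args, **kwargs):
    was_registered = registry.all_api_registered
    if was_registered:
        registry.restore_all_api()          # compile against the real APIs
    try:
        return original_script(*args, **kwargs)
    finally:
        if was_registered:
            registry.register_all_api()     # put the hooks back


torch.jit.script = patched_script
```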
msprobe/pytorch/hook_module/pt_hook_manager.py (new file)
@@ -0,0 +1,68 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from contextlib import nullcontext
+
+from msprobe.core.common.const import Const
+from msprobe.core.common.utils import replace_last_occurrence
+from msprobe.core.hook_manager import BaseHookManager, HookSet
+from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+
+
+class PytorchHookManager(BaseHookManager):
+    @property
+    def _is_recompute(self):
+        return is_recomputation()
+
+    @staticmethod
+    def _no_grad_context():
+        return nullcontext()
+
+    @staticmethod
+    def _add_count(name):
+        HOOKModule.add_module_count(name)
+
+    @staticmethod
+    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
+        kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
+        output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+        return kwargs, output
+
+    def build_hook(self, hook_type, name):
+        if hook_type == Const.API:
+            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
+        else:
+            full_forward_name = name
+        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
+        hookset = HookSet(
+            forward_hook=self._build_forward_hook(hook_type, full_forward_name),
+            forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
+            backward_hook=self._build_backward_hook(hook_type, full_backward_name)
+        )
+        return hookset
+
+    def _need_exchange(self, module):
+        return True
+
+    def _get_params_dict(self, module):
+        params_dict = {}
+        if self.config.task != Const.STRUCTURE:
+            params_dict = {
+                key.split(Const.SEP)[-1]: value
+                for key, value in module.named_parameters(recurse=False)
+            }
+        return params_dict