mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,92 +13,21 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import
|
|
16
|
+
from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
|
|
17
|
+
from msprobe.pytorch.compare.utils import read_pt_data
|
|
17
18
|
|
|
18
|
-
import torch
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
set_dump_path
|
|
25
|
-
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
26
|
-
from msprobe.core.compare.utils import set_stack_json_path
|
|
27
|
-
from msprobe.pytorch.common.log import logger
|
|
28
|
-
from msprobe.pytorch.common.utils import load_pt
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class PTComparator(Comparator):
|
|
32
|
-
def __init__(self, mode_config, data_mapping=None):
|
|
33
|
-
super().__init__(mode_config)
|
|
34
|
-
|
|
35
|
-
self.stack_mode = mode_config.stack_mode
|
|
36
|
-
self.auto_analyze = mode_config.auto_analyze
|
|
37
|
-
self.fuzzy_match = mode_config.fuzzy_match
|
|
38
|
-
self.dump_mode = mode_config.dump_mode
|
|
39
|
-
|
|
40
|
-
self.frame_name = PTComparator.__name__
|
|
41
|
-
self.data_mapping = data_mapping
|
|
42
|
-
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
43
|
-
self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
|
|
44
|
-
elif isinstance(self.data_mapping, dict):
|
|
45
|
-
self.data_mapping_dict = self.data_mapping
|
|
46
|
-
else:
|
|
47
|
-
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
48
|
-
f"{type(self.data_mapping)}")
|
|
49
|
-
|
|
50
|
-
@staticmethod
|
|
51
|
-
def load_mapping_file(mapping_file):
|
|
52
|
-
if isinstance(mapping_file, str):
|
|
53
|
-
mapping_dict = load_yaml(mapping_file)
|
|
54
|
-
else:
|
|
55
|
-
mapping_dict = {}
|
|
56
|
-
return mapping_dict
|
|
57
|
-
|
|
58
|
-
def read_npy_data(self, dir_path, file_name):
|
|
59
|
-
if not file_name:
|
|
60
|
-
return None
|
|
61
|
-
data_path = os.path.join(dir_path, file_name)
|
|
62
|
-
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
63
|
-
FileCheckConst.PT_SUFFIX, False)
|
|
64
|
-
data_path = path_checker.common_check()
|
|
65
|
-
try:
|
|
66
|
-
# detach because numpy can not process gradient information
|
|
67
|
-
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
68
|
-
except RuntimeError as e:
|
|
69
|
-
# 这里捕获 load_pt 中抛出的异常
|
|
70
|
-
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
71
|
-
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
72
|
-
except AttributeError as e:
|
|
73
|
-
# 这里捕获 detach 方法抛出的异常
|
|
74
|
-
logger.error(f"Failed to detach the loaded tensor.")
|
|
75
|
-
raise CompareException(CompareException.DETACH_ERROR) from e
|
|
76
|
-
if data_value.dtype == torch.bfloat16:
|
|
77
|
-
data_value = data_value.to(torch.float32)
|
|
78
|
-
data_value = data_value.numpy()
|
|
79
|
-
return data_value
|
|
20
|
+
def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tuple:
|
|
21
|
+
n_value = read_pt_data(npu_dir, npu_data_name)
|
|
22
|
+
b_value = read_pt_data(bench_dir, bench_data_name)
|
|
23
|
+
return n_value, b_value
|
|
80
24
|
|
|
81
25
|
|
|
82
26
|
def compare(input_param, output_path, **kwargs):
|
|
83
|
-
|
|
84
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
85
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
86
|
-
data_mapping = kwargs.get('data_mapping', None)
|
|
87
|
-
suffix = kwargs.get('suffix', '')
|
|
88
|
-
|
|
89
|
-
set_dump_path(input_param)
|
|
90
|
-
dump_mode = get_dump_mode(input_param)
|
|
91
|
-
if "stack_json_path" in input_param:
|
|
92
|
-
stack_mode = kwargs.get('stack_mode', False)
|
|
93
|
-
else:
|
|
94
|
-
stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
|
|
95
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
96
|
-
create_directory(output_path)
|
|
97
|
-
check_compare_param(input_param, output_path, dump_mode, stack_mode)
|
|
98
|
-
except (CompareException, FileCheckException) as error:
|
|
99
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
100
|
-
raise CompareException(error.code) from error
|
|
27
|
+
config = setup_comparison(input_param, output_path, **kwargs)
|
|
101
28
|
|
|
102
|
-
mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match,
|
|
103
|
-
|
|
104
|
-
|
|
29
|
+
mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
|
|
30
|
+
config.dump_mode, config.compared_file_type)
|
|
31
|
+
mapping_config = MappingConfig(data_mapping=config.data_mapping)
|
|
32
|
+
pt_comparator = Comparator(read_real_data, mode_config, mapping_config)
|
|
33
|
+
pt_comparator.compare_core(input_param, output_path, suffix=config.suffix)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.utils import logger, CompareException
|
|
21
|
+
from msprobe.core.common.file_utils import FileChecker, FileCheckConst
|
|
22
|
+
from msprobe.pytorch.common.utils import load_pt
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def read_pt_data(dir_path, file_name):
|
|
26
|
+
if not file_name:
|
|
27
|
+
return None
|
|
28
|
+
|
|
29
|
+
data_path = os.path.join(dir_path, file_name)
|
|
30
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
31
|
+
FileCheckConst.PT_SUFFIX, False)
|
|
32
|
+
data_path = path_checker.common_check()
|
|
33
|
+
try:
|
|
34
|
+
# detach because numpy can not process gradient information
|
|
35
|
+
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
36
|
+
except RuntimeError as e:
|
|
37
|
+
# 这里捕获 load_pt 中抛出的异常
|
|
38
|
+
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
39
|
+
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
40
|
+
except AttributeError as e:
|
|
41
|
+
# 这里捕获 detach 方法抛出的异常
|
|
42
|
+
logger.error(f"Failed to detach the loaded tensor.")
|
|
43
|
+
raise CompareException(CompareException.DETACH_ERROR) from e
|
|
44
|
+
if data_value.dtype == torch.bfloat16:
|
|
45
|
+
data_value = data_value.to(torch.float32)
|
|
46
|
+
data_value = data_value.numpy()
|
|
47
|
+
return data_value
|
|
@@ -13,11 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import torch
|
|
17
|
-
|
|
18
16
|
from msprobe.core.common.const import Const
|
|
19
17
|
from msprobe.core.common.exceptions import MsprobeException
|
|
20
18
|
from msprobe.pytorch.common.log import logger
|
|
19
|
+
from msprobe.pytorch.common.utils import is_torch_nn_module
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
class DebuggerConfig:
|
|
@@ -60,6 +59,7 @@ class DebuggerConfig:
|
|
|
60
59
|
if isinstance(task_config.online_run_ut_recompute, bool) else False
|
|
61
60
|
|
|
62
61
|
self.check()
|
|
62
|
+
self._check_statistics_config(task_config)
|
|
63
63
|
|
|
64
64
|
if self.level == Const.LEVEL_L2:
|
|
65
65
|
self.is_backward_kernel_dump = False
|
|
@@ -78,10 +78,13 @@ class DebuggerConfig:
|
|
|
78
78
|
if not isinstance(self.async_dump, bool):
|
|
79
79
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
80
80
|
f"The parameters async_dump should be bool.")
|
|
81
|
-
if self.async_dump and self.task == Const.TENSOR
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
81
|
+
if self.async_dump and self.task == Const.TENSOR:
|
|
82
|
+
if self.level == Const.LEVEL_DEBUG:
|
|
83
|
+
self.list = [] # async_dump + debug level case ignore list
|
|
84
|
+
if not self.list and self.level != Const.LEVEL_DEBUG:
|
|
85
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
86
|
+
f"The parameters async_dump is true in tensor task, the parameters list cannot be "
|
|
87
|
+
f"empty.")
|
|
85
88
|
if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
86
89
|
logger.warning_on_rank_0(
|
|
87
90
|
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
@@ -93,25 +96,24 @@ class DebuggerConfig:
|
|
|
93
96
|
self.check_kwargs()
|
|
94
97
|
return True
|
|
95
98
|
|
|
96
|
-
def check_model(self, instance, start_model):
|
|
97
|
-
if
|
|
98
|
-
|
|
99
|
-
logger.info_on_rank_0(
|
|
100
|
-
f"The current level is not L0 or mix level, so the model parameters will not be used.")
|
|
99
|
+
def check_model(self, instance, start_model, token_range=None):
|
|
100
|
+
instance.model = start_model if start_model is not None else instance.model
|
|
101
|
+
if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
|
|
101
102
|
return
|
|
102
|
-
|
|
103
|
+
|
|
104
|
+
if instance.model is None:
|
|
103
105
|
logger.error_on_rank_0(
|
|
104
|
-
f"For level {self.level}
|
|
106
|
+
f"For level {self.level} or non-empty token_range, "
|
|
107
|
+
f"PrecisionDebugger or start interface must receive a 'model' parameter.")
|
|
105
108
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
|
|
106
109
|
|
|
107
|
-
|
|
108
|
-
if isinstance(instance.model, torch.nn.Module):
|
|
110
|
+
if is_torch_nn_module(instance.model):
|
|
109
111
|
return
|
|
110
112
|
|
|
111
113
|
error_model = None
|
|
112
114
|
if isinstance(instance.model, (list, tuple)):
|
|
113
115
|
for model in instance.model:
|
|
114
|
-
if not
|
|
116
|
+
if not is_torch_nn_module(model):
|
|
115
117
|
error_model = model
|
|
116
118
|
break
|
|
117
119
|
else:
|
|
@@ -119,7 +121,7 @@ class DebuggerConfig:
|
|
|
119
121
|
|
|
120
122
|
if error_model is not None:
|
|
121
123
|
error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
|
|
122
|
-
f"type, currently there is
|
|
124
|
+
f"type, currently there is an unsupported {type(error_model)} type.")
|
|
123
125
|
raise MsprobeException(
|
|
124
126
|
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
125
127
|
|
|
@@ -130,8 +132,23 @@ class DebuggerConfig:
|
|
|
130
132
|
if not self.list or len(self.list) != 1:
|
|
131
133
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
132
134
|
f"When level is set to L2, the list must be configured as a list with one api name.")
|
|
135
|
+
if self.task != Const.TENSOR:
|
|
136
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
137
|
+
f"When level is set to L2, the task must be set to tensor.")
|
|
138
|
+
|
|
133
139
|
api_name = self.list[0]
|
|
134
140
|
if api_name.endswith(Const.BACKWARD):
|
|
135
141
|
self.is_backward_kernel_dump = True
|
|
136
142
|
api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
|
|
137
143
|
self.list.append(api_forward_name)
|
|
144
|
+
|
|
145
|
+
def _check_statistics_config(self, task_config):
|
|
146
|
+
if self.task != Const.STATISTICS:
|
|
147
|
+
return
|
|
148
|
+
self.tensor_list = []
|
|
149
|
+
if not hasattr(task_config, "tensor_list"):
|
|
150
|
+
return
|
|
151
|
+
if self.level == Const.LEVEL_DEBUG and task_config.tensor_list:
|
|
152
|
+
logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.")
|
|
153
|
+
return
|
|
154
|
+
self.tensor_list = task_config.tensor_list
|
|
@@ -13,36 +13,22 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from torch.utils.data import dataloader
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
18
|
+
from msprobe.core.common.const import Const, MsgConst
|
|
20
19
|
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
-
from msprobe.core.common.
|
|
22
|
-
from msprobe.core.
|
|
20
|
+
from msprobe.core.common.utils import check_token_range
|
|
21
|
+
from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
|
|
23
22
|
from msprobe.pytorch.common.log import logger
|
|
24
|
-
from msprobe.pytorch.common.utils import check_save_param
|
|
23
|
+
from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module
|
|
25
24
|
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
26
25
|
from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
|
|
27
26
|
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
28
|
-
from msprobe.pytorch.
|
|
29
|
-
from msprobe.pytorch.
|
|
30
|
-
from torch.utils.data import dataloader
|
|
31
|
-
|
|
32
|
-
ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
|
|
33
|
-
"dump_path", "level", "model"])
|
|
27
|
+
from msprobe.pytorch.pytorch_service import PytorchService
|
|
28
|
+
from msprobe.pytorch.pt_config import parse_task_config
|
|
34
29
|
|
|
35
30
|
|
|
36
|
-
class PrecisionDebugger:
|
|
37
|
-
_instance = None
|
|
38
|
-
tasks_not_need_debugger = [Const.GRAD_PROBE]
|
|
39
|
-
|
|
40
|
-
def __new__(cls, *args, **kwargs):
|
|
41
|
-
if cls._instance is None:
|
|
42
|
-
cls._instance = super(PrecisionDebugger, cls).__new__(cls)
|
|
43
|
-
cls._instance.config = None
|
|
44
|
-
cls._instance.enable_dataloader = False
|
|
45
|
-
return cls._instance
|
|
31
|
+
class PrecisionDebugger(BasePrecisionDebugger):
|
|
46
32
|
|
|
47
33
|
def __init__(
|
|
48
34
|
self,
|
|
@@ -53,90 +39,65 @@ class PrecisionDebugger:
|
|
|
53
39
|
model=None,
|
|
54
40
|
step=None
|
|
55
41
|
):
|
|
56
|
-
if
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
self.
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
if step is not None:
|
|
72
|
-
common_config.step = get_real_step_or_rank(step, Const.STEP)
|
|
73
|
-
self.config = DebuggerConfig(
|
|
74
|
-
common_config, task_config, task, dump_path, level
|
|
75
|
-
)
|
|
76
|
-
self.service = Service(self.config)
|
|
77
|
-
self.module_dumper = ModuleDumper(self.service)
|
|
78
|
-
self.enable_dataloader = self.config.enable_dataloader
|
|
79
|
-
if self.enable_dataloader:
|
|
80
|
-
logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
|
|
81
|
-
dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
|
|
82
|
-
|
|
83
|
-
@property
|
|
84
|
-
def instance(self):
|
|
85
|
-
return self._instance
|
|
42
|
+
if self.initialized:
|
|
43
|
+
return
|
|
44
|
+
super().__init__(config_path, task, dump_path, level, step)
|
|
45
|
+
self.model = model
|
|
46
|
+
if self.task == Const.GRAD_PROBE:
|
|
47
|
+
self.gm = GradientMonitor(self.common_config, self.task_config)
|
|
48
|
+
return
|
|
49
|
+
self.config = DebuggerConfig(
|
|
50
|
+
self.common_config, self.task_config, task, dump_path, level
|
|
51
|
+
)
|
|
52
|
+
self.service = PytorchService(self.config)
|
|
53
|
+
self.module_dumper = ModuleDumper(self.service)
|
|
54
|
+
self.ori_customer_func = {}
|
|
55
|
+
self.enable_dataloader = self.config.enable_dataloader
|
|
56
|
+
self._param_warning()
|
|
86
57
|
|
|
87
58
|
@staticmethod
|
|
88
|
-
def
|
|
89
|
-
|
|
90
|
-
if not isinstance(args.config_path, str):
|
|
91
|
-
raise MsprobeException(
|
|
92
|
-
MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
|
|
93
|
-
file_checker = FileChecker(
|
|
94
|
-
file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
95
|
-
file_checker.common_check()
|
|
59
|
+
def _get_task_config(task, json_config):
|
|
60
|
+
return parse_task_config(task, json_config)
|
|
96
61
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
if not isinstance(args.dump_path, str):
|
|
62
|
+
@staticmethod
|
|
63
|
+
def _iter_tracer(func):
|
|
64
|
+
def func_wrapper(*args, **kwargs):
|
|
65
|
+
debugger_instance = PrecisionDebugger._instance
|
|
66
|
+
if not debugger_instance:
|
|
103
67
|
raise MsprobeException(
|
|
104
|
-
MsprobeException.
|
|
68
|
+
MsprobeException.INTERFACE_USAGE_ERROR,
|
|
69
|
+
f"PrecisionDebugger must be instantiated before executing the dataloader iteration"
|
|
70
|
+
)
|
|
105
71
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
72
|
+
debugger_instance.enable_dataloader = False
|
|
73
|
+
if not debugger_instance.service.first_start:
|
|
74
|
+
debugger_instance.stop()
|
|
75
|
+
debugger_instance.step()
|
|
76
|
+
result = func(*args, **kwargs)
|
|
77
|
+
debugger_instance.start()
|
|
78
|
+
debugger_instance.enable_dataloader = True
|
|
79
|
+
return result
|
|
109
80
|
|
|
110
|
-
|
|
111
|
-
logger.warning_on_rank_0(
|
|
112
|
-
"The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
|
|
113
|
-
"It is recommended to pass the 'model' parameter in the start interface instead."
|
|
114
|
-
)
|
|
81
|
+
return func_wrapper
|
|
115
82
|
|
|
116
83
|
@classmethod
|
|
117
|
-
def start(cls, model=None):
|
|
118
|
-
instance = cls.
|
|
119
|
-
if
|
|
120
|
-
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
121
|
-
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
84
|
+
def start(cls, model=None, token_range=None):
|
|
85
|
+
instance = cls._get_instance()
|
|
86
|
+
if instance is None:
|
|
122
87
|
return
|
|
123
|
-
|
|
88
|
+
|
|
89
|
+
check_token_range(token_range)
|
|
90
|
+
instance.config.check_model(instance, model, token_range)
|
|
91
|
+
|
|
124
92
|
if instance.enable_dataloader:
|
|
125
93
|
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
|
|
126
94
|
else:
|
|
127
|
-
instance.service.start(instance.model)
|
|
128
|
-
|
|
129
|
-
@classmethod
|
|
130
|
-
def forward_backward_dump_end(cls):
|
|
131
|
-
instance = cls._instance
|
|
132
|
-
instance.stop()
|
|
95
|
+
instance.service.start(instance.model, token_range)
|
|
133
96
|
|
|
134
97
|
@classmethod
|
|
135
98
|
def stop(cls):
|
|
136
|
-
instance = cls.
|
|
137
|
-
if
|
|
138
|
-
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
139
|
-
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
99
|
+
instance = cls._get_instance()
|
|
100
|
+
if instance is None:
|
|
140
101
|
return
|
|
141
102
|
if instance.enable_dataloader:
|
|
142
103
|
logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
|
|
@@ -145,9 +106,8 @@ class PrecisionDebugger:
|
|
|
145
106
|
|
|
146
107
|
@classmethod
|
|
147
108
|
def step(cls):
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
109
|
+
instance = cls._get_instance()
|
|
110
|
+
if instance is None:
|
|
151
111
|
return
|
|
152
112
|
cls._instance.service.step()
|
|
153
113
|
|
|
@@ -172,21 +132,23 @@ class PrecisionDebugger:
|
|
|
172
132
|
return
|
|
173
133
|
instance.service.save(variable, name, save_backward)
|
|
174
134
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
135
|
+
def _param_warning(self):
    """Emit deprecation warnings for legacy constructor options.

    Logs a warning when ``self.model`` is not ``None`` and another when
    ``self.enable_dataloader`` is truthy. In the latter case it also replaces
    ``dataloader._BaseDataLoaderIter.__next__`` with the result of applying
    ``self._iter_tracer`` to the currently installed ``__next__``.
    """
    if self.model is not None:
        # Adjacent string literals are joined verbatim, so the space between
        # the two sentences must be written explicitly.
        logger.warning_on_rank_0(
            "The 'model' parameter in the PrecisionDebugger will be deprecated in the future. "
            "It is recommended to pass the 'model' parameter in the start interface instead."
        )
    if self.enable_dataloader:
        logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
        # The assignment targets `_BaseDataLoaderIter` itself rather than an
        # instance, and wraps whatever `__next__` is installed at this moment.
        dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__)
|
|
183
144
|
|
|
184
145
|
|
|
185
146
|
def module_dump(module, dump_name):
|
|
186
|
-
if not
|
|
147
|
+
if not is_torch_nn_module(module):
|
|
187
148
|
raise MsprobeException(
|
|
188
149
|
MsprobeException.INVALID_PARAM_ERROR,
|
|
189
|
-
f"the module argument in module_dump must be a torch.nn.Module
|
|
150
|
+
f"the module argument in module_dump must be a torch.nn.Module type, "
|
|
151
|
+
f"but currently there is an unsupported {type(module)} type."
|
|
190
152
|
)
|
|
191
153
|
if not isinstance(dump_name, str):
|
|
192
154
|
raise MsprobeException(
|
|
@@ -210,17 +172,3 @@ def module_dump_end():
|
|
|
210
172
|
f"PrecisionDebugger must be instantiated before using module_dump_end interface"
|
|
211
173
|
)
|
|
212
174
|
instance.module_dumper.stop_module_dump()
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
def iter_tracer(func):
|
|
216
|
-
def func_wrapper(*args, **kwargs):
|
|
217
|
-
debugger_instance = PrecisionDebugger.instance
|
|
218
|
-
debugger_instance.enable_dataloader = False
|
|
219
|
-
if not debugger_instance.service.first_start:
|
|
220
|
-
debugger_instance.stop()
|
|
221
|
-
debugger_instance.step()
|
|
222
|
-
result = func(*args, **kwargs)
|
|
223
|
-
debugger_instance.start()
|
|
224
|
-
debugger_instance.enable_dataloader = True
|
|
225
|
-
return result
|
|
226
|
-
return func_wrapper
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from functools import wraps
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch.utils.hooks import BackwardHook
|
|
20
|
+
|
|
21
|
+
from msprobe.core.common.const import Const
|
|
22
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
23
|
+
from msprobe.pytorch.common.log import logger
|
|
24
|
+
from msprobe.pytorch.common.utils import is_float8_tensor
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def wrap_setup_backward_hook(func):
    """Wrap a hook-setup callable so it can be fed nested argument structures.

    The returned wrapper looks at its second positional argument, collects
    every value accepted by ``requires_clone`` (descending into lists, tuples
    and dicts), calls ``func`` with only those values as a flat tuple, and
    then places the values returned by ``func`` back at the positions the
    collected ones came from. Lists and dicts are updated in place; tuples
    are rebuilt with their own type.

    :param func: callable to wrap; invoked as ``func(args[0], tensors, **kwargs)``.
    :return: the wrapper, decorated with ``functools.wraps(func)``.
    """

    def requires_clone(tensor):
        # The checks are ordered so that is_float8_tensor and requires_grad
        # are only evaluated for torch.Tensor instances.
        if not isinstance(tensor, torch.Tensor):
            return False
        if is_float8_tensor(tensor):
            return False
        return tensor.requires_grad and torch.is_grad_enabled()

    @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
    def parse_tensor(item, tensor_list):
        # Depth-first collection; the visiting order here must match the
        # order in which rebuild_args consumes the hooked tensors.
        if requires_clone(item):
            tensor_list.append(item)
            return
        if isinstance(item, (list, tuple)):
            children = item
        elif isinstance(item, dict):
            children = item.values()
        else:
            return
        for child in children:
            parse_tensor(child, tensor_list)

    @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
    def rebuild_args(item, tensor_iter):
        if requires_clone(item):
            hooked = next(tensor_iter)
            has_base = hasattr(hooked, "_base") and hooked._base is not None
            if has_base:
                autograd = torch._C._autograd
                # NOTE(review): CreationMeta(0) is presumably the default
                # creation meta; confirm against the torch version in use.
                if autograd._get_creation_meta(hooked) != autograd.CreationMeta(0):
                    autograd._set_creation_meta(hooked, autograd.CreationMeta(0))
            return hooked

        if isinstance(item, list):
            # Written back slot by slot so the caller's list object is kept.
            for position, element in enumerate(item):
                item[position] = rebuild_args(element, tensor_iter)
            return item

        if isinstance(item, dict):
            # Only existing keys are reassigned, so the dict does not change
            # size while it is being iterated.
            for name, element in item.items():
                item[name] = rebuild_args(element, tensor_iter)
            return item

        if isinstance(item, tuple):
            rebuilt = [rebuild_args(element, tensor_iter) for element in item]
            # Tuples exposing `_fields` are constructed from positional
            # fields; any other tuple type is constructed from one iterable.
            if hasattr(item, '_fields'):
                return type(item)(*rebuilt)
            return type(item)(rebuilt)

        return item

    @wraps(func)
    def wrap_setup_hook_func(*args, **kwargs):
        # Without a second positional argument there is nothing to unpack.
        if len(args) < 2:
            return func(*args, **kwargs)

        hook_owner, actual_args = args[0], args[1]

        collected = []
        parse_tensor(actual_args, collected)

        hooked_tensors = func(hook_owner, tuple(collected), **kwargs)
        remaining = iter(hooked_tensors)
        try:
            return rebuild_args(actual_args, remaining)
        except Exception as e:
            # Best effort: if the structure cannot be rebuilt, hand back the
            # caller's original argument object.
            logger.debug(f"Unsupported data in setup input/output hook. The detail info: {e}")
            return actual_args

    return wrap_setup_hook_func
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def wrap_setup_input_output_hook():
    """Route both ``BackwardHook`` setup methods through ``wrap_setup_backward_hook``.

    ``BackwardHook.setup_input_hook`` and ``BackwardHook.setup_output_hook``
    are replaced on the class, in that order. Each call wraps whatever is
    installed at that moment, so a second call wraps the wrappers again.
    """
    for method_name in ("setup_input_hook", "setup_output_hook"):
        current = getattr(BackwardHook, method_name)
        setattr(BackwardHook, method_name, wrap_setup_backward_hook(current))
|