mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/core/data_dump/json_writer.py
CHANGED

@@ -17,9 +17,11 @@ import csv
 import os
 import copy
 import threading
+import traceback
+from datetime import datetime, timezone, timedelta
 
 from msprobe.core.common.const import Const, FileCheckConst
-from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
+from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
 from msprobe.core.common.log import logger
 from msprobe.core.common.decorator import recursion_depth_decorator
 
@@ -35,11 +37,15 @@ class DataWriter:
         self.free_benchmark_file_path = None
         self.dump_tensor_data_dir = None
         self.debug_file_path = None
+        self.dump_error_info_path = None
         self.flush_size = 1000
+        self.larger_flush_size = 20000
         self.cache_data = {}
         self.cache_stack = {}
         self.cache_construct = {}
         self.cache_debug = {}
+        self.stat_stack_list = []
+        self._error_log_initialized = False
 
     @staticmethod
     def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -56,13 +62,54 @@ class DataWriter:
         if is_new_file:
             change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
 
+    @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders")
+    def _replace_stat_placeholders(self, data, stat_result):
+        if isinstance(data, dict):
+            keys = list(data.keys())  # snapshot the current keys
+            for key in keys:  # recurse over every entry
+                value = data[key]
+                if key == Const.TENSOR_STAT_INDEX and isinstance(value, int):
+                    if value >= 0:
+                        idx = value
+                    else:
+                        return
+                    stat_values = stat_result[idx] if idx < len(stat_result) else [None] * 4
+
+                    new_entries = {
+                        Const.TYPE: data["type"],
+                        Const.DTYPE: data["dtype"],
+                        Const.SHAPE: data["shape"],
+                        Const.MAX: stat_values[0],
+                        Const.MIN: stat_values[1],
+                        Const.MEAN: stat_values[2],
+                        Const.NORM: stat_values[3],
+                    }
+                    del data[key]
+
+                    # rebuild the dict ordering
+                    updated_dict = {}
+                    # insert the stat fields first so they are written to the JSON in order
+                    updated_dict.update(new_entries)
+                    # copy the remaining original fields (excluding the deleted tensor_stat_index)
+                    for k in data:
+                        if k not in new_entries:
+                            updated_dict[k] = data[k]
+                    data.clear()
+                    data.update(updated_dict)
+                else:
+                    self._replace_stat_placeholders(value, stat_result)
+        elif isinstance(data, (list, tuple)):
+            for item in data:
+                self._replace_stat_placeholders(item, stat_result)
+
     def reset_cache(self):
         self.cache_data = {}
         self.cache_stack = {}
         self.cache_construct = {}
+        self.cache_debug = {}
 
     def initialize_json_file(self, **kwargs):
-        if
+        if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug:
             # debug level case only create debug.json
             debug_dict = copy.deepcopy(kwargs)
             debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
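The `_replace_stat_placeholders` pass rewrites any entry whose `tensor_stat_index` placeholder points into the flushed stat buffer, promoting the resolved statistics to ordered top-level fields. A minimal standalone sketch of that rewrite, using literal key names in place of the `Const.*` attributes (an assumption for illustration):

# Simplified stand-in for DataWriter._replace_stat_placeholders; the literal
# keys here approximate the Const.* attribute values and may not match exactly.
def replace_placeholder(entry: dict, stat_result: list) -> None:
    idx = entry.pop("tensor_stat_index", -1)
    if not isinstance(idx, int) or idx < 0:
        return
    stats = stat_result[idx] if idx < len(stat_result) else [None] * 4
    # Rebuild the dict so type/dtype/shape and the four statistics lead,
    # then keep any remaining keys in their original order.
    ordered = {"type": entry["type"], "dtype": entry["dtype"], "shape": entry["shape"],
               "Max": stats[0], "Min": stats[1], "Mean": stats[2], "Norm": stats[3]}
    ordered.update({k: v for k, v in entry.items() if k not in ordered})
    entry.clear()
    entry.update(ordered)

entry = {"type": "torch.Tensor", "dtype": "torch.float32", "shape": [2, 2], "tensor_stat_index": 0}
replace_placeholder(entry, [[3.0, -1.0, 0.5, 4.2]])
print(entry)  # statistics now inline, placeholder key gone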
@@ -85,12 +132,46 @@ class DataWriter:
         self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir
         self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path
         self.debug_file_path = dump_path_aggregation.debug_file_path
+        self.dump_error_info_path = dump_path_aggregation.dump_error_info_path
 
     def flush_data_periodically(self):
         dump_data = self.cache_data.get(Const.DATA)
-
+
+        if not dump_data or not isinstance(dump_data, dict):
+            return
+
+        length = len(dump_data)
+
+        threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
+
+        if length % threshold == 0:
             self.write_json()
 
+    def write_error_log(self, message: str):
+        """
+        Write an error log entry:
+        - the first call opens the file in 'w' mode to truncate it; later calls append in 'a' mode
+        - each entry is timestamped
+        - the current call stack is written after the message (to make the log's origin traceable)
+        """
+        try:
+            mode = "w" if not self._error_log_initialized else "a"
+            self._error_log_initialized = True
+
+            check_path_before_create(self.dump_error_info_path)
+
+            with FileOpen(self.dump_error_info_path, mode) as f:
+                cst_timezone = timezone(timedelta(hours=8), name="CST")
+                timestamp = datetime.now(cst_timezone).strftime("%Y-%m-%d %H:%M:%S %z")
+                f.write(f"[{timestamp}] {message}\n")
+                f.write("Call stack (most recent call last):\n")
+
+                f.write("".join(traceback.format_stack()[:-1]))  # drop this frame itself
+                f.write("\n")
+        except Exception as e:
+            # if even writing the log fails, fall back to a logger warning
+            logger.warning(f"[FallbackError] Failed to write error log: {e}")
+
     def update_data(self, new_data):
         with lock:
             if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
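`flush_data_periodically` now flushes on a size-dependent cadence: every `flush_size` (1000) entries while the data cache is small, and only every `larger_flush_size` (20000) entries once it has grown past that point, which bounds how often a large cache is re-serialized. A small sketch of that threshold arithmetic:

# Sketch of the flush cadence chosen by flush_data_periodically, using the
# defaults from __init__ (flush_size=1000, larger_flush_size=20000).
FLUSH_SIZE = 1000
LARGER_FLUSH_SIZE = 20000

def should_flush(cache_len: int) -> bool:
    threshold = FLUSH_SIZE if cache_len < LARGER_FLUSH_SIZE else LARGER_FLUSH_SIZE
    return cache_len % threshold == 0

assert should_flush(1000) and should_flush(19000)   # small cache: every 1000 entries
assert not should_flush(21000)                      # large cache: 21000 % 20000 != 0
assert should_flush(40000)                          # large cache: every 20000 entries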
@@ -107,9 +188,13 @@ class DataWriter:
             else:
                 dump_data.update(new_data)
 
-    def update_stack(self,
+    def update_stack(self, name, stack_data):
         with lock:
-            self.cache_stack.
+            api_list = self.cache_stack.get(stack_data)
+            if api_list is None:
+                self.cache_stack.update({stack_data: [name]})
+            else:
+                api_list.append(name)
 
     def update_construct(self, new_data):
         with lock:
@@ -124,7 +209,11 @@ class DataWriter:
         save_json(file_path, self.cache_data, indent=1)
 
     def write_stack_info_json(self, file_path):
-
+        num, new_cache_stack = 0, {}
+        for key, value in self.cache_stack.items():
+            new_cache_stack[num] = [value, key]
+            num += 1
+        save_json(file_path, new_cache_stack, indent=1)
 
     def write_construct_info_json(self, file_path):
         save_json(file_path, self.cache_construct, indent=1)
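With the new `update_stack` signature, the stack cache is keyed by the stack text and accumulates the API names that share it, so duplicate call stacks are stored once; `write_stack_info_json` then numbers each unique stack and emits `index: [[names...], stack]` pairs. A compact sketch of that inversion:

# Sketch of the deduplicated stack cache and the JSON layout it produces.
cache_stack = {}

def update_stack(name, stack_data):
    cache_stack.setdefault(stack_data, []).append(name)

update_stack("Functional.relu.0.forward", "File train.py, line 10, in forward")
update_stack("Functional.relu.1.forward", "File train.py, line 10, in forward")

payload = {num: [names, stack] for num, (stack, names) in enumerate(cache_stack.items())}
# payload == {0: [["Functional.relu.0.forward", "Functional.relu.1.forward"],
#                 "File train.py, line 10, in forward"]}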
@@ -132,8 +221,56 @@ class DataWriter:
     def write_debug_info_json(self, file_path):
         save_json(file_path, self.cache_debug, indent=1)
 
+    def append_stat_to_buffer(self, stat_vector):
+        """
+        Store stat_vector in a plain Python list by
+        appending it to self.stat_stack_list
+        """
+        self.stat_stack_list.append(stat_vector)
+        return len(self.stat_stack_list) - 1
+
+    def get_buffer_values_max(self, index):
+        if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1:
+            return self.stat_stack_list[index][0]
+        else:
+            logger.warning(f"stat_stack_list[{index}] The internal data is incomplete,"
+                           f" and the maximum value cannot be obtained.")
+            return None
+
+    def get_buffer_values_min(self, index):
+        if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1:
+            return self.stat_stack_list[index][1]
+        else:
+            logger.warning(f"stat_stack_list[{index}] Internal data is incomplete"
+                           f" and minimum values cannot be obtained.")
+            return None
+
+    def flush_stat_stack(self):
+        """
+        At flush time, move every buffered statistic from the device to the CPU.
+        Returns a list whose elements are [Max, Min, Mean, Norm] value lists.
+        """
+        if not self.stat_stack_list:
+            return []
+        result = [
+            [
+                x.item() if hasattr(x, "item") else x
+                for x in stat_values
+            ]
+            for stat_values in self.stat_stack_list
+        ]
+        self.stat_stack_list = []
+        return result
+
     def write_json(self):
         with lock:
+            # collect all buffered statistics before writing the JSON
+            stat_result = self.flush_stat_stack()
+            # walk the caches and swap the placeholders for the final statistics
+            if stat_result:
+                self._replace_stat_placeholders(self.cache_data, stat_result)
+                if self.cache_debug:
+                    self._replace_stat_placeholders(self.cache_debug, stat_result)
             if self.cache_data:
                 self.write_data_json(self.dump_file_path)
             if self.cache_stack:
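Together with `_replace_stat_placeholders`, these buffer methods defer the device-to-host transfer: hooks append raw stat vectors (possibly still device tensors) and record only the returned index, and `write_json` drains the whole buffer once via `flush_stat_stack`, where `.item()` performs the sync. A minimal round-trip sketch with a stand-in for a device tensor:

# FakeTensor stands in for a device tensor whose .item() forces the
# device-to-host copy; batching those calls at flush time is the point.
class FakeTensor:
    def __init__(self, v):
        self.v = v

    def item(self):
        return self.v  # the expensive sync happens here, once per value

stat_stack_list = []

def append_stat_to_buffer(stat_vector):
    stat_stack_list.append(stat_vector)
    return len(stat_stack_list) - 1  # this index is what the JSON placeholder stores

def flush_stat_stack():
    result = [[x.item() if hasattr(x, "item") else x for x in vec] for vec in stat_stack_list]
    stat_stack_list.clear()
    return result

idx = append_stat_to_buffer([FakeTensor(3.0), FakeTensor(-1.0), FakeTensor(0.5), FakeTensor(4.2)])
print(flush_stat_stack()[idx])  # [3.0, -1.0, 0.5, 4.2]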
@@ -143,24 +280,3 @@ class DataWriter:
         if self.cache_debug:
             self.write_debug_info_json(self.debug_file_path)
 
-    def fill_stack_tensor_data(self):
-        self.process_stat_data_recursive(self.cache_data)
-
-    @recursion_depth_decorator("AsyncDump: DataWriter.process_stat_data_recursive", max_depth=Const.DUMP_MAX_DEPTH)
-    def process_stat_data_recursive(self, data):
-        if isinstance(data, dict):
-            if "tensor_stat" in data.keys():
-                tensor_stat = data["tensor_stat"]
-                if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]):
-                    logger.warning("Some bad data in async dump")
-                else:
-                    tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
-                    for index, stat in zip(tensor_stat_index, tensor_stat_data):
-                        data.update({index: stat.item()})
-                    del data["tensor_stat"]
-            else:
-                for key in data.keys():
-                    self.process_stat_data_recursive(data[key])
-        elif isinstance(data, (list, tuple)):
-            for i in data:
-                self.process_stat_data_recursive(i)
msprobe/core/debugger/precision_debugger.py
ADDED

@@ -0,0 +1,144 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from msprobe.core.common.const import Const, FileCheckConst, MsgConst
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common.file_utils import FileChecker, load_json
+from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
+from msprobe.core.common_config import CommonConfig
+
+
+class BasePrecisionDebugger:
+    _instance = None
+    tasks_not_need_debugger = [Const.GRAD_PROBE]
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(BasePrecisionDebugger, cls).__new__(cls)
+            cls._instance.config = None
+            cls._instance.enable_dataloader = False
+            cls._instance.initialized = False
+            cls.service = None
+            cls.first_start = False
+        return cls._instance
+
+    def __init__(
+            self,
+            config_path=None,
+            task=None,
+            dump_path=None,
+            level=None,
+            step=None
+    ):
+        if self.initialized:
+            return
+        self.initialized = True
+        self._check_input_params(config_path, task, dump_path, level)
+        self.common_config, self.task_config = self._parse_config_path(config_path, task)
+        self.task = self.common_config.task
+        if step is not None:
+            self.common_config.step = get_real_step_or_rank(step, Const.STEP)
+
+    @staticmethod
+    def _check_input_params(config_path, task, dump_path, level):
+        if not config_path:
+            config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
+        if config_path is not None:
+            if not isinstance(config_path, str):
+                raise MsprobeException(
+                    MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
+            file_checker = FileChecker(
+                file_path=config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
+            file_checker.common_check()
+
+        if task is not None and task not in Const.TASK_LIST:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
+
+        if dump_path is not None:
+            if not isinstance(dump_path, str):
+                raise MsprobeException(
+                    MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
+
+        if level is not None and level not in Const.LEVEL_LIST:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
+
+    @staticmethod
+    def _get_task_config(task, json_config):
+        raise NotImplementedError("Subclass must implement _get_task_config")
+
+    @classmethod
+    def forward_backward_dump_end(cls):
+        instance = cls._instance
+        instance.stop()
+
+    @classmethod
+    def set_init_step(cls, step):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        check_init_step(step)
+        instance.service.init_step = step
+        instance.service.loop = 0
+
+    @classmethod
+    def register_custom_api(cls, module, api, api_prefix=None):
+        if not api_prefix:
+            api_prefix = getattr(module, "__name__", "Custom")
+        if not isinstance(api_prefix, str):
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, "api_prefix must be string")
+        if not hasattr(module, api):
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}")
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        instance.service.register_custom_api(module, api, api_prefix)
+
+    @classmethod
+    def restore_custom_api(cls, module, api):
+        if not hasattr(module, api):
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}")
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        instance.service.restore_custom_api(module, api)
+
+    @classmethod
+    def _get_instance(cls):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        if instance.task in BasePrecisionDebugger.tasks_not_need_debugger:
+            instance = None
+        return instance
+
+    def _parse_config_path(self, json_file_path, task):
+        if not json_file_path:
+            json_file_path = os.path.join(os.path.dirname(__file__), "../../config.json")
+        json_config = load_json(json_file_path)
+        common_config = CommonConfig(json_config)
+        if task:
+            task_config = self._get_task_config(task, json_config)
+        else:
+            if not common_config.task:
+                common_config.task = Const.STATISTICS
+            task_config = self._get_task_config(common_config.task, json_config)
+        return common_config, task_config
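`BasePrecisionDebugger` is a process-wide singleton (`__new__` caches and reuses `_instance`, and `__init__` returns early once initialized) and leaves `_get_task_config` to the framework front ends. A hypothetical minimal subclass, just to show the contract; the shipped subclasses are the PyTorch and MindSpore debuggers under `msprobe/pytorch/debugger/` and `msprobe/mindspore/debugger/`, and this sketch assumes a valid `./config.json` exists:

# Hypothetical subclass illustrating the BasePrecisionDebugger contract; the
# returned object here is a plain dict, whereas the real subclasses build a
# task-specific config object from the relevant json_config section.
from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger


class DemoPrecisionDebugger(BasePrecisionDebugger):
    @staticmethod
    def _get_task_config(task, json_config):
        return json_config.get(task, {})


a = DemoPrecisionDebugger(config_path="./config.json")
b = DemoPrecisionDebugger(config_path="./other.json")  # ignored: singleton already initialized
assert a is b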
msprobe/core/grad_probe/grad_compare.py
CHANGED

@@ -121,7 +121,7 @@ class GradComparator:
         similarities = {}
         logger.info(f"{len(steps)} steps will be compared")
         grad_weight_order = cls._get_grad_weight_order(path1, path2)
-        for step in tqdm(steps, desc="
+        for step in tqdm(steps, desc="calculate similarities (by step)"):
             grad_files = cls._get_matched_grad_files(path1, path2, step)
             same_count_summary = 0
             total_count_summary = 0
msprobe/core/grad_probe/utils.py
CHANGED

@@ -82,7 +82,7 @@ class ListCache(list):
         if len(self) == 0:
             return
         if not self._output_file:
-            logger.warning("dumpfile path is not
+            logger.warning("dumpfile path is not set.")
         write_csv(self, self._output_file)
         logger.info(f"write {len(self)} items to {self._output_file}.")
         self.clear()
msprobe/core/hook_manager.py
ADDED

@@ -0,0 +1,242 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+import os
+
+from msprobe.core.common.runtime import Runtime
+from msprobe.core.common.utils import Const
+from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs)
+
+
+class HookSet:
+    def __init__(self, forward_hook=None, forward_pre_hook=None, backward_hook=None, backward_pre_hook=None):
+        self.forward_hook = forward_hook
+        self.forward_pre_hook = forward_pre_hook
+        self.backward_hook = backward_hook
+        self.backward_pre_hook = backward_pre_hook
+
+
+class BaseHookManager(ABC):
+    inner_switch = False
+    hook_handle_dict = {}
+    params_grad_info = {}
+
+    def __init__(self, data_collector, config, attl_manager=None):
+        self.data_collector = data_collector
+        self.config = config
+        self.attl_manager = attl_manager
+
+    @property
+    def _pid(self):
+        return os.getpid()
+
+    @property
+    @abstractmethod
+    def _is_recompute(self):
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def _no_grad_context():
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def _add_count(name):
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
+        pass
+
+    @staticmethod
+    def _clear_input_kwargs(module):
+        if hasattr(module, 'msprobe_input_kwargs'):
+            del module.msprobe_input_kwargs
+
+    @abstractmethod
+    def build_hook(self):
+        pass
+
+    @abstractmethod
+    def _get_params_dict(self, module):
+        pass
+
+    @abstractmethod
+    def _need_exchange(self, module):
+        pass
+
+    def _register_param_hook(self, name, module, params_dict):
+        ori_name = name.rsplit(Const.SEP, 2)[0]
+        grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
+        # when the forward hook first runs, add the params_grad_name attribute and register the parameter hooks
+        setattr(module, 'params_grad_name', grad_name)
+        # when data_mode is forward-only, do not register parameter hooks
+        if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+            for param_name, param in params_dict.items():
+                if param.requires_grad:
+                    name = ori_name + Const.SEP + param_name
+                    old_handle = BaseHookManager.hook_handle_dict.get(name)
+                    if old_handle and hasattr(old_handle, "remove"):
+                        old_handle.remove()
+                    handle = param.register_hook(self._build_grad_hook(module, ori_name, param_name))
+                    BaseHookManager.hook_handle_dict[name] = handle
+
+    def _init_params_grad_info(self, module, params_dict):
+        '''
+        Initialize the parameter-gradient info: after the forward hook finishes,
+        write it into cache_data as a placeholder
+        '''
+        if not params_dict:
+            return
+        if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+            grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
+            # check whether a placeholder already exists in cache_data; if not, write one first
+            if not BaseHookManager.params_grad_info.get(grad_name):
+                data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
+                # gradients are only computed when some module parameter has requires_grad=True,
+                # so only then is a placeholder needed
+                if data_info.get(grad_name):
+                    # write grad_name's data_info into cache_data first; it is updated after the gradients are computed
+                    self.data_collector.handle_data(grad_name, data_info,
+                                                    flush=self.data_collector.data_processor.is_terminated)
+                    # record that this module's parameter-gradient info has been placed
+                    BaseHookManager.params_grad_info[grad_name] = True
+
+    def _should_execute_hook(self, hook_type, module, is_forward):
+        is_module_hook = hook_type == Const.MODULE
+        if is_module_hook and not Runtime.is_running:
+            return False
+        elif not is_module_hook and is_forward and not Runtime.is_running:
+            return False
+        elif not is_module_hook and not is_forward and not module.forward_data_collected:
+            return False
+        if BaseHookManager.inner_switch:
+            return False
+        if not self.data_collector or self.data_collector.data_processor.is_terminated:
+            return False
+        return True
+
+    def _build_grad_hook(self, module, ori_name, param_name):
+        def hook_fn(grad):
+            if not self._should_execute_hook(Const.MODULE, module, False):
+                return
+            BaseHookManager.inner_switch = True
+            self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
+            BaseHookManager.inner_switch = False
+            return
+        return hook_fn
+
+    def _build_forward_pre_hook(self, hook_type, full_name, api_name):
+        def forward_pre_hook(module, args, kwargs=None):
+            if hook_type == Const.MODULE:
+                return
+            if not self._should_execute_hook(hook_type, module, True):
+                return
+            if kwargs is None:
+                kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
+            with self._no_grad_context():
+                BaseHookManager.inner_switch = False
+                module.forward_data_collected = True
+                self._add_count(api_name)
+                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
+                self.data_collector.update_api_or_module_name(full_name)
+                if getattr(self.config, "online_run_ut", False):
+                    BaseHookManager.inner_switch = False
+                    return
+                self.data_collector.forward_input_data_collect(
+                    full_name,
+                    module,
+                    self._pid,
+                    module_input_output,
+                    self._is_recompute
+                )
+                BaseHookManager.inner_switch = False
+        return forward_pre_hook
+
+    def _build_forward_hook(self, hook_type, full_name):
+        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
+            if not self._should_execute_hook(hook_type, module, True):
+                self._clear_input_kwargs(module)
+                return None
+            kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs)
+            BaseHookManager.inner_switch = True
+            self.data_collector.update_api_or_module_name(full_name)
+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+            with self._no_grad_context():
+                if getattr(self.config, "online_run_ut", False):
+                    if self.data_collector.scope and not self.data_collector.scope.check(full_name):
+                        return None
+                    if self.attl_manager:
+                        self.attl_manager.attl_send(full_name, args, kwargs, output)
+                    BaseHookManager.inner_switch = False
+                    return None
+                if hook_type == Const.MODULE:
+                    params_dict = self._get_params_dict(module)
+                    setattr(module_input_output, Const.PARAMS, params_dict)
+                    if params_dict:
+                        self._register_param_hook(full_name, module, params_dict)
+                    self.data_collector.update_api_or_module_name(full_name)
+                    self.data_collector.forward_data_collect(
+                        full_name,
+                        module,
+                        self._pid,
+                        module_input_output,
+                        self._is_recompute
+                    )
+                    self._init_params_grad_info(module, params_dict)
+                else:
+                    self.data_collector.forward_output_data_collect(
+                        full_name,
+                        module,
+                        self._pid,
+                        module_input_output,
+                        self._is_recompute
+                    )
+                self._clear_input_kwargs(module)
+
+                if self.data_collector.if_return_forward_new_output():
+                    forward_new_output = self.data_collector.get_forward_new_output()
+                    BaseHookManager.inner_switch = False
+                    return forward_new_output
+
+            BaseHookManager.inner_switch = False
+            return output
+        return forward_hook
+
+    def _build_backward_hook(self, hook_type, full_name):
+        def backward_hook(module, grad_input, grad_output):
+            if not self._should_execute_hook(hook_type, module, False):
+                return
+            BaseHookManager.inner_switch = True
+            self.data_collector.update_api_or_module_name(full_name)
+            if getattr(self.config, "online_run_ut", False):
+                BaseHookManager.inner_switch = False
+                return
+            need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
+            if need_exchange:
+                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
+            else:
+                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
+            self.data_collector.backward_data_collect(
+                full_name,
+                module,
+                self._pid,
+                module_input_output,
+                self._is_recompute
+            )
+            BaseHookManager.inner_switch = False
+        return backward_hook
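`HookSet` simply bundles the four hook callables, while `BaseHookManager` owns the shared lifecycle (the `inner_switch` re-entrancy guard, parameter-gradient placeholders, and the forward/backward collection flow) and defers framework specifics to subclasses; the shipped implementations are `pt_hook_manager.py` (PyTorch) and `ms_hook_manager.py` (MindSpore) from the file list above. A hypothetical skeleton showing which pieces a subclass must supply:

# Hypothetical subclass sketch; the comments note what the real framework
# managers plug in at each override.
import contextlib

from msprobe.core.hook_manager import BaseHookManager, HookSet


class DemoHookManager(BaseHookManager):
    @property
    def _is_recompute(self):
        return False  # real managers report activation-recomputation state

    @staticmethod
    def _no_grad_context():
        return contextlib.nullcontext()  # e.g. torch.no_grad() in PyTorch

    @staticmethod
    def _add_count(name):
        pass  # real managers bump the per-API invocation counter here

    @staticmethod
    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
        # normalizes the two forward-hook calling conventions into (kwargs, output)
        return {}, kwargs_or_output

    def _get_params_dict(self, module):
        return {}  # real managers return the module's named parameters

    def _need_exchange(self, module):
        return False  # True swaps grad_input/grad_output before collection

    def build_hook(self):
        full_name = "Demo.block.0"  # illustrative name only
        return HookSet(
            forward_hook=self._build_forward_hook("Demo", full_name),
            forward_pre_hook=self._build_forward_pre_hook("Demo", full_name, "Demo"),
            backward_hook=self._build_backward_hook("Demo", full_name),
        )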