mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -13,17 +13,18 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import concurrent
|
|
17
|
+
import copy
|
|
16
18
|
import csv
|
|
17
19
|
import os
|
|
18
|
-
import copy
|
|
19
20
|
import threading
|
|
20
21
|
import traceback
|
|
21
22
|
from datetime import datetime, timezone, timedelta
|
|
22
23
|
|
|
23
24
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
24
|
-
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
|
|
25
|
-
from msprobe.core.common.log import logger
|
|
26
25
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
26
|
+
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, check_path_before_create
|
|
27
|
+
from msprobe.core.common.log import logger
|
|
27
28
|
|
|
28
29
|
lock = threading.Lock()
|
|
29
30
|
|
|
@@ -39,6 +40,7 @@ class DataWriter:
|
|
|
39
40
|
self.debug_file_path = None
|
|
40
41
|
self.dump_error_info_path = None
|
|
41
42
|
self.flush_size = 1000
|
|
43
|
+
self.md5_flush_size = 5000
|
|
42
44
|
self.larger_flush_size = 20000
|
|
43
45
|
self.cache_data = {}
|
|
44
46
|
self.cache_stack = {}
|
|
@@ -46,6 +48,9 @@ class DataWriter:
|
|
|
46
48
|
self.cache_debug = {}
|
|
47
49
|
self.stat_stack_list = []
|
|
48
50
|
self._error_log_initialized = False
|
|
51
|
+
self._cache_logged_error_types = set()
|
|
52
|
+
self.crc32_stack_list = []
|
|
53
|
+
self.data_updated = False
|
|
49
54
|
|
|
50
55
|
@staticmethod
|
|
51
56
|
def write_data_to_csv(result: list, result_header: tuple, file_path: str):
|
|
@@ -57,11 +62,31 @@ class DataWriter:
|
|
|
57
62
|
spawn_writer = csv.writer(csv_file)
|
|
58
63
|
if not is_exists:
|
|
59
64
|
spawn_writer.writerow(result_header)
|
|
60
|
-
spawn_writer.writerows([result,])
|
|
65
|
+
spawn_writer.writerows([result, ])
|
|
61
66
|
is_new_file = not is_exists
|
|
62
67
|
if is_new_file:
|
|
63
68
|
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
64
69
|
|
|
70
|
+
@recursion_depth_decorator("JsonWriter: DataWriter._replace_crc32_placeholders")
|
|
71
|
+
def _replace_crc32_placeholders(self, data, crc32_results):
|
|
72
|
+
"""
|
|
73
|
+
遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32
|
|
74
|
+
"""
|
|
75
|
+
if isinstance(data, dict):
|
|
76
|
+
for k, v in list(data.items()):
|
|
77
|
+
if k == Const.MD5_INDEX and isinstance(v, int):
|
|
78
|
+
idx = v
|
|
79
|
+
# 防越界
|
|
80
|
+
crc = crc32_results[idx] if idx < len(crc32_results) else None
|
|
81
|
+
# 删除占位符,改成真实字段
|
|
82
|
+
del data[k]
|
|
83
|
+
data[Const.MD5] = crc
|
|
84
|
+
else:
|
|
85
|
+
self._replace_crc32_placeholders(v, crc32_results)
|
|
86
|
+
elif isinstance(data, (list, tuple)):
|
|
87
|
+
for item in data:
|
|
88
|
+
self._replace_crc32_placeholders(item, crc32_results)
|
|
89
|
+
|
|
65
90
|
@recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders")
|
|
66
91
|
def _replace_stat_placeholders(self, data, stat_result):
|
|
67
92
|
if isinstance(data, dict):
|
|
@@ -107,6 +132,25 @@ class DataWriter:
|
|
|
107
132
|
self.cache_stack = {}
|
|
108
133
|
self.cache_construct = {}
|
|
109
134
|
self.cache_debug = {}
|
|
135
|
+
self._cache_logged_error_types = set()
|
|
136
|
+
|
|
137
|
+
def append_crc32_to_buffer(self, future: concurrent.futures.Future) -> int:
|
|
138
|
+
"""
|
|
139
|
+
把一个计算 CRC32 的 Future 放入队列,返回占位符索引
|
|
140
|
+
"""
|
|
141
|
+
idx = len(self.crc32_stack_list)
|
|
142
|
+
self.crc32_stack_list.append(future)
|
|
143
|
+
return idx
|
|
144
|
+
|
|
145
|
+
def flush_crc32_stack(self):
|
|
146
|
+
"""
|
|
147
|
+
等待所有 CRC32 计算完成,返回结果列表
|
|
148
|
+
"""
|
|
149
|
+
if not self.crc32_stack_list:
|
|
150
|
+
return []
|
|
151
|
+
results = [f.result() for f in self.crc32_stack_list]
|
|
152
|
+
self.crc32_stack_list = []
|
|
153
|
+
return results
|
|
110
154
|
|
|
111
155
|
def initialize_json_file(self, **kwargs):
|
|
112
156
|
if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug:
|
|
@@ -142,18 +186,32 @@ class DataWriter:
|
|
|
142
186
|
|
|
143
187
|
length = len(dump_data)
|
|
144
188
|
|
|
145
|
-
|
|
189
|
+
# 1) 先取到 config(如果没有,就拿 None)
|
|
190
|
+
cfg = getattr(self, "config", None)
|
|
191
|
+
# 2) 再取 summary_mode(如果 cfg 是 None 或者没 summary_mode,就拿 None)
|
|
192
|
+
summary_mode = getattr(cfg, "summary_mode", None)
|
|
193
|
+
|
|
194
|
+
if summary_mode == Const.MD5:
|
|
195
|
+
threshold = self.md5_flush_size
|
|
196
|
+
else:
|
|
197
|
+
threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
|
|
146
198
|
|
|
147
199
|
if length % threshold == 0:
|
|
148
200
|
self.write_json()
|
|
149
201
|
|
|
150
|
-
def write_error_log(self, message: str):
|
|
202
|
+
def write_error_log(self, message: str, error_type: str):
|
|
151
203
|
"""
|
|
152
204
|
写错误日志:
|
|
153
205
|
- 第一次调用时以 'w' 模式清空文件,之后都用 'a' 模式追加
|
|
154
206
|
- 添加时间戳
|
|
155
207
|
- 在 message 后写入当前的调用栈(方便追踪日志来源)
|
|
156
208
|
"""
|
|
209
|
+
# 如果同类型错误已经记录过,跳过
|
|
210
|
+
if error_type in self._cache_logged_error_types:
|
|
211
|
+
return
|
|
212
|
+
# 否则添加到已记录集合,并继续写日志
|
|
213
|
+
self._cache_logged_error_types.add(error_type)
|
|
214
|
+
|
|
157
215
|
try:
|
|
158
216
|
mode = "w" if not self._error_log_initialized else "a"
|
|
159
217
|
self._error_log_initialized = True
|
|
@@ -182,6 +240,7 @@ class DataWriter:
|
|
|
182
240
|
logger.warning(f"The dump data({dump_data}) should be a dict.")
|
|
183
241
|
return
|
|
184
242
|
|
|
243
|
+
self.data_updated = True
|
|
185
244
|
key = next(iter(new_data.keys()))
|
|
186
245
|
if key in dump_data:
|
|
187
246
|
dump_data.get(key).update(new_data.get(key))
|
|
@@ -190,6 +249,7 @@ class DataWriter:
|
|
|
190
249
|
|
|
191
250
|
def update_stack(self, name, stack_data):
|
|
192
251
|
with lock:
|
|
252
|
+
self.data_updated = True
|
|
193
253
|
api_list = self.cache_stack.get(stack_data)
|
|
194
254
|
if api_list is None:
|
|
195
255
|
self.cache_stack.update({stack_data: [name]})
|
|
@@ -198,10 +258,12 @@ class DataWriter:
|
|
|
198
258
|
|
|
199
259
|
def update_construct(self, new_data):
|
|
200
260
|
with lock:
|
|
261
|
+
self.data_updated = True
|
|
201
262
|
self.cache_construct.update(new_data)
|
|
202
263
|
|
|
203
264
|
def update_debug(self, new_data):
|
|
204
265
|
with lock:
|
|
266
|
+
self.data_updated = True
|
|
205
267
|
self.cache_debug['data'].update(new_data)
|
|
206
268
|
|
|
207
269
|
def write_data_json(self, file_path):
|
|
@@ -268,9 +330,21 @@ class DataWriter:
|
|
|
268
330
|
stat_result = self.flush_stat_stack()
|
|
269
331
|
# 遍历 cache_data,将占位符替换为最终统计值
|
|
270
332
|
if stat_result:
|
|
333
|
+
self.data_updated = True
|
|
271
334
|
self._replace_stat_placeholders(self.cache_data, stat_result)
|
|
272
335
|
if self.cache_debug:
|
|
273
336
|
self._replace_stat_placeholders(self.cache_debug, stat_result)
|
|
337
|
+
|
|
338
|
+
crc32_result = self.flush_crc32_stack()
|
|
339
|
+
if crc32_result:
|
|
340
|
+
self.data_updated = True
|
|
341
|
+
self._replace_crc32_placeholders(self.cache_data, crc32_result)
|
|
342
|
+
if self.cache_debug:
|
|
343
|
+
self._replace_crc32_placeholders(self.cache_debug, crc32_result)
|
|
344
|
+
|
|
345
|
+
if not self.data_updated:
|
|
346
|
+
return
|
|
347
|
+
|
|
274
348
|
if self.cache_data:
|
|
275
349
|
self.write_data_json(self.dump_file_path)
|
|
276
350
|
if self.cache_stack:
|
|
@@ -279,4 +353,4 @@ class DataWriter:
|
|
|
279
353
|
self.write_construct_info_json(self.construct_file_path)
|
|
280
354
|
if self.cache_debug:
|
|
281
355
|
self.write_debug_info_json(self.debug_file_path)
|
|
282
|
-
|
|
356
|
+
self.data_updated = False
|
msprobe/core/data_dump/scope.py
CHANGED
|
@@ -69,8 +69,7 @@ class BaseScope(ABC):
|
|
|
69
69
|
self.scope = scope
|
|
70
70
|
self.api_list = api_list
|
|
71
71
|
|
|
72
|
-
|
|
73
|
-
def rectify_args(scope, api_list):
|
|
72
|
+
def rectify_args(self, scope, api_list):
|
|
74
73
|
if not isinstance(api_list, list):
|
|
75
74
|
raise ScopeException(ScopeException.InvalidApiStr,
|
|
76
75
|
f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
|
|
@@ -104,12 +103,11 @@ class BaseScope(ABC):
|
|
|
104
103
|
|
|
105
104
|
|
|
106
105
|
class ListScope(BaseScope):
|
|
107
|
-
|
|
108
|
-
def rectify_args(scope, api_list):
|
|
106
|
+
def rectify_args(self, scope, api_list):
|
|
109
107
|
if scope and api_list:
|
|
110
108
|
raise ScopeException(ScopeException.ArgConflict,
|
|
111
109
|
f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
|
|
112
|
-
return super(
|
|
110
|
+
return super().rectify_args(scope, api_list)
|
|
113
111
|
|
|
114
112
|
def check(self, name):
|
|
115
113
|
if not self.scope or name in self.scope:
|
|
@@ -147,7 +145,7 @@ class RangeScope(BaseScope, ABC):
|
|
|
147
145
|
f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
|
|
148
146
|
|
|
149
147
|
def rectify_args(self, scope, api_list):
|
|
150
|
-
scope, api_list = super(
|
|
148
|
+
scope, api_list = super().rectify_args(scope, api_list)
|
|
151
149
|
if scope and len(scope) != 2:
|
|
152
150
|
raise ScopeException(ScopeException.InvalidScope,
|
|
153
151
|
f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
|
msprobe/core/hook_manager.py
CHANGED
|
@@ -13,34 +13,42 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import gc
|
|
16
17
|
import os
|
|
17
18
|
import threading
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from collections import defaultdict
|
|
20
21
|
|
|
21
|
-
from msprobe.core.common.log import logger
|
|
22
22
|
from msprobe.core.common.runtime import Runtime
|
|
23
23
|
from msprobe.core.common.utils import Const, ThreadSafe
|
|
24
24
|
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs)
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class HookSet:
|
|
28
|
-
def __init__(
|
|
29
|
-
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
forward_pre_hook=None,
|
|
31
|
+
forward_hook=None,
|
|
32
|
+
backward_pre_hook=None,
|
|
33
|
+
backward_hook=None,
|
|
34
|
+
distributed_forward_hook=None
|
|
35
|
+
):
|
|
30
36
|
self.forward_pre_hook = forward_pre_hook
|
|
31
|
-
self.
|
|
37
|
+
self.forward_hook = forward_hook
|
|
32
38
|
self.backward_pre_hook = backward_pre_hook
|
|
39
|
+
self.backward_hook = backward_hook
|
|
40
|
+
self.distributed_forward_hook = distributed_forward_hook
|
|
33
41
|
|
|
34
42
|
|
|
35
43
|
class BaseHookManager(ABC):
|
|
36
44
|
inner_switch = defaultdict(bool)
|
|
45
|
+
inner_api_count = defaultdict(int)
|
|
37
46
|
hook_handle_dict = {}
|
|
38
47
|
params_grad_info = {}
|
|
39
48
|
|
|
40
|
-
def __init__(self, data_collector, config
|
|
49
|
+
def __init__(self, data_collector, config):
|
|
41
50
|
self.data_collector = data_collector
|
|
42
51
|
self.config = config
|
|
43
|
-
self.attl_manager = attl_manager
|
|
44
52
|
|
|
45
53
|
@property
|
|
46
54
|
def _pid(self):
|
|
@@ -51,6 +59,30 @@ class BaseHookManager(ABC):
|
|
|
51
59
|
def _is_recompute(self):
|
|
52
60
|
pass
|
|
53
61
|
|
|
62
|
+
@staticmethod
|
|
63
|
+
def reset_status():
|
|
64
|
+
BaseHookManager.inner_switch = defaultdict(bool)
|
|
65
|
+
BaseHookManager.inner_api_count = defaultdict(int)
|
|
66
|
+
BaseHookManager.hook_handle_dict.clear()
|
|
67
|
+
BaseHookManager.params_grad_info.clear()
|
|
68
|
+
|
|
69
|
+
@staticmethod
|
|
70
|
+
def ensure_gc_enabled():
|
|
71
|
+
is_gc_disabled = not gc.isenabled()
|
|
72
|
+
if is_gc_disabled:
|
|
73
|
+
gc.enable()
|
|
74
|
+
return is_gc_disabled
|
|
75
|
+
|
|
76
|
+
@staticmethod
|
|
77
|
+
def restore_gc_state(original_state):
|
|
78
|
+
if original_state:
|
|
79
|
+
gc.disable()
|
|
80
|
+
|
|
81
|
+
@staticmethod
|
|
82
|
+
def _clear_input_kwargs(module, tid):
|
|
83
|
+
if hasattr(module, 'msprobe_input_kwargs') and tid in module.msprobe_input_kwargs:
|
|
84
|
+
del module.msprobe_input_kwargs[tid]
|
|
85
|
+
|
|
54
86
|
@staticmethod
|
|
55
87
|
@abstractmethod
|
|
56
88
|
def _no_grad_context():
|
|
@@ -63,18 +95,30 @@ class BaseHookManager(ABC):
|
|
|
63
95
|
|
|
64
96
|
@staticmethod
|
|
65
97
|
@abstractmethod
|
|
66
|
-
def
|
|
98
|
+
def _get_count(name):
|
|
67
99
|
pass
|
|
68
100
|
|
|
69
101
|
@staticmethod
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
102
|
+
@abstractmethod
|
|
103
|
+
def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs):
|
|
104
|
+
pass
|
|
73
105
|
|
|
74
106
|
@abstractmethod
|
|
75
107
|
def build_hook(self):
|
|
76
108
|
pass
|
|
77
109
|
|
|
110
|
+
@abstractmethod
|
|
111
|
+
def _register_forward_hook(self, module, api_name):
|
|
112
|
+
pass
|
|
113
|
+
|
|
114
|
+
@abstractmethod
|
|
115
|
+
def _register_backward_hook(self, module, full_backward_name, args):
|
|
116
|
+
pass
|
|
117
|
+
|
|
118
|
+
@abstractmethod
|
|
119
|
+
def _register_backward_pre_hook(self, module, full_backward_name, output):
|
|
120
|
+
pass
|
|
121
|
+
|
|
78
122
|
@abstractmethod
|
|
79
123
|
def _get_params_dict(self, module):
|
|
80
124
|
pass
|
|
@@ -96,7 +140,7 @@ class BaseHookManager(ABC):
|
|
|
96
140
|
old_handle = BaseHookManager.hook_handle_dict.get(name)
|
|
97
141
|
if old_handle and hasattr(old_handle, "remove"):
|
|
98
142
|
old_handle.remove()
|
|
99
|
-
handle = param.register_hook(self._build_grad_hook(
|
|
143
|
+
handle = param.register_hook(self._build_grad_hook(ori_name, param_name))
|
|
100
144
|
BaseHookManager.hook_handle_dict[name] = handle
|
|
101
145
|
|
|
102
146
|
def _init_params_grad_info(self, module, params_dict):
|
|
@@ -115,108 +159,116 @@ class BaseHookManager(ABC):
|
|
|
115
159
|
# 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
|
|
116
160
|
self.data_collector.handle_data(grad_name, data_info,
|
|
117
161
|
flush=self.data_collector.data_processor.is_terminated)
|
|
162
|
+
self.data_collector.params_grad_record[grad_name] = True
|
|
118
163
|
# 记录当前模块的参数梯度信息已占位
|
|
119
164
|
BaseHookManager.params_grad_info[grad_name] = True
|
|
120
165
|
|
|
121
|
-
def _should_execute_hook(self, hook_type,
|
|
122
|
-
|
|
123
|
-
if
|
|
124
|
-
return False
|
|
125
|
-
elif not is_module_hook and is_forward and not Runtime.is_running:
|
|
166
|
+
def _should_execute_hook(self, hook_type, tid, is_forward=True):
|
|
167
|
+
is_api_hook = hook_type == Const.API
|
|
168
|
+
if BaseHookManager.inner_switch[tid]:
|
|
126
169
|
return False
|
|
127
|
-
|
|
170
|
+
if not is_api_hook and not Runtime.is_running:
|
|
128
171
|
return False
|
|
129
|
-
|
|
172
|
+
elif is_api_hook and is_forward and not Runtime.is_running:
|
|
130
173
|
return False
|
|
131
174
|
if not self.data_collector or self.data_collector.data_processor.is_terminated:
|
|
132
175
|
return False
|
|
133
176
|
return True
|
|
134
177
|
|
|
135
|
-
def _build_grad_hook(self,
|
|
178
|
+
def _build_grad_hook(self, ori_name, param_name):
|
|
136
179
|
def hook_fn(grad):
|
|
137
180
|
tid = threading.get_ident()
|
|
138
|
-
if not self._should_execute_hook(Const.MODULE,
|
|
181
|
+
if not self._should_execute_hook(Const.MODULE, tid):
|
|
139
182
|
return
|
|
140
183
|
with ThreadSafe():
|
|
184
|
+
original_state = self.ensure_gc_enabled()
|
|
141
185
|
BaseHookManager.inner_switch[tid] = True
|
|
142
186
|
self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
|
|
143
187
|
BaseHookManager.inner_switch[tid] = False
|
|
188
|
+
self.restore_gc_state(original_state)
|
|
144
189
|
return
|
|
145
190
|
|
|
146
191
|
return hook_fn
|
|
147
192
|
|
|
148
|
-
def _build_forward_pre_hook(self, hook_type,
|
|
193
|
+
def _build_forward_pre_hook(self, hook_type, api_name):
|
|
149
194
|
def forward_pre_hook(module, args, kwargs=None):
|
|
150
|
-
"""
|
|
151
|
-
为确保多线程场景下 L1 级别数据采集的正确性,每个封装后的 API 的 init 方法和 forward_pre_hook 需要确保在一个线程内连续完成,
|
|
152
|
-
因此在 API 的 init 方法执行 ThreadSafe.acquire() 加锁操作,
|
|
153
|
-
并且在 API 的 forward_pre_hook 方法执行 ThreadSafe.release() 释放锁操作。
|
|
154
|
-
"""
|
|
155
195
|
if hook_type == Const.MODULE:
|
|
156
|
-
return
|
|
196
|
+
return None
|
|
157
197
|
|
|
158
198
|
tid = threading.get_ident()
|
|
159
|
-
if not self._should_execute_hook(hook_type,
|
|
160
|
-
|
|
161
|
-
return
|
|
199
|
+
if not self._should_execute_hook(hook_type, tid):
|
|
200
|
+
return None
|
|
162
201
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
202
|
+
with ThreadSafe():
|
|
203
|
+
original_state = self.ensure_gc_enabled()
|
|
204
|
+
self._register_forward_hook(module, api_name)
|
|
205
|
+
BaseHookManager.inner_api_count[tid] += 1
|
|
206
|
+
if BaseHookManager.inner_api_count[tid] != 1:
|
|
207
|
+
return None
|
|
208
|
+
|
|
209
|
+
full_forward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.FORWARD
|
|
210
|
+
full_backward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.BACKWARD
|
|
211
|
+
module.full_forward_name = full_forward_name
|
|
212
|
+
if kwargs is None:
|
|
213
|
+
kwargs = module.msprobe_input_kwargs.get(tid, {}) if hasattr(module, 'msprobe_input_kwargs') else {}
|
|
214
|
+
BaseHookManager.inner_switch[tid] = True
|
|
215
|
+
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
|
|
168
216
|
|
|
169
|
-
|
|
170
|
-
if kwargs is None:
|
|
171
|
-
kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
|
|
172
|
-
try:
|
|
217
|
+
args = self._register_backward_hook(module, full_backward_name, args)
|
|
173
218
|
with self._no_grad_context():
|
|
174
|
-
|
|
175
|
-
self.data_collector.update_api_or_module_name(full_name)
|
|
219
|
+
self.data_collector.update_api_or_module_name(full_forward_name)
|
|
176
220
|
self.data_collector.forward_input_data_collect(
|
|
177
|
-
|
|
221
|
+
full_forward_name,
|
|
178
222
|
module,
|
|
179
223
|
self._pid,
|
|
180
224
|
module_input_output,
|
|
181
225
|
self._is_recompute
|
|
182
226
|
)
|
|
183
|
-
except Exception as e:
|
|
184
|
-
logger.error(f"The forward pre hook execution of the {full_name} API failed.")
|
|
185
|
-
raise e
|
|
186
|
-
finally:
|
|
187
227
|
BaseHookManager.inner_switch[tid] = False
|
|
188
|
-
|
|
228
|
+
self.restore_gc_state(original_state)
|
|
229
|
+
return args
|
|
189
230
|
|
|
190
231
|
return forward_pre_hook
|
|
191
232
|
|
|
192
|
-
def _build_forward_hook(self, hook_type,
|
|
233
|
+
def _build_forward_hook(self, hook_type, api_name):
|
|
193
234
|
def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
|
|
194
235
|
tid = threading.get_ident()
|
|
195
|
-
if not self._should_execute_hook(hook_type,
|
|
196
|
-
self._clear_input_kwargs(module)
|
|
236
|
+
if not self._should_execute_hook(hook_type, tid):
|
|
237
|
+
self._clear_input_kwargs(module, tid)
|
|
197
238
|
return None
|
|
198
239
|
|
|
199
240
|
with ThreadSafe():
|
|
200
|
-
|
|
241
|
+
original_state = self.ensure_gc_enabled()
|
|
242
|
+
if hook_type == Const.API:
|
|
243
|
+
if BaseHookManager.inner_api_count[tid] != 1:
|
|
244
|
+
if BaseHookManager.inner_api_count[tid] > 1:
|
|
245
|
+
BaseHookManager.inner_api_count[tid] -= 1
|
|
246
|
+
self._clear_input_kwargs(module, tid)
|
|
247
|
+
return None
|
|
248
|
+
|
|
249
|
+
kwargs, output = self._process_kwargs_and_output(
|
|
250
|
+
module,
|
|
251
|
+
tid,
|
|
252
|
+
hook_type,
|
|
253
|
+
kwargs_or_output,
|
|
254
|
+
output_or_kwargs
|
|
255
|
+
)
|
|
201
256
|
BaseHookManager.inner_switch[tid] = True
|
|
202
|
-
self.data_collector.update_api_or_module_name(full_name)
|
|
203
257
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
|
|
258
|
+
if hook_type == Const.API:
|
|
259
|
+
full_forward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.FORWARD
|
|
260
|
+
full_backward_name = api_name + str(self._get_count(api_name)) + Const.SEP + Const.BACKWARD
|
|
261
|
+
output = self._register_backward_pre_hook(module, full_backward_name, output)
|
|
262
|
+
|
|
204
263
|
with self._no_grad_context():
|
|
205
|
-
if getattr(self.config, "online_run_ut", False):
|
|
206
|
-
if self.data_collector.scope and not self.data_collector.scope.check(full_name):
|
|
207
|
-
return None
|
|
208
|
-
if self.attl_manager:
|
|
209
|
-
self.attl_manager.attl_send(full_name, args, kwargs, output)
|
|
210
|
-
BaseHookManager.inner_switch[tid] = False
|
|
211
|
-
return None
|
|
212
264
|
if hook_type == Const.MODULE:
|
|
213
265
|
params_dict = self._get_params_dict(module)
|
|
214
266
|
setattr(module_input_output, Const.PARAMS, params_dict)
|
|
215
267
|
if params_dict:
|
|
216
|
-
self._register_param_hook(
|
|
217
|
-
self.data_collector.update_api_or_module_name(
|
|
268
|
+
self._register_param_hook(api_name, module, params_dict)
|
|
269
|
+
self.data_collector.update_api_or_module_name(api_name)
|
|
218
270
|
self.data_collector.forward_data_collect(
|
|
219
|
-
|
|
271
|
+
api_name,
|
|
220
272
|
module,
|
|
221
273
|
self._pid,
|
|
222
274
|
module_input_output,
|
|
@@ -224,37 +276,40 @@ class BaseHookManager(ABC):
|
|
|
224
276
|
)
|
|
225
277
|
self._init_params_grad_info(module, params_dict)
|
|
226
278
|
else:
|
|
279
|
+
self.data_collector.update_api_or_module_name(full_forward_name)
|
|
227
280
|
self.data_collector.forward_output_data_collect(
|
|
228
|
-
|
|
281
|
+
full_forward_name,
|
|
229
282
|
module,
|
|
230
283
|
self._pid,
|
|
231
284
|
module_input_output,
|
|
232
285
|
self._is_recompute
|
|
233
286
|
)
|
|
234
|
-
|
|
287
|
+
self._add_count(api_name)
|
|
288
|
+
BaseHookManager.inner_api_count[tid] -= 1
|
|
289
|
+
self._clear_input_kwargs(module, tid)
|
|
235
290
|
|
|
236
291
|
if self.data_collector.if_return_forward_new_output():
|
|
237
292
|
forward_new_output = self.data_collector.get_forward_new_output()
|
|
238
293
|
BaseHookManager.inner_switch[tid] = False
|
|
239
294
|
return forward_new_output
|
|
240
295
|
|
|
241
|
-
|
|
242
|
-
|
|
296
|
+
BaseHookManager.inner_switch[tid] = False
|
|
297
|
+
self.restore_gc_state(original_state)
|
|
298
|
+
return output
|
|
243
299
|
|
|
244
300
|
return forward_hook
|
|
245
301
|
|
|
246
302
|
def _build_backward_hook(self, hook_type, full_name):
|
|
247
303
|
def backward_hook(module, grad_input, grad_output):
|
|
248
304
|
tid = threading.get_ident()
|
|
249
|
-
if not self._should_execute_hook(hook_type,
|
|
305
|
+
if not self._should_execute_hook(hook_type, tid, is_forward=False):
|
|
250
306
|
return
|
|
251
307
|
|
|
252
308
|
with ThreadSafe():
|
|
309
|
+
original_state = self.ensure_gc_enabled()
|
|
253
310
|
BaseHookManager.inner_switch[tid] = True
|
|
254
311
|
self.data_collector.update_api_or_module_name(full_name)
|
|
255
|
-
|
|
256
|
-
BaseHookManager.inner_switch[tid] = False
|
|
257
|
-
return
|
|
312
|
+
|
|
258
313
|
need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
|
|
259
314
|
if need_exchange:
|
|
260
315
|
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
|
|
@@ -267,6 +322,10 @@ class BaseHookManager(ABC):
|
|
|
267
322
|
module_input_output,
|
|
268
323
|
self._is_recompute
|
|
269
324
|
)
|
|
325
|
+
if hook_type == Const.MODULE:
|
|
326
|
+
params_dict = self._get_params_dict(module)
|
|
327
|
+
self.data_collector.params_data_collect_in_bw_hook(params_dict, full_name)
|
|
270
328
|
BaseHookManager.inner_switch[tid] = False
|
|
329
|
+
self.restore_gc_state(original_state)
|
|
271
330
|
|
|
272
331
|
return backward_hook
|