mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -17,6 +17,7 @@ import os
|
|
|
17
17
|
from collections import defaultdict, namedtuple
|
|
18
18
|
|
|
19
19
|
import mindspore as ms
|
|
20
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
20
21
|
from mindspore._c_expression import MSContext
|
|
21
22
|
|
|
22
23
|
from msprobe.core.common.const import Const, MsgConst
|
|
@@ -28,7 +29,8 @@ from msprobe.mindspore.common.const import Const as MsConst
|
|
|
28
29
|
from msprobe.mindspore.common.utils import (
|
|
29
30
|
set_register_backward_hook_functions,
|
|
30
31
|
check_save_param,
|
|
31
|
-
is_graph_mode_cell_dump_allowed
|
|
32
|
+
is_graph_mode_cell_dump_allowed,
|
|
33
|
+
wrap_backward_hook_call_func
|
|
32
34
|
)
|
|
33
35
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
34
36
|
from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
|
|
@@ -41,6 +43,7 @@ from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
|
|
|
41
43
|
|
|
42
44
|
try:
|
|
43
45
|
from mindspore._c_expression import _dump_start, _dump_stop, _dump_step, _set_init_iter, _dump_set_dynamic
|
|
46
|
+
import mindspore as ms
|
|
44
47
|
except ImportError:
|
|
45
48
|
enable_dynamic_kbyk_dump = False
|
|
46
49
|
else:
|
|
@@ -80,6 +83,9 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
80
83
|
if self._is_kernel_dump() and not self.task_config.is_regex_valid:
|
|
81
84
|
raise ValueError('Illegal regular expressions exist in the list.')
|
|
82
85
|
|
|
86
|
+
setattr(inner.CellBackwardHook, '__call__',
|
|
87
|
+
wrap_backward_hook_call_func(getattr(inner.CellBackwardHook, '__call__')))
|
|
88
|
+
|
|
83
89
|
if self._is_kernel_dump() and _msprobe_c:
|
|
84
90
|
os.environ["MS_HOOK_ENABLE"] = "on"
|
|
85
91
|
_msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
|
|
@@ -90,7 +96,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
90
96
|
|
|
91
97
|
Runtime.step_count = 0
|
|
92
98
|
Runtime.is_running = False
|
|
93
|
-
if enable_dynamic_kbyk_dump:
|
|
99
|
+
if enable_dynamic_kbyk_dump and self.config.level_ori == Const.LEVEL_L2:
|
|
94
100
|
_dump_set_dynamic()
|
|
95
101
|
|
|
96
102
|
@staticmethod
|
|
@@ -160,7 +166,8 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
160
166
|
instance.service.stop()
|
|
161
167
|
else:
|
|
162
168
|
Runtime.is_running = False
|
|
163
|
-
if enable_dynamic_kbyk_dump:
|
|
169
|
+
if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2:
|
|
170
|
+
ms.runtime.synchronize()
|
|
164
171
|
_dump_stop()
|
|
165
172
|
if cls._is_kernel_dump() and _msprobe_c:
|
|
166
173
|
_msprobe_c._PrecisionDebugger().stop()
|
|
@@ -175,8 +182,8 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
175
182
|
with ThreadSafe():
|
|
176
183
|
instance.service.step()
|
|
177
184
|
if is_graph_mode_cell_dump_allowed(instance.config):
|
|
178
|
-
GraphModeCellDump.step()
|
|
179
|
-
if enable_dynamic_kbyk_dump:
|
|
185
|
+
GraphModeCellDump.step(instance.config.dump_path, instance.config.step, instance.config.task)
|
|
186
|
+
if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2:
|
|
180
187
|
_dump_step(1)
|
|
181
188
|
if cls._is_kernel_dump() and _msprobe_c:
|
|
182
189
|
_msprobe_c._PrecisionDebugger().step()
|
|
@@ -207,12 +214,9 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
207
214
|
check_save_param(variable, name, save_backward)
|
|
208
215
|
except ValueError:
|
|
209
216
|
return
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
if not instance.service:
|
|
214
|
-
instance.service = MindsporeService(instance.config)
|
|
215
|
-
instance.service.save(variable, name, save_backward)
|
|
217
|
+
if not instance.service:
|
|
218
|
+
instance.service = MindsporeService(instance.config)
|
|
219
|
+
instance.service.save(variable, name, save_backward)
|
|
216
220
|
|
|
217
221
|
@classmethod
|
|
218
222
|
def _need_service(cls):
|
|
@@ -220,7 +224,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
220
224
|
if not instance:
|
|
221
225
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
222
226
|
if instance.config.level_ori == Const.LEVEL_L2:
|
|
223
|
-
return
|
|
227
|
+
return not instance._is_graph_dump(instance.config)
|
|
224
228
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
225
229
|
return False
|
|
226
230
|
else:
|
|
@@ -38,15 +38,19 @@ DEFAULT_RANK_DIR = "rank0"
|
|
|
38
38
|
KEY_LAYERS = "layers"
|
|
39
39
|
construct = {}
|
|
40
40
|
cell_list = []
|
|
41
|
+
free_cells = {}
|
|
42
|
+
parent_cell_types = {}
|
|
41
43
|
KEY_SIDE_EFFECT = "side_effect_io"
|
|
42
44
|
KEY_TOPLAYER = "TopLayer"
|
|
43
45
|
KEY_FORWARD = CoreConst.FORWARD
|
|
44
46
|
KEY_BACKWARD = CoreConst.BACKWARD
|
|
45
47
|
KEY_INPUT = CoreConst.INPUT
|
|
46
48
|
KEY_OUTPUT = CoreConst.OUTPUT
|
|
47
|
-
KEY_DUMP_TENSOR_DATA = "
|
|
49
|
+
KEY_DUMP_TENSOR_DATA = "dump_tensor_data/"
|
|
48
50
|
KEY_STATISTIC_CSV = "statistic.csv"
|
|
49
51
|
KEY_TD_FLAG = "td_flag"
|
|
52
|
+
# 设置落盘文件检测超时时间
|
|
53
|
+
TIMEOUT = 600
|
|
50
54
|
td = ops.TensorDump()
|
|
51
55
|
if (ms.__version__ >= "2.5.0"):
|
|
52
56
|
td_in = ops.TensorDump("in")
|
|
@@ -219,8 +223,16 @@ def cell_construct_wrapper(func, self):
|
|
|
219
223
|
def sort_filenames(path):
|
|
220
224
|
filenames = os.listdir(path)
|
|
221
225
|
id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$')
|
|
222
|
-
|
|
223
|
-
|
|
226
|
+
# 只保留能提取到数字id的文件,避免数组越界
|
|
227
|
+
valid_files = []
|
|
228
|
+
for filename in filenames:
|
|
229
|
+
match = id_pattern.findall(filename)
|
|
230
|
+
if match and match[0].isdigit():
|
|
231
|
+
valid_files.append(filename)
|
|
232
|
+
else:
|
|
233
|
+
logger.warning(f"File {filename} does not match the expected pattern and will be ignored.")
|
|
234
|
+
valid_files.sort(key=lambda x: int(id_pattern.findall(x)[0]))
|
|
235
|
+
return valid_files
|
|
224
236
|
|
|
225
237
|
|
|
226
238
|
def rename_filename(path="", data_df=None):
|
|
@@ -294,7 +306,24 @@ def check_relation(cell_name, parent_cell_name):
|
|
|
294
306
|
return False
|
|
295
307
|
|
|
296
308
|
|
|
309
|
+
def get_parent_cell_name(child_cell_name):
|
|
310
|
+
parent_cell_name = ''
|
|
311
|
+
|
|
312
|
+
last_dot_index = child_cell_name.rfind(CoreConst.SEP)
|
|
313
|
+
if last_dot_index == -1:
|
|
314
|
+
return parent_cell_name
|
|
315
|
+
|
|
316
|
+
layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$"
|
|
317
|
+
if re.search(layers_pattern, child_cell_name):
|
|
318
|
+
parent_cell_name = re.sub(layers_pattern, '', child_cell_name)
|
|
319
|
+
else:
|
|
320
|
+
parent_cell_name = child_cell_name[:last_dot_index]
|
|
321
|
+
|
|
322
|
+
return parent_cell_name
|
|
323
|
+
|
|
324
|
+
|
|
297
325
|
def get_construct(cell_list_input):
|
|
326
|
+
global free_cells, parent_cell_types
|
|
298
327
|
for cell in cell_list_input:
|
|
299
328
|
cell_name = get_cell_name(cell)
|
|
300
329
|
cell_data_mode = get_data_mode(cell)
|
|
@@ -308,7 +337,20 @@ def get_construct(cell_list_input):
|
|
|
308
337
|
found_flag = True
|
|
309
338
|
break
|
|
310
339
|
if not found_flag:
|
|
311
|
-
|
|
340
|
+
cell_name_with_mode = f'{cell_name}{CoreConst.SEP}{cell_data_mode}'
|
|
341
|
+
if cell_name_with_mode in free_cells:
|
|
342
|
+
construct.update({cell: free_cells.get(cell_name_with_mode)})
|
|
343
|
+
continue
|
|
344
|
+
|
|
345
|
+
parent_cell = None
|
|
346
|
+
parent_cell_name = get_parent_cell_name(cell_name)
|
|
347
|
+
if parent_cell_name and cell_name in parent_cell_types:
|
|
348
|
+
parent_cell = CoreConst.SEP.join([CoreConst.CELL, parent_cell_name, parent_cell_types.get(cell_name)])
|
|
349
|
+
second_last_dot_index = cell.rfind(CoreConst.SEP, 0, cell.rfind(CoreConst.SEP))
|
|
350
|
+
parent_cell = f'{parent_cell}{cell[second_last_dot_index:]}'
|
|
351
|
+
free_cells[cell_name_with_mode] = parent_cell
|
|
352
|
+
|
|
353
|
+
construct.update({cell: parent_cell})
|
|
312
354
|
|
|
313
355
|
|
|
314
356
|
def generate_construct(path):
|
|
@@ -462,7 +504,7 @@ def process_csv(path):
|
|
|
462
504
|
if col_name in columns:
|
|
463
505
|
value = convert_special_values(row[col_name])
|
|
464
506
|
tensor_json[json_key] = value
|
|
465
|
-
|
|
507
|
+
|
|
466
508
|
if io_key == KEY_INPUT:
|
|
467
509
|
data_info.append([op_name, CoreConst.INPUT_ARGS, tensor_json])
|
|
468
510
|
elif io_key == KEY_OUTPUT:
|
|
@@ -534,59 +576,75 @@ def generate_stack_info(path):
|
|
|
534
576
|
logger.info(f"Stack data saved to {json_path}")
|
|
535
577
|
|
|
536
578
|
|
|
537
|
-
def is_download_finished(directory,
|
|
579
|
+
def is_download_finished(directory, save_flag):
|
|
538
580
|
"""
|
|
539
581
|
判断指定目录在一段时间后是否有数据被下载完成
|
|
540
582
|
:param directory: 指定目录的路径
|
|
541
|
-
:param
|
|
583
|
+
:param save_flag: 数据落盘完成后的标志文件
|
|
542
584
|
:return: 如有数据被下载完成返回 True,否则返回 False
|
|
543
585
|
"""
|
|
586
|
+
# 设定一定的延迟间隔,避免频繁进行磁盘的io读取操作
|
|
587
|
+
time.sleep(0.5)
|
|
588
|
+
logger.info("Waiting for download...")
|
|
544
589
|
# 检查目录是否存在
|
|
545
590
|
if not os.path.exists(directory):
|
|
546
591
|
logger.warning(f"The specified directory {directory} does not exist.")
|
|
547
592
|
return False
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
return False
|
|
554
|
-
else:
|
|
555
|
-
return True
|
|
593
|
+
|
|
594
|
+
# 遍历当前目录中的所有条目
|
|
595
|
+
for entry_path in os.listdir(directory):
|
|
596
|
+
if entry_path.startswith(save_flag):
|
|
597
|
+
return True
|
|
556
598
|
|
|
599
|
+
return False
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def process_step(dump_path, flag_path, step, step_list):
|
|
603
|
+
if step not in step_list:
|
|
604
|
+
return
|
|
605
|
+
|
|
606
|
+
if not os.path.exists(dump_path):
|
|
607
|
+
logger.warning('No grap cell data is dumped.')
|
|
608
|
+
create_directory(dump_path)
|
|
609
|
+
return
|
|
557
610
|
|
|
558
|
-
def process(dump_path):
|
|
559
611
|
rank_id = os.environ.get('RANK_ID')
|
|
560
612
|
rank_dir = DEFAULT_RANK_DIR
|
|
561
613
|
if rank_id is not None:
|
|
562
614
|
rank_dir = CoreConst.RANK + str(rank_id)
|
|
563
615
|
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
616
|
+
step_dir = CoreConst.STEP + str(step)
|
|
617
|
+
|
|
618
|
+
step_path = os.path.join(dump_path, step_dir)
|
|
619
|
+
rank_path = os.path.join(step_path, rank_dir)
|
|
620
|
+
npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA)
|
|
621
|
+
save_finish_flag = f"step_{step}"
|
|
622
|
+
start_time = time.time()
|
|
623
|
+
while True:
|
|
624
|
+
is_finished = is_download_finished(flag_path, save_finish_flag)
|
|
625
|
+
if not is_finished:
|
|
626
|
+
logger.info("There is data being downloaded in the specified directory, continue checking...")
|
|
627
|
+
else:
|
|
628
|
+
logger.info("There is no data being downloaded in the specified directory, Stop checking.")
|
|
629
|
+
break
|
|
630
|
+
elapsed_time = time.time() - start_time
|
|
631
|
+
if elapsed_time > TIMEOUT:
|
|
632
|
+
logger.error(f"Check timed out after {TIMEOUT} seconds. Exiting.")
|
|
633
|
+
return
|
|
634
|
+
logger.info(f"==========Start processing step_{step}'s data that has already been stored on the disk!==========")
|
|
635
|
+
rename_filename(path=npy_path)
|
|
636
|
+
generate_construct(npy_path)
|
|
637
|
+
generate_dump_info(npy_path)
|
|
638
|
+
generate_stack_info(npy_path)
|
|
639
|
+
# 单卡场景,rank目录名称为rank
|
|
640
|
+
if rank_id is None:
|
|
641
|
+
new_rank_path = os.path.join(step_path, CoreConst.RANK)
|
|
642
|
+
try:
|
|
643
|
+
move_directory(rank_path, new_rank_path)
|
|
644
|
+
logger.info(f"Directory was successfully renamed to: {new_rank_path}")
|
|
645
|
+
except Exception as e:
|
|
646
|
+
logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
|
|
647
|
+
logger.info(f"==========Step_{step}'s JSON file generation completed!==========")
|
|
590
648
|
|
|
591
649
|
|
|
592
650
|
# 删除csv文件中每行数据最后面的逗号
|
|
@@ -644,7 +702,15 @@ def merge_file(dump_path, rank_dir, file_dict):
|
|
|
644
702
|
" and the index is out of bounds.")
|
|
645
703
|
|
|
646
704
|
|
|
647
|
-
def
|
|
705
|
+
def process_statistics_step(dump_path, step, step_list):
|
|
706
|
+
if step_list and step not in step_list:
|
|
707
|
+
return
|
|
708
|
+
|
|
709
|
+
if not os.path.exists(dump_path):
|
|
710
|
+
logger.warning('No grap cell data is dumped.')
|
|
711
|
+
create_directory(dump_path)
|
|
712
|
+
return
|
|
713
|
+
|
|
648
714
|
rank_id = os.environ.get('RANK_ID')
|
|
649
715
|
rank_dir_kbk = "rank_0"
|
|
650
716
|
if rank_id is not None:
|
|
@@ -673,25 +739,24 @@ def process_statistics(dump_path):
|
|
|
673
739
|
|
|
674
740
|
rank_dir = rank_dir_kbk.replace(CoreConst.REPLACEMENT_CHARACTER, '')
|
|
675
741
|
dir_list = os.listdir(dump_path)
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
logger.info("==========JSON file generation completed!==========")
|
|
742
|
+
step_dir = CoreConst.STEP + str(step)
|
|
743
|
+
step_path = os.path.join(dump_path, step_dir)
|
|
744
|
+
rank_path = os.path.join(step_path, rank_dir)
|
|
745
|
+
csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV)
|
|
746
|
+
logger.info("==========Start processing data csv!==========")
|
|
747
|
+
generate_construct(csv_path)
|
|
748
|
+
generate_dump_info(csv_path)
|
|
749
|
+
generate_stack_info(csv_path)
|
|
750
|
+
remove_path(rank_path_kbk)
|
|
751
|
+
# 单卡场景,rank目录名称为rank
|
|
752
|
+
if rank_id is None:
|
|
753
|
+
new_rank_path = os.path.join(step_path, CoreConst.RANK)
|
|
754
|
+
try:
|
|
755
|
+
move_directory(rank_path, new_rank_path)
|
|
756
|
+
logger.info(f"Directory was successfully renamed to: {new_rank_path}")
|
|
757
|
+
except Exception as e:
|
|
758
|
+
logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
|
|
759
|
+
logger.info("==========JSON file generation completed!==========")
|
|
695
760
|
|
|
696
761
|
|
|
697
762
|
def get_yaml_keys(yaml_data):
|
|
@@ -786,7 +851,7 @@ def create_kbyk_json(dump_path, summary_mode, step):
|
|
|
786
851
|
|
|
787
852
|
|
|
788
853
|
def start(config: CellDumpConfig):
|
|
789
|
-
global dump_task
|
|
854
|
+
global dump_task, parent_cell_types
|
|
790
855
|
dump_task = config.task
|
|
791
856
|
net = config.net
|
|
792
857
|
dump_path = config.dump_path
|
|
@@ -814,7 +879,7 @@ def start(config: CellDumpConfig):
|
|
|
814
879
|
return
|
|
815
880
|
|
|
816
881
|
if isinstance(net, nn.Cell):
|
|
817
|
-
net = (('', net),)
|
|
882
|
+
net = (('', net, None),)
|
|
818
883
|
|
|
819
884
|
td_config_path = ""
|
|
820
885
|
try:
|
|
@@ -837,6 +902,7 @@ def start(config: CellDumpConfig):
|
|
|
837
902
|
black_list = ["grad_reducer", ""]
|
|
838
903
|
|
|
839
904
|
for name_and_model in net:
|
|
905
|
+
parent_cell_types[name_and_model[0]] = name_and_model[2].__class__.__name__
|
|
840
906
|
for name, cell in name_and_model[1].cells_and_names(name_prefix=name_and_model[0]):
|
|
841
907
|
class_name = cell.__class__.__name__
|
|
842
908
|
# 跳过黑名单cell
|
|
@@ -871,7 +937,3 @@ def start(config: CellDumpConfig):
|
|
|
871
937
|
cell.data_mode = data_mode
|
|
872
938
|
|
|
873
939
|
logger.info("==========The cell_dump_process_start phase is Finished!==========")
|
|
874
|
-
if dump_task == CoreConst.TENSOR:
|
|
875
|
-
atexit.register(process, dump_path=dump_path)
|
|
876
|
-
if dump_task == CoreConst.STATISTICS:
|
|
877
|
-
atexit.register(process_statistics, dump_path=dump_path)
|
|
@@ -197,8 +197,16 @@ def cell_construct_wrapper(func, self):
|
|
|
197
197
|
def sort_filenames(path):
|
|
198
198
|
filenames = os.listdir(path)
|
|
199
199
|
id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$')
|
|
200
|
-
|
|
201
|
-
|
|
200
|
+
# 只保留能提取到数字id的文件,避免数组越界
|
|
201
|
+
valid_files = []
|
|
202
|
+
for filename in filenames:
|
|
203
|
+
match = id_pattern.findall(filename)
|
|
204
|
+
if match and match[0].isdigit():
|
|
205
|
+
valid_files.append(filename)
|
|
206
|
+
else:
|
|
207
|
+
logger.warning(f"File {filename} does not match the expected pattern and will be ignored.")
|
|
208
|
+
valid_files.sort(key=lambda x: int(id_pattern.findall(x)[0]))
|
|
209
|
+
return valid_files
|
|
202
210
|
|
|
203
211
|
|
|
204
212
|
def rename_filename(path="", data_df=None):
|
|
@@ -14,7 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
|
|
17
|
+
import glob
|
|
18
|
+
import tempfile
|
|
18
19
|
import mindspore as ms
|
|
19
20
|
from mindspore import hal, ops, Tensor
|
|
20
21
|
from mindspore.ops.primitive import _run_op
|
|
@@ -28,15 +29,20 @@ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
|
|
|
28
29
|
import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
|
|
29
30
|
|
|
30
31
|
tensordump_flag = True
|
|
32
|
+
DEFAULT_RANK_DIR = "rank0"
|
|
31
33
|
try:
|
|
32
34
|
from mindspore._c_expression import _tensordump_set_step
|
|
33
35
|
except ImportError:
|
|
34
36
|
tensordump_flag = False
|
|
35
37
|
|
|
38
|
+
graph_step_flag = True
|
|
39
|
+
try:
|
|
40
|
+
from mindspore._c_expression import _dump_step
|
|
41
|
+
except ImportError:
|
|
42
|
+
graph_step_flag = False
|
|
36
43
|
|
|
37
|
-
class GraphModeCellDump:
|
|
38
|
-
task = CoreConst.STATISTICS
|
|
39
44
|
|
|
45
|
+
class GraphModeCellDump:
|
|
40
46
|
def __init__(self, config: DebuggerConfig, model, strict=True):
|
|
41
47
|
self.net = model
|
|
42
48
|
self.white_list = []
|
|
@@ -49,20 +55,40 @@ class GraphModeCellDump:
|
|
|
49
55
|
self.list = config.list
|
|
50
56
|
self.data_mode = config.data_mode
|
|
51
57
|
self.file_format = config.file_format
|
|
52
|
-
GraphModeCellDump.task = config.task
|
|
53
58
|
self.summary_mode = config.summary_mode
|
|
59
|
+
self.task = config.task
|
|
54
60
|
self.check_config(strict)
|
|
55
61
|
self.set_step()
|
|
56
62
|
|
|
57
63
|
@staticmethod
|
|
58
|
-
def step():
|
|
64
|
+
def step(dump_path, step_list, task):
|
|
59
65
|
# 更新TensorDump Step
|
|
60
|
-
if
|
|
66
|
+
if task == CoreConst.TENSOR:
|
|
61
67
|
hal.synchronize()
|
|
62
68
|
temp_tensor = ms.Tensor([1], dtype=ms.float32)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
69
|
+
rank_id = os.environ.get('RANK_ID')
|
|
70
|
+
rank_dir = DEFAULT_RANK_DIR
|
|
71
|
+
|
|
72
|
+
if rank_id is not None:
|
|
73
|
+
rank_dir = CoreConst.RANK + str(rank_id)
|
|
74
|
+
|
|
75
|
+
with tempfile.TemporaryDirectory(dir=dump_path, prefix=rank_dir) as temp_dir:
|
|
76
|
+
save_file_flag = f"{temp_dir}/step_{Runtime.step_count}"
|
|
77
|
+
_run_op(ops.TensorDump(), "TensorDump", (save_file_flag, temp_tensor))
|
|
78
|
+
step_flag = "<tensordump-update-step>"
|
|
79
|
+
_run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
|
|
80
|
+
ops.tensordump(step_flag, temp_tensor)
|
|
81
|
+
cellDumperWithDumpGradient.process_step(dump_path, temp_dir, Runtime.step_count, step_list)
|
|
82
|
+
|
|
83
|
+
# 更新静态图KBK dump的step数
|
|
84
|
+
if task == CoreConst.STATISTICS:
|
|
85
|
+
if not graph_step_flag:
|
|
86
|
+
raise Exception(
|
|
87
|
+
"Importing _dump_step failed, "
|
|
88
|
+
"please use the latest version package of MindSpore."
|
|
89
|
+
)
|
|
90
|
+
_dump_step(1)
|
|
91
|
+
cellDumperWithDumpGradient.process_statistics_step(dump_path, Runtime.step_count, step_list)
|
|
66
92
|
|
|
67
93
|
def check_config(self, strict):
|
|
68
94
|
if not self.net:
|
|
@@ -16,6 +16,8 @@
|
|
|
16
16
|
import os
|
|
17
17
|
from collections import OrderedDict
|
|
18
18
|
import mindspore as ms
|
|
19
|
+
from mindspore import hal, ops, Tensor
|
|
20
|
+
from mindspore.ops.primitive import _run_op
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
def _iterate_items(data):
|
|
@@ -121,3 +123,12 @@ def save_grad(save_dir, name, data):
|
|
|
121
123
|
dump_dir = generate_dump_dir(save_dir)
|
|
122
124
|
suffix_name = name + '_grad'
|
|
123
125
|
return _SaveGradCell(dump_dir, suffix_name)(data)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def step():
|
|
129
|
+
hal.synchronize()
|
|
130
|
+
temp_tensor = Tensor([1], dtype=ms.float32)
|
|
131
|
+
step_flag = "<tensordump-update-step>"
|
|
132
|
+
_run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
|
|
133
|
+
ops.tensordump(step_flag, temp_tensor)
|
|
134
|
+
hal.synchronize()
|
|
@@ -40,36 +40,36 @@ cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
|
40
40
|
if not is_mindtorch():
|
|
41
41
|
_api_types = {
|
|
42
42
|
Const.MS_FRAMEWORK: {
|
|
43
|
-
Const.MS_API_TYPE_OPS: (ops, (ops,)),
|
|
44
|
-
Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
|
|
45
|
-
Const.MS_API_TYPE_MINT: (mint, (mint,)),
|
|
46
|
-
Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
|
|
47
|
-
Const.MS_API_TYPE_COM: (comm_func, (comm_func,)),
|
|
48
|
-
Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,))
|
|
43
|
+
Const.MS_API_TYPE_OPS: ((ops,), (ops,)),
|
|
44
|
+
Const.MS_API_TYPE_TENSOR: ((Tensor,), (Tensor,)),
|
|
45
|
+
Const.MS_API_TYPE_MINT: ((mint,), (mint,)),
|
|
46
|
+
Const.MS_API_TYPE_MINT_FUNC: ((functional,), (functional,)),
|
|
47
|
+
Const.MS_API_TYPE_COM: ((comm_func,), (comm_func,)),
|
|
48
|
+
Const.MS_API_TYPE_MINT_DIST: ((distributed,), (distributed,))
|
|
49
49
|
}
|
|
50
50
|
}
|
|
51
51
|
if stub_tensor_existed:
|
|
52
52
|
_api_types.get(Const.MS_FRAMEWORK).update(
|
|
53
|
-
{Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,))}
|
|
53
|
+
{Const.MS_API_TYPE_STUB_TENSOR: ((StubTensor,), (StubTensor,))}
|
|
54
54
|
)
|
|
55
55
|
|
|
56
56
|
_supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
|
|
57
|
-
|
|
57
|
+
_blacklist = []
|
|
58
58
|
else:
|
|
59
59
|
import torch
|
|
60
60
|
import torch_npu
|
|
61
61
|
_api_types = {
|
|
62
62
|
Const.MT_FRAMEWORK: {
|
|
63
|
-
Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
|
|
64
|
-
Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
|
|
65
|
-
Const.PT_API_TYPE_TORCH: (torch, (torch,)),
|
|
66
|
-
Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)),
|
|
67
|
-
Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d))
|
|
63
|
+
Const.PT_API_TYPE_FUNCTIONAL: ((torch.nn.functional,), (torch.nn.functional,)),
|
|
64
|
+
Const.PT_API_TYPE_TENSOR: ((torch.Tensor,), (torch.Tensor,)),
|
|
65
|
+
Const.PT_API_TYPE_TORCH: ((torch,), (torch,)),
|
|
66
|
+
Const.PT_API_TYPE_NPU: ((torch_npu,), (torch_npu,)),
|
|
67
|
+
Const.PT_API_TYPE_DIST: ((torch.distributed,), (torch.distributed, torch.distributed.distributed_c10d))
|
|
68
68
|
}
|
|
69
69
|
}
|
|
70
70
|
_supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
|
|
71
71
|
MsConst.SUPPORTED_API_LIST_FILE),)
|
|
72
|
-
|
|
72
|
+
_blacklist = []
|
|
73
73
|
|
|
74
74
|
_inner_used_api = {
|
|
75
75
|
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
|
|
@@ -87,12 +87,11 @@ _inner_used_api = {
|
|
|
87
87
|
class ApiTemplate(HOOKCell):
|
|
88
88
|
def __init__(self, api_name, api_func, prefix, hook_build_func):
|
|
89
89
|
self.api_name = api_name
|
|
90
|
-
self.api_func = api_func
|
|
91
90
|
self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
|
|
92
|
-
super().__init__(hook_build_func)
|
|
93
91
|
distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX
|
|
94
|
-
|
|
95
|
-
|
|
92
|
+
self.op_is_distributed = prefix == distributed_prefix
|
|
93
|
+
super().__init__(hook_build_func)
|
|
94
|
+
self.api_func = api_func
|
|
96
95
|
|
|
97
96
|
@staticmethod
|
|
98
97
|
def async_to_sync(output):
|
|
@@ -161,7 +160,7 @@ def get_api_register(return_new=False):
|
|
|
161
160
|
_inner_used_api,
|
|
162
161
|
_supported_api_list_path,
|
|
163
162
|
ApiTemplate,
|
|
164
|
-
|
|
163
|
+
_blacklist
|
|
165
164
|
)
|
|
166
165
|
|
|
167
166
|
global api_register
|
|
@@ -171,6 +170,6 @@ def get_api_register(return_new=False):
|
|
|
171
170
|
_inner_used_api,
|
|
172
171
|
_supported_api_list_path,
|
|
173
172
|
ApiTemplate,
|
|
174
|
-
|
|
173
|
+
_blacklist
|
|
175
174
|
)
|
|
176
175
|
return api_register
|
|
@@ -19,8 +19,6 @@ from collections import defaultdict
|
|
|
19
19
|
import mindspore as ms
|
|
20
20
|
from mindspore import nn
|
|
21
21
|
|
|
22
|
-
from msprobe.core.common.runtime import Runtime
|
|
23
|
-
from msprobe.core.common.utils import ThreadSafe
|
|
24
22
|
from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
|
|
25
23
|
|
|
26
24
|
ms_version = ms.__version__
|
|
@@ -37,48 +35,28 @@ def get_cell_count(name):
|
|
|
37
35
|
def __init__(self, hook_build_func) -> None:
|
|
38
36
|
super(HOOKCell, self).__init__()
|
|
39
37
|
self.msprobe_input_kwargs = {}
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
if not Runtime.is_running:
|
|
47
|
-
return
|
|
48
|
-
prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
|
|
49
|
-
ThreadSafe.acquire()
|
|
50
|
-
if callable(hook_build_func):
|
|
51
|
-
hook_set = hook_build_func(prefix)
|
|
52
|
-
if ms_version < "2.6.0" and not is_mindtorch():
|
|
53
|
-
getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
|
|
38
|
+
prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
|
|
39
|
+
if callable(hook_build_func):
|
|
40
|
+
hook_set = hook_build_func(prefix)
|
|
41
|
+
if ms_version < "2.6.0" and not is_mindtorch():
|
|
42
|
+
getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
|
|
43
|
+
if hook_set.forward_hook:
|
|
54
44
|
getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook
|
|
55
|
-
|
|
56
|
-
|
|
45
|
+
else:
|
|
46
|
+
self.register_forward_pre_hook(hook_set.forward_pre_hook)
|
|
47
|
+
if hook_set.forward_hook:
|
|
57
48
|
self.register_forward_hook(hook_set.forward_hook)
|
|
58
|
-
register_backward_hook_functions["full"](self, hook_set.backward_hook)
|
|
59
|
-
register_backward_hook_functions["pre"](self, hook_set.backward_pre_hook)
|
|
60
49
|
|
|
61
50
|
|
|
62
|
-
# 重载call,加全局标志。
|
|
63
51
|
def __call__(self, *args, **kwargs):
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
changed = True
|
|
68
|
-
try:
|
|
69
|
-
self.msprobe_input_kwargs = kwargs
|
|
70
|
-
out = super(HOOKCell, self).__call__(*args, **kwargs)
|
|
71
|
-
except Exception as e:
|
|
72
|
-
raise e
|
|
73
|
-
finally:
|
|
74
|
-
if changed:
|
|
75
|
-
HOOKCell.inner_stop_hook[self.tid] = False
|
|
52
|
+
tid = threading.get_ident()
|
|
53
|
+
self.msprobe_input_kwargs[tid] = kwargs
|
|
54
|
+
out = super(HOOKCell, self).__call__(*args, **kwargs)
|
|
76
55
|
return out
|
|
77
56
|
|
|
78
57
|
|
|
79
58
|
hook_cell_dict = {
|
|
80
59
|
"cell_count": defaultdict(int),
|
|
81
|
-
"inner_stop_hook": defaultdict(bool),
|
|
82
60
|
"add_cell_count": staticmethod(add_cell_count),
|
|
83
61
|
"get_cell_count": staticmethod(get_cell_count),
|
|
84
62
|
"__init__": __init__,
|