mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -15,15 +15,17 @@
|
|
|
15
15
|
|
|
16
16
|
from msprobe.mindspore.common.const import Const
|
|
17
17
|
from msprobe.core.common.log import logger
|
|
18
|
+
from msprobe.mindspore.common.utils import is_graph_mode_cell_dump_allowed
|
|
18
19
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
19
20
|
from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
|
|
20
21
|
from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
|
|
22
|
+
from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
class DumpToolFactory:
|
|
24
26
|
tools = {
|
|
25
27
|
Const.CELL: {
|
|
26
|
-
Const.GRAPH_KBYK_MODE:
|
|
28
|
+
Const.GRAPH_KBYK_MODE: GraphModeCellDump,
|
|
27
29
|
Const.GRAPH_GE_MODE: None,
|
|
28
30
|
Const.PYNATIVE_MODE: None
|
|
29
31
|
},
|
|
@@ -40,9 +42,15 @@ class DumpToolFactory:
|
|
|
40
42
|
}
|
|
41
43
|
|
|
42
44
|
@staticmethod
|
|
43
|
-
def create(config: DebuggerConfig):
|
|
44
|
-
if
|
|
45
|
-
|
|
45
|
+
def create(config: DebuggerConfig, model=None):
|
|
46
|
+
if config.level == Const.CELL:
|
|
47
|
+
if not is_graph_mode_cell_dump_allowed(config):
|
|
48
|
+
raise Exception("Cell dump is not supported in graph mode.")
|
|
49
|
+
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
|
|
50
|
+
raise Exception("data_mode must be one of all, forward, backward.")
|
|
51
|
+
else:
|
|
52
|
+
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
|
|
53
|
+
raise Exception("data_mode must be one of all, input, output.")
|
|
46
54
|
tool = DumpToolFactory.tools.get(config.level)
|
|
47
55
|
if not tool:
|
|
48
56
|
raise Exception("Valid level is needed.")
|
|
@@ -51,4 +59,4 @@ class DumpToolFactory:
|
|
|
51
59
|
logger.error(f"Data dump is not supported in {config.execution_mode} mode "
|
|
52
60
|
f"when dump level is {config.level}.")
|
|
53
61
|
raise ValueError
|
|
54
|
-
return tool(config)
|
|
62
|
+
return tool(config, model) if tool == GraphModeCellDump else tool(config)
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
import mindspore as ms
|
|
19
|
+
from mindspore import hal, ops, Tensor
|
|
20
|
+
from mindspore.ops.primitive import _run_op
|
|
21
|
+
|
|
22
|
+
from msprobe.core.common.const import Const as CoreConst
|
|
23
|
+
from msprobe.core.common.runtime import Runtime
|
|
24
|
+
from msprobe.mindspore.common.const import Const
|
|
25
|
+
from msprobe.mindspore.common.log import logger
|
|
26
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
27
|
+
import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
|
|
28
|
+
import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
|
|
29
|
+
|
|
30
|
+
# `_tensordump_set_step` may be missing from older MindSpore builds; remember
# whether it could be imported so `GraphModeCellDump.set_step` can report it.
try:
    from mindspore._c_expression import _tensordump_set_step
except ImportError:
    tensordump_flag = False
else:
    tensordump_flag = True
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class GraphModeCellDump:
    """Cell-level data dump for MindSpore graph mode.

    Validates the debugger configuration, registers the dump steps with
    TensorDump and delegates the actual dumping to either
    ``cell_dump_process`` or ``cell_dump_with_insert_gradient``.
    """

    # Kept at class level so the static ``step()`` can read it without an
    # instance; every ``__init__`` overwrites it with ``config.task``.
    task = CoreConst.STATISTICS

    def __init__(self, config: DebuggerConfig, model, strict=True):
        """
        Args:
            config: debugger configuration providing the dump settings.
            model: the network to dump; must be truthy.
            strict: when True, unsupported options raise; when False they are
                reset to supported defaults.

        Raises:
            Exception: on an invalid configuration (see ``check_config``) or
                when ``_tensordump_set_step`` is unavailable (see ``set_step``).
        """
        self.net = model
        self.white_list = []
        self.black_list = []
        self.execution_mode = config.execution_mode
        self.dump_path = config.dump_path if config.dump_path else "./"
        self.rank = config.rank
        # NOTE: this instance attribute shadows the static ``step()`` method on
        # instances; invoke the method through the class (``GraphModeCellDump.step()``).
        self.step = config.step
        self.scope = config.scope
        self.list = config.list
        self.data_mode = config.data_mode
        self.file_format = config.file_format
        GraphModeCellDump.task = config.task
        self.summary_mode = config.summary_mode
        self.check_config(strict)
        self.set_step()

    @staticmethod
    def step():
        """Update the TensorDump step; only acts when the task is tensor dump."""
        if GraphModeCellDump.task == CoreConst.TENSOR:
            hal.synchronize()
            temp_tensor = ms.Tensor([1], dtype=ms.float32)
            # Marker name passed to TensorDump to request a step update.
            step_flag = "<tensordump-update-step>"
            # NOTE(review): the marker is emitted twice (direct primitive run and
            # functional API), presumably to reach both execution paths - confirm.
            _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
            ops.tensordump(step_flag, temp_tensor)

    def check_config(self, strict):
        """Validate (strict) or normalise (non-strict) the dump configuration.

        Args:
            strict: raise on unsupported options instead of resetting them.

        Returns:
            True once the configuration is validated or normalised.

        Raises:
            Exception: if no model was provided or, in strict mode, an
                unsupported option is set.
        """
        if not self.net:
            raise Exception("The model is empty and cell dump is not enabled.")

        if strict:
            if self.rank:
                raise Exception("In graph mode, cell dump does not currently support specifying rank.")
            if self.scope:
                raise Exception("In graph mode, cell dump does not currently support specifying scope.")
            if self.list:
                raise Exception("In graph mode, cell dump does not currently support specifying list.")
            if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
                raise Exception("In graph mode and cell dump, data_mode must be one of all, forward, backward.")
            if self.file_format != []:
                logger.warning("In graph mode, cell dump does not currently support specifying file_format."
                               " The file will be stored in npy format.")
            if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
                raise Exception("The L0 level statistics dump mode does not support "
                                "the calculation of md5 values in graph mode currently.")
        else:
            # Lenient mode: drop unsupported options and fall back to defaults.
            self.rank = []
            self.scope = []
            self.list = []
            self.file_format = []
            if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
                self.data_mode = [CoreConst.ALL]
            if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
                # MD5 is unsupported here; use plain statistics instead.
                self.summary_mode = CoreConst.STATISTICS

        return True

    def set_step(self):
        """Register the configured dump steps with TensorDump.

        Raises:
            Exception: if ``_tensordump_set_step`` could not be imported.
        """
        if not tensordump_flag:
            raise Exception(
                "Importing _tensordump_set_step failed, "
                "please use the latest version package of MindSpore."
            )
        _tensordump_set_step(self.step)

    def handle(self):
        """Start the cell dump of ``self.net`` with the validated configuration."""
        # MS_JIT_MODULES is a comma-separated module list; append msprobe
        # instead of overwriting modules the user already configured.
        jit_modules = [name.strip() for name in os.environ.get('MS_JIT_MODULES', '').split(',') if name.strip()]
        if 'msprobe' not in jit_modules:
            jit_modules.append('msprobe')
        os.environ['MS_JIT_MODULES'] = ','.join(jit_modules)

        if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE:
            dump_path = os.path.join(self.dump_path, Const.GRAPH_MODE)
        else:
            dump_path = self.dump_path

        cell_dumper = cellDumperWithDumpGradient

        if self.execution_mode == Const.PYNATIVE_MODE:
            # DumpGradient may exist but still fail at runtime, so probe it once;
            # fall back to the InsertGradientOf-based dumper if it is unusable.
            enable_dump_gradient = hasattr(ops, 'DumpGradient')
            if enable_dump_gradient:
                try:
                    ops.DumpGradient()('grad.npy', Tensor([0], dtype=ms.float32), 'in')
                except Exception as e:
                    enable_dump_gradient = False
                    logger.warning(f'the DumpGradient operator failed to execute: {e}')
            if not enable_dump_gradient:
                cell_dumper = cellDumperWithInsertGradient

        dump_config = cell_dumper.CellDumpConfig(
            net=self.net,
            dump_path=dump_path,
            data_mode=self.data_mode[0],
            task=self.task,
            summary_mode=self.summary_mode,
            step=self.step
        )

        cell_dumper.start(
            dump_config
        )
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from collections import OrderedDict
|
|
18
|
+
import mindspore as ms
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _iterate_items(data):
|
|
22
|
+
if isinstance(data, (dict, OrderedDict)):
|
|
23
|
+
return data.items()
|
|
24
|
+
elif isinstance(data, (list, tuple)):
|
|
25
|
+
return enumerate(data)
|
|
26
|
+
else:
|
|
27
|
+
raise TypeError("Unsupported data type")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class _SaveBase:
|
|
31
|
+
def __init__(self, save_dir):
|
|
32
|
+
super(_SaveBase, self).__init__()
|
|
33
|
+
self.path = save_dir
|
|
34
|
+
self.save_func = _npy_save
|
|
35
|
+
|
|
36
|
+
def get_save_func(self):
|
|
37
|
+
return self.save_func
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@ms.jit_class
class _SaveCell(_SaveBase):
    """Graph-compatible (``ms.jit_class``) saver.

    Calling it saves ``data`` under ``name`` in the directory given at construction.
    """

    def __call__(self, name, data):
        # Delegate to the configured save function with the stored directory.
        return self.get_save_func()(self.path, name, data)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class _SaveGradBase:
|
|
47
|
+
def __init__(self, save_dir, name):
|
|
48
|
+
super(_SaveGradBase, self).__init__()
|
|
49
|
+
self.file = save_dir + name
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@ms.jit_class
class _SaveGradCell(_SaveGradBase):
    """Graph-compatible wrapper that attaches a gradient-dumping hook to a tensor.

    Calling it passes the tensor through ``InsertGradientOf``. The hook dumps the
    gradient to ``self.file`` during backward and returns it unchanged.
    """

    def __init__(self, save_dir, name):
        super(_SaveGradCell, self).__init__(save_dir, name)
        grad_hook = _wrapper_save_grad_func(self.file)
        self.ms_save_grad = ms.ops.InsertGradientOf(grad_hook)

    def __call__(self, x):
        if not isinstance(x, ms.Tensor):
            raise TypeError(f"For 'save_grad', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
                            f"but got {type(x)}")
        return self.ms_save_grad(x)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _npy_save_ops(file, data):
    """Dump a single tensor to ``file`` with the TensorDump operator.

    bfloat16 tensors are converted with ``Tensor.float()`` before dumping.

    Raises:
        TypeError: if ``data`` is not a ``mindspore.Tensor``.
    """
    if not isinstance(data, ms.Tensor):
        raise TypeError(f"For 'save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
                        f"but got {type(data)}")
    if data.dtype == ms.bfloat16:
        data = data.float()
    ms.ops.TensorDump()(file, data)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _wrapper_save_grad_func(file):
|
|
78
|
+
def _save_grad_func(grad):
|
|
79
|
+
data = grad
|
|
80
|
+
if data.dtype == ms.bfloat16:
|
|
81
|
+
data = data.float()
|
|
82
|
+
ms.ops.TensorDump()(file, data)
|
|
83
|
+
return grad
|
|
84
|
+
return _save_grad_func
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _npy_save(save_dir, item_name, data):
|
|
88
|
+
if isinstance(data, (list, tuple, dict, OrderedDict)):
|
|
89
|
+
for key, val in _iterate_items(data):
|
|
90
|
+
_npy_save(save_dir, f"{item_name}.{key}", val)
|
|
91
|
+
else:
|
|
92
|
+
if data is None:
|
|
93
|
+
return
|
|
94
|
+
_npy_save_ops(f"{save_dir}{item_name}", data)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def generate_dump_dir(save_dir, sep=os.sep):
    """
    Build the dump directory template used in MindSpore graph mode.

    Returns ``save_dir`` followed by ``{step}<sep>{rank}<sep>``. A separator is
    inserted only when ``save_dir`` is non-empty and does not already end with it.
    """
    suffix = '{step}' + sep + '{rank}' + sep
    needs_separator = save_dir and save_dir[-1] != sep
    if not needs_separator:
        return save_dir + suffix
    return save_dir + sep + suffix
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def save(save_dir, name, data):
    """
    Save ``data`` (a tensor, or nested list/tuple/dict of tensors) under ``name``.

    Files go to ``save_dir`` extended with the ``{step}/{rank}/`` template.
    """
    saver = _SaveCell(generate_dump_dir(save_dir))
    saver(name, data)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def save_grad(save_dir, name, data):
    """
    Attach a gradient dump to ``data`` and return the wrapped tensor.

    The gradient is written as ``<name>_grad`` under ``save_dir`` extended with
    the ``{step}/{rank}/`` template.
    """
    grad_saver = _SaveGradCell(generate_dump_dir(save_dir), name + '_grad')
    return grad_saver(data)
|
|
@@ -14,14 +14,17 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
import inspect
|
|
17
18
|
|
|
18
19
|
from mindspore import Tensor, ops, mint
|
|
20
|
+
from mindspore.mint import distributed
|
|
19
21
|
from mindspore.mint.nn import functional
|
|
20
22
|
from mindspore.communication import comm_func
|
|
21
23
|
|
|
22
24
|
from msprobe.core.common.file_utils import load_yaml
|
|
23
25
|
from msprobe.core.common.utils import Const
|
|
24
26
|
from msprobe.core.data_dump.api_registry import ApiRegistry
|
|
27
|
+
from msprobe.mindspore.common.log import logger
|
|
25
28
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
26
29
|
from msprobe.mindspore.common.utils import is_mindtorch
|
|
27
30
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
@@ -41,7 +44,8 @@ if not is_mindtorch():
|
|
|
41
44
|
Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
|
|
42
45
|
Const.MS_API_TYPE_MINT: (mint, (mint,)),
|
|
43
46
|
Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
|
|
44
|
-
Const.MS_API_TYPE_COM: (comm_func, (comm_func,))
|
|
47
|
+
Const.MS_API_TYPE_COM: (comm_func, (comm_func,)),
|
|
48
|
+
Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,))
|
|
45
49
|
}
|
|
46
50
|
}
|
|
47
51
|
if stub_tensor_existed:
|
|
@@ -50,6 +54,7 @@ if not is_mindtorch():
|
|
|
50
54
|
)
|
|
51
55
|
|
|
52
56
|
_supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
|
|
57
|
+
_backlist = []
|
|
53
58
|
else:
|
|
54
59
|
import torch
|
|
55
60
|
import torch_npu
|
|
@@ -64,13 +69,14 @@ else:
|
|
|
64
69
|
}
|
|
65
70
|
_supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
|
|
66
71
|
MsConst.SUPPORTED_API_LIST_FILE),)
|
|
72
|
+
_backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__']
|
|
67
73
|
|
|
68
74
|
_inner_used_api = {
|
|
69
75
|
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
|
|
70
76
|
ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point"
|
|
71
77
|
),
|
|
72
78
|
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: (
|
|
73
|
-
Tensor, "to", "numel"
|
|
79
|
+
Tensor, "to", "numel", 'sum'
|
|
74
80
|
),
|
|
75
81
|
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: (
|
|
76
82
|
mint, "max", "min", "mean", "norm"
|
|
@@ -84,6 +90,9 @@ class ApiTemplate(HOOKCell):
|
|
|
84
90
|
self.api_func = api_func
|
|
85
91
|
self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
|
|
86
92
|
super().__init__(hook_build_func)
|
|
93
|
+
distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX
|
|
94
|
+
if prefix == distributed_prefix:
|
|
95
|
+
self.op_is_distributed = True
|
|
87
96
|
|
|
88
97
|
@staticmethod
|
|
89
98
|
def async_to_sync(output):
|
|
@@ -103,9 +112,22 @@ class ApiTemplate(HOOKCell):
|
|
|
103
112
|
|
|
104
113
|
output = self.api_func(*args, **kwargs)
|
|
105
114
|
|
|
106
|
-
if self.prefix_api_name.startswith(
|
|
107
|
-
|
|
115
|
+
if self.prefix_api_name.startswith(
|
|
116
|
+
(MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX)
|
|
117
|
+
):
|
|
118
|
+
try:
|
|
119
|
+
bound = inspect.signature(self.api_func).bind(*args, **kwargs)
|
|
120
|
+
bound.apply_defaults()
|
|
121
|
+
use_async_op_flag = bound.arguments.get("async_op", False)
|
|
122
|
+
except Exception as e:
|
|
123
|
+
use_async_op_flag = False
|
|
124
|
+
logger.warning(f"fail to get dist api's func signature because {e}, no wait")
|
|
125
|
+
|
|
126
|
+
if use_async_op_flag or self.api_name in ["isend", "irecv"]:
|
|
108
127
|
output = self.async_to_sync(output)
|
|
128
|
+
if self.api_name == "batch_isend_irecv" and isinstance(output, list):
|
|
129
|
+
output = [self.async_to_sync(handle) for handle in output]
|
|
130
|
+
|
|
109
131
|
return output
|
|
110
132
|
|
|
111
133
|
def forward(self, *args, **kwargs):
|
|
@@ -134,9 +156,21 @@ def get_api_register(return_new=False):
|
|
|
134
156
|
stub_tensor_set = True
|
|
135
157
|
|
|
136
158
|
if return_new:
|
|
137
|
-
return ApiRegistry(
|
|
159
|
+
return ApiRegistry(
|
|
160
|
+
_api_types,
|
|
161
|
+
_inner_used_api,
|
|
162
|
+
_supported_api_list_path,
|
|
163
|
+
ApiTemplate,
|
|
164
|
+
_backlist
|
|
165
|
+
)
|
|
138
166
|
|
|
139
167
|
global api_register
|
|
140
168
|
if api_register is None:
|
|
141
|
-
api_register = ApiRegistry(
|
|
169
|
+
api_register = ApiRegistry(
|
|
170
|
+
_api_types,
|
|
171
|
+
_inner_used_api,
|
|
172
|
+
_supported_api_list_path,
|
|
173
|
+
ApiTemplate,
|
|
174
|
+
_backlist
|
|
175
|
+
)
|
|
142
176
|
return api_register
|
|
@@ -15,11 +15,16 @@
|
|
|
15
15
|
|
|
16
16
|
from collections import defaultdict
|
|
17
17
|
|
|
18
|
+
import mindspore as ms
|
|
18
19
|
from mindspore import nn
|
|
19
20
|
|
|
21
|
+
from msprobe.core.common.runtime import Runtime
|
|
20
22
|
from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
ms_version = ms.__version__
|
|
26
|
+
|
|
27
|
+
|
|
23
28
|
def add_cell_count(name):
|
|
24
29
|
HOOKCell.cell_count[name] += 1
|
|
25
30
|
|
|
@@ -31,25 +36,31 @@ def get_cell_count(name):
|
|
|
31
36
|
def __init__(self, hook_build_func) -> None:
|
|
32
37
|
super(HOOKCell, self).__init__()
|
|
33
38
|
self.changed_status = False
|
|
34
|
-
self.
|
|
39
|
+
self.msprobe_input_kwargs = {}
|
|
35
40
|
if not HOOKCell.g_stop_hook:
|
|
36
41
|
HOOKCell.g_stop_hook = True
|
|
37
42
|
self.changed_status = True
|
|
38
43
|
self.forward_data_collected = False
|
|
39
44
|
|
|
45
|
+
if not Runtime.is_running:
|
|
46
|
+
return
|
|
40
47
|
prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
|
|
41
48
|
if callable(hook_build_func):
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
49
|
+
hook_set = hook_build_func(prefix)
|
|
50
|
+
if ms_version < "2.6.0" and not is_mindtorch():
|
|
51
|
+
getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
|
|
52
|
+
getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook
|
|
53
|
+
else:
|
|
54
|
+
self.register_forward_pre_hook(hook_set.forward_pre_hook)
|
|
55
|
+
self.register_forward_hook(hook_set.forward_hook)
|
|
56
|
+
register_backward_hook_functions["full"](self, hook_set.backward_hook)
|
|
57
|
+
register_backward_hook_functions["pre"](self, hook_set.backward_pre_hook)
|
|
47
58
|
|
|
48
59
|
|
|
49
60
|
# 重载call,加全局标志。
|
|
50
61
|
def __call__(self, *args, **kwargs):
|
|
51
62
|
try:
|
|
52
|
-
self.
|
|
63
|
+
self.msprobe_input_kwargs = kwargs
|
|
53
64
|
out = super(HOOKCell, self).__call__(*args, **kwargs)
|
|
54
65
|
except Exception as e:
|
|
55
66
|
raise e
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from mindspore.common.api import _no_grad
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.utils import replace_last_occurrence
|
|
19
|
+
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs
|
|
20
|
+
from msprobe.core.hook_manager import BaseHookManager, HookSet
|
|
21
|
+
from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook
|
|
22
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MindsproeHookManager(BaseHookManager):
    """Build msprobe dump hooks for MindSpore cells and API calls.

    NOTE(review): the class name is misspelled ("Mindsproe"); it is kept
    unchanged because other modules may import it under this name.
    """

    @property
    def _is_recompute(self):
        """Recompute detection is not provided here; always returns ``None``."""
        return None

    @staticmethod
    def _no_grad_context():
        """Return a MindSpore ``_no_grad`` context object."""
        return _no_grad()

    @staticmethod
    def _add_count(name):
        """Increment the HOOKCell call counter for ``name``."""
        HOOKCell.add_cell_count(name)

    @staticmethod
    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
        """Normalise forward-hook arguments into a ``(kwargs, output)`` pair.

        If forward hooks do not receive kwargs (``has_kwargs_in_forward_hook()``
        is false) or this is an API hook, the third hook argument is the output.
        In that case kwargs are read from ``module.msprobe_input_kwargs`` (the
        attribute HOOKCell stores its call kwargs in), defaulting to ``{}``.
        Otherwise the hook received kwargs followed by the output.

        Returns:
            tuple: ``(kwargs, output)``.
        """
        if not has_kwargs_in_forward_hook() or hook_type == Const.API:
            kwargs = getattr(module, 'msprobe_input_kwargs', {})
            output = kwargs_or_output
        else:
            kwargs = kwargs_or_output
            output = output_or_kwargs
        return kwargs, output

    def build_hook(self, hook_type, name):
        """Create the forward/backward hook set for one API or module.

        For ``Const.API`` hooks, the forward name is built as follows:
        ``name``, then the current HOOKCell count for ``name``, then
        ``Const.SEP``, then ``Const.FORWARD``. For other hook types, ``name``
        is used as the forward name unchanged. The backward name replaces the
        last occurrence of ``Const.FORWARD`` with ``Const.BACKWARD``.

        Returns:
            HookSet: forward, forward-pre, backward and backward-pre hooks.
        """
        if hook_type == Const.API:
            full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
        else:
            full_forward_name = name
        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
        return HookSet(
            forward_hook=self._build_forward_hook(hook_type, full_forward_name),
            forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
            backward_hook=self._build_backward_hook(hook_type, full_backward_name),
            backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name)
        )

    def _need_exchange(self, module):
        """Return True only when ``module.has_pre_hook_called`` exists and is truthy."""
        return bool(getattr(module, 'has_pre_hook_called', False))

    def _get_params_dict(self, module):
        """Collect the module's own (non-recursive) parameters keyed by short name.

        Each key is the last ``Const.SEP``-separated segment of the parameter
        name. For the ``Const.STRUCTURE`` task no parameters are collected and
        an empty dict is returned.
        """
        params_dict = {}
        if self.config.task != Const.STRUCTURE:
            params_dict = {
                key.split(Const.SEP)[-1]: value
                for key, value in module.parameters_dict(recurse=False).items()
            }
        return params_dict

    def _build_backward_pre_hook(self, hook_type, name):
        """Build a backward-pre hook that dumps gradient inputs under ``name``.

        The hook only acts at ``Const.LEVEL_L2`` and when
        ``_should_execute_hook`` allows it.
        """
        def backward_pre_hook(module, grad_input):
            if self.config.level != Const.LEVEL_L2:
                return
            if not self._should_execute_hook(hook_type, module, False):
                return
            # inner_switch marks "msprobe is collecting" so other wrapped calls
            # (e.g. primitive hooks) skip dumping meanwhile. Reset it in a
            # finally so one failed collection cannot leave it stuck at True
            # and disable all later dumping.
            BaseHookManager.inner_switch = True
            try:
                module_input = ModuleBackwardInputs(grad_input=grad_input)
                self.data_collector.update_api_or_module_name(name)
                self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
            finally:
                BaseHookManager.inner_switch = False
        return backward_pre_hook
|
|
@@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
|
|
|
21
21
|
from msprobe.core.common.utils import Const, DumpException
|
|
22
22
|
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
|
|
23
23
|
ModuleForwardInputsOutputs)
|
|
24
|
+
from msprobe.core.hook_manager import BaseHookManager
|
|
24
25
|
from msprobe.mindspore.common.log import logger
|
|
25
26
|
|
|
26
27
|
|
|
@@ -58,7 +59,7 @@ class PrimitiveHookService:
|
|
|
58
59
|
def backward_hook(grad):
|
|
59
60
|
captured_grads.extend(grad)
|
|
60
61
|
backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
|
|
61
|
-
|
|
62
|
+
self.service_instance.inner_switch = True
|
|
62
63
|
try:
|
|
63
64
|
if hook_type == Const.INPUT:
|
|
64
65
|
self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
|
|
@@ -77,6 +78,7 @@ class PrimitiveHookService:
|
|
|
77
78
|
logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
|
|
78
79
|
f"updated_primitive_name: {updated_primitive_name}")
|
|
79
80
|
raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
|
|
81
|
+
self.service_instance.inner_switch = False
|
|
80
82
|
|
|
81
83
|
return backward_hook
|
|
82
84
|
|
|
@@ -137,6 +139,7 @@ class PrimitiveHookService:
|
|
|
137
139
|
|
|
138
140
|
def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
|
|
139
141
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
|
|
142
|
+
self.service_instance.inner_switch = True
|
|
140
143
|
try:
|
|
141
144
|
self.service_instance.data_collector.forward_input_data_collect(
|
|
142
145
|
primitive_name,
|
|
@@ -148,9 +151,11 @@ class PrimitiveHookService:
|
|
|
148
151
|
logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
|
|
149
152
|
f"primitive_name: {primitive_name}")
|
|
150
153
|
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
154
|
+
self.service_instance.inner_switch = False
|
|
151
155
|
|
|
152
156
|
def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
|
|
153
157
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
|
|
158
|
+
self.service_instance.inner_switch = True
|
|
154
159
|
try:
|
|
155
160
|
self.service_instance.data_collector.forward_output_data_collect(
|
|
156
161
|
primitive_name,
|
|
@@ -162,6 +167,7 @@ class PrimitiveHookService:
|
|
|
162
167
|
logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
|
|
163
168
|
f"primitive_name: {primitive_name}")
|
|
164
169
|
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
170
|
+
self.service_instance.inner_switch = False
|
|
165
171
|
|
|
166
172
|
def wrapped_primitive_call(instance_self, *args, **kwargs):
|
|
167
173
|
"""
|
|
@@ -179,7 +185,7 @@ class PrimitiveHookService:
|
|
|
179
185
|
current_count = self.primitive_counters.get(primitive_name, 0)
|
|
180
186
|
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
|
|
181
187
|
|
|
182
|
-
if not self.service_instance.primitive_switch:
|
|
188
|
+
if not self.service_instance.primitive_switch or BaseHookManager.inner_switch:
|
|
183
189
|
return origin_func(*args, **kwargs)
|
|
184
190
|
|
|
185
191
|
captured_grads_input, captured_grads_output = [], []
|
|
@@ -1025,3 +1025,21 @@ communication.comm_func:
|
|
|
1025
1025
|
- recv
|
|
1026
1026
|
- isend
|
|
1027
1027
|
- irecv
|
|
1028
|
+
|
|
1029
|
+
mint.distributed:
|
|
1030
|
+
- send
|
|
1031
|
+
- recv
|
|
1032
|
+
- broadcast
|
|
1033
|
+
- all_reduce
|
|
1034
|
+
- reduce
|
|
1035
|
+
- all_gather
|
|
1036
|
+
- gather
|
|
1037
|
+
- isend
|
|
1038
|
+
- irecv
|
|
1039
|
+
- scatter
|
|
1040
|
+
- reduce_scatter
|
|
1041
|
+
- all_to_all_single
|
|
1042
|
+
- all_to_all
|
|
1043
|
+
- all_gather_into_tensor
|
|
1044
|
+
- reduce_scatter_tensor
|
|
1045
|
+
- batch_isend_irecv
|