mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,15 +14,18 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from msprobe.mindspore.common.const import Const
|
|
17
|
+
from msprobe.core.common.log import logger
|
|
18
|
+
from msprobe.mindspore.common.utils import is_graph_mode_cell_dump_allowed
|
|
17
19
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
18
20
|
from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
|
|
19
21
|
from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
|
|
22
|
+
from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
class DumpToolFactory:
|
|
23
26
|
tools = {
|
|
24
27
|
Const.CELL: {
|
|
25
|
-
Const.GRAPH_KBYK_MODE:
|
|
28
|
+
Const.GRAPH_KBYK_MODE: GraphModeCellDump,
|
|
26
29
|
Const.GRAPH_GE_MODE: None,
|
|
27
30
|
Const.PYNATIVE_MODE: None
|
|
28
31
|
},
|
|
@@ -39,14 +42,21 @@ class DumpToolFactory:
|
|
|
39
42
|
}
|
|
40
43
|
|
|
41
44
|
@staticmethod
|
|
42
|
-
def create(config: DebuggerConfig):
|
|
43
|
-
if
|
|
44
|
-
|
|
45
|
+
def create(config: DebuggerConfig, model=None):
|
|
46
|
+
if config.level == Const.CELL:
|
|
47
|
+
if not is_graph_mode_cell_dump_allowed(config):
|
|
48
|
+
raise Exception("Cell dump is not supported in graph mode.")
|
|
49
|
+
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
|
|
50
|
+
raise Exception("data_mode must be one of all, forward, backward.")
|
|
51
|
+
else:
|
|
52
|
+
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
|
|
53
|
+
raise Exception("data_mode must be one of all, input, output.")
|
|
45
54
|
tool = DumpToolFactory.tools.get(config.level)
|
|
46
55
|
if not tool:
|
|
47
56
|
raise Exception("Valid level is needed.")
|
|
48
57
|
tool = tool.get(config.execution_mode)
|
|
49
58
|
if not tool:
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
59
|
+
logger.error(f"Data dump is not supported in {config.execution_mode} mode "
|
|
60
|
+
f"when dump level is {config.level}.")
|
|
61
|
+
raise ValueError
|
|
62
|
+
return tool(config, model) if tool == GraphModeCellDump else tool(config)
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
import mindspore as ms
|
|
19
|
+
from mindspore import hal, ops, Tensor
|
|
20
|
+
from mindspore.ops.primitive import _run_op
|
|
21
|
+
|
|
22
|
+
from msprobe.core.common.const import Const as CoreConst
|
|
23
|
+
from msprobe.core.common.runtime import Runtime
|
|
24
|
+
from msprobe.mindspore.common.const import Const
|
|
25
|
+
from msprobe.mindspore.common.log import logger
|
|
26
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
27
|
+
import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
|
|
28
|
+
import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
|
|
29
|
+
|
|
30
|
+
tensordump_flag = True
|
|
31
|
+
try:
|
|
32
|
+
from mindspore._c_expression import _tensordump_set_step
|
|
33
|
+
except ImportError:
|
|
34
|
+
tensordump_flag = False
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class GraphModeCellDump:
|
|
38
|
+
task = CoreConst.STATISTICS
|
|
39
|
+
|
|
40
|
+
def __init__(self, config: DebuggerConfig, model, strict=True):
|
|
41
|
+
self.net = model
|
|
42
|
+
self.white_list = []
|
|
43
|
+
self.black_list = []
|
|
44
|
+
self.execution_mode = config.execution_mode
|
|
45
|
+
self.dump_path = config.dump_path if config.dump_path else "./"
|
|
46
|
+
self.rank = config.rank
|
|
47
|
+
self.step = config.step
|
|
48
|
+
self.scope = config.scope
|
|
49
|
+
self.list = config.list
|
|
50
|
+
self.data_mode = config.data_mode
|
|
51
|
+
self.file_format = config.file_format
|
|
52
|
+
GraphModeCellDump.task = config.task
|
|
53
|
+
self.summary_mode = config.summary_mode
|
|
54
|
+
self.check_config(strict)
|
|
55
|
+
self.set_step()
|
|
56
|
+
|
|
57
|
+
@staticmethod
|
|
58
|
+
def step():
|
|
59
|
+
# 更新TensorDump Step
|
|
60
|
+
if GraphModeCellDump.task == CoreConst.TENSOR:
|
|
61
|
+
hal.synchronize()
|
|
62
|
+
temp_tensor = ms.Tensor([1], dtype=ms.float32)
|
|
63
|
+
step_flag = "<tensordump-update-step>"
|
|
64
|
+
_run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
|
|
65
|
+
ops.tensordump(step_flag, temp_tensor)
|
|
66
|
+
|
|
67
|
+
def check_config(self, strict):
|
|
68
|
+
if not self.net:
|
|
69
|
+
raise Exception("The model is empty and cell dump is not enabled.")
|
|
70
|
+
|
|
71
|
+
if strict:
|
|
72
|
+
if self.rank:
|
|
73
|
+
raise Exception("In graph mode, cell dump does not currently support specifying rank.")
|
|
74
|
+
if self.scope:
|
|
75
|
+
raise Exception("In graph mode, cell dump does not currently support specifying scope.")
|
|
76
|
+
if self.list:
|
|
77
|
+
raise Exception("In graph mode, cell dump does not currently support specifying list.")
|
|
78
|
+
if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
|
|
79
|
+
raise Exception("In graph mode and cell dump, data_mode must be one of all, forword, backword.")
|
|
80
|
+
if self.file_format != []:
|
|
81
|
+
logger.warning("In graph mode, cell dump does not currently support specifying file_format."
|
|
82
|
+
" The file will be stored in npy format.")
|
|
83
|
+
if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
|
|
84
|
+
raise Exception("The L0 level statistics dump mode does not support "
|
|
85
|
+
"the calculation of md5 values currently In graph mode.")
|
|
86
|
+
else:
|
|
87
|
+
self.rank = []
|
|
88
|
+
self.scope = []
|
|
89
|
+
self.list = []
|
|
90
|
+
self.file_format = []
|
|
91
|
+
if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
|
|
92
|
+
self.data_mode = [CoreConst.ALL]
|
|
93
|
+
if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
|
|
94
|
+
self.summary_mode = CoreConst.STATISTICS
|
|
95
|
+
|
|
96
|
+
return True
|
|
97
|
+
|
|
98
|
+
def set_step(self):
|
|
99
|
+
if tensordump_flag:
|
|
100
|
+
_tensordump_set_step(self.step)
|
|
101
|
+
else:
|
|
102
|
+
raise Exception(
|
|
103
|
+
"Importing _tensordump_set_step failed, "
|
|
104
|
+
"please use the latest version package of MindSpore."
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
def handle(self):
|
|
108
|
+
os.environ['MS_JIT_MODULES'] = 'msprobe'
|
|
109
|
+
|
|
110
|
+
if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE:
|
|
111
|
+
dump_path = os.path.join(self.dump_path, Const.GRAPH_MODE)
|
|
112
|
+
else:
|
|
113
|
+
dump_path = self.dump_path
|
|
114
|
+
|
|
115
|
+
cell_dumper = cellDumperWithDumpGradient
|
|
116
|
+
|
|
117
|
+
if self.execution_mode == Const.PYNATIVE_MODE:
|
|
118
|
+
enable_dump_gradient = hasattr(ops, 'DumpGradient')
|
|
119
|
+
if hasattr(ops, 'DumpGradient'):
|
|
120
|
+
try:
|
|
121
|
+
ops.DumpGradient()('grad.npy', Tensor([0], dtype=ms.float32), 'in')
|
|
122
|
+
except Exception:
|
|
123
|
+
enable_dump_gradient = False
|
|
124
|
+
logger.warning('the DumpGradient operator failed to execute.')
|
|
125
|
+
if not enable_dump_gradient:
|
|
126
|
+
cell_dumper = cellDumperWithInsertGradient
|
|
127
|
+
|
|
128
|
+
dump_config = cell_dumper.CellDumpConfig(
|
|
129
|
+
net=self.net,
|
|
130
|
+
dump_path=dump_path,
|
|
131
|
+
data_mode=self.data_mode[0],
|
|
132
|
+
task=self.task,
|
|
133
|
+
summary_mode=self.summary_mode,
|
|
134
|
+
step=self.step
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
cell_dumper.start(
|
|
138
|
+
dump_config
|
|
139
|
+
)
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from collections import OrderedDict
|
|
18
|
+
import mindspore as ms
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _iterate_items(data):
|
|
22
|
+
if isinstance(data, (dict, OrderedDict)):
|
|
23
|
+
return data.items()
|
|
24
|
+
elif isinstance(data, (list, tuple)):
|
|
25
|
+
return enumerate(data)
|
|
26
|
+
else:
|
|
27
|
+
raise TypeError("Unsupported data type")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class _SaveBase:
|
|
31
|
+
def __init__(self, save_dir):
|
|
32
|
+
super(_SaveBase, self).__init__()
|
|
33
|
+
self.path = save_dir
|
|
34
|
+
self.save_func = _npy_save
|
|
35
|
+
|
|
36
|
+
def get_save_func(self):
|
|
37
|
+
return self.save_func
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@ms.jit_class
|
|
41
|
+
class _SaveCell(_SaveBase):
|
|
42
|
+
def __call__(self, name, data):
|
|
43
|
+
return self.get_save_func()(self.path, name, data)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class _SaveGradBase:
|
|
47
|
+
def __init__(self, save_dir, name):
|
|
48
|
+
super(_SaveGradBase, self).__init__()
|
|
49
|
+
self.file = save_dir + name
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@ms.jit_class
|
|
53
|
+
class _SaveGradCell(_SaveGradBase):
|
|
54
|
+
def __init__(self, save_dir, name):
|
|
55
|
+
super(_SaveGradCell, self).__init__(save_dir, name)
|
|
56
|
+
self.ms_save_grad = ms.ops.InsertGradientOf(
|
|
57
|
+
_wrapper_save_grad_func(self.file))
|
|
58
|
+
|
|
59
|
+
def __call__(self, x):
|
|
60
|
+
if isinstance(x, ms.Tensor):
|
|
61
|
+
return self.ms_save_grad(x)
|
|
62
|
+
else:
|
|
63
|
+
raise TypeError(f"For 'save_grad', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
|
|
64
|
+
f"but got {type(x)}")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _npy_save_ops(file, data):
|
|
68
|
+
if isinstance(data, ms.Tensor):
|
|
69
|
+
if data.dtype == ms.bfloat16:
|
|
70
|
+
data = data.float()
|
|
71
|
+
ms.ops.TensorDump()(file, data)
|
|
72
|
+
else:
|
|
73
|
+
raise TypeError(f"For 'save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
|
|
74
|
+
f"but got {type(data)}")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _wrapper_save_grad_func(file):
|
|
78
|
+
def _save_grad_func(grad):
|
|
79
|
+
data = grad
|
|
80
|
+
if data.dtype == ms.bfloat16:
|
|
81
|
+
data = data.float()
|
|
82
|
+
ms.ops.TensorDump()(file, data)
|
|
83
|
+
return grad
|
|
84
|
+
return _save_grad_func
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _npy_save(save_dir, item_name, data):
|
|
88
|
+
if isinstance(data, (list, tuple, dict, OrderedDict)):
|
|
89
|
+
for key, val in _iterate_items(data):
|
|
90
|
+
_npy_save(save_dir, f"{item_name}.{key}", val)
|
|
91
|
+
else:
|
|
92
|
+
if data is None:
|
|
93
|
+
return
|
|
94
|
+
_npy_save_ops(f"{save_dir}{item_name}", data)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def generate_dump_dir(save_dir, sep=os.sep):
|
|
98
|
+
"""
|
|
99
|
+
usage: generate dump directory path str in mindspore graph mode
|
|
100
|
+
"""
|
|
101
|
+
full_suffix = '{step}' + sep + '{rank}' + sep
|
|
102
|
+
if save_dir and save_dir[-1] != sep:
|
|
103
|
+
result_dir = save_dir + sep + full_suffix
|
|
104
|
+
else:
|
|
105
|
+
result_dir = save_dir + full_suffix
|
|
106
|
+
return result_dir
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def save(save_dir, name, data):
|
|
110
|
+
"""
|
|
111
|
+
save tensor.
|
|
112
|
+
"""
|
|
113
|
+
dump_dir = generate_dump_dir(save_dir)
|
|
114
|
+
_SaveCell(dump_dir)(name, data)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def save_grad(save_dir, name, data):
|
|
118
|
+
"""
|
|
119
|
+
save grad.
|
|
120
|
+
"""
|
|
121
|
+
dump_dir = generate_dump_dir(save_dir)
|
|
122
|
+
suffix_name = name + '_grad'
|
|
123
|
+
return _SaveGradCell(dump_dir, suffix_name)(data)
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import inspect
|
|
18
|
+
|
|
19
|
+
from mindspore import Tensor, ops, mint
|
|
20
|
+
from mindspore.mint import distributed
|
|
21
|
+
from mindspore.mint.nn import functional
|
|
22
|
+
from mindspore.communication import comm_func
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
25
|
+
from msprobe.core.common.utils import Const
|
|
26
|
+
from msprobe.core.data_dump.api_registry import ApiRegistry
|
|
27
|
+
from msprobe.mindspore.common.log import logger
|
|
28
|
+
from msprobe.mindspore.common.const import Const as MsConst
|
|
29
|
+
from msprobe.mindspore.common.utils import is_mindtorch
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
stub_tensor_existed = True
|
|
34
|
+
try:
|
|
35
|
+
from mindspore.common._stub_tensor import StubTensor
|
|
36
|
+
except ImportError:
|
|
37
|
+
stub_tensor_existed = False
|
|
38
|
+
|
|
39
|
+
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
40
|
+
if not is_mindtorch():
|
|
41
|
+
_api_types = {
|
|
42
|
+
Const.MS_FRAMEWORK: {
|
|
43
|
+
Const.MS_API_TYPE_OPS: (ops, (ops,)),
|
|
44
|
+
Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
|
|
45
|
+
Const.MS_API_TYPE_MINT: (mint, (mint,)),
|
|
46
|
+
Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
|
|
47
|
+
Const.MS_API_TYPE_COM: (comm_func, (comm_func,)),
|
|
48
|
+
Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,))
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
if stub_tensor_existed:
|
|
52
|
+
_api_types.get(Const.MS_FRAMEWORK).update(
|
|
53
|
+
{Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,))}
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
_supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
|
|
57
|
+
_backlist = []
|
|
58
|
+
else:
|
|
59
|
+
import torch
|
|
60
|
+
import torch_npu
|
|
61
|
+
_api_types = {
|
|
62
|
+
Const.MT_FRAMEWORK: {
|
|
63
|
+
Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
|
|
64
|
+
Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
|
|
65
|
+
Const.PT_API_TYPE_TORCH: (torch, (torch,)),
|
|
66
|
+
Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)),
|
|
67
|
+
Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d))
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
_supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
|
|
71
|
+
MsConst.SUPPORTED_API_LIST_FILE),)
|
|
72
|
+
_backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__']
|
|
73
|
+
|
|
74
|
+
_inner_used_api = {
|
|
75
|
+
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
|
|
76
|
+
ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point"
|
|
77
|
+
),
|
|
78
|
+
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: (
|
|
79
|
+
Tensor, "to", "numel", 'sum'
|
|
80
|
+
),
|
|
81
|
+
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: (
|
|
82
|
+
mint, "max", "min", "mean", "norm"
|
|
83
|
+
)
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class ApiTemplate(HOOKCell):
    """Hooked wrapper around a single framework API.

    Each instance wraps one API callable (``api_func``); invoking the cell
    runs the hooks installed by ``hook_build_func`` around the original API
    call so its inputs/outputs can be dumped.
    """

    def __init__(self, api_name, api_func, prefix, hook_build_func):
        self.api_name = api_name
        self.api_func = api_func
        # e.g. "Tensor.add." — the trailing separator is later completed
        # with a per-API call count by the hook manager.
        self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
        # super().__init__ may read self.prefix_api_name, so set it first.
        super().__init__(hook_build_func)
        distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX
        if prefix == distributed_prefix:
            self.op_is_distributed = True

    @staticmethod
    def async_to_sync(output):
        """Wait on an async communication result and substitute a no-op
        handle, so callers that later invoke ``wait`` are unaffected."""
        # Fake handle, used to return after the CommHandle executes the wait method
        fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
        if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
            output[1].wait()
            output = (output[0], fake_handle)
        elif hasattr(output, "wait"):
            output.wait()
            output = fake_handle
        return output

    def construct(self, *args, **kwargs):
        # Dropout APIs are short-circuited: return the input unchanged so
        # dumped data stays deterministic.
        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
            return args[0] if args else kwargs.get(Const.INPUT)

        output = self.api_func(*args, **kwargs)

        if self.prefix_api_name.startswith(
            (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX)
        ):
            try:
                bound = inspect.signature(self.api_func).bind(*args, **kwargs)
                bound.apply_defaults()
                # Fix: MindSpore comm_func collectives spell this flag
                # "async_op"; the old lookup only checked "asyn_op" and thus
                # always missed. Keep "asyn_op" as a fallback for safety.
                use_asyn_op_flag = bound.arguments.get(
                    "async_op", bound.arguments.get("asyn_op", False)
                )
            except Exception as e:
                use_asyn_op_flag = False
                logger.warning(f"fail to get dist api's func signature because {e}, no wait")

            # Async results must be synchronized before their tensors can be
            # read for dumping.
            if use_asyn_op_flag or self.api_name in ["isend", "irecv"]:
                output = self.async_to_sync(output)
            if self.api_name == "batch_isend_irecv" and isinstance(output, list):
                output = [self.async_to_sync(handle) for handle in output]

        return output

    def forward(self, *args, **kwargs):
        # mindtorch path: only the dropout short-circuit applies here; the
        # hooks installed in __init__ handle data collection.
        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
            return args[0] if args else kwargs.get(Const.INPUT)
        return self.api_func(*args, **kwargs)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
# Lazily created singleton registry; see get_api_register().
api_register = None
# One-shot guard: True once StubTensor methods have been wrapped, so they
# are never wrapped twice in the same process.
stub_tensor_set = False
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def get_api_register(return_new=False):
    """Return the ApiRegistry used for API-level dumping.

    On first use under native MindSpore (not mindtorch), StubTensor methods
    listed in the supported-API yaml are wrapped in plain Python functions;
    this happens at most once per process (guarded by ``stub_tensor_set``).

    Args:
        return_new: when True, build and return a fresh ApiRegistry instead
            of the cached module-level singleton.
    """
    global stub_tensor_set

    def stub_method(method):
        # Pass-through wrapper — presumably needed so the hook machinery can
        # patch StubTensor methods like ordinary Python attributes; confirm
        # against ApiRegistry's patching logic.
        def wrapped_method(*args, **kwargs):
            return method(*args, **kwargs)
        return wrapped_method

    def build_registry():
        # Single construction point so the return_new path and the cached
        # singleton cannot drift apart (previously duplicated verbatim).
        return ApiRegistry(
            _api_types,
            _inner_used_api,
            _supported_api_list_path,
            ApiTemplate,
            _backlist
        )

    if not is_mindtorch() and stub_tensor_existed and not stub_tensor_set:
        api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, [])
        for attr_name in dir(StubTensor):
            attr = getattr(StubTensor, attr_name)
            if attr_name in api_names and callable(attr):
                setattr(StubTensor, attr_name, stub_method(attr))
        stub_tensor_set = True

    if return_new:
        return build_registry()

    global api_register
    if api_register is None:
        api_register = build_registry()
    return api_register
|
|
@@ -15,11 +15,16 @@
|
|
|
15
15
|
|
|
16
16
|
from collections import defaultdict
|
|
17
17
|
|
|
18
|
+
import mindspore as ms
|
|
18
19
|
from mindspore import nn
|
|
19
20
|
|
|
21
|
+
from msprobe.core.common.runtime import Runtime
|
|
20
22
|
from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
ms_version = ms.__version__
|
|
26
|
+
|
|
27
|
+
|
|
23
28
|
def add_cell_count(name):
|
|
24
29
|
HOOKCell.cell_count[name] += 1
|
|
25
30
|
|
|
@@ -28,29 +33,34 @@ def get_cell_count(name):
|
|
|
28
33
|
return HOOKCell.cell_count[name]
|
|
29
34
|
|
|
30
35
|
|
|
31
|
-
def __init__(self,
|
|
36
|
+
def __init__(self, hook_build_func) -> None:
|
|
32
37
|
super(HOOKCell, self).__init__()
|
|
33
38
|
self.changed_status = False
|
|
34
|
-
self.
|
|
35
|
-
self.prefix = ""
|
|
39
|
+
self.msprobe_input_kwargs = {}
|
|
36
40
|
if not HOOKCell.g_stop_hook:
|
|
37
41
|
HOOKCell.g_stop_hook = True
|
|
38
42
|
self.changed_status = True
|
|
39
|
-
if hasattr(self, "prefix_api_name"):
|
|
40
|
-
self.prefix = self.prefix_api_name
|
|
41
|
-
|
|
42
43
|
self.forward_data_collected = False
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
44
|
+
|
|
45
|
+
if not Runtime.is_running:
|
|
46
|
+
return
|
|
47
|
+
prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
|
|
48
|
+
if callable(hook_build_func):
|
|
49
|
+
hook_set = hook_build_func(prefix)
|
|
50
|
+
if ms_version < "2.6.0" and not is_mindtorch():
|
|
51
|
+
getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
|
|
52
|
+
getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook
|
|
53
|
+
else:
|
|
54
|
+
self.register_forward_pre_hook(hook_set.forward_pre_hook)
|
|
55
|
+
self.register_forward_hook(hook_set.forward_hook)
|
|
56
|
+
register_backward_hook_functions["full"](self, hook_set.backward_hook)
|
|
57
|
+
register_backward_hook_functions["pre"](self, hook_set.backward_pre_hook)
|
|
48
58
|
|
|
49
59
|
|
|
50
60
|
# 重载call,加全局标志。
|
|
51
61
|
def __call__(self, *args, **kwargs):
|
|
52
62
|
try:
|
|
53
|
-
self.
|
|
63
|
+
self.msprobe_input_kwargs = kwargs
|
|
54
64
|
out = super(HOOKCell, self).__call__(*args, **kwargs)
|
|
55
65
|
except Exception as e:
|
|
56
66
|
raise e
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from mindspore.common.api import _no_grad
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.utils import replace_last_occurrence
|
|
19
|
+
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs
|
|
20
|
+
from msprobe.core.hook_manager import BaseHookManager, HookSet
|
|
21
|
+
from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook
|
|
22
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MindsproeHookManager(BaseHookManager):
    """MindSpore hook manager: builds the forward/backward hook set used to
    collect dump data for cells and wrapped APIs.

    NOTE(review): the class name keeps its historical "Mindsproe" spelling;
    renaming would break existing importers.
    """

    @property
    def _is_recompute(self):
        # Recompute status is not tracked for MindSpore.
        return None

    @staticmethod
    def _no_grad_context():
        # Hook-internal computations must not record gradients.
        return _no_grad()

    @staticmethod
    def _add_count(name):
        HOOKCell.add_cell_count(name)

    @staticmethod
    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
        """Normalize forward-hook arguments.

        When the runtime's forward hooks do not pass kwargs (older
        MindSpore), or for API hooks, kwargs come from the values stashed on
        the module (``msprobe_input_kwargs``); otherwise they arrive as the
        third hook argument.
        """
        if not has_kwargs_in_forward_hook() or hook_type == Const.API:
            kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
            output = kwargs_or_output
        else:
            kwargs = kwargs_or_output
            output = output_or_kwargs
        return kwargs, output

    def build_hook(self, hook_type, name):
        """Build the HookSet (forward pre/post, backward pre/post) for one
        cell or API.

        API names are suffixed with the current per-API call count and the
        forward tag; the backward name swaps the last "forward" occurrence.
        """
        if hook_type == Const.API:
            full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
        else:
            full_forward_name = name
        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
        return HookSet(
            forward_hook=self._build_forward_hook(hook_type, full_forward_name),
            forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
            backward_hook=self._build_backward_hook(hook_type, full_backward_name),
            backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name)
        )

    def _need_exchange(self, module):
        # Exchange only when the forward pre-hook actually ran on this module.
        return bool(getattr(module, 'has_pre_hook_called', False))

    def _get_params_dict(self, module):
        """Return the module's own (non-recursive) parameters keyed by their
        short name; empty for the structure-only task."""
        params_dict = {}
        if self.config.task != Const.STRUCTURE:
            params_dict = {
                key.split(Const.SEP)[-1]: value
                for key, value in module.parameters_dict(recurse=False).items()
            }
        return params_dict

    def _build_backward_pre_hook(self, hook_type, name):
        """Build a backward pre-hook that collects gradient inputs; active
        only at level L2."""
        def backward_pre_hook(module, grad_input):
            if self.config.level != Const.LEVEL_L2:
                return
            if not self._should_execute_hook(hook_type, module, False):
                return
            BaseHookManager.inner_switch = True
            try:
                module_input = ModuleBackwardInputs(grad_input=grad_input)
                self.data_collector.update_api_or_module_name(name)
                self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
            finally:
                # Fix: reset the switch even if collection raises, so a
                # single failure cannot leave hooking disabled for the rest
                # of the run.
                BaseHookManager.inner_switch = False
        return backward_pre_hook
|