mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/service.py
DELETED
```diff
@@ -1,543 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-import functools
-import os
-from collections import defaultdict
-
-import mindspore as ms
-from mindspore import nn
-from mindspore.common.api import _no_grad
-from mindspore.ops.primitive import Primitive
-
-try:
-    from mindspore.common._pijit_context import PIJitCaptureContext
-except ImportError:
-    pijit_label = False
-else:
-    pijit_label = True
-
-from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
-from msprobe.core.common.file_utils import create_directory
-from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
-from msprobe.core.data_dump.data_collector import build_data_collector
-from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
-                                                        ModuleBackwardInputs)
-from msprobe.core.data_dump.scope import BaseScope
-from msprobe.mindspore.cell_processor import CellProcessor
-from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
-                                            is_mindtorch, register_backward_hook_functions)
-from msprobe.mindspore.dump.hook_cell.api_registry import api_register
-from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
-from msprobe.mindspore.dump.jit_dump import JitDump
-from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
-from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
-
-if is_mindtorch():
-    import torch
-
-
-class Service:
-    def __init__(self, config):
-        self.model = None
-        self.config = copy.deepcopy(config)
-        self.config.level = self.config.level_ori
-        self.data_collector = build_data_collector(self.config)
-        self.cell_processor = CellProcessor(self.data_collector.scope)
-        self.primitive_hook_service = PrimitiveHookService(self)
-        self.switch = False
-        self.inner_switch = False
-        self.primitive_switch = False
-        self.current_iter = 0
-        self.first_start = True
-        self.current_rank = None
-        self.dump_iter_dir = None
-        self.start_call = False
-        self.should_stop_service = False
-        self.params_grad_info = {}
-        self.hook_handle_dict = {}
-        # Register early so that as many API hooks as possible are in place
-        self.register_api_hook()
-        self.init_for_debug_level()
-
-    @staticmethod
-    def check_model_valid(models):
-        target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
-        if models is None or isinstance(models, target_module_type[0]):
-            return models
-        error_model = None
-        if isinstance(models, (list, tuple)):
-            for model in models:
-                if not isinstance(model, target_module_type[0]):
-                    error_model = model
-                    break
-        else:
-            error_model = models
-
-        if error_model is not None:
-            error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
-                          f"type, currently there is a {type(error_model)} type.")
-            raise MsprobeException(
-                MsprobeException.INVALID_PARAM_ERROR, error_info)
-        return models
-
-    @staticmethod
-    def prepare_module_input_output(target_type, cell, input_data, output):
-        if target_type == BaseScope.Module_Type_Module:
-            module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
-        else:
-            module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
-        return module_input_output
-
-    def build_hook(self, target_type, name):
-        def pre_hook(api_or_cell_name, cell, input_data):
-            if not self.should_execute_hook(target_type, cell, True):
-                clean_input_kwargs(cell)
-                return None
-
-            with _no_grad():
-                self.inner_switch = True
-                if target_type == BaseScope.Module_Type_Module:
-                    api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-                else:
-                    cell.forward_data_collected = True
-                    HOOKCell.add_cell_count(name)
-                module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
-                self.data_collector.update_api_or_module_name(api_or_cell_name)
-                self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
-                self.inner_switch = False
-            return input_data
-
-        def grad_hook(cell, ori_name, param_name):
-            def hook_fn(grad):
-                if not self.should_execute_hook(target_type, cell, False):
-                    return None
-                self.inner_switch = True
-                self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
-                self.inner_switch = False
-                return None
-
-            return hook_fn
-
-        def register_param_hook(ori_name, cell, params_dict):
-            '''
-            Register parameter hooks.
-            '''
-            # When data_mode is forward-only, do not register parameter hooks
-            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
-                for param_name, param in params_dict.items():
-                    if param.requires_grad:
-                        name = ori_name + Const.SEP + param_name
-                        old_handle = self.hook_handle_dict.get(name)
-                        if old_handle and hasattr(old_handle, "remove"):
-                            old_handle.remove()
-                        handle = param.register_hook(grad_hook(cell, ori_name, param_name))
-                        self.hook_handle_dict[name] = handle
-
-        def init_params_grad_info(cell, params_dict):
-            '''
-            Initialize parameter gradient info; after the forward hook finishes, write it into cache_data as a placeholder
-            '''
-            if not params_dict:
-                return
-            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
-                grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
-                # Check whether a placeholder already exists in cache_data; if not, write one first
-                if not self.params_grad_info.get(grad_name):
-                    data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
-                    # Gradients are only computed (and a placeholder needed) when some parameter has requires_grad=True
-                    if data_info.get(grad_name):
-                        # Write data_info for grad_name into cache_data first; update it after gradients are computed
-                        self.data_collector.handle_data(grad_name, data_info,
-                                                        flush=self.data_collector.data_processor.is_terminated)
-                        # Mark that a placeholder for this module's parameter gradients has been written
-                        self.params_grad_info[grad_name] = True
-
-        def forward_hook(api_or_cell_name, cell, input_data, output):
-            if not self.should_execute_hook(target_type, cell, True):
-                clean_input_kwargs(cell)
-                return None
-            with _no_grad():
-                self.inner_switch = True
-                module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
-                if target_type == BaseScope.Module_Type_Module:
-                    api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-                    params_dict = {}
-                    if self.config.task != Const.STRUCTURE:
-                        params_dict = {
-                            key.split(Const.SEP)[-1]: value
-                            for key, value in cell.parameters_dict(recurse=False).items()
-                        }
-                    setattr(module_input_output, Const.PARAMS, params_dict)
-                    # Decide whether parameter hooks need to be registered
-                    if params_dict:
-                        ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
-                        grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
-                        # On the first forward hook run, add the params_grad_name attribute and register parameter hooks
-                        setattr(cell, 'params_grad_name', grad_name)
-                        register_param_hook(ori_name, cell, params_dict)
-                    self.data_collector.update_api_or_module_name(api_or_cell_name)
-                    self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
-                    init_params_grad_info(cell, params_dict)
-                else:
-                    self.data_collector.update_api_or_module_name(api_or_cell_name)
-                    self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
-
-                if self.data_collector.if_return_forward_new_output():
-                    forward_new_output = self.data_collector.get_forward_new_output()
-                    self.inner_switch = False
-                    return forward_new_output
-            clean_input_kwargs(cell)
-            self.inner_switch = False
-            return output
-
-        def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
-            if not self.should_execute_hook(target_type, cell, False):
-                return
-            self.inner_switch = True
-
-            need_exchange = True
-            if target_type == BaseScope.Module_Type_Module:
-                if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
-                    need_exchange = False
-                api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-
-            self.data_collector.update_api_or_module_name(api_or_cell_name)
-            if self.data_collector:
-                # A recent framework API change swapped the meanings of grad_input and grad_output to match torch, so they are passed here in exchanged order
-                if need_exchange:
-                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
-                else:
-                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
-                self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
-            self.inner_switch = False
-
-        def pre_backward_hook(api_or_cell_name, cell, grad_input):
-            if not self.should_execute_hook(target_type, cell, False):
-                return
-            self.inner_switch = True
-            module_input = ModuleBackwardInputs(grad_input=grad_input)
-            self.data_collector.update_api_or_module_name(api_or_cell_name)
-            self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
-
-            self.inner_switch = False
-
-        pid = os.getpid()
-        if target_type == BaseScope.Module_Type_Module:
-            full_forward_name = name + Const.FORWARD
-            full_backward_name = name + Const.BACKWARD
-        else:
-            full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
-            full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
-        pre_forward_hook = functools.partial(pre_hook, full_forward_name)
-        forward_hook = functools.partial(forward_hook, full_forward_name)
-        backward_hook = functools.partial(backward_hook, full_backward_name)
-        pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
-
-        def wrap_pre_forward_hook(cell, input_data):
-            return pre_forward_hook(cell, input_data)
-
-        def wrap_forward_hook(cell, input_data, output_data):
-            return forward_hook(cell, input_data, output_data)
-
-        def wrap_backward_hook(cell, grad_input, grad_output):
-            return backward_hook(cell, grad_input, grad_output)
-
-        def wrap_pre_backward_hook(cell, grad_input):
-            return pre_backward_hook(cell, grad_input)
-
-        return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
-
-    def update_primitive_counters(self, primitive_name):
-        if primitive_name not in self.primitive_counters:
-            self.primitive_counters[primitive_name] = 0
-        else:
-            self.primitive_counters[primitive_name] += 1
-
-    def step(self):
-        if self.config.level == Const.LEVEL_DEBUG:
-            return
-        if self.config.async_dump:
-            self.data_collector.fill_stack_tensor_data()
-            if self.config.task == Const.TENSOR:
-                self.data_collector.data_processor.dump_async_data()
-        self.data_collector.write_json()
-        self.current_iter += 1
-        self.data_collector.update_iter(self.current_iter)
-        self.reset_status()
-
-    def start(self, model=None):
-        if self.config.level == Const.LEVEL_DEBUG:
-            return
-        self.start_call = True
-        if self.should_stop_service:
-            return
-        if self.need_end_service():
-            self.should_stop_service = True
-            self.switch = False
-            self.primitive_switch = False
-            print_tools_ends_info()
-            return
-        if self.config.step and self.current_iter not in self.config.step:
-            return
-        self.model = self.check_model_valid(model)
-
-        logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")
-
-        if self.first_start:
-            try:
-                self.current_rank = get_rank_if_initialized()
-            except DistributedNotInitializedError:
-                self.current_rank = None
-
-            if self.config.rank and self.current_rank not in self.config.rank:
-                return
-            self.register_primitive_hook()
-            self.register_cell_hook()
-            if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
-                JitDump.set_config(self.config)
-                JitDump.set_data_collector(self.data_collector)
-                if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
-                    ms.common.api._MindsporeFunctionExecutor = JitDump
-                else:
-                    ms.common.api._JitExecutor = JitDump
-                ms.common.api._PyNativeExecutor.grad = JitDump.grad
-                if pijit_label:
-                    PIJitCaptureContext.__enter__ = self.empty
-                    PIJitCaptureContext.__exit__ = self.empty
-            self.first_start = False
-
-        api_register.api_set_hook_func()
-        self.switch = True
-        self.primitive_switch = True
-        logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
-        self.create_dirs()
-        logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
-        JitDump.jit_dump_switch = True
-
-    def stop(self):
-        if self.config.level == Const.LEVEL_DEBUG:
-            return
-        if self.should_stop_service:
-            return
-        logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
-                    "Please set debugger.start() to turn on the dump switch again. ")
-        if not self.start_call:
-            logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
-            raise Exception("debugger.start() is not set in the current scope.")
-        if self.config.step and self.current_iter not in self.config.step:
-            return
-        if self.config.rank and self.current_rank not in self.config.rank:
-            return
-        self.switch = False
-        self.primitive_switch = False
-        self.start_call = False
-        if self.config.async_dump:
-            self.data_collector.fill_stack_tensor_data()
-            if self.config.task == Const.TENSOR:
-                self.data_collector.data_processor.dump_async_data()
-        self.data_collector.write_json()
-        JitDump.jit_dump_switch = False
-
-    def need_end_service(self):
-        if self.config.step and self.current_iter > max(self.config.step):
-            return True
-        if self.data_collector and self.data_collector.data_processor.is_terminated:
-            return True
-        return False
-
-    def should_execute_hook(self, hook_type, cell, is_forward):
-        is_cell_hook = hook_type == BaseScope.Module_Type_Module
-        if is_cell_hook and not self.switch:
-            return False
-        elif not is_cell_hook and is_forward and not self.switch:
-            return False
-        elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
-            return False
-
-        if self.inner_switch:
-            return False
-        if not self.data_collector or self.data_collector.data_processor.is_terminated:
-            return False
-        return True
-
-    def create_dirs(self):
-        create_directory(self.config.dump_path)
-        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
-        cur_rank = self.current_rank if self.current_rank is not None else ''
-        if self.config.level == Const.LEVEL_L2:
-            create_directory(self.dump_iter_dir)
-            kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
-            self.config.kernel_config_path = kernel_config_path
-            return
-
-        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
-        create_directory(dump_dir)
-        if self.config.task in self.data_collector.tasks_need_tensor_data:
-            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
-            create_directory(dump_data_dir)
-        else:
-            dump_data_dir = None
-
-        dump_path_aggregation = DumpPathAggregation()
-        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
-        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
-        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
-        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
-        self.data_collector.update_dump_paths(dump_path_aggregation)
-
-        self.data_collector.initialize_json_file(
-            framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
-        )
-
-    def empty(self, *args, **kwargs):
-        pass
-
-    def register_api_hook(self):
-        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
-            logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
-            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
-            api_register.api_set_hook_func()
-
-    def get_cells_and_names(self):
-        cells_and_names_with_index = {}
-
-        def get_cell_or_module(model):
-            return model.named_modules() if is_mindtorch() else model.cells_and_names()
-
-        if isinstance(self.model, (list, tuple)):
-            for index, model in enumerate(self.model):
-                cells_and_names_with_index[str(index)] = get_cell_or_module(model)
-        else:
-            cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
-        return cells_and_names_with_index
-
-    def register_primitive_hook(self):
-        if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
-            return
-        if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
-            return
-
-        primitive_set = set()
-        cells_and_names_with_index = self.get_cells_and_names()
-        for cells_and_names in cells_and_names_with_index.values():
-            for _, cell in cells_and_names:
-                for attribute, value in vars(cell).items():
-                    if isinstance(value, Primitive):
-                        primitive_set.add((attribute, value))
-
-        for pname, primitive in primitive_set:
-            primitive_class_name = primitive.__class__.__name__
-            primitive_combined_name = pname + Const.SEP + primitive_class_name
-            new_primitive = type('NewPrimitive', (primitive.__class__,),
-                                 {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
-                                                                                         primitive_combined_name)})
-            primitive.__class__ = new_primitive
-
-    def register_cell_hook(self):
-        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
-            logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
-            if not self.model:
-                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
-                                       f"The current level is {self.config.level}, the model cannot be None")
-            model_type = Const.MODULE if is_mindtorch() else Const.CELL
-            cells_and_names_with_index = self.get_cells_and_names()
-
-            for index, cells_and_names in cells_and_names_with_index.items():
-                model = self.model if index == "-1" else self.model[int(index)]
-                for name, cell in cells_and_names:
-                    if cell == model:
-                        continue
-                    cell_index = (index + Const.SEP) if index != "-1" else ""
-                    prefix = (model_type + Const.SEP + cell_index + name +
-                              Const.SEP + cell.__class__.__name__ + Const.SEP)
-                    _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
-                    cell.register_forward_hook(forward_hook)
-                    cell.register_forward_pre_hook(
-                        self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
-                    cell.register_forward_hook(
-                        self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
-
-                    register_backward_hook_functions["full"](cell, backward_hook)
-                    register_backward_hook_functions["pre"](
-                        cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
-                    register_backward_hook_functions["full"](
-                        cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
-
-    def reset_status(self):
-        self.primitive_hook_service.primitive_counters.clear()
-        self.data_collector.reset_status()
-        JitDump.jit_count = defaultdict(int)
-        self.params_grad_info.clear()
-        if self.config.level == Const.LEVEL_L2:
-            self.data_collector.data_processor.reset_status()
-            return
-        if self.config.step and self.current_iter not in self.config.step:
-            return
-        if self.config.rank and self.current_rank not in self.config.rank:
-            return
-
-    def init_for_debug_level(self):
-        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
-            return
-        try:
-            self.current_rank = get_rank_if_initialized()
-        except DistributedNotInitializedError:
-            self.current_rank = None
-        # dir: dump_path -- rank{} -- debug.json
-        self.dump_iter_dir = self.config.dump_path
-        cur_rank = self.current_rank if self.current_rank is not None else ''
-        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
-        create_directory(dump_dir)
-        if self.config.task in self.data_collector.tasks_need_tensor_data:
-            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
-            create_directory(dump_data_dir)
-        else:
-            dump_data_dir = None
-
-        dump_path_aggregation = DumpPathAggregation()
-        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
-        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
-        self.data_collector.update_dump_paths(dump_path_aggregation)
-        self.data_collector.initialize_json_file(
-            framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
-        )
-        self.debug_variable_counter = defaultdict(int)
-
-    def save(self, variable, name, save_backward):
-        '''
-        Args:
-            variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
-            name: str
-            save_backward: boolean
-        Return:
-            void
-        '''
-        if self.config.level != Const.LEVEL_DEBUG:
-            return
-        count = self.debug_variable_counter[name]
-        self.debug_variable_counter[name] += 1
-
-        name_with_count = f"{name}.{count}"
-        grad_name_with_count = f"{name}_grad.{count}"
-
-        # forward save
-        self.data_collector.debug_data_collect_forward(variable, name_with_count)
-
-        # backward save
-        if save_backward:
-            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
```