mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/mindspore/service.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -20,6 +20,9 @@ from collections import defaultdict
|
|
|
20
20
|
|
|
21
21
|
import mindspore as ms
|
|
22
22
|
from mindspore import nn
|
|
23
|
+
from mindspore.common.api import _no_grad
|
|
24
|
+
from mindspore.ops.primitive import Primitive
|
|
25
|
+
|
|
23
26
|
try:
|
|
24
27
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
25
28
|
except ImportError:
|
|
@@ -27,19 +30,25 @@ except ImportError:
|
|
|
27
30
|
else:
|
|
28
31
|
pijit_label = True
|
|
29
32
|
|
|
30
|
-
|
|
31
33
|
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
32
34
|
from msprobe.core.common.file_utils import create_directory
|
|
33
|
-
from msprobe.core.common.utils import Const, print_tools_ends_info
|
|
35
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
|
|
34
36
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
35
|
-
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
|
|
37
|
+
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
|
|
38
|
+
ModuleBackwardInputs)
|
|
36
39
|
from msprobe.core.data_dump.scope import BaseScope
|
|
37
40
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
38
41
|
from msprobe.mindspore.common.log import logger
|
|
39
|
-
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
42
|
+
from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
|
|
43
|
+
is_mindtorch, register_backward_hook_functions)
|
|
40
44
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
41
45
|
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
42
46
|
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
47
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
48
|
+
from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
|
|
49
|
+
|
|
50
|
+
if is_mindtorch():
|
|
51
|
+
import torch
|
|
43
52
|
|
|
44
53
|
|
|
45
54
|
class Service:
|
|
@@ -51,54 +60,155 @@ class Service:
|
|
|
51
60
|
self.cell_processor = CellProcessor(self.data_collector.scope)
|
|
52
61
|
self.primitive_hook_service = PrimitiveHookService(self)
|
|
53
62
|
self.switch = False
|
|
63
|
+
self.inner_switch = False
|
|
54
64
|
self.primitive_switch = False
|
|
55
65
|
self.current_iter = 0
|
|
56
66
|
self.first_start = True
|
|
57
67
|
self.current_rank = None
|
|
58
68
|
self.dump_iter_dir = None
|
|
59
69
|
self.start_call = False
|
|
60
|
-
self.check_level_valid()
|
|
61
70
|
self.should_stop_service = False
|
|
71
|
+
self.params_grad_info = {}
|
|
72
|
+
self.hook_handle_dict = {}
|
|
73
|
+
# 提前注册,确保注册尽可能多的API hook
|
|
74
|
+
self.register_api_hook()
|
|
75
|
+
self.init_for_debug_level()
|
|
62
76
|
|
|
63
77
|
@staticmethod
|
|
64
|
-
def check_model_valid(
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
)
|
|
78
|
+
def check_model_valid(models):
|
|
79
|
+
target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
|
|
80
|
+
if models is None or isinstance(models, target_module_type[0]):
|
|
81
|
+
return models
|
|
82
|
+
error_model = None
|
|
83
|
+
if isinstance(models, (list, tuple)):
|
|
84
|
+
for model in models:
|
|
85
|
+
if not isinstance(model, target_module_type[0]):
|
|
86
|
+
error_model = model
|
|
87
|
+
break
|
|
88
|
+
else:
|
|
89
|
+
error_model = models
|
|
70
90
|
|
|
71
|
-
|
|
72
|
-
|
|
91
|
+
if error_model is not None:
|
|
92
|
+
error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
|
|
93
|
+
f"type, currently there is a {type(error_model)} type.")
|
|
73
94
|
raise MsprobeException(
|
|
74
|
-
MsprobeException.INVALID_PARAM_ERROR,
|
|
75
|
-
|
|
95
|
+
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
96
|
+
return models
|
|
97
|
+
|
|
98
|
+
@staticmethod
|
|
99
|
+
def prepare_module_input_output(target_type, cell, input_data, output):
|
|
100
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
101
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
|
|
102
|
+
else:
|
|
103
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
|
|
104
|
+
return module_input_output
|
|
76
105
|
|
|
77
106
|
def build_hook(self, target_type, name):
|
|
78
|
-
def
|
|
79
|
-
if not self.
|
|
80
|
-
|
|
81
|
-
del cell.input_kwargs
|
|
107
|
+
def pre_hook(api_or_cell_name, cell, input_data):
|
|
108
|
+
if not self.should_execute_hook(target_type, cell, True):
|
|
109
|
+
clean_input_kwargs(cell)
|
|
82
110
|
return None
|
|
83
111
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
112
|
+
with _no_grad():
|
|
113
|
+
self.inner_switch = True
|
|
114
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
115
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
116
|
+
else:
|
|
117
|
+
cell.forward_data_collected = True
|
|
118
|
+
HOOKCell.add_cell_count(name)
|
|
119
|
+
module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
|
|
120
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
121
|
+
self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
122
|
+
self.inner_switch = False
|
|
123
|
+
return input_data
|
|
124
|
+
|
|
125
|
+
def grad_hook(cell, ori_name, param_name):
|
|
126
|
+
def hook_fn(grad):
|
|
127
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
128
|
+
return None
|
|
129
|
+
self.inner_switch = True
|
|
130
|
+
self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
|
|
131
|
+
self.inner_switch = False
|
|
132
|
+
return None
|
|
90
133
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
134
|
+
return hook_fn
|
|
135
|
+
|
|
136
|
+
def register_param_hook(ori_name, cell, params_dict):
|
|
137
|
+
'''
|
|
138
|
+
注册参数hook
|
|
139
|
+
'''
|
|
140
|
+
# data_mode为forward时,不注册参数hook
|
|
141
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
142
|
+
for param_name, param in params_dict.items():
|
|
143
|
+
if param.requires_grad:
|
|
144
|
+
name = ori_name + Const.SEP + param_name
|
|
145
|
+
old_handle = self.hook_handle_dict.get(name)
|
|
146
|
+
if old_handle and hasattr(old_handle, "remove"):
|
|
147
|
+
old_handle.remove()
|
|
148
|
+
handle = param.register_hook(grad_hook(cell, ori_name, param_name))
|
|
149
|
+
self.hook_handle_dict[name] = handle
|
|
150
|
+
|
|
151
|
+
def init_params_grad_info(cell, params_dict):
|
|
152
|
+
'''
|
|
153
|
+
初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
|
|
154
|
+
'''
|
|
155
|
+
if not params_dict:
|
|
156
|
+
return
|
|
157
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
158
|
+
grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
|
|
159
|
+
# 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
|
|
160
|
+
if not self.params_grad_info.get(grad_name):
|
|
161
|
+
data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
|
|
162
|
+
# 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
|
|
163
|
+
if data_info.get(grad_name):
|
|
164
|
+
# 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
|
|
165
|
+
self.data_collector.handle_data(grad_name, data_info,
|
|
166
|
+
flush=self.data_collector.data_processor.is_terminated)
|
|
167
|
+
# 记录当前模块的参数梯度信息已占位
|
|
168
|
+
self.params_grad_info[grad_name] = True
|
|
169
|
+
|
|
170
|
+
def forward_hook(api_or_cell_name, cell, input_data, output):
|
|
171
|
+
if not self.should_execute_hook(target_type, cell, True):
|
|
172
|
+
clean_input_kwargs(cell)
|
|
173
|
+
return None
|
|
174
|
+
with _no_grad():
|
|
175
|
+
self.inner_switch = True
|
|
176
|
+
module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
|
|
177
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
178
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
179
|
+
params_dict = {}
|
|
180
|
+
if self.config.task != Const.STRUCTURE:
|
|
181
|
+
params_dict = {
|
|
182
|
+
key.split(Const.SEP)[-1]: value
|
|
183
|
+
for key, value in cell.parameters_dict(recurse=False).items()
|
|
184
|
+
}
|
|
185
|
+
setattr(module_input_output, Const.PARAMS, params_dict)
|
|
186
|
+
# 判断是否需要注册参数hook
|
|
187
|
+
if params_dict:
|
|
188
|
+
ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
|
|
189
|
+
grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
|
|
190
|
+
# 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
|
|
191
|
+
setattr(cell, 'params_grad_name', grad_name)
|
|
192
|
+
register_param_hook(ori_name, cell, params_dict)
|
|
193
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
194
|
+
self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
195
|
+
init_params_grad_info(cell, params_dict)
|
|
196
|
+
else:
|
|
197
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
198
|
+
self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
199
|
+
|
|
200
|
+
if self.data_collector.if_return_forward_new_output():
|
|
201
|
+
forward_new_output = self.data_collector.get_forward_new_output()
|
|
202
|
+
self.inner_switch = False
|
|
203
|
+
return forward_new_output
|
|
204
|
+
clean_input_kwargs(cell)
|
|
205
|
+
self.inner_switch = False
|
|
206
|
+
return output
|
|
98
207
|
|
|
99
208
|
def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
|
|
100
|
-
if not self.
|
|
209
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
101
210
|
return
|
|
211
|
+
self.inner_switch = True
|
|
102
212
|
|
|
103
213
|
need_exchange = True
|
|
104
214
|
if target_type == BaseScope.Module_Type_Module:
|
|
@@ -114,12 +224,32 @@ class Service:
|
|
|
114
224
|
else:
|
|
115
225
|
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
|
|
116
226
|
self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
227
|
+
self.inner_switch = False
|
|
228
|
+
|
|
229
|
+
def pre_backward_hook(api_or_cell_name, cell, grad_input):
|
|
230
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
231
|
+
return
|
|
232
|
+
self.inner_switch = True
|
|
233
|
+
module_input = ModuleBackwardInputs(grad_input=grad_input)
|
|
234
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
235
|
+
self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
|
|
236
|
+
|
|
237
|
+
self.inner_switch = False
|
|
117
238
|
|
|
118
239
|
pid = os.getpid()
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
240
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
241
|
+
full_forward_name = name + Const.FORWARD
|
|
242
|
+
full_backward_name = name + Const.BACKWARD
|
|
243
|
+
else:
|
|
244
|
+
full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
|
|
245
|
+
full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
|
|
246
|
+
pre_forward_hook = functools.partial(pre_hook, full_forward_name)
|
|
247
|
+
forward_hook = functools.partial(forward_hook, full_forward_name)
|
|
248
|
+
backward_hook = functools.partial(backward_hook, full_backward_name)
|
|
249
|
+
pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
|
|
250
|
+
|
|
251
|
+
def wrap_pre_forward_hook(cell, input_data):
|
|
252
|
+
return pre_forward_hook(cell, input_data)
|
|
123
253
|
|
|
124
254
|
def wrap_forward_hook(cell, input_data, output_data):
|
|
125
255
|
return forward_hook(cell, input_data, output_data)
|
|
@@ -127,7 +257,10 @@ class Service:
|
|
|
127
257
|
def wrap_backward_hook(cell, grad_input, grad_output):
|
|
128
258
|
return backward_hook(cell, grad_input, grad_output)
|
|
129
259
|
|
|
130
|
-
|
|
260
|
+
def wrap_pre_backward_hook(cell, grad_input):
|
|
261
|
+
return pre_backward_hook(cell, grad_input)
|
|
262
|
+
|
|
263
|
+
return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
|
|
131
264
|
|
|
132
265
|
def update_primitive_counters(self, primitive_name):
|
|
133
266
|
if primitive_name not in self.primitive_counters:
|
|
@@ -135,33 +268,25 @@ class Service:
|
|
|
135
268
|
else:
|
|
136
269
|
self.primitive_counters[primitive_name] += 1
|
|
137
270
|
|
|
138
|
-
def register_primitive_hooks(self):
|
|
139
|
-
primitive_set = set()
|
|
140
|
-
for _, cell in self.model.cells_and_names():
|
|
141
|
-
for pname, primitive in cell._primitives.items():
|
|
142
|
-
primitive_set.add((pname, primitive))
|
|
143
|
-
|
|
144
|
-
for pname, primitive in primitive_set:
|
|
145
|
-
primitive_class_name = primitive.__class__.__name__
|
|
146
|
-
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
147
|
-
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
148
|
-
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
149
|
-
primitive_combined_name)})
|
|
150
|
-
primitive.__class__ = new_primitive
|
|
151
|
-
|
|
152
271
|
def step(self):
|
|
272
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
273
|
+
return
|
|
274
|
+
if self.config.async_dump:
|
|
275
|
+
self.data_collector.fill_stack_tensor_data()
|
|
276
|
+
if self.config.task == Const.TENSOR:
|
|
277
|
+
self.data_collector.data_processor.dump_async_data()
|
|
278
|
+
self.data_collector.write_json()
|
|
153
279
|
self.current_iter += 1
|
|
154
280
|
self.data_collector.update_iter(self.current_iter)
|
|
155
|
-
self.
|
|
156
|
-
self.data_collector.data_writer.reset_cache()
|
|
157
|
-
JitDump.jit_count = defaultdict(int)
|
|
281
|
+
self.reset_status()
|
|
158
282
|
|
|
159
283
|
def start(self, model=None):
|
|
284
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
285
|
+
return
|
|
160
286
|
self.start_call = True
|
|
161
287
|
if self.should_stop_service:
|
|
162
288
|
return
|
|
163
289
|
if self.need_end_service():
|
|
164
|
-
api_register.api_set_ori_func()
|
|
165
290
|
self.should_stop_service = True
|
|
166
291
|
self.switch = False
|
|
167
292
|
self.primitive_switch = False
|
|
@@ -181,11 +306,15 @@ class Service:
|
|
|
181
306
|
|
|
182
307
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
183
308
|
return
|
|
184
|
-
self.
|
|
309
|
+
self.register_primitive_hook()
|
|
310
|
+
self.register_cell_hook()
|
|
185
311
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
186
312
|
JitDump.set_config(self.config)
|
|
187
313
|
JitDump.set_data_collector(self.data_collector)
|
|
188
|
-
ms.common.api
|
|
314
|
+
if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
|
|
315
|
+
ms.common.api._MindsporeFunctionExecutor = JitDump
|
|
316
|
+
else:
|
|
317
|
+
ms.common.api._JitExecutor = JitDump
|
|
189
318
|
ms.common.api._PyNativeExecutor.grad = JitDump.grad
|
|
190
319
|
if pijit_label:
|
|
191
320
|
PIJitCaptureContext.__enter__ = self.empty
|
|
@@ -200,26 +329,9 @@ class Service:
|
|
|
200
329
|
logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
201
330
|
JitDump.jit_dump_switch = True
|
|
202
331
|
|
|
203
|
-
def forward_backward_dump_end(self):
|
|
204
|
-
if self.should_stop_service:
|
|
205
|
-
return
|
|
206
|
-
logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
|
|
207
|
-
if not self.start_call:
|
|
208
|
-
logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
|
|
209
|
-
raise Exception("debugger.start() is not set in the current scope.")
|
|
210
|
-
if not self.switch:
|
|
211
|
-
logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
|
|
212
|
-
"debugger.start() and debugger.stop() ")
|
|
213
|
-
raise Exception("debugger.stop() is already called. ")
|
|
214
|
-
if self.config.step and self.current_iter not in self.config.step:
|
|
215
|
-
return
|
|
216
|
-
if self.config.rank and self.current_rank not in self.config.rank:
|
|
217
|
-
return
|
|
218
|
-
self.primitive_switch = False
|
|
219
|
-
api_register.api_set_ori_func()
|
|
220
|
-
JitDump.jit_dump_switch = False
|
|
221
|
-
|
|
222
332
|
def stop(self):
|
|
333
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
334
|
+
return
|
|
223
335
|
if self.should_stop_service:
|
|
224
336
|
return
|
|
225
337
|
logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
|
|
@@ -234,6 +346,10 @@ class Service:
|
|
|
234
346
|
self.switch = False
|
|
235
347
|
self.primitive_switch = False
|
|
236
348
|
self.start_call = False
|
|
349
|
+
if self.config.async_dump:
|
|
350
|
+
self.data_collector.fill_stack_tensor_data()
|
|
351
|
+
if self.config.task == Const.TENSOR:
|
|
352
|
+
self.data_collector.data_processor.dump_async_data()
|
|
237
353
|
self.data_collector.write_json()
|
|
238
354
|
JitDump.jit_dump_switch = False
|
|
239
355
|
|
|
@@ -244,8 +360,16 @@ class Service:
|
|
|
244
360
|
return True
|
|
245
361
|
return False
|
|
246
362
|
|
|
247
|
-
def
|
|
248
|
-
|
|
363
|
+
def should_execute_hook(self, hook_type, cell, is_forward):
|
|
364
|
+
is_cell_hook = hook_type == BaseScope.Module_Type_Module
|
|
365
|
+
if is_cell_hook and not self.switch:
|
|
366
|
+
return False
|
|
367
|
+
elif not is_cell_hook and is_forward and not self.switch:
|
|
368
|
+
return False
|
|
369
|
+
elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
|
|
370
|
+
return False
|
|
371
|
+
|
|
372
|
+
if self.inner_switch:
|
|
249
373
|
return False
|
|
250
374
|
if not self.data_collector or self.data_collector.data_processor.is_terminated:
|
|
251
375
|
return False
|
|
@@ -255,6 +379,12 @@ class Service:
|
|
|
255
379
|
create_directory(self.config.dump_path)
|
|
256
380
|
self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
|
|
257
381
|
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
382
|
+
if self.config.level == Const.LEVEL_L2:
|
|
383
|
+
create_directory(self.dump_iter_dir)
|
|
384
|
+
kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
|
|
385
|
+
self.config.kernel_config_path = kernel_config_path
|
|
386
|
+
return
|
|
387
|
+
|
|
258
388
|
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
259
389
|
create_directory(dump_dir)
|
|
260
390
|
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
@@ -263,41 +393,151 @@ class Service:
|
|
|
263
393
|
else:
|
|
264
394
|
dump_data_dir = None
|
|
265
395
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
396
|
+
dump_path_aggregation = DumpPathAggregation()
|
|
397
|
+
dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
|
|
398
|
+
dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
|
|
399
|
+
dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
|
|
400
|
+
dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
|
|
401
|
+
self.data_collector.update_dump_paths(dump_path_aggregation)
|
|
402
|
+
|
|
403
|
+
self.data_collector.initialize_json_file(
|
|
404
|
+
framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
405
|
+
)
|
|
271
406
|
|
|
272
407
|
def empty(self, *args, **kwargs):
|
|
273
408
|
pass
|
|
274
409
|
|
|
275
|
-
def
|
|
276
|
-
|
|
277
|
-
|
|
410
|
+
def register_api_hook(self):
|
|
411
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
|
|
412
|
+
logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
|
|
278
413
|
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
279
414
|
api_register.api_set_hook_func()
|
|
280
|
-
if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
|
|
281
|
-
self.register_primitive_hooks()
|
|
282
415
|
|
|
416
|
+
def get_cells_and_names(self):
|
|
417
|
+
cells_and_names_with_index = {}
|
|
418
|
+
|
|
419
|
+
def get_cell_or_module(model):
|
|
420
|
+
return model.named_modules() if is_mindtorch() else model.cells_and_names()
|
|
421
|
+
|
|
422
|
+
if isinstance(self.model, (list, tuple)):
|
|
423
|
+
for index, model in enumerate(self.model):
|
|
424
|
+
cells_and_names_with_index[str(index)] = get_cell_or_module(model)
|
|
425
|
+
else:
|
|
426
|
+
cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
|
|
427
|
+
return cells_and_names_with_index
|
|
428
|
+
|
|
429
|
+
def register_primitive_hook(self):
|
|
430
|
+
if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
431
|
+
return
|
|
432
|
+
if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
|
|
433
|
+
return
|
|
434
|
+
|
|
435
|
+
primitive_set = set()
|
|
436
|
+
cells_and_names_with_index = self.get_cells_and_names()
|
|
437
|
+
for cells_and_names in cells_and_names_with_index.values():
|
|
438
|
+
for _, cell in cells_and_names:
|
|
439
|
+
for attribute, value in vars(cell).items():
|
|
440
|
+
if isinstance(value, Primitive):
|
|
441
|
+
primitive_set.add((attribute, value))
|
|
442
|
+
|
|
443
|
+
for pname, primitive in primitive_set:
|
|
444
|
+
primitive_class_name = primitive.__class__.__name__
|
|
445
|
+
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
446
|
+
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
447
|
+
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
448
|
+
primitive_combined_name)})
|
|
449
|
+
primitive.__class__ = new_primitive
|
|
450
|
+
|
|
451
|
+
def register_cell_hook(self):
|
|
283
452
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
|
|
453
|
+
logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
|
|
284
454
|
if not self.model:
|
|
285
455
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
286
456
|
f"The current level is {self.config.level}, the model cannot be None")
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
457
|
+
model_type = Const.MODULE if is_mindtorch() else Const.CELL
|
|
458
|
+
cells_and_names_with_index = self.get_cells_and_names()
|
|
459
|
+
|
|
460
|
+
for index, cells_and_names in cells_and_names_with_index.items():
|
|
461
|
+
model = self.model if index == "-1" else self.model[int(index)]
|
|
462
|
+
for name, cell in cells_and_names:
|
|
463
|
+
if cell == model:
|
|
464
|
+
continue
|
|
465
|
+
cell_index = (index + Const.SEP) if index != "-1" else ""
|
|
466
|
+
prefix = (model_type + Const.SEP + cell_index + name +
|
|
467
|
+
Const.SEP + cell.__class__.__name__ + Const.SEP)
|
|
468
|
+
_, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
|
|
469
|
+
cell.register_forward_hook(forward_hook)
|
|
470
|
+
cell.register_forward_pre_hook(
|
|
471
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
|
|
472
|
+
cell.register_forward_hook(
|
|
473
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
|
|
474
|
+
|
|
475
|
+
register_backward_hook_functions["full"](cell, backward_hook)
|
|
476
|
+
register_backward_hook_functions["pre"](
|
|
477
|
+
cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
|
|
478
|
+
register_backward_hook_functions["full"](
|
|
479
|
+
cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
480
|
+
|
|
481
|
+
def reset_status(self):
|
|
482
|
+
self.primitive_hook_service.primitive_counters.clear()
|
|
483
|
+
self.data_collector.reset_status()
|
|
484
|
+
JitDump.jit_count = defaultdict(int)
|
|
485
|
+
self.params_grad_info.clear()
|
|
486
|
+
if self.config.level == Const.LEVEL_L2:
|
|
487
|
+
self.data_collector.data_processor.reset_status()
|
|
488
|
+
return
|
|
489
|
+
if self.config.step and self.current_iter not in self.config.step:
|
|
490
|
+
return
|
|
491
|
+
if self.config.rank and self.current_rank not in self.config.rank:
|
|
492
|
+
return
|
|
493
|
+
|
|
494
|
+
def init_for_debug_level(self):
|
|
495
|
+
if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
|
|
496
|
+
return
|
|
497
|
+
try:
|
|
498
|
+
self.current_rank = get_rank_if_initialized()
|
|
499
|
+
except DistributedNotInitializedError:
|
|
500
|
+
self.current_rank = None
|
|
501
|
+
# dir: dump_path -- rank{} -- debug.json
|
|
502
|
+
self.dump_iter_dir = self.config.dump_path
|
|
503
|
+
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
504
|
+
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
505
|
+
create_directory(dump_dir)
|
|
506
|
+
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
507
|
+
dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
|
|
508
|
+
create_directory(dump_data_dir)
|
|
509
|
+
else:
|
|
510
|
+
dump_data_dir = None
|
|
511
|
+
|
|
512
|
+
dump_path_aggregation = DumpPathAggregation()
|
|
513
|
+
dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
|
|
514
|
+
dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
|
|
515
|
+
self.data_collector.update_dump_paths(dump_path_aggregation)
|
|
516
|
+
self.data_collector.initialize_json_file(
|
|
517
|
+
framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
518
|
+
)
|
|
519
|
+
self.debug_variable_counter = defaultdict(int)
|
|
520
|
+
|
|
521
|
+
def save(self, variable, name, save_backward):
|
|
522
|
+
'''
|
|
523
|
+
Args:
|
|
524
|
+
variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
|
|
525
|
+
name: str
|
|
526
|
+
save_backward: boolean
|
|
527
|
+
Return:
|
|
528
|
+
void
|
|
529
|
+
'''
|
|
530
|
+
if self.config.level != Const.LEVEL_DEBUG:
|
|
531
|
+
return
|
|
532
|
+
count = self.debug_variable_counter[name]
|
|
533
|
+
self.debug_variable_counter[name] += 1
|
|
534
|
+
|
|
535
|
+
name_with_count = f"{name}.{count}"
|
|
536
|
+
grad_name_with_count = f"{name}_grad.{count}"
|
|
537
|
+
|
|
538
|
+
# forward save
|
|
539
|
+
self.data_collector.debug_data_collect_forward(variable, name_with_count)
|
|
540
|
+
|
|
541
|
+
# backward save
|
|
542
|
+
if save_backward:
|
|
543
|
+
self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
|