mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff covers publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only and reflects the changes between those published versions.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0

The detailed hunks below are the in-place diff of what appears to be `msprobe/mindspore/monitor/module_hook.py` (listed above with +568/−444 lines changed). Lines that the rendering truncated or lost are kept as captured.

```diff
@@ -20,22 +20,24 @@ from collections import defaultdict
 from datetime import datetime
 
 import pytz
-import
-import mindspore
-from mindspore import Tensor,
+import pandas as pd
+import mindspore
+from mindspore import Tensor, mint
 from mindspore import nn, _no_grad
-from mindspore.communication import get_rank
 
 from msprobe.core.common.log import logger
-from msprobe.core.common.const import MonitorConst
-from msprobe.core.common.file_utils import load_json
+from msprobe.core.common.const import MonitorConst, Const
+from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+from msprobe.mindspore.common.utils import is_mindtorch
+from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
 from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
-    is_skip_step, get_metrics,
-from msprobe.mindspore.monitor.
-from msprobe.mindspore.monitor.
-
-from msprobe.
-
+    is_skip_step, get_metrics, get_target_output_dir
+from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
+from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
+from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
+from msprobe.core.common.file_utils import write_df_to_csv
+from msprobe.core.common.utils import analyze_api_call_stack
 
 FORMAT_MAPPING = {
     MonitorConst.CSV: CSVWriterWithAD,
```
```diff
@@ -89,24 +91,11 @@ class ModuleHookContext:
         self.actvgrad = []
         self.module_name = module_name
         self.struct = {}
-        self.
-
-
-        self.
-        self.
-
-    def set_format_by_arg(self, key_name: str, target_config: dict):
-        cared = target_config.get(self.module_name, self.struct)
-        if key_name in cared:
-            if isinstance(cared[key_name], dict):
-                # current cared is self.struct
-                config = cared[key_name].get('config')
-                self.format_by_arg[key_name] = config
-            else:
-                # current cared is target_config[self.module_name]
-                self.format_by_arg[key_name] = cared[key_name]
-        elif key_name in ['input', 'input_grad']:
-            self.ignore_in = True
+        self.stack = ""
+
+    def reset(self):
+        self.actv.clear()
+        self.actvgrad.clear()
 
 
 start_step = 0
```
```diff
@@ -116,7 +105,6 @@ start_step = 0
 class OptimizerContext:
     def __init__(self) -> None:
         self.step = start_step
-        self.param_effective_rank = defaultdict(float)
         self.param_mg_direction = defaultdict(float)
         self.param_adam_update = defaultdict()
         self.param_adam_ratio = defaultdict()
```
```diff
@@ -131,6 +119,7 @@ class OptimizerContext:
     def reset(self) -> None:
         self.param_mg_direction.clear()
         self.param_adam_update.clear()
+        self.param_adam_ratio.clear()
         self.param_weight_grad.clear()
         self.param_exp_avg.clear()
         self.exp_avg_metric.clear()
```
```diff
@@ -179,50 +168,107 @@ class CommunicationContext:
 
 class TrainerMon:
     def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+        # TYPE1: variables initialized only here; they are not reset when the config changes mid-training
+        self.config_file_path = config_file_path
+        self.process_group = process_group
+        self.params_have_main_grad = params_have_main_grad
+        self.is_mindtorch = is_mindtorch()
+        self.config_timestamp = 0  # the timestamp is validated later; the first run can enable monitoring via dynamic_on without touching the config file
+        self.config = load_json(config_file_path)
+        validate_config(self.config)
+
+        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target time zone as needed
+        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+        self.unique_id = str(uuid.uuid4())[:8]
+        self.output_base_dir = get_output_base_dir()
+        time_tags = self.config.get("append_output", [])
+        try:
+            self.rank = get_rank()
+            if time_tags:
+                output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+                if str(self.rank) in output_append_dirs:
+                    self.tensorboard_dir = output_append_dirs[str(self.rank)]
+                    logger.info(f"Append rank({self.rank}) result to {self.tensorboard_dir}")
+            else:
+                self.tensorboard_dir = os.path.join(self.output_base_dir,
+                                                    f"{cur_time}-rank{self.rank}-{self.unique_id}")
+        except Exception as e:
+            self.rank = 0
+            self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+
+        self.pp_stage = 0
+        self.group_mates = [0]
+
+        # TYPE2: variables assigned only from the set_monitor() entry point
+        self.model = None
+        self.vpp = False
+        self.dp_group = None
+        self.tp_group = None
+        self.micro_batch_number = 1
+        self.optimizer_mon = None
+
+        # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.optimizer_context = defaultdict(OptimizerContext)
         self.cc_context = defaultdict(CommunicationContext)
         self.grad_context = GradContext()
-        self.params_have_main_grad = params_have_main_grad
         self.handles = defaultdict(list)
-        self.
-
+        self.param2name = defaultdict(str)
+        self.name2index = defaultdict()
+        self.name2indices = defaultdict()
+        self.name2param = {}
+        self.duplicate_param = {}
+        self.name2tag = {}
+        self.param_name_call_id = {}
+        self.call_id = 0
+        self.module_struct = defaultdict(dict)
+        self.grad_accs = []
+        self.weight_hooked = False
+        self.optimizer_hooked = False
+        self.param_registered = False
+        self.struct_printed = False
+        self.pre_step_hooks = []
+        self.post_step_hooks = []
+
+        # distinguish dynamic from static monitoring
+        self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+        if self.dynamic_enable:
+            logger.warning(f"DYNAMIC_MONITOR is set, "
+                           f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+            self.monitoring = False
+        else:
+            self.set_config()
+            # in static mode with collect_times > 0, self.monitoring can be True from step 0; dynamic mode starts at the next step by default
+            if self.collect_times > 0:
+                self.monitoring = True
 
+    def set_config(self):
         self.start_step = self.config.get("start_step", 0)
         self.collect_times = self.config.get("collect_times", 100000000)  # defaults to a large value so collection never stops
         self.step_interval = self.config.get("step_interval", 1)
-        self.has_collect_times = 0
-
-        # monitor target in module, such as layer, weight, grad
+        self.has_collect_times = 0  # reset the collection counter
+        self.print_struct = self.config.get("print_struct", False)
         self.targets = self.config.get("targets", None)
         self.is_select = self.config.get("is_select", False)
         self.module_rank_list = self.config.get("module_ranks", [])
-        # only csv supported in mindspore
-        self.format = self.config.get('format', MonitorConst.CSV)
+        self.format = self.config.get('format', MonitorConst.CSV)  # only csv supported in mindspore
         self.eps = self.config.get('eps', 1e-8)
-        # monitor mean/max/norm/min/nan...
-        self.ops = self.config.get('ops', [])
+        self.ops = self.config.get('ops', [])  # monitor mean/max/norm/min/nan...
         self.ndigits = self.config.get('ndigits', 6)
         self.all_xy = self.config.get('all_xy', False)
-        # module input/output input_grad/output_grad
         self.xy_distribution = self.config.get('xy_distribution', False)
-        # activation forward
         self.forward_only = self.config.get('forward_only', False)
-        # activation backward
         self.backward_only = self.config.get('backward_only', False)
-        #
-        self.
-        # m/v of adam
-        self.mv_distribution = self.config.get("mv_distribution", False)
-        # weight grad
+        self.ur_distribution = self.config.get('ur_distribution', False)  # vector and ratio vector of adam
+        self.mv_distribution = self.config.get("mv_distribution", False)  # m/v of adam
         self.wg_distribution = self.config.get("wg_distribution", False)
-        # optimizer param
         self.param_distribution = self.config.get("param_distribution", False)
-        # main grad direction
-        self.
-
-        self.
+        self.mg_direction = self.config.get('mg_direction', False)  # main grad direction
+        self.cc_distribution = self.config.get("cc_distribution", {})  # communication ops
+        self.stack_info = self.config.get('stack_info', False)
+        self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
         else:
```
```diff
@@ -230,167 +276,227 @@ class TrainerMon:
             self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
             self.cc_logged_stack = defaultdict(set)
             self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
-            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
-            api_register.redirect_api()
         self.common_info()
 
+        # initialize the AnomalyData factory
         alert_setting = self.config.get('alert', {"rules": []})
         self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
-        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target time zone as needed
-
-        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
-        unique_id = str(uuid.uuid4())[:8]
-        output_base_dir = get_output_base_dir()
-
-        time_tags = self.config.get("append_output", [])
-        if time_tags:
-            output_append_dirs = get_target_output_dir(output_base_dir, time_tags[0], time_tags[1])
-        try:
-            rank = get_rank()
-        except Exception as e:
-            rank = 0
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
-            logger.error(f"Failed to get rank, setting tensorboard_dir to {tensorboard_dir}")
-            pp_stage = 0
-            group_mates = [0]
-        else:
-            if time_tags and str(rank) in output_append_dirs:
-                tensorboard_dir = output_append_dirs[str(rank)]
-                logger.info(f"Append rank({rank}) result to {tensorboard_dir}")
-            else:
-                tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
-            pp_stage = 0
-            group_mates = [0]
-
-        self.rank = rank
-
-        # initialize the AnomalyData factory
         self.anomaly_data_factory = None
         if alert_setting.get('dump', False):
-            self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
+            self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
 
+        # initialize the writer and create the output directory
         if self.format not in FORMAT_MAPPING:
             logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
             self.format = MonitorConst.CSV
-        writer = FORMAT_MAPPING[self.format]
         self.step_count_per_record = self.config.get('step_count_per_record', 1)
-
-
-
-
-
-
-
-
-
+        if not self.module_rank_list or (self.rank in self.module_rank_list):
+            writer = FORMAT_MAPPING[self.format]
+            self.summary_writer = writer(
+                WriterInput(
+                    self.tensorboard_dir,
+                    self.alert_rules,
+                    self.unique_id,
+                    self.anomaly_data_factory,
+                    self.ndigits,
+                    self.step_count_per_record
+                )
             )
-        )
-
-        self.micro_batch_number = 1
-
-        self.model = None
-        self.weight_hooked = False
-        self.optimizer_hooked = False
-        self.param_registered = False
-        self.vpp = False
-        self.dp_group = None
-        self.tp_group = None
-        self.enable_megatron = False
 
-
-
-
-
-
-        self.duplicate_param = {}
-        self.name2tag = {}
-        self.call_id = 0
-        self.grad_accs = []
-        self.handles = defaultdict(list)
+        # initialize the anomaly_detected output directory
+        if self.anomaly_data_factory:
+            self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
+                                                         self.rank)
+            self.anomaly_data_writer.init_detected_json()
 
-
-        self.
-
+    def common_info(self):
+        if not self.xy_distribution:
+            logger.info("> module input/output input_grad/output_grad is not monitored. ")
+        if self.forward_only:
+            logger.info("> only module forward is monitored. ")
+        if not self.ur_distribution:
+            logger.info("> update vector and ratio vector of adam is not monitored. ")
+        if not self.mv_distribution:
+            logger.info("> momentum and variance of adam is not monitored. ")
+        if not self.wg_distribution:
+            logger.info("> weight grad of specified module is not monitored. ")
+        if not self.mg_direction:
+            logger.info('> grad and momentum direction will not be compared.')
+        if not self.cc_distribution.get('enable', False):
+            logger.info("> cc operator is not monitored.")
 
-    # Start
     def set_monitor(
             self,
             model,
+            optimizer,
             grad_acc_steps=1,
-            optimizer=None,
             tp_group=None,
             dp_group=None,
-            start_iteration=0
+            start_iteration=0
+    ):
         global start_step
         start_step = start_iteration
-        logger.info(f'grad acc steps {grad_acc_steps}')
-        self.hook_optimizer(optimizer)
         self.micro_batch_number = grad_acc_steps
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
+        self.hook_step_final(optimizer)
+        if not isinstance(model, list):
+            model = [model]
+        self.model = model
+        if len(model) > 1:
+            self.vpp = True
+            logger.info('vpp enabled')
+        if not self.dynamic_enable:
+            self.register_hooks(optimizer)
+
+    def hook_step_final(self, optimizer):
+        def step_final_hook(optimizer, *args, **kwargs):
+            context = self.optimizer_context[optimizer]
+            # static mode can already save at step 0; dynamic mode cannot, because it starts at the step after a reset, so self.monitoring is still False at step 0
+            if self.monitoring:
+                module_rank_valid = self.is_target_rank()
+                step_condition = (context.step >= self.start_step and (
+                        context.step - self.start_step) % self.step_interval == 0)
+                if module_rank_valid and step_condition:
+                    self.has_collect_times += 1
+
+                    if self.anomaly_data_factory:
+                        self.anomaly_data_factory.set_call_id(self.param_name_call_id)
+                    self.write_xy_tb(context.step)
+                    self.write_grad_tb(context.step)
+                    self.write_mv_tb(context)
+                    self.write_param_tb(context)
+                    if self.stack_info:
+                        self.write_stack_info()
+                        self.stack_info = False
+                        for handle in self.handles["stack"]:
+                            handle.remove()
+                        self.handles["stack"].clear()
+
+                    if context.metric_dict:
+                        self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+                    context.metric_dict.clear()
+
+                    if self.anomaly_data_factory:
+                        self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
+                    self.summary_writer.clear_anomalies()
+
+                    self.call_id = 0
+                    self.param_name_call_id.clear()
+
+                    if self.has_collect_times >= self.collect_times:
+                        self._remove_all_hooks_final(optimizer)
 
-
-
+            context.step += 1
+            self.dynamic_monitor(optimizer)
 
-
-
-
-
-
-
+
+        def patch_step(func, optimizer):
+            def wrapper(*args, **kwargs):
+                for hook in self.pre_step_hooks:
+                    hook(optimizer, args, kwargs)
+                out = func(*args, **kwargs)
+                for hook in self.post_step_hooks:
+                    hook(optimizer, args, kwargs)
+                step_final_hook(optimizer, args, kwargs)
+                return out
+            return wrapper
+
+        if self.is_mindtorch:
+            optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+        else:
+            optimizer.__class__.construct = patch_step(optimizer.__class__.construct, optimizer)
+
+        return
+
+    def dynamic_monitor(self, optimizer):
+        """
+        If dynamic monitor enabled and config.json updated,
+        remove hooks and register new hooks according to new configuration.
+        """
+        context = self.optimizer_context[optimizer]
+        if not self.dynamic_enable:
+            return
+        try:
+            # if the file timestamp has not changed, skip reading it to save time
+            config_timestamp = os.path.getmtime(self.config_file_path)
+            if config_timestamp == self.config_timestamp:
+                return
+            # record the config file's latest modification timestamp
+            self.config_timestamp = config_timestamp
+            config = load_json(self.config_file_path)
+        except Exception as e:
+            logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
             return
 
+        if config.get("dynamic_on", False):
+            try:
+                validate_config(config)
+                self.config = config
+                self.set_config()
+                self.start_step = context.step  # dynamic start/stop ignores the original start_step and always begins at the next step
+                logger.warning(f"config is updated at step{context.step - 1}, "
+                               f"will start new hook at step{context.step}.")
+            except Exception as e:
+                logger.error(f"set config wrong because {e}, not updated, please check!!!")
+                return
+
+            self._remove_all_hooks(optimizer)
+            self.register_hooks(optimizer)
+
+    def register_hooks(self, optimizer):
+        self._register_param_name()
+        self.hook_modules()
+        self.hook_optimizer(optimizer)
+        self._patch_grad_sync()
+        if self.cc_distribution.get('enable', False):
+            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+            api_register.redirect_api()
+        self.monitoring = True
+
+    def hook_modules(self):
         if not self.is_target_rank():
             return
+        module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
 
-
-
-
-
-        for param in optimizer.get_parameters():
-            if MonitorConst.EXP_AVG_SQ in param.name:
-                v_list.append(param)
-            elif MonitorConst.EXP_AVG in param.name:
-                m_list.append(param)
-            else:
-                param_list.append(param)
-                grad_names.append(param.name)
+        for key in module_in_all_stage:
+            struct = self.targets.pop(key)
+            self.targets.update(
+                {f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
 
-
-
-
-
-
+        hooked_count = 0
+        for vpp_stage, model_chunk in enumerate(self.model):
+            if not is_valid_instance(model_chunk):
+                logger.info("Target Model is not Cell")
+                continue
+            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+            targets = [x for x, _ in get_submodules(model_chunk)] if self.print_struct else self.targets.keys()
+            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+        logger.info(f"> {hooked_count} modules are monitored.")
+
+    def hook_optimizer(self, optimizer):
+        def optimizer_pre_step_hook(opt, *args, **kwargs):
             context = self.optimizer_context[opt]
-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            for param in v_list:
-                name = param.name
-                if is_select and name not in self.targets:
-                    continue
-                get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric)
-            if self.param_distribution:
-                for param in param_list:
-                    get_single_metrics(self.ops, param.name, param, context.param_metric)
-            self.generate_wgrad_metrics()
+
+            grad_dict = {}
+            if self.wg_distribution:
+                grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+            if self.mv_distribution or self.ur_distribution or self.mg_direction:
+                if self.is_mindtorch:
+                    context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
+                        context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
+                else:
+                    context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
+
+            self.generate_wgrad_metrics(grad_dict)
+            self.generate_mv_metrics(context)
+            self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
+
             metric_dict = {}
             for cc in self.cc_context.values():
                 cc.aggregate()
```
```diff
@@ -402,191 +508,167 @@ class TrainerMon:
             context.metric_dict = metric_dict
             return
 
-        def
-            context = self.optimizer_context[
-
-                            self.has_collect_times, self.collect_times)
-            if step_skip:
-                context.step += 1
-                return
-            self.write_xy_tb(context.step)
-            self.write_grad_tb(context.step)
-            self.write_mv_tb(context)
-            self.write_param_tb(context)
-
-            if context.metric_dict:
-                self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
-                context.metric_dict.clear()
-            self.has_collect_times += 1
-            context.step += 1
-            if self.anomaly_data_factory:
-                self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
-                self.summary_writer.clear_anomalies()
-            self.call_id = 0
-            self.param_name_call_id.clear()
-            return
+        def optimizer_post_step_hook(optimizer, args, kwargs):
+            context = self.optimizer_context[optimizer]
+            self.generate_param_metrics(context, MonitorConst.POST_PARAM)
 
-        def optimizer_pre_hook_wrapper(func, grad_names):
-            def wrapper(opt, gradients):
-                return func(opt, grad_names, gradients)
-            return wrapper
 
-
-
-                return func(opt, args, gradients, outputs)
-            return wrapper
-
-        optimizer.register_forward_pre_hook(optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
-        optimizer.register_forward_hook(optimizer_post_hook_wrapper(optimizer_post_hook_function))
+        if self.optimizer_hooked or not self.is_target_rank():
+            return
 
+        self.pre_step_hooks.append(optimizer_pre_step_hook)
+        self.post_step_hooks.append(optimizer_post_step_hook)
         self.optimizer_hooked = True
         return
 
+    def generate_wgrad_metrics(self, grad_dict):
+        if not self.wg_distribution:
+            return
+
+        get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+        get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+
+    def generate_param_map(self, tag, param_tensor):
+        metrics = {}
+        if not self.is_mindtorch:
+            return param_tensor
+        for name in self.param2name.values():
+            key = get_summary_writer_tag_name(name, tag, self.rank)
+            self.register_param_call_id("optimizer_pre_step_hook", key)
+            if name not in param_tensor or param_tensor[name] is None:
+                continue
+            metrics[key] = param_tensor[name]
+        return metrics
+
+    def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
+        if not self.param_distribution:
+            return
+        tag2param = {
+            self.name2tag.get(name, {}).get(stage): param
+            for name, param in self.name2param.items()
+            if param.numel() != 0
+        }
+        get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
+
+    def get_mv_for_ms(self, opt):
+        if not self.mv_distribution:
+            return {}, {}
+        common_opt = opt
+        if not is_valid_instance(opt):
+            common_opt = getattr(opt, 'optimizer')
+            if not is_valid_instance(common_opt):
+                logger.warning("Optimizer is not valid, please check usage")
+                return {}, {}
+        m_dict = {}
+        v_dict = {}
+        for name, param in get_parameters(common_opt):
+            if MonitorConst.EXP_AVG_SQ in name:
+                v_dict[name] = param
+            elif MonitorConst.EXP_AVG in name:
+                m_dict[name] = param
+        return m_dict, v_dict
+
+    def generate_mv_metrics(self, opt_context):
+        if not self.mv_distribution:
+            return
+        opt_context.exp_avg_metric = {}
+        opt_context.exp_avg_sq_metric = {}
+        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+    def write_stack_info(self):
+        stack_data = []
+        header = ["module_name", "stack_info"]
+        stack_data.append(header)
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            stack_data.append([fwd_context.module_name, fwd_context.stack])
+        filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+        if not os.path.exists(filepath):
+            data_frame = pd.DataFrame(columns=stack_data)
+            write_df_to_csv(data_frame, filepath)
+
     def write_xy_tb(self, step):
         if not self.xy_distribution:
             return
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             if len(fwd_context.actv) == 0:
                 continue
-            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step,
+            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
             fwd_context.actv.clear()
         if self.grad_context.actv:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step,
+            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
 
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-
+        param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+        updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+        self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+        self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)
 
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step,
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step,
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, MonitorConst.EXP_AVG)
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step,
+                                          MonitorConst.EXP_AVG_SQ)
 
     def write_grad_tb(self, step):
         if not self.wg_distribution:
             return
 
-
-
-        else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+        self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+                                          use_micro_step=self.monitor_mbs_grad)
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
-    def common_info(self):
-        if not self.xy_distribution:
-            logger.info("> module input/output input_grad/output_grad is not monitored. ")
-        if self.forward_only:
-            logger.info("> only module forward is monitored. ")
-        if not self.ur_distribution:
-            logger.info("> update vector and ratio vector of adam is not monitored. ")
-        if not self.mv_distribution:
-            logger.info("> momentum and variance of adam is not monitored. ")
-        if not self.wg_distribution:
-            logger.info("> weight grad of specified module is not monitored. ")
-        if not self.mg_direction:
-            logger.info('> grad and momentum direction will not be compared.')
-        if not self.cc_distribution.get('enable', False):
-            logger.info("> cc operator is not monitored.")
-
     def is_target_rank(self):
-
-        if self.module_rank_list and (rank_id not in self.module_rank_list):
+        if self.module_rank_list and (self.rank not in self.module_rank_list):
             return False
         return True
 
-    def
-
-
-
-
-
-
-
-
-
-        for key in module_in_all_stage:
-            struct = self.targets.pop(key)
-            self.targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(model))})
-
-        hooked_count = 0
-        for vpp_stage, model_chunk in enumerate(model):
-            if not isinstance(model_chunk, nn.Cell):
-                logger.info("Target Model is not Cell")
-                continue
-            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
-            targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
-            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-        logger.info(f"> {hooked_count} modules are monitored.")
-
-    def build_tbtag_tensor_map(self, module_name, tag, tensor):
-        rank_id = str(get_rank())
-        metrics = {}
-        key = get_summary_writer_tag_name(module_name, tag, rank_id)
+    def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+        """
+        :param module_name: str of module name
+        :param suffix:
+        :param tag:
+        :param tensor: torch.tensor or tuple/list of torch.tensor
+        :return: tensor_map
+        """
+        tensor_map = {}
         if isinstance(tensor, Tensor):
-
-
-
-
-
-
-
-
-
-
-
-
-            logger.warning(f"An error occurred while generating wgrad pre metrics")
-            return {}, {}
-
-        grad_dict = {}
-        for param, name in self.param2name.items():
-            if self.duplicate_param.get(name, False):
-                continue
-            grad = param.main_grad if self.params_have_main_grad else param.grad
-            if grad is None:
-                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                continue
-            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
-            self._register_param_call_id("hook_optimizer", tag)
-            grad_dict[tag] = grad
-        try:
-            get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
-        except Exception as e:
-            logger.warning(f"An error occurred while generating wgrad post metrics")
-            return {}, {}
-        return self.grad_context.post, self.grad_context.pre
-
-    def _register_param_name(self, model):
-        if self.param_registered:
-            return
+            tensor = [tensor]
+        if isinstance(tensor, tuple) or isinstance(tensor, list):
+            if len(tensor) == 1:
+                key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+                self.register_param_call_id("_hook_module", key)
+                tensor_map[key] = tensor[0]
+            else:
+                for i, tensor_i in enumerate(tensor):
+                    key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+                    self.register_param_call_id("_hook_module", key)
+                    tensor_map[key] = tensor_i
+        return tensor_map
 
-
-
-
+    def register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name:
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1
 
-
+    def _register_param_name(self):
+        for vpp_stage, model_chunk in enumerate(self.model):
             prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
             self._register_chunk(model_chunk, prefix)
 
-        self.param_registered = True
-
-    def _is_target_param(self, param_name, param, prefix):
-        if not self.targets:
-            return True
-        squash_name = prefix + squash_param_name(param_name)
-        name = prefix + param_name
-        for target in self.targets.keys():
-            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
-                setattr(param, "zero_out_wgrad", True)
-                return True
-        return False
-
     def _register_chunk(self, model_chunk, prefix):
         index = 0
-        for param in
-            param_name = param.name
+        for param_name, param in get_parameters(model_chunk):
            if not param.requires_grad:
                continue
            if self._is_target_param(param_name, param, prefix):
```
@@ -601,71 +683,59 @@ class TrainerMon:
|
|
|
601
683
|
self.duplicate_param[name] = True
|
|
602
684
|
if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
|
|
603
685
|
self.duplicate_param[name] = True
|
|
686
|
+
keywords = [
|
|
687
|
+
MonitorConst.PRE_GRAD,
|
|
688
|
+
```diff
+                MonitorConst.POST_GRAD,
+                MonitorConst.PRE_PARAM,
+                MonitorConst.POST_PARAM
+            ]
             self.name2tag[name] = {
-
-
+                k: get_summary_writer_tag_name(name, k, self.rank)
+                for k in keywords
             }
             index += 1
 
-    def _is_target_module(self, module_name, targets, vpp_stage):
-        if self.all_xy or self.print_struct:
-            return vpp_stage + squash_param_name(module_name)
-        for pattern in [
-            vpp_stage + squash_param_name(module_name),
-            vpp_stage + module_name,
-        ]:
-            if pattern in targets:
-                return pattern
-        return ""
-
     def _hook_module(self, target_names, module, vpp_stage=''):
-        if not
+        if not is_valid_instance(module):
             # nothing to hook
             return 0
 
-        def fwd_hook_fun(module,
+        def fwd_hook_fun(module, args, kwargs, module_output, name):
+
+            module_input = [tensor for tensor in args if isinstance(tensor, Tensor)]
+            if kwargs:
+                kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
+                module_input.extend(kwargs_tensors)
+
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
             if not context.struct:
                 context.struct = {
-
-
+                    Const.INPUT: get_param_struct(module_input),
+                    Const.OUTPUT: get_param_struct(module_output)
                 }
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
             if not module.training:
                 return
-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 step_accumulates_one(context, self.micro_batch_number)
                 return
-            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets)
-                context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets)
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
-                                                                  module_input, context.module_name,
-                                                                  MonitorConst.ACTV_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
-                                                               module_output, context.module_name,
-                                                               MonitorConst.ACTV_OUT)
-                context.verified = True
 
             tbtag_tensor_map = {}
-            if not context.ignore_in:
-                cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
-                                                cared_input))
-            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_input))
+            module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
+                if isinstance(module_output, tuple) else module_output
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_output))
             try:
                 get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
             except Exception as e:
```
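The rewritten forward hook above drops the `format_by_arg`/`validate_config_spec` plumbing: it filters positional and keyword inputs down to `Tensor` objects and keys everything by module name, I/O direction, and micro-step before handing the map to `get_metrics`. A minimal, self-contained sketch of that keying scheme follows; `NAME_SEP`, `FakeTensor`, and `build_tag_map` are illustrative stand-ins, not msprobe APIs.

```python
NAME_SEP = ":"  # assumption: separator between the I/O name and the micro-step index


class FakeTensor:
    """Stand-in for a framework Tensor; only used to show the isinstance filter."""

    def __init__(self, value):
        self.value = value


def build_tag_map(module_name, micro_step, io_values):
    """Keep only tensors and key them as '<module>.<io><SEP><micro_step>'."""
    tag_map = {}
    for io_name, items in io_values.items():  # e.g. {'input': [...], 'output': [...]}
        tensors = [t for t in items if isinstance(t, FakeTensor)]
        tag_map[f"{module_name}.{io_name}{NAME_SEP}{micro_step}"] = tensors
    return tag_map


if __name__ == "__main__":
    demo = {"input": [FakeTensor(1), "not-a-tensor"], "output": [FakeTensor(2)]}
    print(list(build_tag_map("layer1.mlp", 0, demo)))
    # ['layer1.mlp.input:0', 'layer1.mlp.output:0']
```

The same naming pattern is reused for activation gradients in the next hunk, so forward values and their gradients line up under matching tags.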
```diff
@@ -685,36 +755,22 @@ class TrainerMon:
                 self.module_struct[context.module_name].update(context.struct)
                 return
 
-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 step_accumulates_one(context, self.micro_batch_number)
                 return
 
-            if
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets)
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets)
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
-                                                                  input_grad, context.module_name,
-                                                                  MonitorConst.ACTVGRAD_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
-                                                               output_grad, context.module_name,
-                                                               MonitorConst.ACTVGRAD_OUT)
-                context.verified = True
-
+            valid_input_grad = [tensor for tensor in input_grad if isinstance(tensor, Tensor)]
             tbtag_tensor_map = {}
-            if not context.ignore_in:
-                cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(
-                        f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
-            cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, valid_input_grad))
+
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
```
```diff
@@ -728,21 +784,39 @@ class TrainerMon:
                 step_accumulates_one(context, self.micro_batch_number)
                 return
 
-        def
-
-
-
+        def fwd_hook_register(module, fwd_hook_fun, name):
+            if mindspore.__version__ >= '2.6.0':
+                def wrapper(module, args, kwargs, module_output):
+                    return fwd_hook_fun(module, args, kwargs, module_output, name)
+                return module.register_forward_hook(wrapper, with_kwargs=True)
+
+            else:
+                def wrapper(module, args, module_output):
+                    return fwd_hook_fun(module, args, None, module_output, name)
+                return module.register_forward_hook(wrapper)
+
+        def stack_hook(module, args, kwargs, module_output, name):
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            context.stack = analyze_api_call_stack(name)
+            return
 
         if self.backward_only and self.forward_only:
             logger.warning('not enable backward_only and forward_only simultaneously')
         hooked_count = 0
-
-
-
-
-
+
+        for module_name, submodule in get_submodules(module):
+            if self.stack_info:
+                name = vpp_stage + squash_param_name(module_name)
+                handle = fwd_hook_register(submodule, stack_hook, name=name)
+                self.handles["stack"].append(handle)
+            name = self._is_target_module(module_name, target_names, vpp_stage)
+            if not name:
+                continue
+            if self.xy_distribution or self.print_struct:
                 if not self.backward_only:
-                    handle = submodule
+                    handle = fwd_hook_register(submodule, fwd_hook_fun, name=name)
                     self.handles['xy'].append(handle)
                 if not self.forward_only:
                     handle = submodule.register_backward_hook(bwd_hook_fun)
```
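The new `fwd_hook_register` above picks the hook signature by MindSpore version: on newer versions it registers a hook that also receives keyword arguments, and on older versions it falls back to the positional-only signature and passes `None` for `kwargs`. A rough, framework-free sketch of that adapter pattern is shown below; `make_adapter`, `supports_kwargs`, and `my_hook` are illustrative names, not MindSpore APIs.

```python
def make_adapter(hook_fun, name, supports_kwargs):
    """Adapt a (module, args, kwargs, output, name) hook to the signature
    the framework will actually call; supports_kwargs stands in for the version check."""
    if supports_kwargs:
        def wrapper(module, args, kwargs, output):
            return hook_fun(module, args, kwargs, output, name)
    else:
        def wrapper(module, args, output):
            # older hook signature has no kwargs; pass None explicitly
            return hook_fun(module, args, None, output, name)
    return wrapper


def my_hook(module, args, kwargs, output, name):
    print(f"{name}: args={args}, kwargs={kwargs}, output={output}")


# Simulate what a newer and an older framework version would call.
make_adapter(my_hook, "layer1", supports_kwargs=True)(object(), (1, 2), {"mask": 3}, "out")
make_adapter(my_hook, "layer1", supports_kwargs=False)(object(), (1, 2), "out")
```

Keeping the five-argument `fwd_hook_fun` as the single implementation and adapting only at registration time avoids duplicating the hook body per framework version.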
```diff
@@ -752,70 +826,120 @@ class TrainerMon:
             hooked_count += 1
         return hooked_count
 
-    def _register_param_call_id(self, hook_name: str, key: str):
-        """
-        :param hook_name:
-        :param key: str, '0:relu_0/output_grad'
-        :return:
-        """
-        logger.debug(f"{hook_name} {key}: {self.call_id}")
-        self.param_name_call_id[key] = self.call_id
-        self.call_id += 1
-
     def _patch_grad_sync(self):
-        # MindSpore does not use Megatron for now
-        def patch_sync(sync_grad_func):
-            def wrapper(bucket):
-                grad_dict = {}
-                for param, name in self.param2name.items():
-                    if param not in bucket.params_list:
-                        continue
-                    grad = param.main_grad if self.params_have_main_grad else param.grad
-                    if grad is None:
-                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                        continue
-                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
-                    if tag is None:
-                        continue
-                    grad_dict[tag] = grad
-                try:
-                    get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
-                except Exception as e:
-                    logger.warning(f"An error occurred while generating weight grad metrics")
-                out = sync_grad_func(bucket)
-                return out
-
-            return wrapper
-
-        self.enable_megatron = False
-
         if not self.wg_distribution:
             return
-
-        if self.enable_megatron:
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
-        else:
-            self._hook_weights()
+        self._hook_weights()
 
     def _hook_weights(self):
         context = self.grad_context
 
         @_no_grad()
-        def param_hook(grad, context_dict, param,
+        def param_hook(grad, context_dict, param, name):
+            key = name
+            if self.monitor_mbs_grad:
+                key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+            key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+            self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-
+
+            if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
+                context_dict[key] = grad
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
-                context_dict[key] = grad
 
-        def param_hook_wrapper(param_hook, context_dict, param,
+        def param_hook_wrapper(param_hook, context_dict, param, name):
             def wrapper(grad):
-                return param_hook(grad, context_dict, param,
+                return param_hook(grad, context_dict, param, name)
+
             return wrapper
 
+        logger.info("hooking weights.")
         for param, name in self.param2name.items():
-            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
-            handle = param.register_hook(
+            handle = param.register_hook(
+                param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name))
             self.handles['wgrads'].append(handle)
         self.weight_hooked = True
```
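The reworked `param_hook` above keeps a per-parameter micro-step counter and either records every micro-batch gradient (when `monitor_mbs_grad` is on) or only the gradient observed at the last micro-batch of the global step. A minimal sketch of that bookkeeping, with stand-in names (`Param`, the `:`-separated key) rather than msprobe's own helpers:

```python
class Param:
    """Stand-in for a parameter object carrying the micro_step attribute."""

    def __init__(self):
        self.micro_step = 0


def param_hook(grad, context_dict, param, name, monitor_mbs_grad, micro_batch_number):
    # one key per micro-batch when monitor_mbs_grad is on, otherwise one key per step
    key = name + (f":{param.micro_step}" if monitor_mbs_grad else "")
    param.micro_step += 1
    if monitor_mbs_grad or param.micro_step == micro_batch_number:
        context_dict[key] = grad
    if param.micro_step == micro_batch_number:
        param.micro_step = 0  # reset for the next global step


acc = {}
p = Param()
for grad in [0.1, 0.2, 0.3, 0.4]:
    param_hook(grad, acc, p, "fc.weight", monitor_mbs_grad=False, micro_batch_number=4)
print(acc)  # {'fc.weight': 0.4} -- only the last micro-batch's grad is kept
```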
```diff
+
+    def _is_target_param(self, param_name, param, prefix):
+        if not self.targets:
+            return True
+        squash_name = prefix + squash_param_name(param_name)
+        name = prefix + param_name
+        for target in self.targets.keys():
+            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
+                setattr(param, "zero_out_wgrad", True)
+                return True
+        return False
+
+    def _is_target_module(self, module_name, targets, vpp_stage):
+        if self.all_xy or self.print_struct:
+            return vpp_stage + squash_param_name(module_name)
+        for pattern in [
+            vpp_stage + squash_param_name(module_name),
+            vpp_stage + module_name,
+        ]:
+            if pattern in targets:
+                return pattern
+        return ""
+
+    def _remove_all_hooks(self, optimizer):
+        # clear hook handles
+        for handle in self.handles['xy']:
+            handle.remove()
+        self.handles['xy'].clear()
+        # clear the corresponding context caches
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            fwd_context.reset()
+        for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+            bwd_context.reset()
+        self.grad_context.reset()  # weight grads and activation grads both live here
+
+        for handle in self.handles['wgrads']:
+            handle.remove()
+        self.handles['wgrads'].clear()
+        self.weight_hooked = False
+
+        if self.optimizer_hooked:
+            self.pre_step_hooks.clear()
+            self.post_step_hooks.clear()
+            for _, context in self.optimizer_context.items():
+                context.reset()
+        self.optimizer_hooked = False
+
+        for handle in self.handles['cc']:
+            handle.remove()
+        self.handles['cc'].clear()
+        api_register.restore_api()
+        for _, context in self.cc_context.items():
+            context.reset()
+
+        # clear node caches
+        self.param2name.clear()
+        self.name2index.clear()
+        self.name2indices.clear()
+        self.name2param.clear()
+        self.duplicate_param.clear()
+        self.name2tag.clear()
+        self.module_struct.clear()
+        self.grad_accs.clear()
+
+        # turn off the collection state
+        self.monitoring = False
+
+    def _remove_all_hooks_final(self, optimizer):
+        if self.dynamic_enable:
+            # automatically reset dynamic_on to False when monitoring finishes; the user must set it back to True to restart
+            try:
+                config = load_json(self.config_file_path)
+                config['dynamic_on'] = False
+                save_json(self.config_file_path, config, indent=2)
+                config_timestamp = os.path.getmtime(self.config_file_path)
+                self.config_timestamp = config_timestamp
+                logger.info(
+                    "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
+            except Exception as e:
+                logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
+        logger.info("Finish monitor")
+        self._remove_all_hooks(optimizer)
```