mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
The expanded hunks below belong to msprobe/pytorch/monitor/module_hook.py.

```diff
@@ -22,27 +22,29 @@ from functools import partial
 import pytz
 import torch
 import torch.distributed as dist
+import pandas as pd
 from torch.utils.hooks import BackwardHook
 
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json
 from msprobe.core.common.decorator import recursion_depth_decorator
+from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+from msprobe.core.common.file_utils import write_df_to_csv
+from msprobe.core.common.utils import analyze_api_call_stack
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
-from msprobe.pytorch.monitor.
-from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
-    CSVWriterWithAD, BaseWriterWithAD, WriterInput
+from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
 from msprobe.pytorch.monitor.features import get_sign_matches
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
     TensorMetrics, squash_param_name
-from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
     get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
+
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
@@ -72,36 +74,7 @@ class ModuleHookContext:
         self.actvgrad = []
         self.module_name = module_name
         self.struct = {}
-        self.
-        self.verified = False
-        self.focused_in_col = 0
-        self.focused_out_col = 0
-
-    def set_format_by_arg(self, key_name: str, target_config: dict):
-        """ Configure format_by_arg according to the monitored targets:
-        1) module_name has a monitoring target configured in target
-        2) module_name is not configured in targets, and all_xy monitors everything
-        3) module_name is not configured in targets, and all_xy does not monitor everything
-
-        :param key_name: str, one of [input, output, input_grad, output_grad]
-        :param target_config: target obj in config json.
-        :return:
-        """
-        cared = target_config.get(self.module_name, self.struct)
-        if key_name in cared:
-            target_module_config = cared[key_name]
-            if isinstance(target_module_config, dict):
-                # current cared is self.struct, monitor all data for module_name
-                self.format_by_arg[key_name] = target_module_config.get('config')
-            elif isinstance(target_module_config, str):
-                # current cared is target_config[self.module_name]
-                self.format_by_arg[key_name] = target_module_config
-            else:
-                logger.warning_on_rank_0(f"target module config error, result maybe empty."
-                                         f"module_name: {self.module_name}, key_name: {key_name}")
-                self.format_by_arg[key_name] = None
-        else:
-            self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
+        self.stack = ""
 
     def reset(self):
         self.actv.clear()
@@ -185,8 +158,8 @@ class TrainerMon:
         self.params_have_main_grad = params_have_main_grad
         self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.origin_step_func = None
         self.origin_start_grad_sync = None
+        self.fsdp_post_backward_hook = None
         self.config_timestamp = 0  # the timestamp is checked later; the first run does not need to touch the config file just to refresh its timestamp, monitoring can be turned on directly via the dynamic_on switch
         self.config = load_json(config_file_path)
         validate_config(self.config)
@@ -221,8 +194,8 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
+        self.fsdp_wrapped_module = False
         self.micro_batch_number = 1
-        self.optimizer_class = None
         self.optimizer_mon = None
         self.optimizer_trans = None
 
@@ -234,7 +207,6 @@ class TrainerMon:
         self.grad_context = GradContext()
         self.handles = defaultdict(list)
         self.param2name = defaultdict(str)
-        self.name2index = defaultdict()
         self.name2indices = defaultdict()
         self.name2param = {}
         self.duplicate_param = {}
@@ -247,6 +219,8 @@ class TrainerMon:
         self.optimizer_hooked = False
         self.param_registered = False
         self.struct_printed = False
+        self.pre_step_hooks = []
+        self.post_step_hooks = []
 
         # distinguish dynamic / static mode
         self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
@@ -317,6 +291,8 @@ class TrainerMon:
         self.param_distribution = self.config.get("param_distribution", False)
         self.mg_direction = self.config.get('mg_direction', False)
         self.cc_distribution = self.config.get("cc_distribution", {})
+        self.stack_info = self.config.get('stack_info', False)
+        self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
 
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
```
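The two `config.get` calls added in the hunk above imply that the monitor config file now accepts `stack_info` and `monitor_mbs_grad` switches alongside the existing distribution flags. A hypothetical fragment of such a config, shown as the dict that `load_json(config_file_path)` would return; only the two new keys come from this diff, the rest are illustrative placeholders:

```python
# Hypothetical monitor config; only "stack_info" and "monitor_mbs_grad" are new
# in this version, the other keys are placeholders for flags read nearby.
monitor_config = {
    "targets": {},                         # placeholder: which modules to monitor
    "param_distribution": True,            # existing flag, read a few lines above
    "cc_distribution": {"enable": False},  # existing flag, read a few lines above
    "stack_info": True,                    # new: dump per-module call stacks to stack_info.csv
    "monitor_mbs_grad": True,              # new: collect accumulated grads per micro-batch
}

print(monitor_config.get("stack_info", False))        # True
print(monitor_config.get("monitor_mbs_grad", False))  # True
```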
```diff
@@ -411,7 +387,7 @@ class TrainerMon:
         self.micro_batch_number = grad_acc_steps
         self.dp_group = dp_group
         self.tp_group = tp_group
-        self.optimizer_mon
+        self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
         self.hook_step_final(optimizer)
         if not isinstance(model, list):
             model = [model]
@@ -440,25 +416,48 @@ class TrainerMon:
             return
         self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
 
-    def build_tbtag_tensor_map(self, module_name, tag, tensor):
-
-
-
+    def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+        """
+        :param module_name: str of module name
+        :param suffix:
+        :param tag:
+        :param tensor: torch.tensor or tuple/list of torch.tensor
+        :return: tensor_map
+        """
+        tensor_map = {}
+        if isinstance(tensor, torch.Tensor):
+            tensor = [tensor]
+        if isinstance(tensor, tuple) or isinstance(tensor, list):
+            if len(tensor) == 1:
+                key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+                self.register_param_call_id("_hook_module", key)
+                tensor_map[key] = tensor[0]
+            else:
+                for i, tensor_i in enumerate(tensor):
+                    key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+                    self.register_param_call_id("_hook_module", key)
+                    tensor_map[key] = tensor_i
+        return tensor_map
 
     def generate_param_map(self, tag, param_tensor):
         metrics = {}
         for name in self.param2name.values():
             key = get_summary_writer_tag_name(name, tag, self.rank)
-            self.
+            self.register_param_call_id("optimizer_pre_step_hook", key)
             if name not in param_tensor or param_tensor[name] is None:
                 continue
             metrics[key] = param_tensor[name]
         return metrics
 
-    def generate_param_metrics(self, opt_context):
+    def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
         if not self.param_distribution:
             return
-
+        tag2param = {
+            self.name2tag.get(name, {}).get(stage): param
+            for name, param in self.name2param.items()
+            if param.numel() != 0
+        }
+        get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
 
     def generate_mv_metrics(self, opt_context):
         if not self.mv_distribution:
```
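The reworked `build_tbtag_tensor_map` accepts a single tensor or a tuple/list and fans a multi-element input out into one indexed tag per element. A self-contained sketch of that fan-out, with `make_tag` standing in for `get_summary_writer_tag_name` and the tag layout assumed purely for illustration:

```python
import torch


def make_tag(name, tag, rank):
    # Stand-in for get_summary_writer_tag_name; the "{rank}:{name}/{tag}" layout
    # is an assumption for this example, not the library's exact format.
    return f"{rank}:{name}/{tag}"


def fan_out(module_name, suffix, tag, tensor, rank=0):
    """Mirror the new fan-out: one key for a single tensor, an indexed key per
    element when a tuple/list of tensors is passed."""
    tensor_map = {}
    if isinstance(tensor, torch.Tensor):
        tensor = [tensor]
    if isinstance(tensor, (tuple, list)):
        if len(tensor) == 1:
            tensor_map[make_tag(module_name + suffix, tag, rank)] = tensor[0]
        else:
            for i, tensor_i in enumerate(tensor):
                tensor_map[make_tag(module_name + f"_{i}" + suffix, tag, rank)] = tensor_i
    return tensor_map


keys = fan_out("encoder.layer0.input", "_0", "actv", (torch.zeros(2), torch.ones(2)))
print(list(keys))  # ['0:encoder.layer0.input_0_0/actv', '0:encoder.layer0.input_1_0/actv']
```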
```diff
@@ -470,28 +469,22 @@ class TrainerMon:
         get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
         get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
 
-    def generate_wgrad_metrics(self):
+    def generate_wgrad_metrics(self, post_grad_dict):
         if not self.wg_distribution:
             return {}, {}
 
         if self.weight_hooked:
             get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
 
-
-
-
-
-
-
-
-                continue
-            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
-            self._register_param_call_id("hook_optimizer", tag)
-            grad_dict[tag] = grad
+        get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
+        reduced_grad = self.grad_context.post
+
+        if self.weight_hooked:
+            unreduced_grad = self.grad_context.acc_metric
+        else:
+            unreduced_grad = self.grad_context.pre
 
-
-        unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
-        return self.grad_context.post, unreduced_grad
+        return reduced_grad, unreduced_grad
 
     def generate_xy_metrics(self):
         actv = {}
@@ -517,6 +510,17 @@ class TrainerMon:
     def write_adhoc_check(self, step):
         self.tensor_metrics.flush(self.summary_writer)
 
+    def write_stack_info(self):
+        stack_data = []
+        header = ["module_name", "stack_info"]
+        stack_data.append(header)
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            stack_data.append([fwd_context.module_name, fwd_context.stack])
+        filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+        if not os.path.exists(filepath):
+            data_frame = pd.DataFrame(columns=stack_data)
+            write_df_to_csv(data_frame, filepath)
+
     def write_xy_tb(self, step):
         if not self.xy_distribution:
             return
@@ -531,7 +535,10 @@ class TrainerMon:
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-
+        param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+        updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+        self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+        self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)
 
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
```
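`write_param_tb` now records parameter statistics twice per step and separates the two groups by looking for the `MonitorConst.PRE_PARAM` or `MonitorConst.POST_PARAM` marker inside each tag. A toy version of that split with placeholder marker strings (the real values come from `MonitorConst` and are assumed not to overlap):

```python
# Placeholder markers; the real strings are MonitorConst.PRE_PARAM and
# MonitorConst.POST_PARAM, assumed not to be substrings of one another.
PRE_PARAM, POST_PARAM = "param_before_step", "param_after_step"

param_metric = {
    f"0:layer0.weight/{PRE_PARAM}": 1.00,   # filled by optimizer_pre_step_hook
    f"0:layer0.weight/{POST_PARAM}": 0.98,  # filled by optimizer_post_step_hook
}

# Same split as the new write_param_tb body.
param_metrics = {k: v for k, v in param_metric.items() if PRE_PARAM in k}
updated_param_metrics = {k: v for k, v in param_metric.items() if POST_PARAM in k}
print(param_metrics)          # {'0:layer0.weight/param_before_step': 1.0}
print(updated_param_metrics)  # {'0:layer0.weight/param_after_step': 0.98}
```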
```diff
@@ -545,10 +552,11 @@ class TrainerMon:
         if not self.wg_distribution:
             return
 
-        if self.
-            self.summary_writer.write_metrics(self.ops, self.grad_context.
+        if self.weight_hooked:
+            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+                                              use_micro_step=self.monitor_mbs_grad)
         else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
     def hook_optimizer(self, optimizer):
@@ -570,21 +578,23 @@ class TrainerMon:
             # skip generate metrics
             if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
                 return
-            if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class:  # use deepspeed with zero1/2/3
-                if not self.name2indices:
-                    self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
-                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
-                self.param2name = mv_result.grad
-            else:
-                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
-            context.param_exp_avg = mv_result.exp_avg
-            context.param_exp_avg_sq = mv_result.exp_avg_sq
-            context.param_adam_update = mv_result.update
-            context.param_adam_ratio = mv_result.ratio
 
-
+            grad_dict = {}
+            if self.wg_distribution:
+                grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+            mv_result = None
+            if self.mv_distribution or self.ur_distribution or self.mg_direction:
+                mv_result = self.optimizer_mon.fetch_mv(self, self.param2name)
+            if mv_result:
+                context.param_exp_avg = mv_result.exp_avg
+                context.param_exp_avg_sq = mv_result.exp_avg_sq
+                context.param_adam_update = mv_result.update
+                context.param_adam_ratio = mv_result.ratio
+
+            self.generate_wgrad_metrics(grad_dict)
             self.generate_mv_metrics(context)
-            self.generate_param_metrics(context)
+            self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
 
             tbtag_tensor_map = {}
             if self.mg_direction:
@@ -612,17 +622,15 @@ class TrainerMon:
             context.metric_dict = metric_dict
             return
 
-        def
-
-
-            out = func(*args, **kwargs)
-            return out
-        return wrapper
+        def optimizer_post_step_hook(optimizer, args, kwargs):
+            context = self.optimizer_context[optimizer]
+            self.generate_param_metrics(context, MonitorConst.POST_PARAM)
 
         if self.optimizer_hooked:
             return
 
-
+        self.pre_step_hooks.append(optimizer_pre_step_hook)
+        self.post_step_hooks.append(optimizer_post_step_hook)
 
         self.optimizer_hooked = True
         return
@@ -682,6 +690,12 @@ class TrainerMon:
             self.write_mv_tb(context)
             self.write_param_tb(context)
             self.write_adhoc_check(context.step)
+            if self.stack_info:
+                self.write_stack_info()
+                self.stack_info = False
+                for handle in self.handles["stack"]:
+                    handle.remove()
+                self.handles["stack"].clear()
 
             if self.ur_distribution:
                 for param_name, _ in context.param_adam_update.items():
@@ -714,13 +728,16 @@ class TrainerMon:
 
         def patch_step(func, optimizer):
             def wrapper(*args, **kwargs):
+                for hook in self.pre_step_hooks:
+                    hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
+                for hook in self.post_step_hooks:
+                    hook(optimizer, args, kwargs)
                 step_final_hook(optimizer, args, kwargs)
                 return out
             return wrapper
 
         optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
-        self.origin_step_func = optimizer.__class__.step
         return
 
     def hook_modules(self):
```
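`patch_step` now drives two hook lists around the original `step` instead of a single hard-wired pre-step call. A minimal, runnable sketch of the same patching pattern outside of `TrainerMon` (the module-level `pre_step_hooks`/`post_step_hooks` lists here stand in for the instance attributes):

```python
import torch

pre_step_hooks, post_step_hooks = [], []


def patch_step(func, optimizer):
    # Same shape as the wrapper above: run every pre-step hook, call the
    # original step, run every post-step hook, then return its result.
    def wrapper(*args, **kwargs):
        for hook in pre_step_hooks:
            hook(optimizer, args, kwargs)
        out = func(*args, **kwargs)
        for hook in post_step_hooks:
            hook(optimizer, args, kwargs)
        return out
    return wrapper


model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

pre_step_hooks.append(lambda optimizer, args, kwargs: print("before step"))
post_step_hooks.append(lambda optimizer, args, kwargs: print("after step"))

# Patch at class level, as the diff does with optimizer.__class__.step.
opt.__class__.step = patch_step(opt.__class__.step, opt)

model(torch.randn(2, 4)).sum().backward()
opt.step()  # prints "before step", then "after step"
```

Because the method is replaced on the class, clearing the two hook lists (as the updated `_remove_all_hooks` does) is enough to disable the monitoring without restoring the original `step`.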
```diff
@@ -764,6 +781,16 @@ class TrainerMon:
         BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
         return
 
+    def register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name:
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1
+
     def _remove_all_hooks(self, optimizer):
         # clear hook handles
         for handle in self.handles['xy']:
@@ -789,14 +816,18 @@ class TrainerMon:
                 logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
             except ImportError:
                 pass
-
+        elif self.fsdp_post_backward_hook:  # fsdp
+            torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook
+            logger.info("remove patch_post_backward_hook in fsdp.")
+        else:  # not megatron and not fsdp
             for handle in self.handles['wgrads']:
                 handle.remove()
             self.handles['wgrads'].clear()
             self.weight_hooked = False
 
         if self.optimizer_hooked:
-
+            self.pre_step_hooks.clear()
+            self.post_step_hooks.clear()
 
             for _, context in self.optimizer_context.items():
                 context.reset()
@@ -811,7 +842,6 @@ class TrainerMon:
 
         # clear node caches
         self.param2name.clear()
-        self.name2index.clear()
         self.name2indices.clear()
         self.name2param.clear()
         self.duplicate_param.clear()
@@ -871,27 +901,33 @@ class TrainerMon:
         return False
 
     def _register_chunk(self, model_chunk, prefix):
-        index = 0
         for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
+            if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
+                self.fsdp_wrapped_module = True
             if self._is_target_param(param_name, param, prefix):
                 name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
                     name = prefix + param_name
                 self.param2name[param] = name
                 self.name2param[name] = param
-                self.name2index[name] = index
 
                 if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
                     self.duplicate_param[name] = True
                 if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                     self.duplicate_param[name] = True
+
+                keywords = [
+                    MonitorConst.PRE_GRAD,
+                    MonitorConst.POST_GRAD,
+                    MonitorConst.PRE_PARAM,
+                    MonitorConst.POST_PARAM
+                ]
                 self.name2tag[name] = {
-
-
+                    k: get_summary_writer_tag_name(name, k, self.rank)
+                    for k in keywords
                 }
-            index += 1
 
     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
@@ -914,11 +950,17 @@ class TrainerMon:
             # nothing to hook
             return 0
 
-        def fwd_hook_fun(module,
+        def fwd_hook_fun(module, args, kwargs, module_output, name):
             if not module.training or is_recomputation():
                 # 1 only monitor training stage.
                 # 2 when open recompute, skip recomputed forward stage.
                 return
+
+            module_input = [tensor for tensor in args if torch.is_tensor(tensor)]
+            if kwargs:
+                kwargs_tensors = [tensor for tensor in kwargs.values() if torch.is_tensor(tensor)]
+                module_input.extend(kwargs_tensors)
+
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
@@ -927,34 +969,20 @@ class TrainerMon:
                 Const.INPUT: get_param_struct(module_input),
                 Const.OUTPUT: get_param_struct(module_output)
             }
+
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
-
-            context.set_format_by_arg(Const.INPUT, self.config['targets'])
-            context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
-                                                              module_input, context.module_name,
-                                                              Const.INPUT)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
-                                                               module_output, context.module_name,
-                                                               Const.OUTPUT)
-                context.verified = True
-            # expect output be tensor type
+
             tbtag_tensor_map = {}
-            cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV,
-            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_input))
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV,
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_output))
 
             get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
             context.micro_step += 1
@@ -972,31 +1000,17 @@ class TrainerMon:
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
-            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
-                context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                context.focused_in_col = validate_config_spec(
-                    context.format_by_arg[MonitorConst.INPUT_GRAD],
-                    input_grad, context.module_name, MonitorConst.INPUT_GRAD)
-                context.focused_out_col = validate_config_spec(
-                    context.format_by_arg[MonitorConst.OUTPUT_GRAD],
-                    output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
-                context.verified = True
 
             tbtag_tensor_map = {}
-            cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.
-
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, input_grad))
+
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -1010,17 +1024,30 @@ class TrainerMon:
                 context.micro_step = 0
             return
 
+        def stack_hook(module, args, kwargs, module_output, name):
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            context.stack = analyze_api_call_stack(name)
+            return
+
         if self.backward_only and self.forward_only:
             logger.warning('not enable backward_only and forward_only simultaneously')
 
         hooked_count = 0
-
-
-            name =
-
-
+        for module_name, submodule in module.named_modules():
+            if self.stack_info:
+                name = vpp_stage + squash_param_name(module_name, self.squash_name)
+                handle = submodule.register_forward_hook(partial(stack_hook, name=name), with_kwargs=True)
+                self.handles['stack'].append(handle)
+            name = self._is_target_module(module_name, target_names, vpp_stage)
+            if not name:
+                continue
+            if submodule.__class__.__name__ == "FullyShardedDataParallel":
+                continue
+            if self.xy_distribution or self.print_struct:
                 if not self.backward_only:
-                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
+                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name), with_kwargs=True)
                     self.handles['xy'].append(handle)
                 if not self.forward_only and not self.has_register_backward_hook(name, submodule):
                     handle = submodule.register_full_backward_hook(bwd_hook_fun)
```
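`stack_hook` and the reworked `fwd_hook_fun` are registered with `with_kwargs=True`, so each hook now receives `(module, args, kwargs, output)` and keyword-passed tensors become visible to the monitor. A small runnable sketch of that signature (module and hook names are made up):

```python
from functools import partial

import torch
import torch.nn as nn


class Affine(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale


def fwd_hook(module, args, kwargs, output, name):
    # with_kwargs=True delivers positional and keyword arguments separately,
    # mirroring how the new fwd_hook_fun gathers tensors from both.
    tensors = [t for t in args if torch.is_tensor(t)]
    tensors += [t for t in kwargs.values() if torch.is_tensor(t)]
    print(f"{name}: {len(tensors)} input tensor(s), output shape {tuple(output.shape)}")


module = Affine()
module.register_forward_hook(partial(fwd_hook, name="affine_0"), with_kwargs=True)
module(torch.ones(2, 3), scale=torch.tensor(2.0))  # affine_0: 2 input tensor(s), output shape (2, 3)
```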
```diff
@@ -1049,7 +1076,7 @@ class TrainerMon:
                 if tag is None:
                     continue
                 grad_dict[tag] = grad
-                self.
+                self.register_param_call_id("sync_grad_func", tag)
             get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
             out = sync_grad_func(bucket)
             return out
@@ -1058,7 +1085,14 @@ class TrainerMon:
 
         if not self.wg_distribution:
             return
+        if self.fsdp_wrapped_module:
+            # patch fsdp _runtime_utils._post_backward_hook
+            self._patch_fsdp_post_backward_hook()
+            return
 
+        if self.monitor_mbs_grad:
+            self._hook_weights()
+            return
         try:
             from megatron.core.distributed.param_and_grad_buffer import Bucket
             self.origin_start_grad_sync = Bucket.start_grad_sync
@@ -1076,19 +1110,62 @@ class TrainerMon:
             logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
         except ImportError:
             self.enable_megatron = False | self.enable_megatron
+        if self.enable_megatron:
+            return
 
-
-
+        # default hook weights
+        self._hook_weights()
+
+    def _patch_fsdp_post_backward_hook(self):
+        """
+        The FSDP runtime drives the whole forward/backward compute and communication flow by overriding nn.Module.forward and defining the corresponding logic there.
+        Registering a hook on the AccumulateGrad object would run right after backward computes the grad, collecting the accumulated gradient before the reduce_scatter aggregation.
+        However, FSDP re-registers its own hooks on AccumulateGrad in every forward pass, so hooks registered by the monitor tool never take effect;
+        therefore _post_backward_hook is patched to collect the gradient after backward and before reduce_scatter.
+        """
+        def patch_post_backward_hook(_post_backward_hook):
+            def wrapper(state, handle, *unused):
+                grad_dict = {}
+                offset = 0
+                for param, name in self.param2name.items():
+                    limit = param.numel()
+                    if not limit:
+                        continue
+                    grad = handle.flat_param.grad[offset:offset + limit]
+                    offset += limit
+                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                    self.register_param_call_id("_post_backward_hook", tag)
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = _post_backward_hook(state, handle, *unused)
+                return out
+
+            return wrapper
+
+        logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
+        self.fsdp_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook
+        torch.distributed.fsdp._runtime_utils._post_backward_hook = \
+            patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook)
 
     def _hook_weights(self):
+        """
+        Walk each parameter's gradient-accumulation function (grad_acc) and register a hook, so that once all gradients of the parameter have been computed, the pre-aggregation gradient can be collected.
+        """
         context = self.grad_context
 
         @torch.no_grad
-        def param_hook(*args, context_dict, param,
+        def param_hook(*args, context_dict, param, name):
+            key = name
+            if self.monitor_mbs_grad:
+                key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+
+            key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+            self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-
-            if param.micro_step == self.micro_batch_number:
-                param.micro_step = 0
+
+            if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
                 if self.params_have_main_grad:
                     grad = param.main_grad
                 else:
```
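The FSDP patch recovers per-parameter gradients by slicing `handle.flat_param.grad` with a running `numel()` offset. The indexing idea in isolation, with made-up parameter names and shapes and a plain tensor standing in for the flat FSDP gradient buffer:

```python
import torch

# Made-up parameters, listed in flattening order.
params = {
    "layer0.weight": torch.zeros(4, 3),
    "layer0.bias": torch.zeros(4),
}

# Stand-in for handle.flat_param.grad: one flat buffer covering all parameters.
flat_grad = torch.arange(sum(p.numel() for p in params.values()), dtype=torch.float32)

# Same walk as the patched _post_backward_hook: take numel() elements per
# parameter and advance the offset; the slice is kept flat, as in the patch.
offset = 0
per_param_grad = {}
for name, param in params.items():
    limit = param.numel()
    if not limit:
        continue
    per_param_grad[name] = flat_grad[offset:offset + limit]
    offset += limit

print(per_param_grad["layer0.bias"])  # tensor([12., 13., 14., 15.])
```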
```diff
@@ -1097,25 +1174,17 @@ class TrainerMon:
                     grad = grad.float()
                 context_dict[key] = grad.clone()
 
+            if param.micro_step == self.micro_batch_number:
+                param.micro_step = 0
+
         logger.info("hooking weights.")
         for param, name in self.param2name.items():
-            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
            setattr(param, 'micro_step', 0)
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            handle = grad_acc.register_hook(
-                partial(param_hook, context_dict=context.acc, param=param,
+                partial(param_hook, context_dict=context.acc, param=param, name=name))
            self.grad_accs.append(grad_acc)
            self.handles['wgrads'].append(handle)

        self.weight_hooked = True
-
-    def _register_param_call_id(self, hook_name: str, key: str):
-        """
-        :param hook_name:
-        :param key: str, '0:relu_0/output_grad'
-        :return:
-        """
-        logger.debug(f"{hook_name} {key}: {self.call_id}")
-        self.param_name_call_id[key] = self.call_id
-        self.call_id += 1
```