mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/monitor/module_hook.py

```diff
@@ -22,26 +22,29 @@ from functools import partial
 import pytz
 import torch
 import torch.distributed as dist
+import pandas as pd
 from torch.utils.hooks import BackwardHook

 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.core.common.decorator import recursion_depth_decorator
+from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+from msprobe.core.common.file_utils import write_df_to_csv
+from msprobe.core.common.utils import analyze_api_call_stack
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_recomputation
-from msprobe.pytorch.monitor.
-from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
-    CSVWriterWithAD, BaseWriterWithAD, WriterInput
+from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
+from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
 from msprobe.pytorch.monitor.features import get_sign_matches
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
     TensorMetrics, squash_param_name
-from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
-    get_output_base_dir, get_target_output_dir
+    get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

+
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
@@ -71,36 +74,7 @@ class ModuleHookContext:
         self.actvgrad = []
         self.module_name = module_name
         self.struct = {}
-        self.
-        self.verified = False
-        self.focused_in_col = 0
-        self.focused_out_col = 0
-
-    def set_format_by_arg(self, key_name: str, target_config: dict):
-        """ 按照监控对象配置format_by_arg
-        1) module_name 在 target 中配置监控对象
-        2) module_name 未在 targets 中配置,且 all_xy 全量监控
-        3) module_name 未在 targets 中配置,且 all_xy 未全量监控
-
-        :param key_name: str, one of [input, output, input_grad, output_grad]
-        :param target_config: target obj in config json.
-        :return:
-        """
-        cared = target_config.get(self.module_name, self.struct)
-        if key_name in cared:
-            target_module_config = cared[key_name]
-            if isinstance(target_module_config, dict):
-                # current cared is self.struct, monitor all data for module_name
-                self.format_by_arg[key_name] = target_module_config.get('config')
-            elif isinstance(target_module_config, str):
-                # current cared is target_config[self.module_name]
-                self.format_by_arg[key_name] = target_module_config
-            else:
-                logger.warning_on_rank_0(f"target module config error, result maybe empty."
-                                         f"module_name: {self.module_name}, key_name: {key_name}")
-                self.format_by_arg[key_name] = None
-        else:
-            self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
+        self.stack = ""

     def reset(self):
         self.actv.clear()
@@ -176,15 +150,16 @@ class GradContext:
 class TrainerMon:
     tensor_metrics = TensorMetrics()

-
+    # 保留原opt_ty参数, 兼容msprobe1.2.2前旧版本
+    def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
         # TYPE1: 只在这里初始化的变量, 不会随着训练中途config配置改变而重置
         self.config_file_path = config_file_path
         self.process_group = get_process_group(process_group)
         self.params_have_main_grad = params_have_main_grad
         self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.origin_step_func = None
         self.origin_start_grad_sync = None
+        self.fsdp_post_backward_hook = None
         self.config_timestamp = 0  # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开
         self.config = load_json(config_file_path)
         validate_config(self.config)
@@ -219,9 +194,10 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
+        self.fsdp_wrapped_module = False
         self.micro_batch_number = 1
-        self.optimizer_class = None
         self.optimizer_mon = None
+        self.optimizer_trans = None

         # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -231,7 +207,6 @@ class TrainerMon:
         self.grad_context = GradContext()
         self.handles = defaultdict(list)
         self.param2name = defaultdict(str)
-        self.name2index = defaultdict()
         self.name2indices = defaultdict()
         self.name2param = {}
         self.duplicate_param = {}
@@ -244,6 +219,8 @@ class TrainerMon:
         self.optimizer_hooked = False
         self.param_registered = False
         self.struct_printed = False
+        self.pre_step_hooks = []
+        self.post_step_hooks = []

         # 动静态区分
         self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
@@ -314,6 +291,8 @@ class TrainerMon:
         self.param_distribution = self.config.get("param_distribution", False)
         self.mg_direction = self.config.get('mg_direction', False)
         self.cc_distribution = self.config.get("cc_distribution", {})
+        self.stack_info = self.config.get('stack_info', False)
+        self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)

         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
@@ -322,8 +301,6 @@ class TrainerMon:
             self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
             self.cc_logged_stack = defaultdict(set)
             self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
-            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
-            api_register.redirect_api()

         self.common_info()

@@ -336,11 +313,11 @@ class TrainerMon:

         # 初始化writer, 创建输出目录
         if self.format not in FORMAT_MAPPING:
-            logger.
+            logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
             self.format = MonitorConst.CSV

         if self.ur_distribution and self.format != 'tensorboard':
-            logger.
+            logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
             self.ur_distribution = False

         writer = FORMAT_MAPPING[self.format]
@@ -363,19 +340,6 @@ class TrainerMon:
                 self.rank)
             self.anomaly_data_writer.init_detected_json()

-    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
-        rank = None
-        if dist.is_initialized():
-            rank = dist.get_rank()
-        if (rank not in rank_list) and len(rank_list) != 0:
-            return
-        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
-
-    def build_tbtag_tensor_map(self, module_name, tag, tensor):
-        key = get_summary_writer_tag_name(module_name, tag, self.rank)
-        self._register_param_call_id("_hook_module", key)
-        return {key: tensor}
-
     def common_info(self):
         if not self.xy_distribution:
             logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
@@ -392,101 +356,38 @@ class TrainerMon:
         if not self.cc_distribution.get('enable', False):
             logger.info_on_rank_0("> cc operator is not monitored.")

-
-
-
-
-
-
-
-
-
-
-        hooked_count = 0
-        for vpp_stage, model_chunk in enumerate(self.model):
-            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
-            targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
-                'targets'].keys()
-            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-
-        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
-
-        def clone_if_tensor(args):
-            if isinstance(args, tuple):
-                return tuple([clone_if_tensor(arg) for arg in args])
-            elif isinstance(args, torch.Tensor):
-                return args.clone()
-            else:
-                return args
-
-        @torch.no_grad
-        def wrap_hook_setup(setup):
-            def wrapped_setup(*args, **kwargs):
-                args = setup(*args, **kwargs)
-                args = clone_if_tensor(args)
-                return args
-
-            return wrapped_setup
-
-        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
-
-        return
-
-    def generate_param_metrics(self, opt_context):
-        if not self.param_distribution:
-            return
-        get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
-
-    def generate_mv_metrics(self, opt_context):
-        if not self.mv_distribution:
-            return
-        opt_context.exp_avg_metric = {}
-        opt_context.exp_avg_sq_metric = {}
-        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
-        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
-        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
-        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
-
-    def generate_wgrad_metrics(self):
-        if not self.wg_distribution:
-            return {}, {}
-
-        if self.weight_hooked:
-            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
-        grad_dict = {}
-        for param, name in self.param2name.items():
-            if self.duplicate_param.get(name, False):
-                continue
-            grad = param.main_grad if self.params_have_main_grad else param.grad
-            if grad is None:
-                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                continue
-            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
-            self._register_param_call_id("hook_optimizer", tag)
-            grad_dict[tag] = grad
+    # 保留原接口, 兼容msprobe1.2.2前旧版本
+    def monitor_gnorm_with_ad(self, model, optimizer=None, grad_acc_steps=1, tp_group=None, dp_group=None,
+                              start_iteration=0):
+        if optimizer is None:
+            optimizer = getattr(self, "optimizer_trans", None)  # 兼容老版本可传None的情况, 从set_wrapped_optimizer获取
+        if optimizer is None:
+            logger.error("monitor_gnorm_with_ad: please set_wrapped_optimizer before it or input optimizer!=None")
+            return
+        self.set_monitor(model, optimizer, grad_acc_steps, tp_group, dp_group, start_iteration)

-
-
-
+    # 保留原接口, 兼容msprobe1.2.2前旧版本
+    def set_wrapped_optimizer(self, optimizer):
+        self.optimizer_trans = optimizer

     def set_monitor(
             self,
             model,
+            optimizer,
             grad_acc_steps=1,
-            optimizer=None,
             tp_group=None,
             dp_group=None,
             start_iteration=0
     ):
         """External interface"""
+        grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration)
         global start_step
         start_step = start_iteration
         logger.info(f'grad acc steps {grad_acc_steps}')
         self.micro_batch_number = grad_acc_steps
         self.dp_group = dp_group
         self.tp_group = tp_group
-        self.optimizer_mon
+        self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
         self.hook_step_final(optimizer)
         if not isinstance(model, list):
             model = [model]
@@ -502,18 +403,89 @@ class TrainerMon:
         self.hook_optimizer(optimizer)
         self._patch_grad_sync()
         self.hook_modules()
+        if self.cc_distribution.get('enable', False):
+            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+            api_register.redirect_api()
         self.monitoring = True

+    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+        rank = None
+        if dist.is_initialized():
+            rank = dist.get_rank()
+        if (rank not in rank_list) and len(rank_list) != 0:
+            return
+        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
+
+    def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+        """
+        :param module_name: str of module name
+        :param suffix:
+        :param tag:
+        :param tensor: torch.tensor or tuple/list of torch.tensor
+        :return: tensor_map
+        """
+        tensor_map = {}
+        if isinstance(tensor, torch.Tensor):
+            tensor = [tensor]
+        if isinstance(tensor, tuple) or isinstance(tensor, list):
+            if len(tensor) == 1:
+                key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+                self.register_param_call_id("_hook_module", key)
+                tensor_map[key] = tensor[0]
+            else:
+                for i, tensor_i in enumerate(tensor):
+                    key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+                    self.register_param_call_id("_hook_module", key)
+                    tensor_map[key] = tensor_i
+        return tensor_map
+
     def generate_param_map(self, tag, param_tensor):
         metrics = {}
         for name in self.param2name.values():
             key = get_summary_writer_tag_name(name, tag, self.rank)
-            self.
+            self.register_param_call_id("optimizer_pre_step_hook", key)
             if name not in param_tensor or param_tensor[name] is None:
                 continue
             metrics[key] = param_tensor[name]
         return metrics

+    def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
+        if not self.param_distribution:
+            return
+        tag2param = {
+            self.name2tag.get(name, {}).get(stage): param
+            for name, param in self.name2param.items()
+            if param.numel() != 0
+        }
+        get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
+
+    def generate_mv_metrics(self, opt_context):
+        if not self.mv_distribution:
+            return
+        opt_context.exp_avg_metric = {}
+        opt_context.exp_avg_sq_metric = {}
+        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+    def generate_wgrad_metrics(self, post_grad_dict):
+        if not self.wg_distribution:
+            return {}, {}
+
+        if self.weight_hooked:
+            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+
+        get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
+        reduced_grad = self.grad_context.post
+
+        if self.weight_hooked:
+            unreduced_grad = self.grad_context.acc_metric
+        else:
+            unreduced_grad = self.grad_context.pre
+
+        return reduced_grad, unreduced_grad
+
     def generate_xy_metrics(self):
         actv = {}
         for fwd_context in self.module_fwd_hook_context_by_module.values():
@@ -538,6 +510,17 @@ class TrainerMon:
     def write_adhoc_check(self, step):
         self.tensor_metrics.flush(self.summary_writer)

+    def write_stack_info(self):
+        stack_data = []
+        header = ["module_name", "stack_info"]
+        stack_data.append(header)
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            stack_data.append([fwd_context.module_name, fwd_context.stack])
+        filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+        if not os.path.exists(filepath):
+            data_frame = pd.DataFrame(columns=stack_data)
+            write_df_to_csv(data_frame, filepath)
+
     def write_xy_tb(self, step):
         if not self.xy_distribution:
             return
@@ -552,27 +535,31 @@ class TrainerMon:
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-
+        param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+        updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+        self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+        self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)

     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
                                           opt_context.step, MonitorConst.EXP_AVG)
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
                                           opt_context.step, MonitorConst.EXP_AVG_SQ)

     def write_grad_tb(self, step):
         if not self.wg_distribution:
             return

-        if self.
-            self.summary_writer.write_metrics(self.ops, self.grad_context.
+        if self.weight_hooked:
+            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+                                              use_micro_step=self.monitor_mbs_grad)
         else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

-    def hook_optimizer(self, optimizer
+    def hook_optimizer(self, optimizer):
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
@@ -591,21 +578,23 @@ class TrainerMon:
             # skip generate metrics
             if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
                 return
-            if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class:  # use deepspeed with zero1/2/3
-                if not self.name2indices:
-                    self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
-                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
-                self.param2name = mv_result.grad
-            else:
-                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
-            context.param_exp_avg = mv_result.exp_avg
-            context.param_exp_avg_sq = mv_result.exp_avg_sq
-            context.param_adam_update = mv_result.update
-            context.param_adam_ratio = mv_result.ratio

-
+            grad_dict = {}
+            if self.wg_distribution:
+                grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+            mv_result = None
+            if self.mv_distribution or self.ur_distribution or self.mg_direction:
+                mv_result = self.optimizer_mon.fetch_mv(self, self.param2name)
+            if mv_result:
+                context.param_exp_avg = mv_result.exp_avg
+                context.param_exp_avg_sq = mv_result.exp_avg_sq
+                context.param_adam_update = mv_result.update
+                context.param_adam_ratio = mv_result.ratio
+
+            self.generate_wgrad_metrics(grad_dict)
             self.generate_mv_metrics(context)
-            self.generate_param_metrics(context)
+            self.generate_param_metrics(context, MonitorConst.PRE_PARAM)

             tbtag_tensor_map = {}
             if self.mg_direction:
@@ -633,18 +622,15 @@ class TrainerMon:
             context.metric_dict = metric_dict
             return

-        def 
-
-
-            out = func(*args, **kwargs)
-            return out
-
-        return wrapper
+        def optimizer_post_step_hook(optimizer, args, kwargs):
+            context = self.optimizer_context[optimizer]
+            self.generate_param_metrics(context, MonitorConst.POST_PARAM)

         if self.optimizer_hooked:
             return

-
+        self.pre_step_hooks.append(optimizer_pre_step_hook)
+        self.post_step_hooks.append(optimizer_post_step_hook)

         self.optimizer_hooked = True
         return
@@ -674,6 +660,7 @@ class TrainerMon:
                 validate_config(config)
                 self.config = config
                 self.set_config()
+                self.start_step = context.step  # 动态启停时不受原start_step影响,永远从下一步开始
                 logger.warning(f"config is updated at step{context.step - 1}, "
                                f"will start new hook at step{context.step}.")
             except Exception as e:
@@ -703,6 +690,12 @@ class TrainerMon:
             self.write_mv_tb(context)
             self.write_param_tb(context)
             self.write_adhoc_check(context.step)
+            if self.stack_info:
+                self.write_stack_info()
+                self.stack_info = False
+                for handle in self.handles["stack"]:
+                    handle.remove()
+                self.handles["stack"].clear()

             if self.ur_distribution:
                 for param_name, _ in context.param_adam_update.items():
@@ -721,6 +714,9 @@ class TrainerMon:
             if self.anomaly_data_factory:
                 self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
             self.summary_writer.clear_anomalies()
+
+            if self.format == MonitorConst.TENSORBOARD:
+                chmod_tensorboard_dir(self.tensorboard_dir)
             self.call_id = 0
             self.param_name_call_id.clear()

@@ -732,16 +728,69 @@ class TrainerMon:

         def patch_step(func, optimizer):
             def wrapper(*args, **kwargs):
+                for hook in self.pre_step_hooks:
+                    hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
+                for hook in self.post_step_hooks:
+                    hook(optimizer, args, kwargs)
                 step_final_hook(optimizer, args, kwargs)
                 return out
             return wrapper

         optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
-
+        return
+
+    def hook_modules(self):
+        if self.module_rank_list and (self.rank not in self.module_rank_list):
+            return

+        targets = self.config['targets']
+        module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
+        for key in module_in_all_stage:
+            struct = targets.pop(key)
+            targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+
+        hooked_count = 0
+        for vpp_stage, model_chunk in enumerate(self.model):
+            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+            targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
+                'targets'].keys()
+            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+
+        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
+
+        @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor')
+        def clone_if_tensor(args):
+            if isinstance(args, tuple):
+                return tuple([clone_if_tensor(arg) for arg in args])
+            elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+                return args.clone()
+            else:
+                return args
+
+        @torch.no_grad
+        def wrap_hook_setup(setup):
+            def wrapped_setup(*args, **kwargs):
+                args = setup(*args, **kwargs)
+                args = clone_if_tensor(args)
+                return args
+
+            return wrapped_setup
+
+        BackwardHook.setup_input_hook = wrap_hook_setup(BackwardHook.setup_input_hook)
+        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
         return

+    def register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name:
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1
+
     def _remove_all_hooks(self, optimizer):
         # 清空hook handle
         for handle in self.handles['xy']:
@@ -767,14 +816,18 @@ class TrainerMon:
                 logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
             except ImportError:
                 pass
-
+        elif self.fsdp_post_backward_hook:  # fsdp
+            torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook
+            logger.info("remove patch_post_backward_hook in fsdp.")
+        else:  # not megatron and not fsdp
             for handle in self.handles['wgrads']:
                 handle.remove()
             self.handles['wgrads'].clear()
             self.weight_hooked = False

         if self.optimizer_hooked:
-
+            self.pre_step_hooks.clear()
+            self.post_step_hooks.clear()

             for _, context in self.optimizer_context.items():
                 context.reset()
@@ -783,12 +836,12 @@ class TrainerMon:
             for handle in self.handles['cc']:
                 handle.remove()
             self.handles['cc'].clear()
+            api_register.restore_api()
             for _, context in self.cc_context.items():
                 context.reset()

         # 清空节点缓存
         self.param2name.clear()
-        self.name2index.clear()
         self.name2indices.clear()
         self.name2param.clear()
         self.duplicate_param.clear()
@@ -848,27 +901,33 @@ class TrainerMon:
         return False

     def _register_chunk(self, model_chunk, prefix):
-        index = 0
         for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
+            if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
+                self.fsdp_wrapped_module = True
             if self._is_target_param(param_name, param, prefix):
                 name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
                     name = prefix + param_name
                 self.param2name[param] = name
                 self.name2param[name] = param
-                self.name2index[name] = index

                 if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
                     self.duplicate_param[name] = True
                 if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                     self.duplicate_param[name] = True
+
+                keywords = [
+                    MonitorConst.PRE_GRAD,
+                    MonitorConst.POST_GRAD,
+                    MonitorConst.PRE_PARAM,
+                    MonitorConst.POST_PARAM
+                ]
                 self.name2tag[name] = {
-
-
+                    k: get_summary_writer_tag_name(name, k, self.rank)
+                    for k in keywords
                 }
-            index += 1

     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
@@ -891,11 +950,17 @@ class TrainerMon:
             # nothing to hook
             return 0

-        def fwd_hook_fun(module,
+        def fwd_hook_fun(module, args, kwargs, module_output, name):
             if not module.training or is_recomputation():
                 # 1 only monitor training stage.
                 # 2 when open recompute, skip recomputed forward stage.
                 return
+
+            module_input = [tensor for tensor in args if torch.is_tensor(tensor)]
+            if kwargs:
+                kwargs_tensors = [tensor for tensor in kwargs.values() if torch.is_tensor(tensor)]
+                module_input.extend(kwargs_tensors)
+
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
@@ -904,34 +969,20 @@ class TrainerMon:
                     Const.INPUT: get_param_struct(module_input),
                     Const.OUTPUT: get_param_struct(module_output)
                 }
+
            if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
-
-            context.set_format_by_arg(Const.INPUT, self.config['targets'])
-            context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
-                                                              module_input, context.module_name,
-                                                              Const.INPUT)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
-                                                               module_output, context.module_name,
-                                                               Const.OUTPUT)
-                context.verified = True
-            # expect output be tensor type
+
             tbtag_tensor_map = {}
-            cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV,
-            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_input))
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV,
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_output))

             get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
             context.micro_step += 1
@@ -949,31 +1000,17 @@ class TrainerMon:
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
-            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
-                context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                context.focused_in_col = validate_config_spec(
-                    context.format_by_arg[MonitorConst.INPUT_GRAD],
-                    input_grad, context.module_name, MonitorConst.INPUT_GRAD)
-                context.focused_out_col = validate_config_spec(
-                    context.format_by_arg[MonitorConst.OUTPUT_GRAD],
-                    output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
-                context.verified = True

             tbtag_tensor_map = {}
-            cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.
-
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, input_grad))
+
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, output_grad))

             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -987,17 +1024,30 @@ class TrainerMon:
                 context.micro_step = 0
             return

+        def stack_hook(module, args, kwargs, module_output, name):
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            context.stack = analyze_api_call_stack(name)
+            return
+
         if self.backward_only and self.forward_only:
             logger.warning('not enable backward_only and forward_only simultaneously')

         hooked_count = 0
-
-
-            name = 
-
-
+        for module_name, submodule in module.named_modules():
+            if self.stack_info:
+                name = vpp_stage + squash_param_name(module_name, self.squash_name)
+                handle = submodule.register_forward_hook(partial(stack_hook, name=name), with_kwargs=True)
+                self.handles['stack'].append(handle)
+            name = self._is_target_module(module_name, target_names, vpp_stage)
+            if not name:
+                continue
+            if submodule.__class__.__name__ == "FullyShardedDataParallel":
+                continue
+            if self.xy_distribution or self.print_struct:
                 if not self.backward_only:
-                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
+                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name), with_kwargs=True)
                     self.handles['xy'].append(handle)
                 if not self.forward_only and not self.has_register_backward_hook(name, submodule):
                     handle = submodule.register_full_backward_hook(bwd_hook_fun)
@@ -1026,7 +1076,7 @@ class TrainerMon:
                 if tag is None:
                     continue
                 grad_dict[tag] = grad
-                self.
+                self.register_param_call_id("sync_grad_func", tag)
             get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
             out = sync_grad_func(bucket)
             return out
@@ -1035,7 +1085,14 @@ class TrainerMon:

         if not self.wg_distribution:
             return
+        if self.fsdp_wrapped_module:
+            # patch fsdp _runtime_utils._post_backward_hook
+            self._patch_fsdp_post_backward_hook()
+            return

+        if self.monitor_mbs_grad:
+            self._hook_weights()
+            return
         try:
             from megatron.core.distributed.param_and_grad_buffer import Bucket
             self.origin_start_grad_sync = Bucket.start_grad_sync
@@ -1052,44 +1109,82 @@ class TrainerMon:
             self.enable_megatron = True
             logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
         except ImportError:
-            self.enable_megatron = False
+            self.enable_megatron = False | self.enable_megatron
+        if self.enable_megatron:
+            return

-
-
+        # default hook weights
+        self._hook_weights()
+
+    def _patch_fsdp_post_backward_hook(self):
+        """
+        FSDP runtime 需要处理整个forward和backward计算和通信的流程,通过override nn.Module的forward,定义相应的逻辑。
+        对AccumulateGrad对象注册hook,可以在backward计算grad后立刻执行,在reduce_scatter操作前采集梯度累计后,通信聚合前的梯度。
+        每个forward阶段,fsdp对AccumulateGrad重复注册hook方法,monitor工具内注册hook无法生效,
+        因此对_post_backward_hook进行patch,在backward后,reduce_scatter前采集梯度。
+        """
+        def patch_post_backward_hook(_post_backward_hook):
+            def wrapper(state, handle, *unused):
+                grad_dict = {}
+                offset = 0
+                for param, name in self.param2name.items():
+                    limit = param.numel()
+                    if not limit:
+                        continue
+                    grad = handle.flat_param.grad[offset:offset + limit]
+                    offset += limit
+                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                    self.register_param_call_id("_post_backward_hook", tag)
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = _post_backward_hook(state, handle, *unused)
+                return out
+
+            return wrapper
+
+        logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
+        self.fsdp_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook
+        torch.distributed.fsdp._runtime_utils._post_backward_hook = \
+            patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook)

     def _hook_weights(self):
+        """
+        遍历参数的梯度生成函数(grad_acc),并挂载hook,以便在该参数所有梯度计算后,采集通信聚合前梯度数据。
+        """
         context = self.grad_context

         @torch.no_grad
-        def param_hook(*args, context_dict, param,
+        def param_hook(*args, context_dict, param, name):
+            key = name
+            if self.monitor_mbs_grad:
+                key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+
+            key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+            self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-
-            if param.micro_step == self.micro_batch_number:
-                param.micro_step = 0
+
+            if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
                 if self.params_have_main_grad:
-
+                    grad = param.main_grad
                 else:
-
+                    grad = param.grad
+                if is_float8_tensor(grad):
+                    grad = grad.float()
+                context_dict[key] = grad.clone()
+
+            if param.micro_step == self.micro_batch_number:
+                param.micro_step = 0

         logger.info("hooking weights.")
         for param, name in self.param2name.items():
-            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
             param_tmp = param.expand_as(param)
             grad_acc = param_tmp.grad_fn.next_functions[0][0]
             handle = grad_acc.register_hook(
-                partial(param_hook, context_dict=context.acc, param=param,
+                partial(param_hook, context_dict=context.acc, param=param, name=name))
             self.grad_accs.append(grad_acc)
             self.handles['wgrads'].append(handle)

         self.weight_hooked = True
-
-    def _register_param_call_id(self, hook_name: str, key: str):
-        """
-        :param hook_name:
-        :param key: str, '0:relu_0/output_grad'
-        :return:
-        """
-        logger.debug(f"{hook_name} {key}: {self.call_id}")
-        self.param_name_call_id[key] = self.call_id
-        self.call_id += 1
```