mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,50 +12,45 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import time
 import json
 import os
 import uuid
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime
 from functools import partial
 
 import pytz
 import torch
 import torch.distributed as dist
-from
-
-from msprobe.core.common.
+from torch.utils.hooks import BackwardHook
+
+from msprobe.core.common.const import MonitorConst, Const
+from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.common.utils import is_recomputation
 from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
 from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
     CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
 from msprobe.pytorch.monitor.features import get_sign_matches
-from msprobe.pytorch.monitor.module_metric import get_metrics,
-    TensorMetrics,
+from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
+    TensorMetrics, squash_param_name
 from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
-from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
-from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops,
+from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
+from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
+    get_output_base_dir, get_target_output_dir
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
-from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
-from torch.utils.hooks import BackwardHook
-
-try:
-    import torch_npu
-except ImportError:
-    pass
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
 
-output_base_dir = os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
 
 FORMAT_MAPPING = {
-    MonitorConst.TENSORBOARD:
-    MonitorConst.CSV:
-    MonitorConst.API:
+    MonitorConst.TENSORBOARD: SummaryWriterWithAD,
+    MonitorConst.CSV: CSVWriterWithAD,
+    MonitorConst.API: BaseWriterWithAD
 }
 
 
@@ -71,7 +66,6 @@ def param_is_data_parallel_duplicate(dp_group):
 
 class ModuleHookContext:
     def __init__(self, module_name) -> None:
-        self.step = 0
         self.micro_step = 0
         self.actv = defaultdict(dict)
         self.actvgrad = []
@@ -81,26 +75,44 @@ class ModuleHookContext:
         self.verified = False
         self.focused_in_col = 0
         self.focused_out_col = 0
-        self.ignore_in = False  # no need to care when no key 'input' or 'input_grad' found
 
     def set_format_by_arg(self, key_name: str, target_config: dict):
+        """ Configure format_by_arg according to the monitored targets.
+        1) module_name is configured as a monitored target in targets
+        2) module_name is not configured in targets and all_xy monitors everything
+        3) module_name is not configured in targets and all_xy does not monitor everything
+
+        :param key_name: str, one of [input, output, input_grad, output_grad]
+        :param target_config: target obj in config json.
+        :return:
+        """
         cared = target_config.get(self.module_name, self.struct)
         if key_name in cared:
-
-
-
-                self.format_by_arg[key_name] = config
-
+            target_module_config = cared[key_name]
+            if isinstance(target_module_config, dict):
+                # current cared is self.struct, monitor all data for module_name
+                self.format_by_arg[key_name] = target_module_config.get('config')
+            elif isinstance(target_module_config, str):
                 # current cared is target_config[self.module_name]
-                self.format_by_arg[key_name] =
-
+                self.format_by_arg[key_name] = target_module_config
+            else:
+                logger.warning_on_rank_0(f"target module config error, result maybe empty."
+                                         f"module_name: {self.module_name}, key_name: {key_name}")
+                self.format_by_arg[key_name] = None
+        else:
+            self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
+
+    def reset(self):
+        self.actv.clear()
+        self.actvgrad.clear()
+
+
+start_step = 0
 
 
 class OptimizerContext:
     def __init__(self) -> None:
-        self.step =
-        self.param_effective_rank = defaultdict(float)
+        self.step = start_step
         self.param_mg_direction = defaultdict(float)
         self.param_adam_update = defaultdict()
         self.param_adam_ratio = defaultdict()
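The reworked `set_format_by_arg` above accepts two shapes for each monitored key (`input`, `output`, `input_grad`, `output_grad`): a plain format string, or a dict whose `config` entry holds the format string. A minimal sketch of a `targets` entry covering both shapes; the module name and format strings are illustrative placeholders, not taken from the configuration reference:

```python
# Hypothetical "targets" entry of a monitor config; all values are placeholders.
targets = {
    "language_model.encoder.layers.0": {
        "input": "tuple[2]:0",                    # str form: stored directly in format_by_arg["input"]
        "input_grad": {"config": "tuple[2]:0"},   # dict form: the nested "config" value is stored
    }
}
```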
@@ -112,6 +124,18 @@ class OptimizerContext:
         self.metric_dict = {}
         self.param_metric = {}
 
+    def reset(self):
+        self.param_mg_direction.clear()
+        self.param_adam_update.clear()
+        self.param_adam_ratio.clear()
+        self.param_weight_grad.clear()
+        self.param_exp_avg.clear()
+        self.exp_avg_metric.clear()
+        self.param_exp_avg_sq.clear()
+        self.exp_avg_sq_metric.clear()
+        self.metric_dict.clear()
+        self.param_metric.clear()
+
 
 class CommunicationContext:
     def __init__(self) -> None:
@@ -152,23 +176,131 @@ class GradContext:
 class TrainerMon:
     tensor_metrics = TensorMetrics()
 
-    def __init__(self, config_file_path, process_group=None, params_have_main_grad=True
-
-
-
+    def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+        # TYPE1: variables initialized only here; they are not reset when the config changes mid-training
+        self.config_file_path = config_file_path
+        self.process_group = get_process_group(process_group)
+        self.params_have_main_grad = params_have_main_grad
+        self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
+        self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
+        self.origin_step_func = None
+        self.origin_start_grad_sync = None
+        self.config_timestamp = 0  # timestamp is validated later; the first run need not touch the config file, monitoring can be enabled directly via the dynamic_on switch
+        self.config = load_json(config_file_path)
+        validate_config(self.config)
+
+        self.squash_name = self.config.get('squash_name', True)  # must not change mid-run, otherwise names before and after will not match
+        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target timezone as needed
+        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+        self.unique_id = str(uuid.uuid4())[:8]
+        self.output_base_dir = get_output_base_dir()
+        time_tags = self.config.get("append_output", [])
+        if dist.is_initialized():
+            self.rank = dist.get_rank()
+            if time_tags:
+                output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+                if str(self.rank) in output_append_dirs:
+                    self.tensorboard_dir = output_append_dirs[str(self.rank)]
+                    logger.info(f"append rank({self.rank}) result to {self.tensorboard_dir}")
+            else:
+                self.tensorboard_dir = os.path.join(self.output_base_dir,
+                                                    f"{cur_time}-rank{self.rank}-{self.unique_id}")
+            self.pp_stage = dist.get_group_rank(self.process_group, self.rank)
+            self.group_mates = dist.get_process_group_ranks(self.process_group)
+        else:
+            self.rank = 0
+            self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+            self.pp_stage = 0
+            self.group_mates = [0]
+
+        # TYPE2: variables assigned only in the set_monitor() entry point
+        self.model = None
+        self.vpp = False
+        self.dp_group = None
+        self.tp_group = None
+        self.enable_megatron = False
+        self.micro_batch_number = 1
+        self.optimizer_class = None
+        self.optimizer_mon = None
+
+        # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.optimizer_context = defaultdict(OptimizerContext)
         self.cc_context = defaultdict(CommunicationContext)
         self.grad_context = GradContext()
-        self.
-        self.
-        self.
-        self.
-
+        self.handles = defaultdict(list)
+        self.param2name = defaultdict(str)
+        self.name2index = defaultdict()
+        self.name2indices = defaultdict()
+        self.name2param = {}
+        self.duplicate_param = {}
+        self.name2tag = {}
+        self.param_name_call_id = {}
+        self.call_id = 0
+        self.module_struct = defaultdict(dict)
+        self.grad_accs = []
+        self.weight_hooked = False
+        self.optimizer_hooked = False
+        self.param_registered = False
+        self.struct_printed = False
+
+        # distinguish dynamic vs. static monitoring
+        self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+        if self.dynamic_enable:
+            logger.warning(f"DYNAMIC_MONITOR is set, "
+                           f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+            self.monitoring = False
+        else:
+            self.set_config()
+            # in static mode with collect_times>0, self.monitoring can be True at step 0; dynamic mode enables it at the next step by default
+            if self.collect_times > 0:
+                self.monitoring = True
+
+    def __del__(self):
+        if hasattr(self, "summary_writer"):
+            self.summary_writer.close()
+
+    @property
+    def ops(self):
+        return self._ops
 
+    @ops.setter
+    def ops(self, value):
+        self._ops = validate_ops(value)
+
+    @staticmethod
+    def has_register_backward_hook(module_name, module):
+        if hasattr(module, '_backward_hooks') and \
+                len(module._backward_hooks) > 0 and \
+                module._is_full_backward_hook is False:
+            logger.warning(
+                f"The {module_name} has registered deprecated register_backward_hook,"
+                f"which may cause abnormal data dump. The backward input/output for this module will be skipped."
+            )
+            return True
+        return False
+
+    @staticmethod
+    def generate_cc_metrics(cc_name, cc_tensor):
+        metrics = defaultdict(dict)
+        rank = dist.get_rank() if dist.is_initialized() else None
+        for op, tag2tensor in cc_tensor.data.items():
+            for tag, tensor in tag2tensor.items():
+                key = get_summary_writer_tag_name(cc_name, tag, rank)
+                metrics[op].update({key: tensor})
+        cc_tensor.reset()
+        return metrics
+
+    def set_config(self):
+        logger.info(f"current config: {self.config}")
+        self.start_step = self.config.get("start_step", 0)
+        self.collect_times = self.config.get("collect_times", 100000000)  # defaults to a large value so collection never stops
+        self.step_interval = self.config.get("step_interval", 1)
+        self.has_collect_times = 0  # reset the collection counter
+        self.print_struct = self.config.get("print_struct", False)
         self.module_rank_list = self.config.get("module_ranks", [])
-        self.format = self.config.get('format',
+        self.format = self.config.get('format', MonitorConst.CSV)
         self.eps = self.config.get('eps', 1e-8)
         self.ops = self.config.get('ops', [])
         self.ndigits = self.config.get('ndigits', 6)
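`set_config()` now pulls every runtime knob from the JSON file handed to `TrainerMon`. A minimal sketch of such a file, written from Python; the key names mirror the `self.config.get(...)` calls in the hunk above, while the concrete values are examples only:

```python
import json

# Example-only values; key names follow the set_config() reads shown in the diff.
monitor_config = {
    "targets": {},                    # per-module monitoring spec; empty here for brevity
    "format": "csv",                  # selects the writer class through FORMAT_MAPPING
    "ops": ["min", "max", "norm"],    # statistics list, checked by validate_ops
    "start_step": 0,
    "step_interval": 1,
    "collect_times": 100000000,       # default keeps collecting indefinitely
    "alert": {"rules": []},
}

with open("monitor_config.json", "w") as f:
    json.dump(monitor_config, f, indent=2)
```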
@@ -182,6 +314,7 @@ class TrainerMon:
         self.param_distribution = self.config.get("param_distribution", False)
         self.mg_direction = self.config.get('mg_direction', False)
         self.cc_distribution = self.config.get("cc_distribution", {})
+
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
         else:
@@ -189,49 +322,36 @@ class TrainerMon:
             self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
             self.cc_logged_stack = defaultdict(set)
             self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
-            api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
             api_register.redirect_api()
 
         self.common_info()
 
+        # initialize the AnomalyData factory
         alert_setting = self.config.get('alert', {"rules": []})
         self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
-        # set the timezone, using 'UTC' as an example
-        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target timezone as needed
-
-        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
-        unique_id = str(uuid.uuid4())[:8]
-
-        if dist.is_initialized():
-            rank = dist.get_rank()
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
-            pp_stage = dist.get_group_rank(self.process_group, rank)
-            group_mates = dist.get_process_group_ranks(self.process_group)
-        else:
-            rank = 0
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
-            pp_stage = 0
-            group_mates = [0]
-        self.rank = rank
-
-        # initialize the AnomalyData factory
         self.anomaly_data_factory = None
         if alert_setting.get('dump', False):
-            self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
+            self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
 
+        # initialize the writer and create the output directory
         if self.format not in FORMAT_MAPPING:
-
-
+            logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+            self.format = MonitorConst.CSV
+
+        if self.ur_distribution and self.format != 'tensorboard':
+            logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
+            self.ur_distribution = False
+
+        writer = FORMAT_MAPPING[self.format]
         self.step_count_per_record = self.config.get('step_count_per_record', 1)
 
-        if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
+        if (self.rank in self.module_rank_list) or len(self.module_rank_list) == 0:
             self.summary_writer = writer(
                 WriterInput(
-                    tensorboard_dir,
+                    self.tensorboard_dir,
                     self.alert_rules,
-                    unique_id,
-                    None,
+                    self.unique_id,
                     self.anomaly_data_factory,
                     self.ndigits,
                     self.step_count_per_record
@@ -239,83 +359,22 @@ class TrainerMon:
             )
             # initialize the anomaly_detected output directory
         if self.anomaly_data_factory:
-            self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"),
+            self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
+                                                         self.rank)
             self.anomaly_data_writer.init_detected_json()
 
-
-        self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.micro_batch_number = 1
-
-        self.model = None
-        self.weight_hooked = False
-        self.optimizer_hooked = False
-        self.param_registered = False
-        self.vpp = False
-        self.dp_group = None
-        self.tp_group = None
-        self.enable_megatron = False
-
-        self.param2name = defaultdict(str)
-        self.name2index = defaultdict()
-        self.name2indices = defaultdict()
-        self.name2param = {}
-        self.param_name_call_id = {}
-        self.duplicate_param = {}
-        self.name2tag = {}
-        self.call_id = 0
-        self.grad_accs = []
-        self.handles = defaultdict(list)
-
-        self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
-        self.print_struct = self.config.get("print_struct", False)
-        self.struct_printed = False
-        self.module_struct = {}
-
-    def __del__(self):
-        if hasattr(self, "summary_writer"):
-            self.summary_writer.close()
-
-    @property
-    def ops(self):
-        return self._ops
-
-    @ops.setter
-    def ops(self, value):
-        self._ops = validate_ops(value)
-
-    @staticmethod
-    def set_wrapped_optimizer(_wrapped_optimizer):
-        OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
-
-    @staticmethod
-    def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
         rank = None
         if dist.is_initialized():
             rank = dist.get_rank()
         if (rank not in rank_list) and len(rank_list) != 0:
             return
-
+        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
 
-
-
-
-
-        key = get_summary_writer_tag_name(module_name, tag, rank)
-        if torch.is_tensor(tensor):
-            metrics[key] = tensor
-        return metrics
-
-    @staticmethod
-    def generate_cc_metrics(cc_name, cc_tensor):
-        metrics = defaultdict(dict)
-        rank = dist.get_rank() if dist.is_initialized() else None
-        for op, tag2tensor in cc_tensor.data.items():
-            for tag, tensor in tag2tensor.items():
-                key = get_summary_writer_tag_name(cc_name, tag, rank)
-                metrics[op].update({key: tensor})
-        cc_tensor.reset()
-        return metrics
+    def build_tbtag_tensor_map(self, module_name, tag, tensor):
+        key = get_summary_writer_tag_name(module_name, tag, self.rank)
+        self._register_param_call_id("_hook_module", key)
+        return {key: tensor}
 
     def common_info(self):
         if not self.xy_distribution:
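`adhoc_check` becomes an instance method that feeds the class-level `tensor_metrics` store, and `write_adhoc_check` flushes it to the summary writer. A hedged usage sketch; `monitor`, the tensor, and the module/tensor names are placeholders:

```python
# Ad-hoc statistics on a tensor of interest, outside the regular module hooks.
monitor.adhoc_check(hidden_states, module_name="decoder.layers.0",
                    tensor_name="hidden_states", rank_list=[], ops_list=["norm"])
monitor.write_adhoc_check(step)   # flushes tensor_metrics into the summary writer
```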
@@ -332,37 +391,25 @@ class TrainerMon:
             logger.info_on_rank_0('> grad and momentum direction will not be compared.')
         if not self.cc_distribution.get('enable', False):
             logger.info_on_rank_0("> cc operator is not monitored.")
-        if not self.opt_ty:
-            if self.ur_distribution:
-                raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
-            if self.mv_distribution:
-                raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
 
-    def hook_modules(self
+    def hook_modules(self):
         if self.module_rank_list and (self.rank not in self.module_rank_list):
             return
 
-        if not isinstance(model, list):
-            model = [model]
-        self.model = model
-        self._register_param_name(model)
-
-        self.micro_batch_number = grad_acc_steps
-
         targets = self.config['targets']
-        module_in_all_stage = [key for key in targets.keys() if MonitorConst.
+        module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
         for key in module_in_all_stage:
             struct = targets.pop(key)
-            targets.update({f'{vpp_stage}{MonitorConst.
+            targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
 
         hooked_count = 0
-        for vpp_stage, model_chunk in enumerate(model):
-            vpp_stage = f'{vpp_stage}{MonitorConst.
+        for vpp_stage, model_chunk in enumerate(self.model):
+            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
             targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
                 'targets'].keys()
             hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
 
-        logger.info_on_rank_0(f"> {hooked_count}
+        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
 
     def clone_if_tensor(args):
         if isinstance(args, tuple):
@@ -383,11 +430,11 @@ class TrainerMon:
 
         BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
 
-        if not self.optimizer_hooked:
-            self.hook_optimizer()
         return
 
     def generate_param_metrics(self, opt_context):
+        if not self.param_distribution:
+            return
         get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
 
     def generate_mv_metrics(self, opt_context):
@@ -395,8 +442,8 @@ class TrainerMon:
             return
         opt_context.exp_avg_metric = {}
         opt_context.exp_avg_sq_metric = {}
-        m_tag_tensor_map = self.generate_param_map(
-        v_tag_tensor_map = self.generate_param_map(
+        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
         get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
         get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
 
@@ -416,29 +463,52 @@ class TrainerMon:
                 logger.warning(f"grad is None: {name}, maybe something wrong happened.")
                 continue
             tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            self._register_param_call_id("hook_optimizer", tag)
             grad_dict[tag] = grad
 
         get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
-
-
-
+        unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
+        return self.grad_context.post, unreduced_grad
+
+    def set_monitor(
+            self,
+            model,
+            grad_acc_steps=1,
+            optimizer=None,
+            tp_group=None,
+            dp_group=None,
+            start_iteration=0
+    ):
         """External interface"""
+        global start_step
+        start_step = start_iteration
         logger.info(f'grad acc steps {grad_acc_steps}')
-        self.hook_optimizer(optimizer)
         self.micro_batch_number = grad_acc_steps
-
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
+        self.hook_step_final(optimizer)
+        if not isinstance(model, list):
+            model = [model]
+        self.model = model
+        if len(model) > 1:
+            self.vpp = True
+            self._smallest_rank_print('vpp enabled')
+        if not self.dynamic_enable:
+            self.register_hooks(optimizer)
 
-
+    def register_hooks(self, optimizer):
+        self._register_param_name()
+        self.hook_optimizer(optimizer)
         self._patch_grad_sync()
-        self.hook_modules(
+        self.hook_modules()
+        self.monitoring = True
 
     def generate_param_map(self, tag, param_tensor):
         metrics = {}
-        rank = dist.get_rank() if dist.is_initialized() else None
         for name in self.param2name.values():
-            key = get_summary_writer_tag_name(name, tag, rank)
+            key = get_summary_writer_tag_name(name, tag, self.rank)
+            self._register_param_call_id("optimizer_pre_step_hook", key)
             if name not in param_tensor or param_tensor[name] is None:
                 continue
             metrics[key] = param_tensor[name]
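The hunks above move model and optimizer wiring out of `__init__` into an explicit `set_monitor()` call. A rough end-to-end sketch, assuming this diff belongs to `msprobe/pytorch/monitor/module_hook.py` (consistent with the file list) and using placeholder model, optimizer, dataloader and config path:

```python
from msprobe.pytorch.monitor.module_hook import TrainerMon  # import path assumed from this file

# model, optimizer, dataloader and the config path below are placeholders.
monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False)
monitor.set_monitor(model, grad_acc_steps=1, optimizer=optimizer, start_iteration=0)

for step, batch in enumerate(dataloader):
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()        # the patched step() drives metric collection and writing
    optimizer.zero_grad()
```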
@@ -454,17 +524,19 @@ class TrainerMon:
         return actv, actv_grad
 
     def reload_xy(self, xy_distribution=False):
+        logger.warning("reload_xy() is deprecated and will be removed in a future version. "
+                       "Use DYNAMIC_MONITOR instead.")
         self.xy_distribution = xy_distribution
 
         for handle in self.handles['xy']:
             handle.remove()
         self.handles['xy'].clear()
-        self.hook_modules(
+        self.hook_modules()
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             fwd_context.actv.clear()
 
     def write_adhoc_check(self, step):
-
+        self.tensor_metrics.flush(self.summary_writer)
 
     def write_xy_tb(self, step):
         if not self.xy_distribution:
@@ -472,65 +544,65 @@ class TrainerMon:
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             if len(fwd_context.actv) == 0:
                 continue
-            self.write_metrics(self.ops,
+            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
             fwd_context.actv.clear()
         if self.grad_context.actv:
-            self.write_metrics(self.ops, self.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
 
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-        self.write_metrics(self.ops,
+        self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM)
 
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        self.write_metrics(self.ops,
-
-        self.write_metrics(self.ops,
-
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+                                          opt_context.step, MonitorConst.EXP_AVG)
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+                                          opt_context.step, MonitorConst.EXP_AVG_SQ)
 
     def write_grad_tb(self, step):
         if not self.wg_distribution:
             return
 
         if self.enable_megatron:
-            self.write_metrics(self.ops, self.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
         else:
-            self.write_metrics(self.ops, self.
-            self.write_metrics(self.ops, self.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+            self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
     def hook_optimizer(self, optimizer=None):
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
-            if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
-                if context.step == 0:
-                    self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
-                                                                                          self.name2index)
-                mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
-                                                                      self.name2indices)
-                self.param2name = mv_result.grad
-            else:
-                mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
-            context.param_exp_avg = mv_result.exp_avg
-            context.param_exp_avg_sq = mv_result.exp_avg_sq
-            context.param_adam_update = mv_result.update
-            context.param_adam_ratio = mv_result.ratio
 
             if (self.print_struct and not all(value == {} for value in self.module_struct.values())
                     and not self.struct_printed):
-                self.
-                self._smallest_rank_print(json.dumps(self.module_struct))
-                self.struct_printed = True
+                self._save_module_struct()
                 if not self.cc_log_only:
-                    raise Exception("exit after first step when print model struct")
+                    raise Exception("exit after first monitor step when print model struct")
             if self.cc_log_only and context.step > 0:
                 self._smallest_rank_print("> Used communication ops and corresponding stack")
                 self._smallest_rank_print(
                     json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
                 raise Exception("exit after first step when print cc stack")
 
+            # skip generate metrics
+            if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
+                return
+            if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class:  # use deepspeed with zero1/2/3
+                if not self.name2indices:
+                    self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
+                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
+                self.param2name = mv_result.grad
+            else:
+                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
+            context.param_exp_avg = mv_result.exp_avg
+            context.param_exp_avg_sq = mv_result.exp_avg_sq
+            context.param_adam_update = mv_result.update
+            context.param_adam_ratio = mv_result.ratio
+
             self.generate_wgrad_metrics()
             self.generate_mv_metrics(context)
             self.generate_param_metrics(context)
@@ -561,41 +633,10 @@ class TrainerMon:
             context.metric_dict = metric_dict
             return
 
-        def optimizer_post_step_hook(optimizer, args, kwargs):
-            context = self.optimizer_context[optimizer]
-            rank = dist.get_rank() if dist.is_initialized() else None
-
-            if self.anomaly_data_factory:
-                self.anomaly_data_factory.set_call_id(self.param_name_call_id)
-            self.write_xy_tb(context.step)
-            self.write_grad_tb(context.step)
-            self.write_mv_tb(context)
-            self.write_param_tb(context)
-            self.write_adhoc_check(context.step)
-
-            if self.ur_distribution:
-                for param_name, _ in context.param_adam_update.items():
-                    self.update_heatmap_visualizer[param_name].visualize(
-                        get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
-                for param_name, _ in context.param_adam_ratio.items():
-                    self.ratio_heatmap_visualizer[param_name].visualize(
-                        get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
-
-            if context.metric_dict:
-                self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other')
-            context.metric_dict.clear()
-            context.step += 1
-            if self.anomaly_data_factory:
-                self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
-            self.summary_writer.clear_anomalies()
-            self.call_id = 0
-            return
-
         def patch_step(func, optimizer):
             def wrapper(*args, **kwargs):
                 optimizer_pre_step_hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
-                optimizer_post_step_hook(optimizer, args, kwargs)
                 return out
 
             return wrapper
@@ -603,16 +644,177 @@ class TrainerMon:
         if self.optimizer_hooked:
             return
 
-
-            optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+        optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
 
-        else:
-            if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
-                register_optimizer_step_pre_hook(optimizer_pre_step_hook)
-                register_optimizer_step_post_hook(optimizer_post_step_hook)
         self.optimizer_hooked = True
         return
 
+    def dynamic_monitor(self, optimizer):
+        """
+        If dynamic monitor enabled and config.json updated,
+        remove hooks and register new hooks according to new configuration.
+        """
+        context = self.optimizer_context[optimizer]
+        if not self.dynamic_enable:
+            return
+        try:
+            # if the file timestamp has not changed, skip reading to save time
+            config_timestamp = os.path.getmtime(self.config_file_path)
+            if config_timestamp == self.config_timestamp:
+                return
+            # record the latest modification timestamp of the config file
+            self.config_timestamp = config_timestamp
+            config = load_json(self.config_file_path)
+        except Exception as e:
+            logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
+            return
+
+        if config.get("dynamic_on", False):
+            try:
+                validate_config(config)
+                self.config = config
+                self.set_config()
+                logger.warning(f"config is updated at step{context.step - 1}, "
+                               f"will start new hook at step{context.step}.")
+            except Exception as e:
+                logger.error(f"set config wrong because {e}, not updated, please check!!!")
+                return
+
+            self._remove_all_hooks(optimizer)
+            self.register_hooks(optimizer)
+
+    def hook_step_final(self, optimizer):
+        def step_final_hook(optimizer, args, kwargs):
+            context = self.optimizer_context[optimizer]
+            rank = dist.get_rank() if dist.is_initialized() else None
+            # static mode can save at step 0; dynamic cannot, because it is designed to start at the step after a reset, so self.monitoring is still False at step 0
+            if self.monitoring:
+                module_rank_valid = not self.module_rank_list or (
+                        dist.is_initialized() and dist.get_rank() in self.module_rank_list)
+                step_condition = (context.step >= self.start_step and (
+                        context.step - self.start_step) % self.step_interval == 0)
+                if module_rank_valid and step_condition:
+                    self.has_collect_times += 1
+
+                    if self.anomaly_data_factory:
+                        self.anomaly_data_factory.set_call_id(self.param_name_call_id)
+                    self.write_xy_tb(context.step)
+                    self.write_grad_tb(context.step)
+                    self.write_mv_tb(context)
+                    self.write_param_tb(context)
+                    self.write_adhoc_check(context.step)
+
+                    if self.ur_distribution:
+                        for param_name, _ in context.param_adam_update.items():
+                            self.update_heatmap_visualizer[param_name].visualize(
+                                get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step,
+                                self.summary_writer)
+                        for param_name, _ in context.param_adam_ratio.items():
+                            self.ratio_heatmap_visualizer[param_name].visualize(
+                                get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step,
+                                self.summary_writer)
+
+                    if context.metric_dict:
+                        self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+                    context.metric_dict.clear()
+
+                    if self.anomaly_data_factory:
+                        self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
+                    self.summary_writer.clear_anomalies()
+                    self.call_id = 0
+                    self.param_name_call_id.clear()
+
+                    if self.has_collect_times >= self.collect_times:
+                        self._remove_all_hooks_final(optimizer)
+
+            context.step += 1
+            self.dynamic_monitor(optimizer)
+
+        def patch_step(func, optimizer):
+            def wrapper(*args, **kwargs):
+                out = func(*args, **kwargs)
+                step_final_hook(optimizer, args, kwargs)
+                return out
+            return wrapper
+
+        optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+        self.origin_step_func = optimizer.__class__.step
+
+        return
+
+    def _remove_all_hooks(self, optimizer):
+        # clear hook handles
+        for handle in self.handles['xy']:
+            handle.remove()
+        self.handles['xy'].clear()
+        # clear the corresponding context caches
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            fwd_context.reset()
+        for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+            bwd_context.reset()
+        self.grad_context.reset()  # both weight grads and activation grads live here
+
+        if self.origin_start_grad_sync:  # megatron
+            try:
+                from megatron.core.distributed.param_and_grad_buffer import Bucket
+                Bucket.start_grad_sync = self.origin_start_grad_sync
+                logger.info("remove Bucket start_grad_sync")
+            except ImportError:
+                pass
+            try:
+                from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+                _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
+                logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
+            except ImportError:
+                pass
+        else:  # not megatron
+            for handle in self.handles['wgrads']:
+                handle.remove()
+            self.handles['wgrads'].clear()
+            self.weight_hooked = False
+
+        if self.optimizer_hooked:
+            optimizer.__class__.step = self.origin_step_func
+
+            for _, context in self.optimizer_context.items():
+                context.reset()
+            self.optimizer_hooked = False
+
+        for handle in self.handles['cc']:
+            handle.remove()
+        self.handles['cc'].clear()
+        for _, context in self.cc_context.items():
+            context.reset()
+
+        # clear per-parameter name caches
+        self.param2name.clear()
+        self.name2index.clear()
+        self.name2indices.clear()
+        self.name2param.clear()
+        self.duplicate_param.clear()
+        self.name2tag.clear()
+        self.module_struct.clear()
+        self.grad_accs.clear()
+
+        # turn off the collecting state
+        self.monitoring = False
+
+    def _remove_all_hooks_final(self, optimizer):
+        if self.dynamic_enable:
+            # after finishing, automatically reset dynamic_on to False and wait for the user to enable it again
+            try:
+                config = load_json(self.config_file_path)
+                config['dynamic_on'] = False
+                save_json(self.config_file_path, config, indent=2)
+                config_timestamp = os.path.getmtime(self.config_file_path)
+                self.config_timestamp = config_timestamp
+                logger.info(
+                    "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
+            except Exception as e:
+                logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
+        logger.info("Finish monitor")
+        self._remove_all_hooks(optimizer)
+
     def _smallest_rank_print(self, msg):
         if dist.is_initialized():
             if self.module_rank_list:
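The dynamic path added above is gated twice: the `DYNAMIC_MONITOR` environment variable defers hook registration, and `dynamic_on` in config.json (re)starts collection whenever the file's mtime changes; once `collect_times` is exhausted the tool writes `dynamic_on` back to false. A sketch of driving it from the launcher side; the path and step values are examples:

```python
import json
import os

os.environ["DYNAMIC_MONITOR"] = "True"   # must be set before TrainerMon is constructed

# While training runs, rewrite the config to (re)start collection for 10 steps from step 100.
with open("./monitor_config.json") as f:
    cfg = json.load(f)
cfg.update({"dynamic_on": True, "collect_times": 10, "start_step": 100})
with open("./monitor_config.json", "w") as f:
    json.dump(cfg, f, indent=2)
# dynamic_monitor() notices the new mtime at the end of the next optimizer step.
```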
@@ -624,9 +826,20 @@ class TrainerMon:
         else:
             logger.info(msg)
 
+    def _save_module_struct(self):
+        save_module_struct = (not dist.is_initialized()
+                              or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list))
+                              or (not self.module_rank_list and dist.get_rank() == 0))
+
+        if save_module_struct:
+            module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
+            save_json(module_struct_file, self.module_struct, indent=2)
+            logger.info(f"> save module struct to {module_struct_file}")
+        self.struct_printed = True
+
     def _is_target_param(self, param_name, param, prefix):
-        squash_name = prefix + squash_param_name(param_name)
         name = prefix + param_name
+        squash_name = prefix + squash_param_name(param_name, self.squash_name)
         for target in self.config['targets'].keys():
             if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
                 setattr(param, "zero_out_wgrad", True)
@@ -635,15 +848,14 @@ class TrainerMon:
         return False
 
     def _register_chunk(self, model_chunk, prefix):
-
+        index = 0
+        for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
             if self._is_target_param(param_name, param, prefix):
-                name = prefix + squash_param_name(param_name)
+                name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
-
-                    May be error of squash_param_name')
-                    raise Exception("param with same name will be overwritten.")
+                    name = prefix + param_name
                 self.param2name[param] = name
                 self.name2param[name] = param
                 self.name2index[name] = index
@@ -652,34 +864,22 @@ class TrainerMon:
                     self.duplicate_param[name] = True
                 if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                     self.duplicate_param[name] = True
-                self.name2tag[name] = {
-
-
-
-
-
-    def _register_param_name(self
-
-
-
-        if not isinstance(model, list):
-            model = [model]
-
-        if len(model) > 1:
-            self.vpp = True
-            self._smallest_rank_print('vpp enabled')
-
-        for vpp_stage, model_chunk in enumerate(model):
-            prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
+                self.name2tag[name] = {
+                    MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
+                    MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
+                }
+                index += 1
+
+    def _register_param_name(self):
+        for vpp_stage, model_chunk in enumerate(self.model):
+            prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
             self._register_chunk(model_chunk, prefix)
 
-        self.param_registered = True
-
     def _is_target_module(self, module_name, targets, vpp_stage):
         if self.all_xy or self.print_struct:
-            return vpp_stage + squash_param_name(module_name)
+            return vpp_stage + squash_param_name(module_name, self.squash_name)
         for pattern in [
-            vpp_stage + squash_param_name(module_name),
+            vpp_stage + squash_param_name(module_name, self.squash_name),
             vpp_stage + module_name,
         ]:
             if pattern in targets:
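The rewritten `_register_param_name` above walks `self.model` as a list of virtual-pipeline chunks and prefixes every parameter name with its chunk index joined by `MonitorConst.NAME_SEP`, so identical parameter names in different chunks stay distinct. A rough, self-contained illustration of that naming scheme (the `":"` separator is an assumption standing in for `MonitorConst.NAME_SEP`):

```python
import torch.nn as nn

NAME_SEP = ":"  # assumption: stands in for MonitorConst.NAME_SEP


def chunk_param_names(model_chunks):
    """Prefix every trainable parameter with its virtual-pipeline (vpp) stage index."""
    names = []
    for vpp_stage, chunk in enumerate(model_chunks):
        prefix = f"{vpp_stage}{NAME_SEP}"
        for param_name, param in chunk.named_parameters():
            if param.requires_grad:
                names.append(prefix + param_name)
    return names


print(chunk_param_names([nn.Linear(4, 4), nn.Linear(4, 4)]))
# ['0:weight', '0:bias', '1:weight', '1:bias']
```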
@@ -692,90 +892,88 @@ class TrainerMon:
            return 0
 
         def fwd_hook_fun(module, module_input, module_output, name):
-            if is_recomputation():
+            if not module.training or is_recomputation():
+                # 1 only monitor training stage.
+                # 2 when open recompute, skip recomputed forward stage.
                 return
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
             if not context.struct:
-                context.struct = {
-
+                context.struct = {
+                    Const.INPUT: get_param_struct(module_input),
+                    Const.OUTPUT: get_param_struct(module_output)
+                }
             if self.print_struct:
-                if context.module_name not in self.module_struct:
-                    self.module_struct[context.module_name] = {}
                 self.module_struct[context.module_name].update(context.struct)
                 return
-            if not module.training:
-                return
             if not context.format_by_arg:
-                context.set_format_by_arg(
-                context.set_format_by_arg(
+                context.set_format_by_arg(Const.INPUT, self.config['targets'])
+                context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
             if not context.format_by_arg:
                 return
             if not context.verified:
-
-
-
-
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
+                context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
+                                                              module_input, context.module_name,
+                                                              Const.INPUT)
+                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
                                                                module_output, context.module_name,
-
+                                                               Const.OUTPUT)
                 context.verified = True
             # expect output be tensor type
             tbtag_tensor_map = {}
-            if
-
-
-
-
+            cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_input))
             cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_output))
 
             get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
-
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
-                context.step += 1
             return
 
         def bwd_hook_fun(module, input_grad, output_grad):
             context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
             if not context.struct:
-                context.struct = {
-
+                context.struct = {
+                    MonitorConst.INPUT_GRAD: get_param_struct(input_grad),
+                    MonitorConst.OUTPUT_GRAD: get_param_struct(output_grad)
+                }
             if self.print_struct:
-                if context.module_name not in self.module_struct:
-                    self.module_struct[context.module_name] = {}
                 self.module_struct[context.module_name].update(context.struct)
                 return
             if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.
-                context.set_format_by_arg(MonitorConst.
+                context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
+                context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
             if not context.format_by_arg:
                 return
             if not context.verified:
-
-                context.
-
-
-
-
-                    MonitorConst.ACTVGRAD_OUT)
+                context.focused_in_col = validate_config_spec(
+                    context.format_by_arg[MonitorConst.INPUT_GRAD],
+                    input_grad, context.module_name, MonitorConst.INPUT_GRAD)
+                context.focused_out_col = validate_config_spec(
+                    context.format_by_arg[MonitorConst.OUTPUT_GRAD],
+                    output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
                 context.verified = True
 
             tbtag_tensor_map = {}
-            if
-
-
-
-
+            cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_input_grad))
             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
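Both hook bodies above pick a single "cared" tensor out of a possibly tuple-valued input/output once `validate_config_spec` has pinned a column index, and fall back to the whole structure when no column is configured. The selection rule in isolation (names here are illustrative, not msprobe's API):

```python
from typing import Optional, Sequence, Union

import torch


def pick_cared_tensor(value: Union[torch.Tensor, Sequence], col: Optional[int]):
    """Return the whole value when no column is configured, else the configured element."""
    return value if col is None else value[col]


x = (torch.ones(2), torch.zeros(3))
assert pick_cared_tensor(x, None) is x   # no column configured: keep the tuple
assert pick_cared_tensor(x, 1) is x[1]   # column 1 configured: monitor only that tensor
```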
@@ -787,7 +985,6 @@ class TrainerMon:
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
-                context.step += 1
             return
 
         if self.backward_only and self.forward_only:
@@ -802,7 +999,7 @@ class TrainerMon:
             if not self.backward_only:
                 handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
                 self.handles['xy'].append(handle)
-            if not self.forward_only:
+            if not self.forward_only and not self.has_register_backward_hook(name, submodule):
                 handle = submodule.register_full_backward_hook(bwd_hook_fun)
                 self.handles['xy'].append(handle)
                 self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
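The hunk above wires the hooks in: forward activations via `register_forward_hook` with the module name bound through `functools.partial`, activation gradients via `register_full_backward_hook`, and every handle is kept so it can be removed later. A stripped-down sketch of the same wiring on a toy module (hook bodies and names are illustrative only):

```python
from functools import partial

import torch
import torch.nn as nn


def fwd_hook(module, inputs, output, name):
    print(f"forward {name}: output mean {output.mean().item():.4f}")


def bwd_hook(module, grad_input, grad_output):
    print(f"backward: output-grad norm {grad_output[0].norm().item():.4f}")


layer = nn.Linear(4, 4)
handles = [
    layer.register_forward_hook(partial(fwd_hook, name="0:layers.0.linear")),
    layer.register_full_backward_hook(bwd_hook),
]
layer(torch.randn(2, 4)).sum().backward()
for h in handles:
    h.remove()  # mirrors TrainerMon keeping handles in self.handles['xy'] for later removal
```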
@@ -814,8 +1011,12 @@ class TrainerMon:
         def patch_sync(sync_grad_func):
             def wrapper(bucket):
                 grad_dict = {}
+                # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
+                # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
+                # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
+                bucket_params_id_list = [id(params) for params in bucket.params]
                 for param, name in self.param2name.items():
-                    if param not in
+                    if id(param) not in bucket_params_id_list:
                         continue
                     grad = param.main_grad if self.params_have_main_grad else param.grad
                     if grad is None:
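`patch_sync` above wraps Megatron's `start_grad_sync` so that, just before the bucket's gradients are communicated, the pre-sync gradients of monitored parameters in that bucket are captured; membership is checked by `id()` against `bucket.params` rather than with `in`, which is unreliable for tensors. A generic sketch of that wrap-and-snapshot pattern with no Megatron dependency (all names hypothetical):

```python
def patch_sync(sync_grad_func, monitored):
    """Wrap a bucket-level grad-sync function so monitored grads are snapshotted pre-sync."""
    def wrapper(bucket):
        bucket_param_ids = {id(p) for p in bucket.params}
        snapshot = {name: p.grad.clone()
                    for p, name in monitored.items()
                    if id(p) in bucket_param_ids and p.grad is not None}
        print(f"pre-sync grads captured for: {sorted(snapshot)}")
        return sync_grad_func(bucket)  # then run the original reduce/scatter
    return wrapper
```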
@@ -825,21 +1026,35 @@ class TrainerMon:
                     if tag is None:
                         continue
                     grad_dict[tag] = grad
+                    self._register_param_call_id("sync_grad_func", tag)
                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                 out = sync_grad_func(bucket)
                 return out
 
             return wrapper
 
+        if not self.wg_distribution:
+            return
+
         try:
             from megatron.core.distributed.param_and_grad_buffer import Bucket
+            self.origin_start_grad_sync = Bucket.start_grad_sync
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
             self.enable_megatron = True
+            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
         except ImportError:
             self.enable_megatron = False
 
-
-
-
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+            self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
+            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+            self.enable_megatron = True
+            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+        except ImportError:
+            self.enable_megatron = False
+
+        if not self.enable_megatron:
             self._hook_weights()
 
     def _hook_weights(self):
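Which class owns `start_grad_sync` depends on the Megatron-LM release: `Bucket` for core_r0.6.0 through core_r0.8.0, `_ParamAndGradBucketGroup` from core_r0.9.0. The hunk above probes both import paths and falls back to `_hook_weights()` when neither is importable. A simplified sketch of that probe-and-patch idea (not the exact control flow above, which attempts both patches in sequence):

```python
def patch_megatron_grad_sync(patch):
    """Probe the two known locations of start_grad_sync and patch whichever one imports."""
    try:
        # Megatron-LM core_r0.6.0 .. core_r0.8.0
        from megatron.core.distributed.param_and_grad_buffer import Bucket
        Bucket.start_grad_sync = patch(Bucket.start_grad_sync)
        return True
    except ImportError:
        pass
    try:
        # Megatron-LM core_r0.9.0
        from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
        _ParamAndGradBucketGroup.start_grad_sync = patch(_ParamAndGradBucketGroup.start_grad_sync)
        return True
    except ImportError:
        return False  # no Megatron available: caller falls back to direct weight hooks
```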
@@ -848,8 +1063,7 @@ class TrainerMon:
         @torch.no_grad
         def param_hook(*args, context_dict, param, key, name):
             param.micro_step += 1
-            self.
-            self.call_id += 1
+            self._register_param_call_id("param_hook", key)
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
                 if self.params_have_main_grad:
@@ -857,6 +1071,7 @@ class TrainerMon:
                 else:
                     context_dict[key] = param.grad.clone()
 
+        logger.info("hooking weights.")
         for param, name in self.param2name.items():
             key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
@@ -868,3 +1083,13 @@ class TrainerMon:
             self.handles['wgrads'].append(handle)
 
         self.weight_hooked = True
+
+    def _register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name: name of the calling hook, used only for debug logging
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1
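The `_register_param_call_id` helper added at the end records a monotonically increasing call id per tag, i.e. the order in which gradient-related hooks fired within a step; consumers can sort tags by that id to recover the firing order. A toy standalone version of the bookkeeping (hypothetical class and tag names):

```python
class CallOrderRecorder:
    """Standalone analogue of the param_name_call_id / call_id bookkeeping."""

    def __init__(self):
        self.call_id = 0
        self.param_name_call_id = {}

    def register(self, key: str):
        # remember the position at which this key's hook fired
        self.param_name_call_id[key] = self.call_id
        self.call_id += 1


recorder = CallOrderRecorder()
for tag in ["0:ffn.weight/post_grad", "0:attn.weight/post_grad"]:  # hypothetical tags
    recorder.register(tag)
print(sorted(recorder.param_name_call_id, key=recorder.param_name_call_id.get))
# ['0:ffn.weight/post_grad', '0:attn.weight/post_grad']
```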