mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,50 +12,44 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import time
 import json
 import os
 import uuid
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime
 from functools import partial
 
 import pytz
 import torch
 import torch.distributed as dist
+from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
+from torch.utils.hooks import BackwardHook
+
 from msprobe.core.common.const import MonitorConst
-from msprobe.core.common.file_utils import load_json
-from msprobe.
+from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
 from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
     CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
 from msprobe.pytorch.monitor.features import get_sign_matches
-from msprobe.pytorch.monitor.module_metric import get_metrics,
-    TensorMetrics,
+from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
+    TensorMetrics, squash_param_name
 from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
-from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation
+from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation, \
+    get_output_base_dir, get_target_output_dir
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
-from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
-from torch.utils.hooks import BackwardHook
-
-try:
-    import torch_npu
-except ImportError:
-    pass
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
 
-output_base_dir = os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
-
 FORMAT_MAPPING = {
-    MonitorConst.TENSORBOARD:
-    MonitorConst.CSV:
-    MonitorConst.API:
+    MonitorConst.TENSORBOARD: SummaryWriterWithAD,
+    MonitorConst.CSV: CSVWriterWithAD,
+    MonitorConst.API: BaseWriterWithAD
 }
 
 
@@ -71,7 +65,6 @@ def param_is_data_parallel_duplicate(dp_group):
 
 class ModuleHookContext:
     def __init__(self, module_name) -> None:
-        self.step = 0
         self.micro_step = 0
         self.actv = defaultdict(dict)
         self.actvgrad = []
@@ -81,26 +74,47 @@ class ModuleHookContext:
         self.verified = False
         self.focused_in_col = 0
         self.focused_out_col = 0
-        self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found
 
     def set_format_by_arg(self, key_name: str, target_config: dict):
+        """ 按照监控对象配置format_by_arg
+        1) module_name 在 target 中配置监控对象
+        2) module_name 未在 targets 中配置,且 all_xy 全量监控
+        3) module_name 未在 targets 中配置,且 all_xy 未全量监控
+
+        :param key_name: str, one of [input, output, input_grad, output_grad]
+        :param target_config: target obj in config json.
+        :return:
+        """
+        valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
+        if key_name not in valid_key:
+            raise ValueError(f"key({key_name}) error, valid_key: {valid_key}")
         cared = target_config.get(self.module_name, self.struct)
         if key_name in cared:
-
-
-
-            self.format_by_arg[key_name] = config
-
+            target_module_config = cared[key_name]
+            if isinstance(target_module_config, dict):
+                # current cared is self.struct, monitor all data for module_name
+                self.format_by_arg[key_name] = target_module_config.get('config')
+            elif isinstance(target_module_config, str):
                 # current cared is target_config[self.module_name]
-            self.format_by_arg[key_name] =
-
-
+                self.format_by_arg[key_name] = target_module_config
+            else:
+                logger.warning_on_rank_0(f"target module config error, result maybe empty."
+                                         f"module_name: {self.module_name}, key_name: {key_name}")
+                self.format_by_arg[key_name] = None
+        else:
+            self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
+
+    def reset(self):
+        self.actv.clear()
+        self.actvgrad.clear()
+
+
+start_step = 0
 
 
 class OptimizerContext:
     def __init__(self) -> None:
-        self.step =
-        self.param_effective_rank = defaultdict(float)
+        self.step = start_step
         self.param_mg_direction = defaultdict(float)
         self.param_adam_update = defaultdict()
         self.param_adam_ratio = defaultdict()
@@ -112,6 +126,18 @@ class OptimizerContext:
         self.metric_dict = {}
         self.param_metric = {}
 
+    def reset(self):
+        self.param_mg_direction.clear()
+        self.param_adam_update.clear()
+        self.param_adam_ratio.clear()
+        self.param_weight_grad.clear()
+        self.param_exp_avg.clear()
+        self.exp_avg_metric.clear()
+        self.param_exp_avg_sq.clear()
+        self.exp_avg_sq_metric.clear()
+        self.metric_dict.clear()
+        self.param_metric.clear()
+
 
 class CommunicationContext:
     def __init__(self) -> None:
@@ -156,17 +182,131 @@ class TrainerMon:
         """
         opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
         """
-
-        self.
-        self.optimizer_context = defaultdict(OptimizerContext)
-        self.cc_context = defaultdict(CommunicationContext)
-        self.grad_context = GradContext()
+        # TYPE1: 只在这里初始化的变量, 不会随着训练中途config配置改变而重置
+        self.config_file_path = config_file_path
         self.process_group = get_process_group(process_group)
         self.params_have_main_grad = params_have_main_grad
         self.opt_ty = opt_ty
+        self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
+        self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
+        self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
+        self.origin_step_func = None
+        self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过switch开关直接打开
         self.config = load_json(config_file_path)
         validate_config(self.config)
 
+        self.squash_name = self.config.get('squash_name', True) # 不允许修改防止前后名字对不上
+        local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区
+        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+        self.unique_id = str(uuid.uuid4())[:8]
+        self.output_base_dir = get_output_base_dir()
+        time_tags = self.config.get("append_output", [])
+        if dist.is_initialized():
+            self.rank = dist.get_rank()
+            if time_tags:
+                output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+                if str(self.rank) in output_append_dirs:
+                    self.tensorboard_dir = output_append_dirs[str(self.rank)]
+                    logger.info(f"append rank({self.rank}) result to {self.tensorboard_dir}")
+            else:
+                self.tensorboard_dir = os.path.join(self.output_base_dir,
+                                                    f"{cur_time}-rank{self.rank}-{self.unique_id}")
+            self.pp_stage = dist.get_group_rank(self.process_group, self.rank)
+            self.group_mates = dist.get_process_group_ranks(self.process_group)
+        else:
+            self.rank = 0
+            self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+            self.pp_stage = 0
+            self.group_mates = [0]
+
+        # TYPE2: 只会在monitor_gnorm_with_ad()主调中赋值的变量
+        self.model = None
+        self.vpp = False
+        self.dp_group = None
+        self.tp_group = None
+        self.enable_megatron = False
+        self.micro_batch_number = 1
+
+        # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量
+        self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
+        self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
+        self.optimizer_context = defaultdict(OptimizerContext)
+        self.cc_context = defaultdict(CommunicationContext)
+        self.grad_context = GradContext()
+        self.handles = defaultdict(list)
+        self.param2name = defaultdict(str)
+        self.name2index = defaultdict()
+        self.name2indices = defaultdict()
+        self.name2param = {}
+        self.duplicate_param = {}
+        self.name2tag = {}
+        self.param_name_call_id = {}
+        self.call_id = 0
+        self.module_struct = defaultdict(dict)
+        self.grad_accs = []
+        self.weight_hooked = False
+        self.optimizer_hooked = False
+        self.param_registered = False
+        self.struct_printed = False
+
+        # 动静态区分
+        self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+        if self.dynamic_enable:
+            logger.warning(f"DYNAMIC_MONITOR is set, "
+                           f"please make sure you have 'switch' and 'collect_times' item in {self.config_file_path}")
+            self.monitoring = False
+        else:
+            self.set_config()
+            # 静态且collect_times>0时在第0步self.monitoring就可以True, 动态默认在下一步开启
+            if self.collect_times > 0:
+                self.monitoring = True
+
+    def __del__(self):
+        if hasattr(self, "summary_writer"):
+            self.summary_writer.close()
+
+    @property
+    def ops(self):
+        return self._ops
+
+    @ops.setter
+    def ops(self, value):
+        self._ops = validate_ops(value)
+
+    @staticmethod
+    def set_wrapped_optimizer(_wrapped_optimizer):
+        OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
+
+    @staticmethod
+    def has_register_backward_hook(module_name, module):
+        if hasattr(module, '_backward_hooks') and \
+                len(module._backward_hooks) > 0 and \
+                module._is_full_backward_hook is False:
+            logger.warning(
+                f"The {module_name} has registered deprecated register_backward_hook,"
+                f"which may cause abnormal data dump. The backward input/output for this module will be skipped."
+            )
+            return True
+        return False
+
+    @staticmethod
+    def generate_cc_metrics(cc_name, cc_tensor):
+        metrics = defaultdict(dict)
+        rank = dist.get_rank() if dist.is_initialized() else None
+        for op, tag2tensor in cc_tensor.data.items():
+            for tag, tensor in tag2tensor.items():
+                key = get_summary_writer_tag_name(cc_name, tag, rank)
+                metrics[op].update({key: tensor})
+        cc_tensor.reset()
+        return metrics
+
+    def set_config(self):
+        logger.info(f"current config: {self.config}")
+        self.start_step = self.config.get("start_step", 0)
+        self.collect_times = self.config.get("collect_times", 100000000) # 默认大值, 目的是一直采集
+        self.step_interval = self.config.get("step_interval", 1)
+        self.has_collect_times = 0 # 重设采集计数器
+        self.print_struct = self.config.get("print_struct", False)
         self.module_rank_list = self.config.get("module_ranks", [])
         self.format = self.config.get('format', 'tensorboard')
         self.eps = self.config.get('eps', 1e-8)
@@ -182,6 +322,7 @@ class TrainerMon:
         self.param_distribution = self.config.get("param_distribution", False)
         self.mg_direction = self.config.get('mg_direction', False)
         self.cc_distribution = self.config.get("cc_distribution", {})
+
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
         else:
@@ -189,49 +330,30 @@ class TrainerMon:
             self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
             self.cc_logged_stack = defaultdict(set)
             self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
-            api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
             api_register.redirect_api()
 
         self.common_info()
 
+        # 初始化AnomalyData工厂
         alert_setting = self.config.get('alert', {"rules": []})
         self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
-        # 设置时区,使用 'UTC' 作为示例
-        local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区
-
-        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
-        unique_id = str(uuid.uuid4())[:8]
-
-        if dist.is_initialized():
-            rank = dist.get_rank()
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
-            pp_stage = dist.get_group_rank(self.process_group, rank)
-            group_mates = dist.get_process_group_ranks(self.process_group)
-        else:
-            rank = 0
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
-            pp_stage = 0
-            group_mates = [0]
-        self.rank = rank
-
-        # 初始化AnomalyData工厂
         self.anomaly_data_factory = None
         if alert_setting.get('dump', False):
-            self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
+            self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
 
+        # 初始化writer, 创建输出目录
         if self.format not in FORMAT_MAPPING:
             raise ValueError(f"Unsupported format: {self.format}")
-        writer
+        writer = FORMAT_MAPPING[self.format]
         self.step_count_per_record = self.config.get('step_count_per_record', 1)
 
-        if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
+        if (self.rank in self.module_rank_list) or len(self.module_rank_list) == 0:
             self.summary_writer = writer(
                 WriterInput(
-                    tensorboard_dir,
+                    self.tensorboard_dir,
                     self.alert_rules,
-                    unique_id,
-                    None,
+                    self.unique_id,
                     self.anomaly_data_factory,
                     self.ndigits,
                     self.step_count_per_record
@@ -239,83 +361,22 @@ class TrainerMon:
             )
             # 初始化anomaly detected文件目录
             if self.anomaly_data_factory:
-                self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"),
+                self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
+                                                              self.rank)
                 self.anomaly_data_writer.init_detected_json()
 
-
-        self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.micro_batch_number = 1
-
-        self.model = None
-        self.weight_hooked = False
-        self.optimizer_hooked = False
-        self.param_registered = False
-        self.vpp = False
-        self.dp_group = None
-        self.tp_group = None
-        self.enable_megatron = False
-
-        self.param2name = defaultdict(str)
-        self.name2index = defaultdict()
-        self.name2indices = defaultdict()
-        self.name2param = {}
-        self.param_name_call_id = {}
-        self.duplicate_param = {}
-        self.name2tag = {}
-        self.call_id = 0
-        self.grad_accs = []
-        self.handles = defaultdict(list)
-
-        self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
-        self.print_struct = self.config.get("print_struct", False)
-        self.struct_printed = False
-        self.module_struct = {}
-
-    def __del__(self):
-        if hasattr(self, "summary_writer"):
-            self.summary_writer.close()
-
-    @property
-    def ops(self):
-        return self._ops
-
-    @ops.setter
-    def ops(self, value):
-        self._ops = validate_ops(value)
-
-    @staticmethod
-    def set_wrapped_optimizer(_wrapped_optimizer):
-        OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
-
-    @staticmethod
-    def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
         rank = None
         if dist.is_initialized():
             rank = dist.get_rank()
         if (rank not in rank_list) and len(rank_list) != 0:
             return
-
+        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
 
-
-
-
-
-        key = get_summary_writer_tag_name(module_name, tag, rank)
-        if torch.is_tensor(tensor):
-            metrics[key] = tensor
-        return metrics
-
-    @staticmethod
-    def generate_cc_metrics(cc_name, cc_tensor):
-        metrics = defaultdict(dict)
-        rank = dist.get_rank() if dist.is_initialized() else None
-        for op, tag2tensor in cc_tensor.data.items():
-            for tag, tensor in tag2tensor.items():
-                key = get_summary_writer_tag_name(cc_name, tag, rank)
-                metrics[op].update({key: tensor})
-        cc_tensor.reset()
-        return metrics
+    def build_tbtag_tensor_map(self, module_name, tag, tensor):
+        key = get_summary_writer_tag_name(module_name, tag, self.rank)
+        self._register_param_call_id("_hook_module", key)
+        return {key: tensor}
 
     def common_info(self):
         if not self.xy_distribution:
@@ -338,31 +399,24 @@ class TrainerMon:
         if self.mv_distribution:
             raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
 
-    def hook_modules(self
+    def hook_modules(self):
         if self.module_rank_list and (self.rank not in self.module_rank_list):
             return
 
-        if not isinstance(model, list):
-            model = [model]
-        self.model = model
-        self._register_param_name(model)
-
-        self.micro_batch_number = grad_acc_steps
-
         targets = self.config['targets']
         module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
         for key in module_in_all_stage:
             struct = targets.pop(key)
-            targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(model))})
+            targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(self.model))})
 
         hooked_count = 0
-        for vpp_stage, model_chunk in enumerate(model):
+        for vpp_stage, model_chunk in enumerate(self.model):
             vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
             targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
                 'targets'].keys()
             hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
 
-        logger.info_on_rank_0(f"> {hooked_count}
+        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
 
     def clone_if_tensor(args):
         if isinstance(args, tuple):
@@ -383,11 +437,11 @@ class TrainerMon:
 
         BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
 
-        if not self.optimizer_hooked:
-            self.hook_optimizer()
         return
 
     def generate_param_metrics(self, opt_context):
+        if not self.param_distribution:
+            return
         get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
 
     def generate_mv_metrics(self, opt_context):
@@ -416,29 +470,50 @@ class TrainerMon:
                 logger.warning(f"grad is None: {name}, maybe something wrong happened.")
                 continue
             tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            self._register_param_call_id("hook_optimizer", tag)
             grad_dict[tag] = grad
 
         get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
         return self.grad_context.post, self.grad_context.pre
 
-    def monitor_gnorm_with_ad(
+    def monitor_gnorm_with_ad(
+            self,
+            model,
+            grad_acc_steps=1,
+            optimizer=None,
+            tp_group=None,
+            dp_group=None,
+            start_iteration=0
+    ):
         """External interface"""
+        global start_step
+        start_step = start_iteration
         logger.info(f'grad acc steps {grad_acc_steps}')
-        self.hook_optimizer(optimizer)
         self.micro_batch_number = grad_acc_steps
-
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.hook_step_final(optimizer)
+        if not isinstance(model, list):
+            model = [model]
+        self.model = model
+        if len(model) > 1:
+            self.vpp = True
+            self._smallest_rank_print('vpp enabled')
+        if not self.dynamic_enable:
+            self.register_hooks(optimizer)
 
-
+    def register_hooks(self, optimizer):
+        self._register_param_name()
+        self.hook_optimizer(optimizer)
         self._patch_grad_sync()
-        self.hook_modules(
+        self.hook_modules()
+        self.monitoring = True
 
     def generate_param_map(self, tag, param_tensor):
         metrics = {}
-        rank = dist.get_rank() if dist.is_initialized() else None
         for name in self.param2name.values():
-            key = get_summary_writer_tag_name(name, tag, rank)
+            key = get_summary_writer_tag_name(name, tag, self.rank)
+            self._register_param_call_id("optimizer_pre_step_hook", key)
             if name not in param_tensor or param_tensor[name] is None:
                 continue
             metrics[key] = param_tensor[name]
@@ -459,12 +534,12 @@ class TrainerMon:
         for handle in self.handles['xy']:
             handle.remove()
         self.handles['xy'].clear()
-        self.hook_modules(
+        self.hook_modules()
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             fwd_context.actv.clear()
 
     def write_adhoc_check(self, step):
-
+        self.tensor_metrics.flush(self.summary_writer)
 
     def write_xy_tb(self, step):
         if not self.xy_distribution:
@@ -472,40 +547,53 @@ class TrainerMon:
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             if len(fwd_context.actv) == 0:
                 continue
-            self.write_metrics(self.ops,
+            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
             fwd_context.actv.clear()
         if self.grad_context.actv:
-            self.write_metrics(self.ops, self.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
 
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-        self.write_metrics(self.ops,
+        self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
 
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        self.write_metrics(self.ops,
-
-        self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_sq_metric,
-                           opt_context.step, 'exp_avg_sq')
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
 
     def write_grad_tb(self, step):
         if not self.wg_distribution:
             return
 
         if self.enable_megatron:
-            self.write_metrics(self.ops, self.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
         else:
-            self.write_metrics(self.ops, self.
-            self.write_metrics(self.ops, self.
+            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+        self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
     def hook_optimizer(self, optimizer=None):
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
+
+            if (self.print_struct and not all(value == {} for value in self.module_struct.values())
+                    and not self.struct_printed):
+                self._save_module_struct()
+                if not self.cc_log_only:
+                    raise Exception("exit after first monitor step when print model struct")
+            if self.cc_log_only and context.step > 0:
+                self._smallest_rank_print("> Used communication ops and corresponding stack")
+                self._smallest_rank_print(
+                    json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
+                raise Exception("exit after first step when print cc stack")
+
+            # skip generate metrics
+            if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
+                return
             if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
-                if
+                if not self.name2indices:
                     self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
                                                                                          self.name2index)
                 mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
@@ -518,19 +606,6 @@ class TrainerMon:
             context.param_adam_update = mv_result.update
             context.param_adam_ratio = mv_result.ratio
 
-            if (self.print_struct and not all(value == {} for value in self.module_struct.values())
-                    and not self.struct_printed):
-                self._smallest_rank_print("> module struct:")
-                self._smallest_rank_print(json.dumps(self.module_struct))
-                self.struct_printed = True
-                if not self.cc_log_only:
-                    raise Exception("exit after first step when print model struct")
-            if self.cc_log_only and context.step > 0:
-                self._smallest_rank_print("> Used communication ops and corresponding stack")
-                self._smallest_rank_print(
-                    json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
-                raise Exception("exit after first step when print cc stack")
-
             self.generate_wgrad_metrics()
             self.generate_mv_metrics(context)
             self.generate_param_metrics(context)
@@ -561,41 +636,10 @@ class TrainerMon:
             context.metric_dict = metric_dict
             return
 
-        def optimizer_post_step_hook(optimizer, args, kwargs):
-            context = self.optimizer_context[optimizer]
-            rank = dist.get_rank() if dist.is_initialized() else None
-
-            if self.anomaly_data_factory:
-                self.anomaly_data_factory.set_call_id(self.param_name_call_id)
-            self.write_xy_tb(context.step)
-            self.write_grad_tb(context.step)
-            self.write_mv_tb(context)
-            self.write_param_tb(context)
-            self.write_adhoc_check(context.step)
-
-            if self.ur_distribution:
-                for param_name, _ in context.param_adam_update.items():
-                    self.update_heatmap_visualizer[param_name].visualize(
-                        get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
-                for param_name, _ in context.param_adam_ratio.items():
-                    self.ratio_heatmap_visualizer[param_name].visualize(
-                        get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
-
-            if context.metric_dict:
-                self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other')
-            context.metric_dict.clear()
-            context.step += 1
-            if self.anomaly_data_factory:
-                self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
-            self.summary_writer.clear_anomalies()
-            self.call_id = 0
-            return
-
         def patch_step(func, optimizer):
             def wrapper(*args, **kwargs):
                 optimizer_pre_step_hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
-                optimizer_post_step_hook(optimizer, args, kwargs)
                 return out
 
             return wrapper
@@ -605,14 +649,171 @@ class TrainerMon:
 
         if optimizer:
             optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
-
+            self.handles['optimizer'] = []
         else:
             if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
-                register_optimizer_step_pre_hook(optimizer_pre_step_hook)
-
+                step_pre_hook = register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+                self.handles['optimizer'] = [step_pre_hook]
         self.optimizer_hooked = True
         return
 
+    def dynamic_monitor(self, optimizer):
+        """
+        If dynamic monitor enabled and config.json updated,
+        remove hooks and register new hooks according to new configuration.
+        """
+        context = self.optimizer_context[optimizer]
+        if not self.dynamic_enable:
+            return
+        try:
+            # 如果文件时间戳没变, 可以不读取节省时间
+            config_timestamp = os.path.getmtime(self.config_file_path)
+            if config_timestamp == self.config_timestamp:
+                return
+            # 更新config文件最新修改时间戳
+            self.config_timestamp = config_timestamp
+            config = load_json(self.config_file_path)
+        except Exception as e:
+            logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
+            return
+
+        if config.get("switch", False):
+            try:
+                validate_config(config)
+                self.config = config
+                self.set_config()
+                logger.warning(f"config is updated at step{context.step - 1}, "
+                               f"will start new hook at step{context.step}.")
+            except Exception as e:
+                logger.error(f"set config wrong because {e}, not updated, please check!!!")
+                return
+
+            self._remove_all_hooks(optimizer)
+            self.register_hooks(optimizer)
+
+    def hook_step_final(self, optimizer):
+        def step_final_hook(optimizer, args, kwargs):
+            context = self.optimizer_context[optimizer]
+            rank = dist.get_rank() if dist.is_initialized() else None
+            # 静态在第0步就可以保存, 动态在第0步不可以, 因为动态设计的就是重置后下一步开启, 第0步的self.monitoring还是False
+            if self.monitoring:
+                module_rank_valid = not self.module_rank_list or (
+                        dist.is_initialized() and dist.get_rank() in self.module_rank_list)
+                step_condition = (context.step >= self.start_step and (
+                        context.step - self.start_step) % self.step_interval == 0)
+                if module_rank_valid and step_condition:
+                    self.has_collect_times += 1
+
+                    if self.anomaly_data_factory:
+                        self.anomaly_data_factory.set_call_id(self.param_name_call_id)
+                    self.write_xy_tb(context.step)
+                    self.write_grad_tb(context.step)
+                    self.write_mv_tb(context)
+                    self.write_param_tb(context)
+                    self.write_adhoc_check(context.step)
+
+                    if self.ur_distribution:
+                        for param_name, _ in context.param_adam_update.items():
+                            self.update_heatmap_visualizer[param_name].visualize(
+                                get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step,
+                                self.summary_writer)
+                        for param_name, _ in context.param_adam_ratio.items():
+                            self.ratio_heatmap_visualizer[param_name].visualize(
+                                get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step,
+                                self.summary_writer)
+
+                    if context.metric_dict:
+                        self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+                    context.metric_dict.clear()
+
+                    if self.anomaly_data_factory:
+                        self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
+                    self.summary_writer.clear_anomalies()
+                    self.call_id = 0
+                    self.param_name_call_id.clear()
+
+                    if self.has_collect_times >= self.collect_times:
+                        self._remove_all_hooks_final(optimizer)
+
+            context.step += 1
+            self.dynamic_monitor(optimizer)
+
+        def patch_step(func, optimizer):
+            def wrapper(*args, **kwargs):
+                out = func(*args, **kwargs)
+                step_final_hook(optimizer, args, kwargs)
+                return out
+            return wrapper
+
+        if optimizer:
+            optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+            self.origin_step_func = optimizer.__class__.step
+        else:
+            register_optimizer_step_post_hook(step_final_hook)
+        return
+
+    def _remove_all_hooks(self, optimizer):
+        # 清空hook handle
+        for handle in self.handles['xy']:
+            handle.remove()
+        self.handles['xy'].clear()
+        # 清空对应context缓存
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            fwd_context.reset()
+        for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+            bwd_context.reset()
+        self.grad_context.reset() # 权重梯度和激活值梯度都在这
+
+        for handle in self.handles['wgrads']:
+            handle.remove()
+        self.handles['wgrads'].clear()
+        self.weight_hooked = False
+
+        if len(self.handles['optimizer']) == 0 and self.optimizer_hooked:
+            optimizer.__class__.step = self.origin_step_func
+        else:
+            for handle in self.handles['optimizer']:
+                handle.remove()
+            self.handles['optimizer'].clear()
+        for _, context in self.optimizer_context.items():
+            context.reset()
+        self.optimizer_hooked = False
+
+        for handle in self.handles['cc']:
+            handle.remove()
+        self.handles['cc'].clear()
+        for _, context in self.cc_context.items():
+            context.reset()
+
+        # 清空节点缓存
+        self.param2name.clear()
+        self.name2index.clear()
+        self.name2indices.clear()
+        self.name2param.clear()
+        self.duplicate_param.clear()
+        self.name2tag.clear()
+        self.module_struct.clear()
+        self.grad_accs.clear()
+
+        # 关闭采集状态
+        self.monitoring = False
+
+    def _remove_all_hooks_final(self, optimizer):
+        if self.dynamic_enable:
+            # 结束后自动重置switch为False等待用户手动开启
+            try:
+                config = load_json(self.config_file_path)
+                config['switch'] = False
+                save_json(self.config_file_path, config, indent=2)
+                config_timestamp = os.path.getmtime(self.config_file_path)
+                self.config_timestamp = config_timestamp
+                logger.info(
+                    "Finish monitor, set config'switch=False, will restart by set switch=True and update content")
+            except Exception as e:
+                logger.warning(f"Finish monitor, set config'switch=False fail because {e}, please check!!!")
+        logger.info("Finish monitor")
+        self._remove_all_hooks(optimizer)
+
     def _smallest_rank_print(self, msg):
         if dist.is_initialized():
             if self.module_rank_list:
@@ -624,9 +825,20 @@ class TrainerMon:
         else:
             logger.info(msg)
 
+    def _save_module_struct(self):
+        save_module_struct = (not dist.is_initialized()
+                              or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list))
+                              or (not self.module_rank_list and dist.get_rank() == 0))
+
+        if save_module_struct:
+            module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
+            save_json(module_struct_file, self.module_struct, indent=2)
+            logger.info(f"> save module struct to {module_struct_file}")
+        self.struct_printed = True
+
     def _is_target_param(self, param_name, param, prefix):
-        squash_name = prefix + squash_param_name(param_name)
         name = prefix + param_name
+        squash_name = prefix + squash_param_name(param_name, self.squash_name)
         for target in self.config['targets'].keys():
             if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
                 setattr(param, "zero_out_wgrad", True)
@@ -635,15 +847,14 @@ class TrainerMon:
         return False
 
     def _register_chunk(self, model_chunk, prefix):
-
+        index = 0
+        for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
             if self._is_target_param(param_name, param, prefix):
-                name = prefix + squash_param_name(param_name)
+                name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
-
-                    May be error of squash_param_name')
-                    raise Exception("param with same name will be overwritten.")
+                    name = prefix + param_name
                 self.param2name[param] = name
                 self.name2param[name] = param
                 self.name2index[name] = index
@@ -652,34 +863,22 @@ class TrainerMon:
                     self.duplicate_param[name] = True
                 if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                     self.duplicate_param[name] = True
-                self.name2tag[name] = {
-
-
-
-
-
-    def _register_param_name(self
-
-            return
-
-        if not isinstance(model, list):
-            model = [model]
-
-        if len(model) > 1:
-            self.vpp = True
-            self._smallest_rank_print('vpp enabled')
-
-        for vpp_stage, model_chunk in enumerate(model):
+                self.name2tag[name] = {
+                    MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
+                    MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
+                }
+                index += 1
+
+    def _register_param_name(self):
+        for vpp_stage, model_chunk in enumerate(self.model):
             prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
             self._register_chunk(model_chunk, prefix)
 
-        self.param_registered = True
-
     def _is_target_module(self, module_name, targets, vpp_stage):
         if self.all_xy or self.print_struct:
-            return vpp_stage + squash_param_name(module_name)
+            return vpp_stage + squash_param_name(module_name, self.squash_name)
         for pattern in [
-            vpp_stage + squash_param_name(module_name),
+            vpp_stage + squash_param_name(module_name, self.squash_name),
             vpp_stage + module_name,
         ]:
             if pattern in targets:
@@ -692,63 +891,59 @@ class TrainerMon:
             return 0
 
         def fwd_hook_fun(module, module_input, module_output, name):
-            if is_recomputation():
+            if not module.training or is_recomputation():
+                # 1 only monitor training stage.
+                # 2 when open recompute, skip recomputed forward stage.
                 return
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
             if not context.struct:
-                context.struct = {
-
+                context.struct = {
+                    MonitorConst.ACTV_IN: get_param_struct(module_input),
+                    MonitorConst.ACTV_OUT: get_param_struct(module_output)
+                }
             if self.print_struct:
-                if context.module_name not in self.module_struct:
-                    self.module_struct[context.module_name] = {}
                 self.module_struct[context.module_name].update(context.struct)
                 return
-            if not module.training:
-                return
             if not context.format_by_arg:
                 context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
                 context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
             if not context.format_by_arg:
                 return
            if not context.verified:
-
-
-
-                    MonitorConst.ACTV_IN)
+                context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
+                                                              module_input, context.module_name,
+                                                              MonitorConst.ACTV_IN)
                 context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
                                                                module_output, context.module_name,
                                                                MonitorConst.ACTV_OUT)
                 context.verified = True
             # expect output be tensor type
             tbtag_tensor_map = {}
-            if
-
-
-
-                    cared_input))
+            cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
+                                            cared_input))
             cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
                                             cared_output))
 
             get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
-
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
-                context.step += 1
             return
 
         def bwd_hook_fun(module, input_grad, output_grad):
             context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
             if not context.struct:
-                context.struct = {
-
+                context.struct = {
+                    MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
+                    MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
+                }
             if self.print_struct:
-                if context.module_name not in self.module_struct:
-                    self.module_struct[context.module_name] = {}
                 self.module_struct[context.module_name].update(context.struct)
                 return
             if not context.format_by_arg:
@@ -757,21 +952,19 @@ class TrainerMon:
             if not context.format_by_arg:
                 return
             if not context.verified:
-
-
-
-                    MonitorConst.ACTVGRAD_IN)
+                context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
+                                                              input_grad, context.module_name,
+                                                              MonitorConst.ACTVGRAD_IN)
                 context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
                                                                output_grad, context.module_name,
                                                                MonitorConst.ACTVGRAD_OUT)
                 context.verified = True
 
             tbtag_tensor_map = {}
-            if
-
-
-
-                f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
+            cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN,
+                                            cared_input_grad))
             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
@@ -787,7 +980,6 @@ class TrainerMon:
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
-                context.step += 1
             return
 
         if self.backward_only and self.forward_only:
@@ -802,7 +994,7 @@ class TrainerMon:
             if not self.backward_only:
                 handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
                 self.handles['xy'].append(handle)
-            if not self.forward_only:
+            if not self.forward_only and not self.has_register_backward_hook(name, submodule):
                 handle = submodule.register_full_backward_hook(bwd_hook_fun)
                 self.handles['xy'].append(handle)
             self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
@@ -814,8 +1006,9 @@ class TrainerMon:
         def patch_sync(sync_grad_func):
             def wrapper(bucket):
                 grad_dict = {}
+                bucket_params_id_list = [id(params) for params in bucket.params_list]
                 for param, name in self.param2name.items():
-                    if param not in
+                    if id(param) not in bucket_params_id_list:
                         continue
                     grad = param.main_grad if self.params_have_main_grad else param.grad
                     if grad is None:
@@ -825,6 +1018,7 @@ class TrainerMon:
                     if tag is None:
                         continue
                     grad_dict[tag] = grad
+                    self._register_param_call_id("sync_grad_func", tag)
                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                 out = sync_grad_func(bucket)
                 return out
@@ -837,6 +1031,9 @@ class TrainerMon:
         except ImportError:
             self.enable_megatron = False
 
+        if not self.wg_distribution:
+            return
+
         if self.enable_megatron:
             Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version
         else:
@@ -848,8 +1045,7 @@ class TrainerMon:
         @torch.no_grad
         def param_hook(*args, context_dict, param, key, name):
             param.micro_step += 1
-            self.
-            self.call_id += 1
+            self._register_param_call_id("param_hook", key)
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
                 if self.params_have_main_grad:
@@ -868,3 +1064,13 @@ class TrainerMon:
             self.handles['wgrads'].append(handle)
 
         self.weight_hooked = True
+
+    def _register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name:
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1