mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
- msprobe/README.md +2 -2
- msprobe/core/common/const.py +34 -9
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +14 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -7
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/utils.py +10 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +92 -8
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
- msprobe/core/data_dump/json_writer.py +26 -8
- msprobe/docs/01.installation.md +25 -0
- msprobe/docs/02.config_introduction.md +14 -12
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +34 -15
- msprobe/docs/06.data_dump_MindSpore.md +45 -22
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
- msprobe/docs/19.monitor.md +257 -260
- msprobe/docs/21.visualization_PyTorch.md +10 -0
- msprobe/docs/22.visualization_MindSpore.md +11 -0
- msprobe/docs/27.dump_json_instruction.md +24 -20
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/utils.py +20 -2
- msprobe/mindspore/debugger/debugger_config.py +25 -2
- msprobe/mindspore/debugger/precision_debugger.py +25 -6
- msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/service.py +95 -21
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +71 -0
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +14 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +10 -12
- msprobe/pytorch/monitor/module_hook.py +123 -104
- msprobe/pytorch/monitor/module_metric.py +6 -6
- msprobe/pytorch/monitor/optimizer_collect.py +45 -63
- msprobe/pytorch/monitor/utils.py +8 -43
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +103 -24
- msprobe/visualization/builder/graph_builder.py +31 -5
- msprobe/visualization/builder/msprobe_adapter.py +7 -5
- msprobe/visualization/graph/base_node.py +3 -2
- msprobe/visualization/graph/distributed_analyzer.py +80 -3
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +3 -4
- msprobe/visualization/utils.py +10 -2
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
msprobe/mindspore/monitor/module_hook.py
@@ -0,0 +1,821 @@
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import uuid
from collections import defaultdict
from datetime import datetime

import pytz
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import Tensor, ops, mint
from mindspore import nn, _no_grad
from mindspore.communication import get_rank

from msprobe.core.common.log import logger
from msprobe.core.common.const import MonitorConst
from msprobe.core.common.file_utils import load_json
from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
    is_skip_step, get_metrics, get_single_metrics
from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
    CSVWriterWithAD, BaseWriterWithAD, WriterInput
from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
    get_process_group

FORMAT_MAPPING = {
    MonitorConst.CSV: CSVWriterWithAD,
    MonitorConst.API: BaseWriterWithAD
}


def get_output_base_dir():
    return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)


def get_param_struct(param):
    res = {}
    if isinstance(param, (tuple, list)):
        res['config'] = f'{type(param).__name__}[{len(param)}]'
        for i, x in enumerate(param):
            res[i] = f'size={tuple(x.shape)}, dtype={x.dtype}' if isinstance(x, Tensor) else f'{type(x)}'
    elif isinstance(param, Tensor):
        res['config'] = 'tensor'
        res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}'
    else:
        res['config'] = f'{type(param)}'
        logger.warning(f'Not support type({type(param)}) now, please check the type of param {param}')
    return res


def param_is_not_tensor_parallel_duplicate(param, tp_group):
    return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
        mint.distributed.get_rank(group=tp_group) == 0
    )


def param_is_data_parallel_duplicate(dp_group):
    return mint.distributed.get_rank(group=dp_group) != 0


def squash_param_name(param_name):
    for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
        match = re.findall(pattern, param_name)
        if match:
            return match[0]
    return param_name


# Used For Module Forward & Backward Collect
class ModuleHookContext:
    def __init__(self, module_name) -> None:
        self.step = 0
        self.micro_step = 0
        self.actv = defaultdict(dict)
        self.actvgrad = []
        self.module_name = module_name
        self.struct = {}
        self.format_by_arg = {}
        self.verified = False
        self.focused_in_col = 0
        self.focused_out_col = 0
        self.ignore_in = False  # no need to care when no key 'input' or 'input_grad' found

    def set_format_by_arg(self, key_name: str, target_config: dict):
        cared = target_config.get(self.module_name, self.struct)
        if key_name in cared:
            if isinstance(cared[key_name], dict):
                # current cared is self.struct
                config = cared[key_name].get('config')
                self.format_by_arg[key_name] = config
            else:
                # current cared is target_config[self.module_name]
                self.format_by_arg[key_name] = cared[key_name]
        elif key_name in ['input', 'input_grad']:
            self.ignore_in = True


start_step = 0


# Used For Optimizer Weight Grad & M/V Collect
class OptimizerContext:
    def __init__(self) -> None:
        self.step = start_step
        self.param_effective_rank = defaultdict(float)
        self.param_mg_direction = defaultdict(float)
        self.param_adam_update = defaultdict()
        self.param_adam_ratio = defaultdict()
        self.param_weight_grad = defaultdict()
        self.param_exp_avg = defaultdict()
        self.exp_avg_metric = {}
        self.param_exp_avg_sq = defaultdict()
        self.exp_avg_sq_metric = {}
        self.metric_dict = {}
        self.param_metric = {}

    def reset(self) -> None:
        self.param_mg_direction.clear()
        self.param_adam_update.clear()
        self.param_weight_grad.clear()
        self.param_exp_avg.clear()
        self.exp_avg_metric.clear()
        self.param_exp_avg_sq.clear()
        self.exp_avg_sq_metric.clear()
        self.metric_dict.clear()
        self.param_metric.clear()


# Used For Weight Grad Collect
class GradContext:
    def __init__(self) -> None:
        self.pre = {}
        self.post = {}
        self.acc_metric = {}
        self.acc = {}
        self.actv = {}

    def reset(self):
        self.pre.clear()
        self.post.clear()
        self.acc_metric.clear()
        self.acc.clear()
        self.actv.clear()


class CommunicationContext:
    def __init__(self) -> None:
        self.data = {}

    @staticmethod
    def _agg(data):
        aggregated_data = {}
        for tag, op2tensorlist in data.items():
            aggregated_data[tag] = {}
            for op, tensorlist in op2tensorlist.items():
                aggregated_data[tag][op] = op_aggregate(op, tensorlist)
        return aggregated_data

    def reset(self):
        self.data = {}

    def aggregate(self):
        self.data = self._agg(self.data)


class TrainerMon:
    def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
        self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
        self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
        self.optimizer_context = defaultdict(OptimizerContext)
        self.cc_context = defaultdict(CommunicationContext)
        self.grad_context = GradContext()
        self.params_have_main_grad = params_have_main_grad
        self.handles = defaultdict(list)
        self.config = load_json(config_file_path)
        validate_config(self.config)

        self.start_step = self.config.get("start_step", 0)
        self.collect_times = self.config.get("collect_times", 100000000)  # defaults to a large value so collection never stops
        self.step_interval = self.config.get("step_interval", 1)
        self.has_collect_times = 0

        # monitor target in module, such as layer, weight, grad
        self.targets = self.config.get("targets", None)
        self.is_select = self.config.get("is_select", False)
        self.module_rank_list = self.config.get("module_ranks", [])
        # only csv supported in mindspore
        self.format = self.config.get('format', MonitorConst.CSV)
        self.eps = self.config.get('eps', 1e-8)
        # monitor mean/max/norm/min/nan...
        self.ops = self.config.get('ops', [])
        self.ndigits = self.config.get('ndigits', 6)
        self.all_xy = self.config.get('all_xy', False)
        # module input/output input_grad/output_grad
        self.xy_distribution = self.config.get('xy_distribution', False)
        # activation forward
        self.forward_only = self.config.get('forward_only', False)
        # activation backward
        self.backward_only = self.config.get('backward_only', False)
        # update vector and ratio vector of adam
        self.ur_distribution = self.config.get('ur_distribution', False)
        # m/v of adam
        self.mv_distribution = self.config.get("mv_distribution", False)
        # weight grad
        self.wg_distribution = self.config.get("wg_distribution", False)
        # optimizer param
        self.param_distribution = self.config.get("param_distribution", False)
        # main grad direction
        self.mg_direction = self.config.get('mg_direction', False)
        # communication ops
        self.cc_distribution = self.config.get("cc_distribution", {})
        if not self.cc_distribution.get('enable', False):
            self.cc_log_only = False
        else:
            self.cc_codeline = self.cc_distribution.get('cc_codeline', [])
            self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
            self.cc_logged_stack = defaultdict(set)
            self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
            api_register.redirect_api()
        self.common_info()

        alert_setting = self.config.get('alert', {"rules": []})
        self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])

        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target time zone if needed

        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
        unique_id = str(uuid.uuid4())[:8]
        output_base_dir = get_output_base_dir()

        time_tags = self.config.get("append_output", [])
        if time_tags:
            output_append_dirs = get_target_output_dir(output_base_dir, time_tags[0], time_tags[1])
        try:
            rank = get_rank()
        except Exception as e:
            rank = 0
            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
            logger.error(f"Failed to get rank, setting tensorboard_dir to {tensorboard_dir}")
            pp_stage = 0
            group_mates = [0]
        else:
            if time_tags and str(rank) in output_append_dirs:
                tensorboard_dir = output_append_dirs[str(rank)]
                logger.info(f"Append rank({rank}) result to {tensorboard_dir}")
            else:
                tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
            pp_stage = 0
            group_mates = [0]

        self.rank = rank

        # Initialize the AnomalyData factory
        self.anomaly_data_factory = None
        if alert_setting.get('dump', False):
            self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)

        if self.format not in FORMAT_MAPPING:
            logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
            self.format = MonitorConst.CSV
        writer = FORMAT_MAPPING[self.format]
        self.step_count_per_record = self.config.get('step_count_per_record', 1)

        self.summary_writer = writer(
            WriterInput(
                tensorboard_dir,
                self.alert_rules,
                unique_id,
                self.anomaly_data_factory,
                self.ndigits,
                self.step_count_per_record
            )
        )

        self.micro_batch_number = 1

        self.model = None
        self.weight_hooked = False
        self.optimizer_hooked = False
        self.param_registered = False
        self.vpp = False
        self.dp_group = None
        self.tp_group = None
        self.enable_megatron = False

        self.param2name = defaultdict(str)
        self.name2index = defaultdict()
        self.name2indices = defaultdict()
        self.name2param = {}
        self.param_name_call_id = {}
        self.duplicate_param = {}
        self.name2tag = {}
        self.call_id = 0
        self.grad_accs = []
        self.handles = defaultdict(list)

        self.print_struct = self.config.get("print_struct", False)
        self.struct_printed = False
        self.module_struct = defaultdict(dict)

    # Start
    def set_monitor(
            self,
            model,
            grad_acc_steps=1,
            optimizer=None,
            tp_group=None,
            dp_group=None,
            start_iteration=0):
        global start_step
        start_step = start_iteration
        logger.info(f'grad acc steps {grad_acc_steps}')
        self.hook_optimizer(optimizer)
        self.micro_batch_number = grad_acc_steps
        self.dp_group = dp_group
        self.tp_group = tp_group

        self.hook_modules(model, grad_acc_steps)
        self._patch_grad_sync()

    """
    Start
    """
    def hook_optimizer(self, optimizer):
        rank_id = str(get_rank())
        if self.optimizer_hooked:
            return

        if not self.is_target_rank():
            return

        m_list = []
        v_list = []
        param_list = []
        grad_names = []
        for param in optimizer.get_parameters():
            if MonitorConst.EXP_AVG_SQ in param.name:
                v_list.append(param)
            elif MonitorConst.EXP_AVG in param.name:
                m_list.append(param)
            else:
                param_list.append(param)
                grad_names.append(param.name)

        """
        grad reduced
        m/v
        """
        def optimizer_pre_hook_function(opt, grad_names, gradients):
            context = self.optimizer_context[opt]
            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
                            self.collect_times):
                return
            gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients
            is_select = self.is_select
            for idx, grad in enumerate(gradient_list):
                grad_name = grad_names[idx]
                if is_select and grad_name not in self.targets:
                    continue
                get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad)

            if self.mv_distribution:
                # fetch mean
                for param in m_list:
                    name = param.name
                    if is_select and name not in self.targets:
                        continue
                    get_single_metrics(self.ops, name, param, context.exp_avg_metric)
                # fetch variance
                for param in v_list:
                    name = param.name
                    if is_select and name not in self.targets:
                        continue
                    get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric)
            if self.param_distribution:
                for param in param_list:
                    get_single_metrics(self.ops, param.name, param, context.param_metric)
            self.generate_wgrad_metrics()
            metric_dict = {}
            for cc in self.cc_context.values():
                cc.aggregate()
                metric_dict.update(cc.data)
                cc.reset()

            if not metric_dict:
                return
            context.metric_dict = metric_dict
            return

        def optimizer_post_hook_function(opt, args, gradients, outputs):
            context = self.optimizer_context[opt]
            step_skip = is_skip_step(context.step, self.start_step, self.step_interval, \
                                     self.has_collect_times, self.collect_times)
            if step_skip:
                context.step += 1
                return
            self.write_xy_tb(context.step)
            self.write_grad_tb(context.step)
            self.write_mv_tb(context)
            self.write_param_tb(context)

            if context.metric_dict:
                self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
            context.metric_dict.clear()
            self.has_collect_times += 1
            context.step += 1
            if self.anomaly_data_factory:
                self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
            self.summary_writer.clear_anomalies()
            self.call_id = 0
            self.param_name_call_id.clear()
            return

        def optimizer_pre_hook_wrapper(func, grad_names):
            def wrapper(opt, gradients):
                return func(opt, grad_names, gradients)
            return wrapper

        def optimizer_post_hook_wrapper(func, args=None):
            def wrapper(opt, gradients, outputs):
                return func(opt, args, gradients, outputs)
            return wrapper

        optimizer.register_forward_pre_hook(optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
        optimizer.register_forward_hook(optimizer_post_hook_wrapper(optimizer_post_hook_function))

        self.optimizer_hooked = True
        return

    def write_xy_tb(self, step):
        if not self.xy_distribution:
            return
        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
            if len(fwd_context.actv) == 0:
                continue
            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
            fwd_context.actv.clear()
        if self.grad_context.actv:
            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')

    def write_param_tb(self, opt_context):
        if not self.param_distribution:
            return
        self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')

    def write_mv_tb(self, opt_context):
        if not self.mv_distribution:
            return
        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')

    def write_grad_tb(self, step):
        if not self.wg_distribution:
            return

        if self.enable_megatron:
            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
        else:
            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
        self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

    def common_info(self):
        if not self.xy_distribution:
            logger.info("> module input/output input_grad/output_grad is not monitored. ")
        if self.forward_only:
            logger.info("> only module forward is monitored. ")
        if not self.ur_distribution:
            logger.info("> update vector and ratio vector of adam is not monitored. ")
        if not self.mv_distribution:
            logger.info("> momentum and variance of adam is not monitored. ")
        if not self.wg_distribution:
            logger.info("> weight grad of specified module is not monitored. ")
        if not self.mg_direction:
            logger.info('> grad and momentum direction will not be compared.')
        if not self.cc_distribution.get('enable', False):
            logger.info("> cc operator is not monitored.")

    def is_target_rank(self):
        rank_id = str(get_rank())
        if self.module_rank_list and (rank_id not in self.module_rank_list):
            return False
        return True

    def hook_modules(self, model, grad_acc_steps):
        if not self.is_target_rank():
            return
        if not isinstance(model, list):
            model = [model]
        self.model = model  # list
        self._register_param_name(model)
        self.micro_batch_number = grad_acc_steps
        module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]

        for key in module_in_all_stage:
            struct = self.targets.pop(key)
            self.targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(model))})

        hooked_count = 0
        for vpp_stage, model_chunk in enumerate(model):
            if not isinstance(model_chunk, nn.Cell):
                logger.info("Target Model is not Cell")
                continue
            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
            targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
        logger.info(f"> {hooked_count} modules are monitored.")

    def build_tbtag_tensor_map(self, module_name, tag, tensor):
        rank_id = str(get_rank())
        metrics = {}
        key = get_summary_writer_tag_name(module_name, tag, rank_id)
        if isinstance(tensor, Tensor):
            self._register_param_call_id("_hook_module", key)
            metrics[key] = tensor
        return metrics

    def generate_wgrad_metrics(self):
        if not self.wg_distribution:
            return {}, {}

        if self.weight_hooked:
            try:
                get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
            except Exception as e:
                logger.warning(f"An error occurred while generating wgrad pre metrics")
                return {}, {}

        grad_dict = {}
        for param, name in self.param2name.items():
            if self.duplicate_param.get(name, False):
                continue
            grad = param.main_grad if self.params_have_main_grad else param.grad
            if grad is None:
                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
                continue
            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
            self._register_param_call_id("hook_optimizer", tag)
            grad_dict[tag] = grad
        try:
            get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
        except Exception as e:
            logger.warning(f"An error occurred while generating wgrad post metrics")
            return {}, {}
        return self.grad_context.post, self.grad_context.pre

    def _register_param_name(self, model):
        if self.param_registered:
            return

        if len(model) > 1:
            self.vpp = True
            logger.info('vpp enabled')

        for vpp_stage, model_chunk in enumerate(model):
            prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
            self._register_chunk(model_chunk, prefix)

        self.param_registered = True

    def _is_target_param(self, param_name, param, prefix):
        if not self.targets:
            return True
        squash_name = prefix + squash_param_name(param_name)
        name = prefix + param_name
        for target in self.targets.keys():
            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
                setattr(param, "zero_out_wgrad", True)
                return True
        return False

    def _register_chunk(self, model_chunk, prefix):
        index = 0
        for param in model_chunk.get_parameters():
            param_name = param.name
            if not param.requires_grad:
                continue
            if self._is_target_param(param_name, param, prefix):
                name = prefix + squash_param_name(param_name)
                if name in self.param2name.values():
                    name = prefix + param_name
                self.param2name[param] = name
                self.name2param[name] = param
                self.name2index[name] = index

                if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
                    self.duplicate_param[name] = True
                if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                    self.duplicate_param[name] = True
                self.name2tag[name] = {
                    MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
                    MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
                }
            index += 1

    def _is_target_module(self, module_name, targets, vpp_stage):
        if self.all_xy or self.print_struct:
            return vpp_stage + squash_param_name(module_name)
        for pattern in [
            vpp_stage + squash_param_name(module_name),
            vpp_stage + module_name,
        ]:
            if pattern in targets:
                return pattern
        return ""

    def _hook_module(self, target_names, module, vpp_stage=''):
        if not isinstance(module, nn.Cell):
            # nothing to hook
            return 0

        def fwd_hook_fun(module, module_input, module_output, name):
            if module not in self.module_fwd_hook_context_by_module:
                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
            if not context.struct:
                context.struct = {
                    MonitorConst.ACTV_IN: get_param_struct(module_input),
                    MonitorConst.ACTV_OUT: get_param_struct(module_output)
                }
            if self.print_struct:
                self.module_struct[context.module_name].update(context.struct)
                return
            if not module.training:
                return
            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
                            self.collect_times):
                step_accumulates_one(context, self.micro_batch_number)
                return
            if not context.format_by_arg:
                context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets)
                context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets)
            if not context.format_by_arg:
                return
            if not context.verified:
                if not context.ignore_in:
                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
                                                                  module_input, context.module_name,
                                                                  MonitorConst.ACTV_IN)
                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
                                                               module_output, context.module_name,
                                                               MonitorConst.ACTV_OUT)
                context.verified = True

            tbtag_tensor_map = {}
            if not context.ignore_in:
                cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
                tbtag_tensor_map.update(
                    self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
                                                cared_input))
            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
            tbtag_tensor_map.update(
                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
                                            cared_output))
            try:
                get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
            except Exception as e:
                logger.warning(f"An error occurred while generating forward activation metrics")

            step_accumulates_one(context, self.micro_batch_number)
            return

        def bwd_hook_fun(module, input_grad, output_grad):
            context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
            if not context.struct:
                context.struct = {
                    MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
                    MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
                }
            if self.print_struct:
                self.module_struct[context.module_name].update(context.struct)
                return

            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
                            self.collect_times):
                step_accumulates_one(context, self.micro_batch_number)
                return

            if not context.format_by_arg:
                context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets)
                context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets)
            if not context.format_by_arg:
                return
            if not context.verified:
                if not context.ignore_in:
                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
                                                                  input_grad, context.module_name,
                                                                  MonitorConst.ACTVGRAD_IN)
                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
                                                               output_grad, context.module_name,
                                                               MonitorConst.ACTVGRAD_OUT)
                context.verified = True

            tbtag_tensor_map = {}
            if not context.ignore_in:
                cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
                tbtag_tensor_map.update(
                    self.build_tbtag_tensor_map(
                        f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
            cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
            tbtag_tensor_map.update(
                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
                                            cared_output_grad))

            if context.micro_step == 0 and context.actvgrad:
                logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
                               f"maybe something wrong happened. Now clear it.")
                context.actvgrad.clear()
            try:
                get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
            except Exception as e:
                logger.warning(f"An error occurred while generating backward activation metrics: {e}")

            step_accumulates_one(context, self.micro_batch_number)
            return

        def fwd_hook_fun_wrapper(fwd_hook_fun, name):
            def wrapper(module, module_input, module_output):
                return fwd_hook_fun(module, module_input, module_output, name)
            return wrapper

        if self.backward_only and self.forward_only:
            logger.warning('not enable backward_only and forward_only simultaneously')
        hooked_count = 0
        if self.xy_distribution or self.print_struct:
            for module_name, submodule in module.cells_and_names():
                name = self._is_target_module(module_name, target_names, vpp_stage)
                if not name:
                    continue
                if not self.backward_only:
                    handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name))
                    self.handles['xy'].append(handle)
                if not self.forward_only:
                    handle = submodule.register_backward_hook(bwd_hook_fun)
                    self.handles['xy'].append(handle)
                    self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
                logger.info(f"> {name} is monitored successfully")
                hooked_count += 1
        return hooked_count

    def _register_param_call_id(self, hook_name: str, key: str):
        """
        :param hook_name:
        :param key: str, '0:relu_0/output_grad'
        :return:
        """
        logger.debug(f"{hook_name} {key}: {self.call_id}")
        self.param_name_call_id[key] = self.call_id
        self.call_id += 1

    def _patch_grad_sync(self):
        # megatron is not used with mindspore for now
        def patch_sync(sync_grad_func):
            def wrapper(bucket):
                grad_dict = {}
                for param, name in self.param2name.items():
                    if param not in bucket.params_list:
                        continue
                    grad = param.main_grad if self.params_have_main_grad else param.grad
                    if grad is None:
                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
                        continue
                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
                    if tag is None:
                        continue
                    grad_dict[tag] = grad
                try:
                    get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                except Exception as e:
                    logger.warning(f"An error occurred while generating weight grad metrics")
                out = sync_grad_func(bucket)
                return out

            return wrapper

        self.enable_megatron = False

        if not self.wg_distribution:
            return

        if self.enable_megatron:
            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
        else:
            self._hook_weights()

    def _hook_weights(self):
        context = self.grad_context

        @_no_grad()
        def param_hook(grad, context_dict, param, key):
            param.micro_step += 1
            self._register_param_call_id("param_hook", key)
            if param.micro_step == self.micro_batch_number:
                param.micro_step = 0
                context_dict[key] = grad

        def param_hook_wrapper(param_hook, context_dict, param, key):
            def wrapper(grad):
                return param_hook(grad, context_dict, param, key)
            return wrapper

        for param, name in self.param2name.items():
            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
            setattr(param, 'micro_step', 0)
            handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key))
            self.handles['wgrads'].append(handle)
        self.weight_hooked = True