mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,821 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ import uuid
19
+ from collections import defaultdict
20
+ from datetime import datetime
21
+
22
+ import pytz
23
+ import mindspore as ms
24
+ import mindspore.common.dtype as mstype
25
+ from mindspore import Tensor, ops, mint
26
+ from mindspore import nn, _no_grad
27
+ from mindspore.communication import get_rank
28
+
29
+ from msprobe.core.common.log import logger
30
+ from msprobe.core.common.const import MonitorConst
31
+ from msprobe.core.common.file_utils import load_json
32
+ from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
33
+ is_skip_step, get_metrics, get_single_metrics
34
+ from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
35
+ from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
36
+ CSVWriterWithAD, BaseWriterWithAD, WriterInput
37
+ from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
38
+ get_process_group
39
+
40
+ FORMAT_MAPPING = {
41
+ MonitorConst.CSV: CSVWriterWithAD,
42
+ MonitorConst.API: BaseWriterWithAD
43
+ }
44
+
45
+
46
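+ # The output root is taken from the MonitorConst.MONITOR_OUTPUT_DIR environment variable, with the built-in default directory as fallback.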
+ def get_output_base_dir():
47
+ return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
48
+
49
+
50
+ def get_param_struct(param):
51
+ res = {}
52
+ if isinstance(param, (tuple, list)):
53
+ res['config'] = f'{type(param).__name__}[{len(param)}]'
54
+ for i, x in enumerate(param):
55
+ res[i] = f'size={tuple(x.shape)}, dtype={x.dtype}' if isinstance(x, Tensor) else f'{type(x)}'
56
+ elif isinstance(param, Tensor):
57
+ res['config'] = 'tensor'
58
+ res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}'
59
+ else:
60
+ res['config'] = f'{type(param)}'
61
+ logger.warning(f'Unsupported type: {type(param)}, please check the type of param {param}')
62
+ return res
63
+
64
+
65
+ def param_is_not_tensor_parallel_duplicate(param, tp_group):
66
+ return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
67
+ mint.distributed.get_rank(group=tp_group) == 0
68
+ )
69
+
70
+
71
+ def param_is_data_parallel_duplicate(dp_group):
72
+ return mint.distributed.get_rank(group=dp_group) != 0
73
+
74
+
75
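+ # Squash a parameter name to the sub-name captured by the layers/embeddings/final/output/norm patterns; the full name is returned when nothing matches.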
+ def squash_param_name(param_name):
76
+ for pattern in [r'layers?\.(.*)', r'embeddings?\.(.*)', r'final.*', r'output.*', r'norm.*']:
77
+ match = re.findall(pattern, param_name)
78
+ if match:
79
+ return match[0]
80
+ return param_name
81
+
82
+
83
+ # Used For Module Forward & Backward Collect
84
+ class ModuleHookContext:
85
+ def __init__(self, module_name) -> None:
86
+ self.step = 0
87
+ self.micro_step = 0
88
+ self.actv = defaultdict(dict)
89
+ self.actvgrad = []
90
+ self.module_name = module_name
91
+ self.struct = {}
92
+ self.format_by_arg = {}
93
+ self.verified = False
94
+ self.focused_in_col = 0
95
+ self.focused_out_col = 0
96
+ self.ignore_in = False # becomes True when the config has no 'input'/'input_grad' key, so the input side is skipped
97
+
98
+ def set_format_by_arg(self, key_name: str, target_config: dict):
99
+ cared = target_config.get(self.module_name, self.struct)
100
+ if key_name in cared:
101
+ if isinstance(cared[key_name], dict):
102
+ # current cared is self.struct
103
+ config = cared[key_name].get('config')
104
+ self.format_by_arg[key_name] = config
105
+ else:
106
+ # current cared is target_config[self.module_name]
107
+ self.format_by_arg[key_name] = cared[key_name]
108
+ elif key_name in ['input', 'input_grad']:
109
+ self.ignore_in = True
110
+
111
+
112
+ start_step = 0
113
+
114
+
115
+ # Used For Optimizer Weight Grad & M/V Collect
116
+ class OptimizerContext:
117
+ def __init__(self) -> None:
118
+ self.step = start_step
119
+ self.param_effective_rank = defaultdict(float)
120
+ self.param_mg_direction = defaultdict(float)
121
+ self.param_adam_update = defaultdict()
122
+ self.param_adam_ratio = defaultdict()
123
+ self.param_weight_grad = defaultdict()
124
+ self.param_exp_avg = defaultdict()
125
+ self.exp_avg_metric = {}
126
+ self.param_exp_avg_sq = defaultdict()
127
+ self.exp_avg_sq_metric = {}
128
+ self.metric_dict = {}
129
+ self.param_metric = {}
130
+
131
+ def reset(self) -> None:
132
+ self.param_mg_direction.clear()
133
+ self.param_adam_update.clear()
134
+ self.param_weight_grad.clear()
135
+ self.param_exp_avg.clear()
136
+ self.exp_avg_metric.clear()
137
+ self.param_exp_avg_sq.clear()
138
+ self.exp_avg_sq_metric.clear()
139
+ self.metric_dict.clear()
140
+ self.param_metric.clear()
141
+
142
+
143
+ # Used For Weight Grad Collect
144
+ class GradContext:
145
+ def __init__(self) -> None:
146
+ self.pre = {}
147
+ self.post = {}
148
+ self.acc_metric = {}
149
+ self.acc = {}
150
+ self.actv = {}
151
+
152
+ def reset(self):
153
+ self.pre.clear()
154
+ self.post.clear()
155
+ self.acc_metric.clear()
156
+ self.acc.clear()
157
+ self.actv.clear()
158
+
159
+
160
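+ # Buffers communication-op statistics keyed by tag and op; aggregate() reduces each tensor list with op_aggregate once per step.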
+ class CommunicationContext:
161
+ def __init__(self) -> None:
162
+ self.data = {}
163
+
164
+ @staticmethod
165
+ def _agg(data):
166
+ aggregated_data = {}
167
+ for tag, op2tensorlist in data.items():
168
+ aggregated_data[tag] = {}
169
+ for op, tensorlist in op2tensorlist.items():
170
+ aggregated_data[tag][op] = op_aggregate(op, tensorlist)
171
+ return aggregated_data
172
+
173
+ def reset(self):
174
+ self.data = {}
175
+
176
+ def aggregate(self):
177
+ self.data = self._agg(self.data)
178
+
179
+
180
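+ # Entry point of the MindSpore training monitor: loads and validates the JSON config,
+ # builds a CSV (or API) writer with anomaly detection, and registers module, optimizer and
+ # weight-grad hooks according to the configured switches. Minimal usage sketch
+ # (config path, model and optimizer below are illustrative):
+ #     monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False)
+ #     monitor.set_monitor(model, grad_acc_steps=1, optimizer=optimizer)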
+ class TrainerMon:
181
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
182
+ self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
183
+ self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
184
+ self.optimizer_context = defaultdict(OptimizerContext)
185
+ self.cc_context = defaultdict(CommunicationContext)
186
+ self.grad_context = GradContext()
187
+ self.params_have_main_grad = params_have_main_grad
188
+ self.handles = defaultdict(list)
189
+ self.config = load_json(config_file_path)
190
+ validate_config(self.config)
191
+
192
+ self.start_step = self.config.get("start_step", 0)
193
+ self.collect_times = self.config.get("collect_times", 100000000) # defaults to a large value so collection never stops
194
+ self.step_interval = self.config.get("step_interval", 1)
195
+ self.has_collect_times = 0
196
+
197
+ # monitor target in module, such as layer, weight, grad
198
+ self.targets = self.config.get("targets", None)
199
+ self.is_select = self.config.get("is_select", False)
200
+ self.module_rank_list = self.config.get("module_ranks", [])
201
+ # only CSV is supported in MindSpore
202
+ self.format = self.config.get('format', MonitorConst.CSV)
203
+ self.eps = self.config.get('eps', 1e-8)
204
+ # monitor mean/max/norm/min/nan...
205
+ self.ops = self.config.get('ops', [])
206
+ self.ndigits = self.config.get('ndigits', 6)
207
+ self.all_xy = self.config.get('all_xy', False)
208
+ # module input/output input_grad/output_grad
209
+ self.xy_distribution = self.config.get('xy_distribution', False)
210
+ # activation forward
211
+ self.forward_only = self.config.get('forward_only', False)
212
+ # activation backward
213
+ self.backward_only = self.config.get('backward_only', False)
214
+ # update vector and ratio vector of adam
215
+ self.ur_distribution = self.config.get('ur_distribution', False)
216
+ # m/v of adam
217
+ self.mv_distribution = self.config.get("mv_distribution", False)
218
+ # weight grad
219
+ self.wg_distribution = self.config.get("wg_distribution", False)
220
+ # optimizer param
221
+ self.param_distribution = self.config.get("param_distribution", False)
222
+ # main grad direction
223
+ self.mg_direction = self.config.get('mg_direction', False)
224
+ # communication ops
225
+ self.cc_distribution = self.config.get("cc_distribution", {})
226
+ if not self.cc_distribution.get('enable', False):
227
+ self.cc_log_only = False
228
+ else:
229
+ self.cc_codeline = self.cc_distribution.get('cc_codeline', [])
230
+ self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
231
+ self.cc_logged_stack = defaultdict(set)
232
+ self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
233
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
234
+ api_register.redirect_api()
235
+ self.common_info()
236
+
237
+ alert_setting = self.config.get('alert', {"rules": []})
238
+ self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
239
+
240
+ local_tz = pytz.timezone("Asia/Shanghai") # adjust to the target time zone as needed
241
+
242
+ cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
243
+ unique_id = str(uuid.uuid4())[:8]
244
+ output_base_dir = get_output_base_dir()
245
+
246
+ time_tags = self.config.get("append_output", [])
247
+ if time_tags:
248
+ output_append_dirs = get_target_output_dir(output_base_dir, time_tags[0], time_tags[1])
249
+ try:
250
+ rank = get_rank()
251
+ except Exception as e:
252
+ rank = 0
253
+ tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
254
+ logger.error(f"Failed to get rank, setting tensorboard_dir to {tensorboard_dir}")
255
+ pp_stage = 0
256
+ group_mates = [0]
257
+ else:
258
+ if time_tags and str(rank) in output_append_dirs:
259
+ tensorboard_dir = output_append_dirs[str(rank)]
260
+ logger.info(f"Append rank({rank}) result to {tensorboard_dir}")
261
+ else:
262
+ tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
263
+ pp_stage = 0
264
+ group_mates = [0]
265
+
266
+ self.rank = rank
267
+
268
+ # Initialize the AnomalyData factory
269
+ self.anomaly_data_factory = None
270
+ if alert_setting.get('dump', False):
271
+ self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
272
+
273
+ if self.format not in FORMAT_MAPPING:
274
+ logger.error(f"Unsupported format: {self.format}, falling back to default format: {MonitorConst.CSV}")
275
+ self.format = MonitorConst.CSV
276
+ writer = FORMAT_MAPPING[self.format]
277
+ self.step_count_per_record = self.config.get('step_count_per_record', 1)
278
+
279
+ self.summary_writer = writer(
280
+ WriterInput(
281
+ tensorboard_dir,
282
+ self.alert_rules,
283
+ unique_id,
284
+ self.anomaly_data_factory,
285
+ self.ndigits,
286
+ self.step_count_per_record
287
+ )
288
+ )
289
+
290
+ self.micro_batch_number = 1
291
+
292
+ self.model = None
293
+ self.weight_hooked = False
294
+ self.optimizer_hooked = False
295
+ self.param_registered = False
296
+ self.vpp = False
297
+ self.dp_group = None
298
+ self.tp_group = None
299
+ self.enable_megatron = False
300
+
301
+ self.param2name = defaultdict(str)
302
+ self.name2index = defaultdict()
303
+ self.name2indices = defaultdict()
304
+ self.name2param = {}
305
+ self.param_name_call_id = {}
306
+ self.duplicate_param = {}
307
+ self.name2tag = {}
308
+ self.call_id = 0
309
+ self.grad_accs = []
310
+ self.handles = defaultdict(list)
311
+
312
+ self.print_struct = self.config.get("print_struct", False)
313
+ self.struct_printed = False
314
+ self.module_struct = defaultdict(dict)
315
+
316
+ # Start
317
+ def set_monitor(
318
+ self,
319
+ model,
320
+ grad_acc_steps=1,
321
+ optimizer=None,
322
+ tp_group=None,
323
+ dp_group=None,
324
+ start_iteration=0):
325
+ global start_step
326
+ start_step = start_iteration
327
+ logger.info(f'grad acc steps {grad_acc_steps}')
328
+ self.hook_optimizer(optimizer)
329
+ self.micro_batch_number = grad_acc_steps
330
+ self.dp_group = dp_group
331
+ self.tp_group = tp_group
332
+
333
+ self.hook_modules(model, grad_acc_steps)
334
+ self._patch_grad_sync()
335
+
336
+ """
337
+ Start
338
+ """
339
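+ # Register forward pre/post hooks on the optimizer: the pre-hook collects gradient,
+ # exp_avg/exp_avg_sq and parameter metrics before the update, the post-hook writes the
+ # collected metrics to the summary writer and advances the step counter.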
+ def hook_optimizer(self, optimizer):
340
+ rank_id = str(get_rank())
341
+ if self.optimizer_hooked:
342
+ return
343
+
344
+ if not self.is_target_rank():
345
+ return
346
+
347
+ m_list = []
348
+ v_list = []
349
+ param_list = []
350
+ grad_names = []
351
+ for param in optimizer.get_parameters():
352
+ if MonitorConst.EXP_AVG_SQ in param.name:
353
+ v_list.append(param)
354
+ elif MonitorConst.EXP_AVG in param.name:
355
+ m_list.append(param)
356
+ else:
357
+ param_list.append(param)
358
+ grad_names.append(param.name)
359
+
360
+ """
361
+ grad reduced
362
+ m/v
363
+ """
364
+ def optimizer_pre_hook_function(opt, grad_names, gradients):
365
+ context = self.optimizer_context[opt]
366
+ if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
367
+ self.collect_times):
368
+ return
369
+ gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients
370
+ is_select = self.is_select
371
+ for idx, grad in enumerate(gradient_list):
372
+ grad_name = grad_names[idx]
373
+ if is_select and grad_name not in self.targets:
374
+ continue
375
+ get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad)
376
+
377
+ if self.mv_distribution:
378
+ # fetch mean
379
+ for param in m_list:
380
+ name = param.name
381
+ if is_select and name not in self.targets:
382
+ continue
383
+ get_single_metrics(self.ops, name, param, context.exp_avg_metric)
384
+ # fetch variance
385
+ for param in v_list:
386
+ name = param.name
387
+ if is_select and name not in self.targets:
388
+ continue
389
+ get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric)
390
+ if self.param_distribution:
391
+ for param in param_list:
392
+ get_single_metrics(self.ops, param.name, param, context.param_metric)
393
+ self.generate_wgrad_metrics()
394
+ metric_dict = {}
395
+ for cc in self.cc_context.values():
396
+ cc.aggregate()
397
+ metric_dict.update(cc.data)
398
+ cc.reset()
399
+
400
+ if not metric_dict:
401
+ return
402
+ context.metric_dict = metric_dict
403
+ return
404
+
405
+ def optimizer_post_hook_function(opt, args, gradients, outputs):
406
+ context = self.optimizer_context[opt]
407
+ step_skip = is_skip_step(context.step, self.start_step, self.step_interval, \
408
+ self.has_collect_times, self.collect_times)
409
+ if step_skip:
410
+ context.step += 1
411
+ return
412
+ self.write_xy_tb(context.step)
413
+ self.write_grad_tb(context.step)
414
+ self.write_mv_tb(context)
415
+ self.write_param_tb(context)
416
+
417
+ if context.metric_dict:
418
+ self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
419
+ context.metric_dict.clear()
420
+ self.has_collect_times += 1
421
+ context.step += 1
422
+ if self.anomaly_data_factory:
423
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
424
+ self.summary_writer.clear_anomalies()
425
+ self.call_id = 0
426
+ self.param_name_call_id.clear()
427
+ return
428
+
429
+ def optimizer_pre_hook_wrapper(func, grad_names):
430
+ def wrapper(opt, gradients):
431
+ return func(opt, grad_names, gradients)
432
+ return wrapper
433
+
434
+ def optimizer_post_hook_wrapper(func, args=None):
435
+ def wrapper(opt, gradients, outputs):
436
+ return func(opt, args, gradients, outputs)
437
+ return wrapper
438
+
439
+ optimizer.register_forward_pre_hook(optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
440
+ optimizer.register_forward_hook(optimizer_post_hook_wrapper(optimizer_post_hook_function))
441
+
442
+ self.optimizer_hooked = True
443
+ return
444
+
445
+ def write_xy_tb(self, step):
446
+ if not self.xy_distribution:
447
+ return
448
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
449
+ if len(fwd_context.actv) == 0:
450
+ continue
451
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
452
+ fwd_context.actv.clear()
453
+ if self.grad_context.actv:
454
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
455
+
456
+ def write_param_tb(self, opt_context):
457
+ if not self.param_distribution:
458
+ return
459
+ self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
460
+
461
+ def write_mv_tb(self, opt_context):
462
+ if not self.mv_distribution:
463
+ return
464
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
465
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
466
+
467
+ def write_grad_tb(self, step):
468
+ if not self.wg_distribution:
469
+ return
470
+
471
+ if self.enable_megatron:
472
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
473
+ else:
474
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
475
+ self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
476
+
477
+ def common_info(self):
478
+ if not self.xy_distribution:
479
+ logger.info("> module input/output input_grad/output_grad are not monitored. ")
480
+ if self.forward_only:
481
+ logger.info("> only module forward is monitored. ")
482
+ if not self.ur_distribution:
483
+ logger.info("> update vector and ratio vector of adam are not monitored. ")
484
+ if not self.mv_distribution:
485
+ logger.info("> momentum and variance of adam is not monitored. ")
486
+ if not self.wg_distribution:
487
+ logger.info("> weight grad of specified module is not monitored. ")
488
+ if not self.mg_direction:
489
+ logger.info('> grad and momentum direction will not be compared.')
490
+ if not self.cc_distribution.get('enable', False):
491
+ logger.info("> cc operator is not monitored.")
492
+
493
+ def is_target_rank(self):
494
+ rank_id = str(get_rank())
495
+ if self.module_rank_list and (rank_id not in self.module_rank_list):
496
+ return False
497
+ return True
498
+
499
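+ # Expand stage-agnostic target keys into per-vpp-stage keys, then register
+ # forward/backward hooks on every matching submodule of each model chunk.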
+ def hook_modules(self, model, grad_acc_steps):
500
+ if not self.is_target_rank():
501
+ return
502
+ if not isinstance(model, list):
503
+ model = [model]
504
+ self.model = model # list
505
+ self._register_param_name(model)
506
+ self.micro_batch_number = grad_acc_steps
507
+ module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
508
+
509
+ for key in module_in_all_stage:
510
+ struct = self.targets.pop(key)
511
+ self.targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(model))})
512
+
513
+ hooked_count = 0
514
+ for vpp_stage, model_chunk in enumerate(model):
515
+ if not isinstance(model_chunk, nn.Cell):
516
+ logger.info("Target model is not an nn.Cell, skip hooking it")
517
+ continue
518
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
519
+ targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
520
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
521
+ logger.info(f"> {hooked_count} modules are monitored.")
522
+
523
+ def build_tbtag_tensor_map(self, module_name, tag, tensor):
524
+ rank_id = str(get_rank())
525
+ metrics = {}
526
+ key = get_summary_writer_tag_name(module_name, tag, rank_id)
527
+ if isinstance(tensor, Tensor):
528
+ self._register_param_call_id("_hook_module", key)
529
+ metrics[key] = tensor
530
+ return metrics
531
+
532
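+ # Build weight-grad metrics: gradients accumulated by _hook_weights() become the
+ # unreduced metrics, and param.main_grad / param.grad become the reduced metrics.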
+ def generate_wgrad_metrics(self):
533
+ if not self.wg_distribution:
534
+ return {}, {}
535
+
536
+ if self.weight_hooked:
537
+ try:
538
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
539
+ except Exception as e:
540
+ logger.warning(f"An error occurred while generating wgrad pre metrics: {e}")
541
+ return {}, {}
542
+
543
+ grad_dict = {}
544
+ for param, name in self.param2name.items():
545
+ if self.duplicate_param.get(name, False):
546
+ continue
547
+ grad = param.main_grad if self.params_have_main_grad else param.grad
548
+ if grad is None:
549
+ logger.warning(f"grad is None for {name}; something may have gone wrong.")
550
+ continue
551
+ tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
552
+ self._register_param_call_id("hook_optimizer", tag)
553
+ grad_dict[tag] = grad
554
+ try:
555
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
556
+ except Exception as e:
557
+ logger.warning(f"An error occurred while generating wgrad post metrics: {e}")
558
+ return {}, {}
559
+ return self.grad_context.post, self.grad_context.pre
560
+
561
+ def _register_param_name(self, model):
562
+ if self.param_registered:
563
+ return
564
+
565
+ if len(model) > 1:
566
+ self.vpp = True
567
+ logger.info('vpp enabled')
568
+
569
+ for vpp_stage, model_chunk in enumerate(model):
570
+ prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
571
+ self._register_chunk(model_chunk, prefix)
572
+
573
+ self.param_registered = True
574
+
575
+ def _is_target_param(self, param_name, param, prefix):
576
+ if not self.targets:
577
+ return True
578
+ squash_name = prefix + squash_param_name(param_name)
579
+ name = prefix + param_name
580
+ for target in self.targets.keys():
581
+ if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
582
+ setattr(param, "zero_out_wgrad", True)
583
+ return True
584
+ return False
585
+
586
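+ # Map every trainable parameter of the chunk to a squashed name and record its index;
+ # tensor-parallel / data-parallel duplicates are marked so their gradients are skipped.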
+ def _register_chunk(self, model_chunk, prefix):
587
+ index = 0
588
+ for param in model_chunk.get_parameters():
589
+ param_name = param.name
590
+ if not param.requires_grad:
591
+ continue
592
+ if self._is_target_param(param_name, param, prefix):
593
+ name = prefix + squash_param_name(param_name)
594
+ if name in self.param2name.values():
595
+ name = prefix + param_name
596
+ self.param2name[param] = name
597
+ self.name2param[name] = param
598
+ self.name2index[name] = index
599
+
600
+ if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
601
+ self.duplicate_param[name] = True
602
+ if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
603
+ self.duplicate_param[name] = True
604
+ self.name2tag[name] = {
605
+ MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
606
+ MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
607
+ }
608
+ index += 1
609
+
610
+ def _is_target_module(self, module_name, targets, vpp_stage):
611
+ if self.all_xy or self.print_struct:
612
+ return vpp_stage + squash_param_name(module_name)
613
+ for pattern in [
614
+ vpp_stage + squash_param_name(module_name),
615
+ vpp_stage + module_name,
616
+ ]:
617
+ if pattern in targets:
618
+ return pattern
619
+ return ""
620
+
621
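+ # Register forward and/or backward hooks on each target cell; the hooks capture the
+ # configured input/output (or their grads) per micro step and turn them into metrics.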
+ def _hook_module(self, target_names, module, vpp_stage=''):
622
+ if not isinstance(module, nn.Cell):
623
+ # nothing to hook
624
+ return 0
625
+
626
+ def fwd_hook_fun(module, module_input, module_output, name):
627
+ if module not in self.module_fwd_hook_context_by_module:
628
+ self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
629
+ context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
630
+ if not context.struct:
631
+ context.struct = {
632
+ MonitorConst.ACTV_IN: get_param_struct(module_input),
633
+ MonitorConst.ACTV_OUT: get_param_struct(module_output)
634
+ }
635
+ if self.print_struct:
636
+ self.module_struct[context.module_name].update(context.struct)
637
+ return
638
+ if not module.training:
639
+ return
640
+ if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
641
+ self.collect_times):
642
+ step_accumulates_one(context, self.micro_batch_number)
643
+ return
644
+ if not context.format_by_arg:
645
+ context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets)
646
+ context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets)
647
+ if not context.format_by_arg:
648
+ return
649
+ if not context.verified:
650
+ if not context.ignore_in:
651
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
652
+ module_input, context.module_name,
653
+ MonitorConst.ACTV_IN)
654
+ context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
655
+ module_output, context.module_name,
656
+ MonitorConst.ACTV_OUT)
657
+ context.verified = True
658
+
659
+ tbtag_tensor_map = {}
660
+ if not context.ignore_in:
661
+ cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
662
+ tbtag_tensor_map.update(
663
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
664
+ cared_input))
665
+ cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
666
+ tbtag_tensor_map.update(
667
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
668
+ cared_output))
669
+ try:
670
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
671
+ except Exception as e:
672
+ logger.warning(f"An error occurred while generating forward activation metrics: {e}")
673
+
674
+ step_accumulates_one(context, self.micro_batch_number)
675
+ return
676
+
677
+ def bwd_hook_fun(module, input_grad, output_grad):
678
+ context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
679
+ if not context.struct:
680
+ context.struct = {
681
+ MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
682
+ MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
683
+ }
684
+ if self.print_struct:
685
+ self.module_struct[context.module_name].update(context.struct)
686
+ return
687
+
688
+ if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
689
+ self.collect_times):
690
+ step_accumulates_one(context, self.micro_batch_number)
691
+ return
692
+
693
+ if not context.format_by_arg:
694
+ context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets)
695
+ context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets)
696
+ if not context.format_by_arg:
697
+ return
698
+ if not context.verified:
699
+ if not context.ignore_in:
700
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
701
+ input_grad, context.module_name,
702
+ MonitorConst.ACTVGRAD_IN)
703
+ context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
704
+ output_grad, context.module_name,
705
+ MonitorConst.ACTVGRAD_OUT)
706
+ context.verified = True
707
+
708
+ tbtag_tensor_map = {}
709
+ if not context.ignore_in:
710
+ cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
711
+ tbtag_tensor_map.update(
712
+ self.build_tbtag_tensor_map(
713
+ f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
714
+ cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
715
+ tbtag_tensor_map.update(
716
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
717
+ cared_output_grad))
718
+
719
+ if context.micro_step == 0 and context.actvgrad:
720
+ logger.warning(f"actvgrad context of {context.module_name} is not empty at the first micro_step, "
721
+ f"something may have gone wrong. Clearing it now.")
722
+ context.actvgrad.clear()
723
+ try:
724
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
725
+ except Exception as e:
726
+ logger.warning(f"An error occurred while generating backward activation metrics: {e}")
727
+
728
+ step_accumulates_one(context, self.micro_batch_number)
729
+ return
730
+
731
+ def fwd_hook_fun_wrapper(fwd_hook_fun, name):
732
+ def wrapper(module, module_input, module_output):
733
+ return fwd_hook_fun(module, module_input, module_output, name)
734
+ return wrapper
735
+
736
+ if self.backward_only and self.forward_only:
737
+ logger.warning('backward_only and forward_only should not be enabled simultaneously')
738
+ hooked_count = 0
739
+ if self.xy_distribution or self.print_struct:
740
+ for module_name, submodule in module.cells_and_names():
741
+ name = self._is_target_module(module_name, target_names, vpp_stage)
742
+ if not name:
743
+ continue
744
+ if not self.backward_only:
745
+ handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name))
746
+ self.handles['xy'].append(handle)
747
+ if not self.forward_only:
748
+ handle = submodule.register_backward_hook(bwd_hook_fun)
749
+ self.handles['xy'].append(handle)
750
+ self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
751
+ logger.info(f"> {name} is monitored successfully")
752
+ hooked_count += 1
753
+ return hooked_count
754
+
755
+ def _register_param_call_id(self, hook_name: str, key: str):
756
+ """
757
+ :param hook_name:
758
+ :param key: str, '0:relu_0/output_grad'
759
+ :return:
760
+ """
761
+ logger.debug(f"{hook_name} {key}: {self.call_id}")
762
+ self.param_name_call_id[key] = self.call_id
763
+ self.call_id += 1
764
+
765
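+ # Megatron-style bucketed grad sync is not used with MindSpore (enable_megatron stays False),
+ # so weight-grad collection falls back to per-parameter hooks registered in _hook_weights().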
+ def _patch_grad_sync(self):
766
+ # Megatron is not used with MindSpore for now
767
+ def patch_sync(sync_grad_func):
768
+ def wrapper(bucket):
769
+ grad_dict = {}
770
+ for param, name in self.param2name.items():
771
+ if param not in bucket.params_list:
772
+ continue
773
+ grad = param.main_grad if self.params_have_main_grad else param.grad
774
+ if grad is None:
775
+ logger.warning(f"grad is None for {name}; something may have gone wrong.")
776
+ continue
777
+ tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
778
+ if tag is None:
779
+ continue
780
+ grad_dict[tag] = grad
781
+ try:
782
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
783
+ except Exception as e:
784
+ logger.warning(f"An error occurred while generating weight grad metrics: {e}")
785
+ out = sync_grad_func(bucket)
786
+ return out
787
+
788
+ return wrapper
789
+
790
+ self.enable_megatron = False
791
+
792
+ if not self.wg_distribution:
793
+ return
794
+
795
+ if self.enable_megatron:
796
+ Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version
797
+ else:
798
+ self._hook_weights()
799
+
800
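+ # Attach a gradient hook to every tracked parameter; on the last micro step of a global
+ # step the accumulated gradient is stored into grad_context.acc for metric generation.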
+ def _hook_weights(self):
801
+ context = self.grad_context
802
+
803
+ @_no_grad()
804
+ def param_hook(grad, context_dict, param, key):
805
+ param.micro_step += 1
806
+ self._register_param_call_id("param_hook", key)
807
+ if param.micro_step == self.micro_batch_number:
808
+ param.micro_step = 0
809
+ context_dict[key] = grad
810
+
811
+ def param_hook_wrapper(param_hook, context_dict, param, key):
812
+ def wrapper(grad):
813
+ return param_hook(grad, context_dict, param, key)
814
+ return wrapper
815
+
816
+ for param, name in self.param2name.items():
817
+ key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
818
+ setattr(param, 'micro_step', 0)
819
+ handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key))
820
+ self.handles['wgrads'].append(handle)
821
+ self.weight_hooked = True