mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
@@ -136,8 +136,8 @@ class AnomalyDataFactory(ABC):
  tag_name = tag[0]
  param_name = tag_name.split('/')[0]
  call_id = self.name2callid.get(tag_name, -1)
- if MonitorConst.VPP_SEP in param_name:
- vpp_stage = int(param_name.split(MonitorConst.VPP_SEP)[0])
+ if MonitorConst.NAME_SEP in param_name:
+ vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0])
  else:
  vpp_stage = 0

@@ -161,10 +161,10 @@ class TrainStage:
  OPTIMIZER_STAGE = 2


- FORWARD_KEY = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
- BACKWARD_KEY = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT,
- MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD]
- OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EFXP_AVG_SQ]
+ FORWARD_KEY = [MonitorConst.ACTV]
+ BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD,
+ MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD]
+ OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ]
  TRAIN_STAGE = {
  **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY},
  **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY},
@@ -221,7 +221,7 @@ class GradAnomalyData:
  @staticmethod
  def get_train_stage(tag_name):
  """
- :param tag_name: "0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/efxp_avg_sq"
+ :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq"
  :return: int, if forward return 0; if backward return 1; if optimizer return 2
  """
  key_ = tag_name.split("/")[-1]
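
Read together with the TrainStage hunk above, the new stage lookup keys off the last path segment of the tag and the collapsed key lists (actv, actv_grad, the grad keys, and the fixed exp_avg/exp_avg_sq spelling). A minimal runnable sketch of that lookup; the literal string values of the MonitorConst members and the DEFAULT_STAGE fallback are assumptions, since only the constant names appear in this diff:

    # Hypothetical stand-ins for MonitorConst members (defined in msprobe/core/common/const.py).
    ACTV, ACTVGRAD = "actv", "actv_grad"
    PRE_GRAD, POST_GRAD, ACC_GRAD = "pre_grad", "post_grad", "acc_grad"
    EXP_AVG, EXP_AVG_SQ = "exp_avg", "exp_avg_sq"

    FORWARD_STAGE, BACKWARD_STAGE, OPTIMIZER_STAGE, DEFAULT_STAGE = 0, 1, 2, -1
    TRAIN_STAGE = {
        **{k: FORWARD_STAGE for k in [ACTV]},
        **{k: BACKWARD_STAGE for k in [ACTVGRAD, PRE_GRAD, POST_GRAD, ACC_GRAD]},
        **{k: OPTIMIZER_STAGE for k in [EXP_AVG, EXP_AVG_SQ]},
    }

    def get_train_stage(tag_name):
        key_ = tag_name.split("/")[-1]
        return TRAIN_STAGE.get(key_, DEFAULT_STAGE)

    assert get_train_stage("0:fc2.input:0/rank0/actv") == FORWARD_STAGE
    assert get_train_stage("0:fc1.weight/rank0/post_grad") == BACKWARD_STAGE
    assert get_train_stage("0:fc2.weight/rank0/exp_avg_sq") == OPTIMIZER_STAGE
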
@@ -361,10 +361,10 @@ class CSVWriterWithAD(BaseWriterWithAD):

  new_data = []
  for name, metric_value in self.context_dict.items():
- if MonitorConst.VPP_SEP not in name:
- new_data.append([name] + [step] + metric_value)
- else:
- new_data.append(name.split(MonitorConst.VPP_SEP) + [step] + metric_value)
+ new_line = name.split(MonitorConst.NAME_SEP) + metric_value
+ new_line.insert(2, step)
+ new_data.append(new_line)
+
  new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan")
  write_df_to_csv(new_data, filepath, mode='a+', header=False)
  self.context_dict = defaultdict(list)
@@ -381,26 +381,11 @@ class CSVWriterWithAD(BaseWriterWithAD):
  def write_metrics(self, ops, metric_value, step, prefix=''):
  super().write_metrics(ops, metric_value, step, prefix='')

- # generate csv headers
- # set hashmap to reduce the number of headers generated.
- # 前向的norm用input.ops_和output.ops_,反向的用input_grad.ops_和output_grad.ops_
- if prefix in {"actv", "actv_grad"}:
- if prefix == "actv":
- input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
- else:
- input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
- ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, ops)]
- csv_header = ["module_name", "step", *ops_]
+ if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD]:
+ self.header = MonitorConst.CSV_HEADER_XY + ops
  else:
- csv_header = ["param_name", "step", *ops]
-
- keys = list(metric_value.keys())
- if keys and MonitorConst.VPP_SEP in keys[0]:
- csv_header.insert(0, "vpp_stage")
-
- self.header = csv_header
+ self.header = MonitorConst.CSV_HEADER + ops
  self.write_csv(prefix, step)
- self.header = []

  def close(self):
  pass
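
The two CSVWriterWithAD hunks above change how rows and headers are built: every metric name is split on MonitorConst.NAME_SEP, the step is inserted as a fixed column, and the header comes from the constants CSV_HEADER / CSV_HEADER_XY instead of being regenerated per prefix. A minimal sketch of the resulting layout, assuming NAME_SEP is ':' and the header constants listed below (both assumptions; the real values live in msprobe/core/common/const.py):

    import pandas as pd

    NAME_SEP = ":"                                               # assumed MonitorConst.NAME_SEP
    CSV_HEADER_XY = ["vpp_stage", "name", "step", "micro_step"]  # assumed MonitorConst.CSV_HEADER_XY
    ops = ["min", "max"]

    # context_dict key for an activation metric: "<vpp_stage>:<module>.input:<micro_step>"
    context_dict = {"0:fc2.input:0": [0.1, 3.2]}
    step = 7

    rows = []
    for name, metric_value in context_dict.items():
        new_line = name.split(NAME_SEP) + metric_value   # ['0', 'fc2.input', '0', 0.1, 3.2]
        new_line.insert(2, step)                         # ['0', 'fc2.input', 7, '0', 0.1, 3.2]
        rows.append(new_line)

    print(pd.DataFrame(rows, columns=CSV_HEADER_XY + ops))
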
@@ -31,28 +31,26 @@ all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unredu
  CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"


- def parse_step_line(data, line_id, name, ops):
- vp_id = data["vpp_stage"][line_id]
- module_name = data[name][line_id]
- step = data["step"][line_id]
+ def parse_step_line(line, ops):
+ vp_id = line["vpp_stage"]
+ module_name = line[MonitorConst.HEADER_NAME]
+ step = line["step"]
  vpp_name = f"vp{vp_id}:{module_name}"
+ if 'micro_step' in line:
+ vpp_name = f'{vpp_name}{MonitorConst.NAME_SEP}micro{line["micro_step"]}'
  ops_result = {}
  for op in ops:
- ops_result[op] = data[op][line_id]
+ ops_result[op] = line[op]
  return vpp_name, step, ops_result


  def parse_step_fn(filepath):
  data = read_csv(filepath)
-
- header = list(data.keys())
- name = header[MonitorConst.HEADER_NAME_INDEX]
- ops = header[MonitorConst.OPS_START_INDEX:]
-
+ ops = [k for k in data.keys() if k in MonitorConst.OP_LIST]
  parse_step_result = {}

- for line_id in range(len(data)):
- vpp_name, step, ops_result = parse_step_line(data, line_id, name, ops)
+ for _, line in data.iterrows():
+ vpp_name, step, ops_result = parse_step_line(line, ops)
  if vpp_name not in parse_step_result:
  parse_step_result[vpp_name] = {}
  if step in parse_step_result[vpp_name]:
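
The rewritten csv2tb parser walks DataFrame rows directly and keys each record by a vp-prefixed name, optionally suffixed with the micro step. A standalone sketch of that row handling; MonitorConst.HEADER_NAME, NAME_SEP and OP_LIST are assumed values here, not taken from the diff:

    import pandas as pd

    HEADER_NAME = "name"                      # assumed MonitorConst.HEADER_NAME
    NAME_SEP = ":"                            # assumed MonitorConst.NAME_SEP
    OP_LIST = ["min", "max", "norm", "mean"]  # assumed superset of supported ops

    def parse_step_line(line, ops):
        vpp_name = f'vp{line["vpp_stage"]}:{line[HEADER_NAME]}'
        if "micro_step" in line:
            vpp_name = f'{vpp_name}{NAME_SEP}micro{line["micro_step"]}'
        return vpp_name, line["step"], {op: line[op] for op in ops}

    data = pd.DataFrame(
        [[0, "fc2.input", 7, 0, 0.1, 3.2]],
        columns=["vpp_stage", "name", "step", "micro_step", "min", "max"],
    )
    ops = [k for k in data.keys() if k in OP_LIST]
    for _, line in data.iterrows():
        print(parse_step_line(line, ops))  # ('vp0:fc2.input:micro0', 7, {'min': 0.1, 'max': 3.2})
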
@@ -22,12 +22,12 @@ from functools import partial
  import pytz
  import torch
  import torch.distributed as dist
- from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
  from torch.utils.hooks import BackwardHook

- from msprobe.core.common.const import MonitorConst
+ from msprobe.core.common.const import MonitorConst, Const
  from msprobe.core.common.file_utils import load_json, save_json
  from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import is_recomputation
  from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
  from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
  CSVWriterWithAD, BaseWriterWithAD, WriterInput
@@ -37,8 +37,8 @@ from msprobe.pytorch.monitor.features import get_sign_matches
  from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
  TensorMetrics, squash_param_name
  from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
- from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
- from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation, \
+ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
+ from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
  get_output_base_dir, get_target_output_dir
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

@@ -46,6 +46,7 @@ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
  if not torch_version_above_or_equal_2:
  raise ValueError("monitor require torch>=2.0")

+
  FORMAT_MAPPING = {
  MonitorConst.TENSORBOARD: SummaryWriterWithAD,
  MonitorConst.CSV: CSVWriterWithAD,
@@ -85,9 +86,6 @@ class ModuleHookContext:
  :param target_config: target obj in config json.
  :return:
  """
- valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
- if key_name not in valid_key:
- raise ValueError(f"key({key_name}) error, valid_key: {valid_key}")
  cared = target_config.get(self.module_name, self.struct)
  if key_name in cared:
  target_module_config = cared[key_name]
@@ -178,20 +176,16 @@ class GradContext:
  class TrainerMon:
  tensor_metrics = TensorMetrics()

- def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
- """
- opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
- """
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
  # TYPE1: 只在这里初始化的变量, 不会随着训练中途config配置改变而重置
  self.config_file_path = config_file_path
  self.process_group = get_process_group(process_group)
  self.params_have_main_grad = params_have_main_grad
- self.opt_ty = opt_ty
- self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
  self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
  self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
  self.origin_step_func = None
- self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过switch开关直接打开
+ self.origin_start_grad_sync = None
+ self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开
  self.config = load_json(config_file_path)
  validate_config(self.config)

@@ -219,13 +213,15 @@ class TrainerMon:
  self.pp_stage = 0
  self.group_mates = [0]

- # TYPE2: 只会在monitor_gnorm_with_ad()主调中赋值的变量
+ # TYPE2: 只会在set_monitor()主调中赋值的变量
  self.model = None
  self.vpp = False
  self.dp_group = None
  self.tp_group = None
  self.enable_megatron = False
  self.micro_batch_number = 1
+ self.optimizer_class = None
+ self.optimizer_mon = None

  # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -253,7 +249,7 @@ class TrainerMon:
  self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
  if self.dynamic_enable:
  logger.warning(f"DYNAMIC_MONITOR is set, "
- f"please make sure you have 'switch' and 'collect_times' item in {self.config_file_path}")
+ f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
  self.monitoring = False
  else:
  self.set_config()
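
With the DYNAMIC_MONITOR environment variable set to true, monitoring now starts only once the config file carries dynamic_on (formerly switch) together with collect_times. A hypothetical minimal config sketch; the key names follow this diff, but the values and the remaining schema (presumably documented in msprobe/docs/19.monitor.md) are assumptions:

    import json

    monitor_config = {
        "targets": {"fc1": {}, "fc2": {}},  # target structure is illustrative only
        "format": "csv",        # unsupported formats now fall back to csv instead of raising
        "ops": ["min", "max"],
        "collect_times": 10,
        "dynamic_on": True,     # reset to False automatically when monitoring finishes
    }
    with open("monitor_config.json", "w") as f:
        json.dump(monitor_config, f, indent=2)
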
@@ -273,10 +269,6 @@ class TrainerMon:
  def ops(self, value):
  self._ops = validate_ops(value)

- @staticmethod
- def set_wrapped_optimizer(_wrapped_optimizer):
- OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
-
  @staticmethod
  def has_register_backward_hook(module_name, module):
  if hasattr(module, '_backward_hooks') and \
@@ -308,7 +300,7 @@ class TrainerMon:
  self.has_collect_times = 0 # 重设采集计数器
  self.print_struct = self.config.get("print_struct", False)
  self.module_rank_list = self.config.get("module_ranks", [])
- self.format = self.config.get('format', 'tensorboard')
+ self.format = self.config.get('format', MonitorConst.CSV)
  self.eps = self.config.get('eps', 1e-8)
  self.ops = self.config.get('ops', [])
  self.ndigits = self.config.get('ndigits', 6)
@@ -344,7 +336,13 @@ class TrainerMon:

  # 初始化writer, 创建输出目录
  if self.format not in FORMAT_MAPPING:
- raise ValueError(f"Unsupported format: {self.format}")
+ logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+ self.format = MonitorConst.CSV
+
+ if self.ur_distribution and self.format != 'tensorboard':
+ logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
+ self.ur_distribution = False
+
  writer = FORMAT_MAPPING[self.format]
  self.step_count_per_record = self.config.get('step_count_per_record', 1)

@@ -393,25 +391,20 @@ class TrainerMon:
  logger.info_on_rank_0('> grad and momentum direction will not be compared.')
  if not self.cc_distribution.get('enable', False):
  logger.info_on_rank_0("> cc operator is not monitored.")
- if not self.opt_ty:
- if self.ur_distribution:
- raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
- if self.mv_distribution:
- raise Exception("mv_distribution cannot be enabled with unknown optimizer.")

  def hook_modules(self):
  if self.module_rank_list and (self.rank not in self.module_rank_list):
  return

  targets = self.config['targets']
- module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
+ module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
  for key in module_in_all_stage:
  struct = targets.pop(key)
- targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+ targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})

  hooked_count = 0
  for vpp_stage, model_chunk in enumerate(self.model):
- vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
  targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
  'targets'].keys()
  hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
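
hook_modules still expands stage-agnostic target names to one entry per model chunk, now using MonitorConst.NAME_SEP as the stage separator. A small sketch of that expansion, assuming NAME_SEP is ':' (an assumption consistent with the example tags elsewhere in this diff):

    NAME_SEP = ":"  # assumed MonitorConst.NAME_SEP

    targets = {"fc1": {}, "0:fc2": {}}
    model = [object(), object()]  # two virtual-pipeline model chunks

    # Keys without a stage prefix are duplicated for every chunk.
    module_in_all_stage = [key for key in targets if NAME_SEP not in key]
    for key in module_in_all_stage:
        struct = targets.pop(key)
        targets.update({f"{vpp_stage}{NAME_SEP}{key}": struct for vpp_stage in range(len(model))})

    print(targets)  # {'0:fc2': {}, '0:fc1': {}, '1:fc1': {}}
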
@@ -449,8 +442,8 @@ class TrainerMon:
  return
  opt_context.exp_avg_metric = {}
  opt_context.exp_avg_sq_metric = {}
- m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
- v_tag_tensor_map = self.generate_param_map('efxp_avg_sq', opt_context.param_exp_avg_sq)
+ m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+ v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
  get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
  get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)

@@ -474,9 +467,10 @@ class TrainerMon:
  grad_dict[tag] = grad

  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- return self.grad_context.post, self.grad_context.pre
+ unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
+ return self.grad_context.post, unreduced_grad

- def monitor_gnorm_with_ad(
+ def set_monitor(
  self,
  model,
  grad_acc_steps=1,
@@ -492,6 +486,7 @@ class TrainerMon:
  self.micro_batch_number = grad_acc_steps
  self.dp_group = dp_group
  self.tp_group = tp_group
+ self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
  self.hook_step_final(optimizer)
  if not isinstance(model, list):
  model = [model]
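
Together, the hunks above rename the entry point monitor_gnorm_with_ad() to set_monitor() and derive the OptimizerMon from the optimizer instance itself, so the old opt_ty constructor argument is gone. A hedged usage sketch; only model, grad_acc_steps, optimizer, dp_group and tp_group are visible in this diff, so the exact signature and the config path below are assumptions:

    import torch
    from msprobe.pytorch.monitor.module_hook import TrainerMon

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
    optimizer = torch.optim.Adam(model.parameters())

    monitor = TrainerMon("monitor_config.json", process_group=None, params_have_main_grad=False)
    monitor.set_monitor(
        model,                # a module or a list of virtual-pipeline model chunks
        grad_acc_steps=1,
        optimizer=optimizer,  # the optimizer instance now selects the OptimizerMon implementation
        dp_group=None,
        tp_group=None,
    )
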
@@ -529,6 +524,8 @@ class TrainerMon:
  return actv, actv_grad

  def reload_xy(self, xy_distribution=False):
+ logger.warning("reload_xy() is deprecated and will be removed in a future version. "
+ "Use DYNAMIC_MONITOR instead.")
  self.xy_distribution = xy_distribution

  for handle in self.handles['xy']:
@@ -547,21 +544,23 @@ class TrainerMon:
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
  if len(fwd_context.actv) == 0:
  continue
- self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
  fwd_context.actv.clear()
  if self.grad_context.actv:
- self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)

  def write_param_tb(self, opt_context):
  if not self.param_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
+ self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM)

  def write_mv_tb(self, opt_context):
  if not self.mv_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+ opt_context.step, MonitorConst.EXP_AVG)
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+ opt_context.step, MonitorConst.EXP_AVG_SQ)

  def write_grad_tb(self, step):
  if not self.wg_distribution:
@@ -592,15 +591,13 @@ class TrainerMon:
  # skip generate metrics
  if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
  return
- if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
+ if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class: # use deepspeed with zero1/2/3
  if not self.name2indices:
- self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
- self.name2index)
- mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
- self.name2indices)
+ self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
+ mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
  self.param2name = mv_result.grad
  else:
- mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
+ mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
  context.param_exp_avg = mv_result.exp_avg
  context.param_exp_avg_sq = mv_result.exp_avg_sq
  context.param_adam_update = mv_result.update
@@ -647,13 +644,8 @@ class TrainerMon:
  if self.optimizer_hooked:
  return

- if optimizer:
- optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
- self.handles['optimizer'] = []
- else:
- if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
- step_pre_hook = register_optimizer_step_pre_hook(optimizer_pre_step_hook)
- self.handles['optimizer'] = [step_pre_hook]
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+
  self.optimizer_hooked = True
  return

@@ -677,7 +669,7 @@ class TrainerMon:
  logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
  return

- if config.get("switch", False):
+ if config.get("dynamic_on", False):
  try:
  validate_config(config)
  self.config = config
@@ -745,11 +737,9 @@ class TrainerMon:
  return out
  return wrapper

- if optimizer:
- optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
- self.origin_step_func = optimizer.__class__.step
- else:
- register_optimizer_step_post_hook(step_final_hook)
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+ self.origin_step_func = optimizer.__class__.step
+
  return

  def _remove_all_hooks(self, optimizer):
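
Both hook_optimizer and hook_step_final now always wrap the optimizer class's step method instead of falling back to register_optimizer_step_pre_hook / register_optimizer_step_post_hook. A simplified, self-contained sketch of that wrapping pattern (the real patch_step additionally drives the monitor's pre-step metric collection and final flush):

    import torch

    def patch_step(func, target_optimizer):
        # Wrap Optimizer.step so extra work runs around the original call.
        def wrapper(self, *args, **kwargs):
            if self is target_optimizer:
                print("before optimizer.step()")  # e.g. collect exp_avg / exp_avg_sq metrics
            out = func(self, *args, **kwargs)
            if self is target_optimizer:
                print("after optimizer.step()")   # e.g. flush metrics for this step
            return out
        return wrapper

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    origin_step_func = optimizer.__class__.step                # kept for later restoration
    optimizer.__class__.step = patch_step(origin_step_func, optimizer)

    model(torch.randn(1, 4)).sum().backward()
    optimizer.step()                                           # prints the before/after messages
    optimizer.__class__.step = origin_step_func                # restored, as in _remove_all_hooks
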
@@ -764,17 +754,28 @@ class TrainerMon:
  bwd_context.reset()
  self.grad_context.reset() # 权重梯度和激活值梯度都在这

- for handle in self.handles['wgrads']:
- handle.remove()
- self.handles['wgrads'].clear()
- self.weight_hooked = False
+ if self.origin_start_grad_sync: # megatron
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import Bucket
+ Bucket.start_grad_sync = self.origin_start_grad_sync
+ logger.info("remove Bucket start_grad_sync")
+ except ImportError:
+ pass
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+ _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
+ logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
+ except ImportError:
+ pass
+ else: # not megatron
+ for handle in self.handles['wgrads']:
+ handle.remove()
+ self.handles['wgrads'].clear()
+ self.weight_hooked = False

- if len(self.handles['optimizer']) == 0 and self.optimizer_hooked:
+ if self.optimizer_hooked:
  optimizer.__class__.step = self.origin_step_func
- else:
- for handle in self.handles['optimizer']:
- handle.remove()
- self.handles['optimizer'].clear()
+
  for _, context in self.optimizer_context.items():
  context.reset()
  self.optimizer_hooked = False
@@ -800,17 +801,17 @@ class TrainerMon:

  def _remove_all_hooks_final(self, optimizer):
  if self.dynamic_enable:
- # 结束后自动重置switch为False等待用户手动开启
+ # 结束后自动重置dynamic_on为False等待用户手动开启
  try:
  config = load_json(self.config_file_path)
- config['switch'] = False
+ config['dynamic_on'] = False
  save_json(self.config_file_path, config, indent=2)
  config_timestamp = os.path.getmtime(self.config_file_path)
  self.config_timestamp = config_timestamp
  logger.info(
- "Finish monitor, set config'switch=False, will restart by set switch=True and update content")
+ "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
  except Exception as e:
- logger.warning(f"Finish monitor, set config'switch=False fail because {e}, please check!!!")
+ logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
  logger.info("Finish monitor")
  self._remove_all_hooks(optimizer)

@@ -871,7 +872,7 @@ class TrainerMon:

  def _register_param_name(self):
  for vpp_stage, model_chunk in enumerate(self.model):
- prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
+ prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
  self._register_chunk(model_chunk, prefix)

  def _is_target_module(self, module_name, targets, vpp_stage):
@@ -900,35 +901,37 @@ class TrainerMon:
  context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
  if not context.struct:
  context.struct = {
- MonitorConst.ACTV_IN: get_param_struct(module_input),
- MonitorConst.ACTV_OUT: get_param_struct(module_output)
+ Const.INPUT: get_param_struct(module_input),
+ Const.OUTPUT: get_param_struct(module_output)
  }
  if self.print_struct:
  self.module_struct[context.module_name].update(context.struct)
  return
  if not context.format_by_arg:
- context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
- context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
+ context.set_format_by_arg(Const.INPUT, self.config['targets'])
+ context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
  if not context.format_by_arg:
  return
  if not context.verified:
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
+ context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
  module_input, context.module_name,
- MonitorConst.ACTV_IN)
+ Const.INPUT)
- context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
+ context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
  module_output, context.module_name,
- MonitorConst.ACTV_OUT)
+ Const.OUTPUT)
  context.verified = True
  # expect output be tensor type
  tbtag_tensor_map = {}
  cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
- cared_input))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_input))
  cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
- cared_output))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_output))

  get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
  context.micro_step += 1
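
The forward hook now tags activations as '<module>.<input|output><NAME_SEP><micro_step>' under the single 'actv' prefix instead of the separate ACTV_IN / ACTV_OUT keys (the backward hook below does the same for gradients). A small sketch of the tag strings this produces, assuming NAME_SEP is ':' and Const.INPUT / Const.OUTPUT are 'input' / 'output', which matches the example tag in get_train_stage above but is not confirmed here:

    NAME_SEP = ":"                     # assumed MonitorConst.NAME_SEP
    INPUT, OUTPUT = "input", "output"  # assumed Const.INPUT / Const.OUTPUT

    module_name = "0:fc2"              # "<vpp_stage>:<squashed module name>"
    micro_step = 0

    old_tag = f"{module_name}_{micro_step}"                        # "0:fc2_0"
    new_in_tag = f"{module_name}.{INPUT}{NAME_SEP}{micro_step}"    # "0:fc2.input:0"
    new_out_tag = f"{module_name}.{OUTPUT}{NAME_SEP}{micro_step}"  # "0:fc2.output:0"
    print(old_tag, new_in_tag, new_out_tag)
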
@@ -940,35 +943,37 @@ class TrainerMon:
  context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
  if not context.struct:
  context.struct = {
- MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
- MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
+ MonitorConst.INPUT_GRAD: get_param_struct(input_grad),
+ MonitorConst.OUTPUT_GRAD: get_param_struct(output_grad)
  }
  if self.print_struct:
  self.module_struct[context.module_name].update(context.struct)
  return
  if not context.format_by_arg:
- context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets'])
- context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets'])
+ context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
+ context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
  if not context.format_by_arg:
  return
  if not context.verified:
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
- input_grad, context.module_name,
- MonitorConst.ACTVGRAD_IN)
- context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
- output_grad, context.module_name,
- MonitorConst.ACTVGRAD_OUT)
+ context.focused_in_col = validate_config_spec(
+ context.format_by_arg[MonitorConst.INPUT_GRAD],
+ input_grad, context.module_name, MonitorConst.INPUT_GRAD)
+ context.focused_out_col = validate_config_spec(
+ context.format_by_arg[MonitorConst.OUTPUT_GRAD],
+ output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
  context.verified = True

  tbtag_tensor_map = {}
  cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN,
- cared_input_grad))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_input_grad))
  cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
- cared_output_grad))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_output_grad))

  if context.micro_step == 0 and context.actvgrad:
  logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -1006,7 +1011,10 @@ class TrainerMon:
  def patch_sync(sync_grad_func):
  def wrapper(bucket):
  grad_dict = {}
- bucket_params_id_list = [id(params) for params in bucket.params_list]
+ # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
+ # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
+ # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
+ bucket_params_id_list = [id(params) for params in bucket.params]
  for param, name in self.param2name.items():
  if id(param) not in bucket_params_id_list:
  continue
@@ -1025,18 +1033,28 @@ class TrainerMon:

  return wrapper

+ if not self.wg_distribution:
+ return
+
  try:
  from megatron.core.distributed.param_and_grad_buffer import Bucket
+ self.origin_start_grad_sync = Bucket.start_grad_sync
+ Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
  self.enable_megatron = True
+ logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
  except ImportError:
  self.enable_megatron = False

- if not self.wg_distribution:
- return
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+ self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
+ _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+ self.enable_megatron = True
+ logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+ except ImportError:
+ self.enable_megatron = False

- if self.enable_megatron:
- Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version
- else:
+ if not self.enable_megatron:
  self._hook_weights()

  def _hook_weights(self):
@@ -1053,6 +1071,7 @@ class TrainerMon:
  else:
  context_dict[key] = param.grad.clone()

+ logger.info("hooking weights.")
  for param, name in self.param2name.items():
  key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
  setattr(param, 'micro_step', 0)