mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/monitor/module_hook.py
@@ -22,12 +22,13 @@ from functools import partial
  import pytz
  import torch
  import torch.distributed as dist
- from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
  from torch.utils.hooks import BackwardHook

- from msprobe.core.common.const import MonitorConst
+ from msprobe.core.common.const import MonitorConst, Const
  from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.core.common.decorator import recursion_depth_decorator
  from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
  from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
  from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
  CSVWriterWithAD, BaseWriterWithAD, WriterInput
@@ -37,15 +38,16 @@ from msprobe.pytorch.monitor.features import get_sign_matches
  from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
  TensorMetrics, squash_param_name
  from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
- from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
- from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation, \
- get_output_base_dir, get_target_output_dir
+ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
+ from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
+ get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
  if not torch_version_above_or_equal_2:
  raise ValueError("monitor require torch>=2.0")

+
  FORMAT_MAPPING = {
  MonitorConst.TENSORBOARD: SummaryWriterWithAD,
  MonitorConst.CSV: CSVWriterWithAD,
@@ -85,9 +87,6 @@ class ModuleHookContext:
  :param target_config: target obj in config json.
  :return:
  """
- valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
- if key_name not in valid_key:
- raise ValueError(f"key({key_name}) error, valid_key: {valid_key}")
  cared = target_config.get(self.module_name, self.struct)
  if key_name in cared:
  target_module_config = cared[key_name]
@@ -178,20 +177,17 @@ class GradContext:
  class TrainerMon:
  tensor_metrics = TensorMetrics()

+ # Keep the original opt_ty parameter for compatibility with versions before msprobe 1.2.2
  def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
- """
- opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
- """
  # TYPE1: variables initialized only here; they are not reset when the config changes mid-training
  self.config_file_path = config_file_path
  self.process_group = get_process_group(process_group)
  self.params_have_main_grad = params_have_main_grad
- self.opt_ty = opt_ty
- self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
  self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
  self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
  self.origin_step_func = None
- self.config_timestamp = 0 # the timestamp is checked later; the first monitoring run does not need to touch the config file just to refresh its timestamp, it can be enabled directly via the switch flag
+ self.origin_start_grad_sync = None
+ self.config_timestamp = 0 # the timestamp is checked later; the first monitoring run does not need to touch the config file just to refresh its timestamp, it can be enabled directly via the dynamic_on flag
  self.config = load_json(config_file_path)
  validate_config(self.config)

@@ -219,13 +215,16 @@ class TrainerMon:
  self.pp_stage = 0
  self.group_mates = [0]

- # TYPE2: variables assigned only in the monitor_gnorm_with_ad() entry call
+ # TYPE2: variables assigned only in the set_monitor() entry call
  self.model = None
  self.vpp = False
  self.dp_group = None
  self.tp_group = None
  self.enable_megatron = False
  self.micro_batch_number = 1
+ self.optimizer_class = None
+ self.optimizer_mon = None
+ self.optimizer_trans = None

  # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -253,7 +252,7 @@
  self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
  if self.dynamic_enable:
  logger.warning(f"DYNAMIC_MONITOR is set, "
- f"please make sure you have 'switch' and 'collect_times' item in {self.config_file_path}")
+ f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
  self.monitoring = False
  else:
  self.set_config()
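
As an aside on the renamed flag above: a minimal sketch of driving dynamic monitoring from a training script, assuming a hypothetical monitor_config.json path. The names DYNAMIC_MONITOR, dynamic_on, collect_times, targets and format come from this diff; the file path and the concrete values are illustrative assumptions, not documented defaults.

import json
import os

os.environ["DYNAMIC_MONITOR"] = "True"  # checked via os.getenv in TrainerMon.__init__ (hunk above)

config = {
    "dynamic_on": False,    # flip to True (and save) later to start the next round of hooks
    "collect_times": 10,    # illustrative value; required alongside dynamic_on per the warning above
    "targets": {},          # the 'targets' key is consumed by hook_modules(); contents assumed
    "format": "csv",        # 'format' now falls back to MonitorConst.CSV (see set_config further down)
}
with open("monitor_config.json", "w") as f:  # hypothetical path
    json.dump(config, f, indent=2)
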
@@ -273,10 +272,6 @@
  def ops(self, value):
  self._ops = validate_ops(value)

- @staticmethod
- def set_wrapped_optimizer(_wrapped_optimizer):
- OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
-
  @staticmethod
  def has_register_backward_hook(module_name, module):
  if hasattr(module, '_backward_hooks') and \
@@ -308,7 +303,7 @@
  self.has_collect_times = 0 # reset the collection counter
  self.print_struct = self.config.get("print_struct", False)
  self.module_rank_list = self.config.get("module_ranks", [])
- self.format = self.config.get('format', 'tensorboard')
+ self.format = self.config.get('format', MonitorConst.CSV)
  self.eps = self.config.get('eps', 1e-8)
  self.ops = self.config.get('ops', [])
  self.ndigits = self.config.get('ndigits', 6)
@@ -330,8 +325,6 @@
  self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
  self.cc_logged_stack = defaultdict(set)
  self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
- self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
- api_register.redirect_api()

  self.common_info()

@@ -344,7 +337,13 @@

  # initialize the writer and create the output directory
  if self.format not in FORMAT_MAPPING:
- raise ValueError(f"Unsupported format: {self.format}")
+ logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+ self.format = MonitorConst.CSV
+
+ if self.ur_distribution and self.format != 'tensorboard':
+ logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
+ self.ur_distribution = False
+
  writer = FORMAT_MAPPING[self.format]
  self.step_count_per_record = self.config.get('step_count_per_record', 1)

@@ -365,19 +364,6 @@
  self.rank)
  self.anomaly_data_writer.init_detected_json()

- def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
- rank = None
- if dist.is_initialized():
- rank = dist.get_rank()
- if (rank not in rank_list) and len(rank_list) != 0:
- return
- self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
-
- def build_tbtag_tensor_map(self, module_name, tag, tensor):
- key = get_summary_writer_tag_name(module_name, tag, self.rank)
- self._register_param_call_id("_hook_module", key)
- return {key: tensor}
-
  def common_info(self):
  if not self.xy_distribution:
  logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
@@ -393,105 +379,39 @@
  logger.info_on_rank_0('> grad and momentum direction will not be compared.')
  if not self.cc_distribution.get('enable', False):
  logger.info_on_rank_0("> cc operator is not monitored.")
- if not self.opt_ty:
- if self.ur_distribution:
- raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
- if self.mv_distribution:
- raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
-
- def hook_modules(self):
- if self.module_rank_list and (self.rank not in self.module_rank_list):
- return
-
- targets = self.config['targets']
- module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
- for key in module_in_all_stage:
- struct = targets.pop(key)
- targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(self.model))})
-
- hooked_count = 0
- for vpp_stage, model_chunk in enumerate(self.model):
- vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
- targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
- 'targets'].keys()
- hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-
- logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
-
- def clone_if_tensor(args):
- if isinstance(args, tuple):
- return tuple([clone_if_tensor(arg) for arg in args])
- elif isinstance(args, torch.Tensor):
- return args.clone()
- else:
- return args
-
- @torch.no_grad
- def wrap_hook_setup(setup):
- def wrapped_setup(*args, **kwargs):
- args = setup(*args, **kwargs)
- args = clone_if_tensor(args)
- return args
-
- return wrapped_setup
-
- BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
-
- return
-
- def generate_param_metrics(self, opt_context):
- if not self.param_distribution:
- return
- get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
-
- def generate_mv_metrics(self, opt_context):
- if not self.mv_distribution:
- return
- opt_context.exp_avg_metric = {}
- opt_context.exp_avg_sq_metric = {}
- m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
- v_tag_tensor_map = self.generate_param_map('efxp_avg_sq', opt_context.param_exp_avg_sq)
- get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
- get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
-
- def generate_wgrad_metrics(self):
- if not self.wg_distribution:
- return {}, {}

- if self.weight_hooked:
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
- grad_dict = {}
- for param, name in self.param2name.items():
- if self.duplicate_param.get(name, False):
- continue
- grad = param.main_grad if self.params_have_main_grad else param.grad
- if grad is None:
- logger.warning(f"grad is None: {name}, maybe something wrong happened.")
- continue
- tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
- self._register_param_call_id("hook_optimizer", tag)
- grad_dict[tag] = grad
+ # Keep the original interface for compatibility with versions before msprobe 1.2.2
+ def monitor_gnorm_with_ad(self, model, optimizer=None, grad_acc_steps=1, tp_group=None, dp_group=None,
+ start_iteration=0):
+ if optimizer is None:
+ optimizer = getattr(self, "optimizer_trans", None) # compatible with older versions where None could be passed; taken from set_wrapped_optimizer
+ if optimizer is None:
+ logger.error("monitor_gnorm_with_ad: please set_wrapped_optimizer before it or input optimizer!=None")
+ return
+ self.set_monitor(model, optimizer, grad_acc_steps, tp_group, dp_group, start_iteration)

- get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- return self.grad_context.post, self.grad_context.pre
+ # Keep the original interface for compatibility with versions before msprobe 1.2.2
+ def set_wrapped_optimizer(self, optimizer):
+ self.optimizer_trans = optimizer

- def monitor_gnorm_with_ad(
+ def set_monitor(
  self,
  model,
+ optimizer,
  grad_acc_steps=1,
- optimizer=None,
  tp_group=None,
  dp_group=None,
  start_iteration=0
  ):
  """External interface"""
+ grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration)
  global start_step
  start_step = start_iteration
  logger.info(f'grad acc steps {grad_acc_steps}')
  self.micro_batch_number = grad_acc_steps
  self.dp_group = dp_group
  self.tp_group = tp_group
+ self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
  self.hook_step_final(optimizer)
  if not isinstance(model, list):
  model = [model]
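
A minimal usage sketch of the interface change in this hunk, assuming TrainerMon is importable from msprobe.pytorch.monitor.module_hook (the file this diff belongs to), a valid monitor config exists at the illustrative path below, and toy torch objects stand in for the real training model and optimizer; only the set_monitor signature and the compatibility shims are taken from the hunk, everything else is an assumption.

import torch
from msprobe.pytorch.monitor.module_hook import TrainerMon  # assumed import path

model = torch.nn.Linear(8, 8)                     # toy stand-ins for the real training objects
optimizer = torch.optim.Adam(model.parameters())

monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False)  # illustrative path
# 1.3.0 entry point: the optimizer is now passed explicitly and is required.
monitor.set_monitor(model, optimizer, grad_acc_steps=1, tp_group=None, dp_group=None, start_iteration=0)

# Pre-1.2.2 call pattern, still accepted through the shims kept above:
# monitor.set_wrapped_optimizer(optimizer)
# monitor.monitor_gnorm_with_ad(model, grad_acc_steps=1)
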
@@ -507,8 +427,24 @@
  self.hook_optimizer(optimizer)
  self._patch_grad_sync()
  self.hook_modules()
+ if self.cc_distribution.get('enable', False):
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+ api_register.redirect_api()
  self.monitoring = True

+ def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+ rank = None
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ if (rank not in rank_list) and len(rank_list) != 0:
+ return
+ self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
+
+ def build_tbtag_tensor_map(self, module_name, tag, tensor):
+ key = get_summary_writer_tag_name(module_name, tag, self.rank)
+ self._register_param_call_id("_hook_module", key)
+ return {key: tensor}
+
  def generate_param_map(self, tag, param_tensor):
  metrics = {}
  for name in self.param2name.values():
@@ -519,6 +455,44 @@
  metrics[key] = param_tensor[name]
  return metrics

+ def generate_param_metrics(self, opt_context):
+ if not self.param_distribution:
+ return
+ get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
+
+ def generate_mv_metrics(self, opt_context):
+ if not self.mv_distribution:
+ return
+ opt_context.exp_avg_metric = {}
+ opt_context.exp_avg_sq_metric = {}
+ m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+ v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+ get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+ get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+ def generate_wgrad_metrics(self):
+ if not self.wg_distribution:
+ return {}, {}
+
+ if self.weight_hooked:
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+
+ grad_dict = {}
+ for param, name in self.param2name.items():
+ if self.duplicate_param.get(name, False):
+ continue
+ grad = param.main_grad if self.params_have_main_grad else param.grad
+ if grad is None:
+ logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+ continue
+ tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+ self._register_param_call_id("hook_optimizer", tag)
+ grad_dict[tag] = grad
+
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+ unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
+ return self.grad_context.post, unreduced_grad
+
  def generate_xy_metrics(self):
  actv = {}
  for fwd_context in self.module_fwd_hook_context_by_module.values():
@@ -529,6 +503,8 @@
  return actv, actv_grad

  def reload_xy(self, xy_distribution=False):
+ logger.warning("reload_xy() is deprecated and will be removed in a future version. "
+ "Use DYNAMIC_MONITOR instead.")
  self.xy_distribution = xy_distribution

  for handle in self.handles['xy']:
@@ -547,21 +523,23 @@
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
  if len(fwd_context.actv) == 0:
  continue
- self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
  fwd_context.actv.clear()
  if self.grad_context.actv:
- self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)

  def write_param_tb(self, opt_context):
  if not self.param_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
+ self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM)

  def write_mv_tb(self, opt_context):
  if not self.mv_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+ opt_context.step, MonitorConst.EXP_AVG)
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+ opt_context.step, MonitorConst.EXP_AVG_SQ)

  def write_grad_tb(self, step):
  if not self.wg_distribution:
@@ -573,7 +551,7 @@
  self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

- def hook_optimizer(self, optimizer=None):
+ def hook_optimizer(self, optimizer):
  # in DDP by default use params_have_main_grad
  def optimizer_pre_step_hook(optimizer, args, kwargs):
  context = self.optimizer_context[optimizer]
@@ -592,15 +570,13 @@
  # skip generate metrics
  if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
  return
- if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
+ if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class: # use deepspeed with zero1/2/3
  if not self.name2indices:
- self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
- self.name2index)
- mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
- self.name2indices)
+ self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
+ mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
  self.param2name = mv_result.grad
  else:
- mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
+ mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
  context.param_exp_avg = mv_result.exp_avg
  context.param_exp_avg_sq = mv_result.exp_avg_sq
  context.param_adam_update = mv_result.update
@@ -641,19 +617,13 @@
  optimizer_pre_step_hook(optimizer, args, kwargs)
  out = func(*args, **kwargs)
  return out
-
  return wrapper

  if self.optimizer_hooked:
  return

- if optimizer:
- optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
- self.handles['optimizer'] = []
- else:
- if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
- step_pre_hook = register_optimizer_step_pre_hook(optimizer_pre_step_hook)
- self.handles['optimizer'] = [step_pre_hook]
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+
  self.optimizer_hooked = True
  return

@@ -677,11 +647,12 @@
  logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
  return

- if config.get("switch", False):
+ if config.get("dynamic_on", False):
  try:
  validate_config(config)
  self.config = config
  self.set_config()
+ self.start_step = context.step # dynamic start/stop is not affected by the original start_step; always start from the next step
  logger.warning(f"config is updated at step{context.step - 1}, "
  f"will start new hook at step{context.step}.")
  except Exception as e:
@@ -729,6 +700,9 @@
  if self.anomaly_data_factory:
  self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
  self.summary_writer.clear_anomalies()
+
+ if self.format == MonitorConst.TENSORBOARD:
+ chmod_tensorboard_dir(self.tensorboard_dir)
  self.call_id = 0
  self.param_name_call_id.clear()

@@ -745,11 +719,49 @@
  return out
  return wrapper

- if optimizer:
- optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
- self.origin_step_func = optimizer.__class__.step
- else:
- register_optimizer_step_post_hook(step_final_hook)
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+ self.origin_step_func = optimizer.__class__.step
+ return
+
+ def hook_modules(self):
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
+ return
+
+ targets = self.config['targets']
+ module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
+ for key in module_in_all_stage:
+ struct = targets.pop(key)
+ targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+
+ hooked_count = 0
+ for vpp_stage, model_chunk in enumerate(self.model):
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+ targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
+ 'targets'].keys()
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+
+ logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
+
+ @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor')
+ def clone_if_tensor(args):
+ if isinstance(args, tuple):
+ return tuple([clone_if_tensor(arg) for arg in args])
+ elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+ return args.clone()
+ else:
+ return args
+
+ @torch.no_grad
+ def wrap_hook_setup(setup):
+ def wrapped_setup(*args, **kwargs):
+ args = setup(*args, **kwargs)
+ args = clone_if_tensor(args)
+ return args
+
+ return wrapped_setup
+
+ BackwardHook.setup_input_hook = wrap_hook_setup(BackwardHook.setup_input_hook)
+ BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
  return

  def _remove_all_hooks(self, optimizer):
@@ -764,17 +776,28 @@
  bwd_context.reset()
  self.grad_context.reset() # weight grads and activation grads both live here

- for handle in self.handles['wgrads']:
- handle.remove()
- self.handles['wgrads'].clear()
- self.weight_hooked = False
+ if self.origin_start_grad_sync: # megatron
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import Bucket
+ Bucket.start_grad_sync = self.origin_start_grad_sync
+ logger.info("remove Bucket start_grad_sync")
+ except ImportError:
+ pass
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+ _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
+ logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
+ except ImportError:
+ pass
+ else: # not megatron
+ for handle in self.handles['wgrads']:
+ handle.remove()
+ self.handles['wgrads'].clear()
+ self.weight_hooked = False

- if len(self.handles['optimizer']) == 0 and self.optimizer_hooked:
+ if self.optimizer_hooked:
  optimizer.__class__.step = self.origin_step_func
- else:
- for handle in self.handles['optimizer']:
- handle.remove()
- self.handles['optimizer'].clear()
+
  for _, context in self.optimizer_context.items():
  context.reset()
  self.optimizer_hooked = False
@@ -782,6 +805,7 @@
  for handle in self.handles['cc']:
  handle.remove()
  self.handles['cc'].clear()
+ api_register.restore_api()
  for _, context in self.cc_context.items():
  context.reset()

@@ -800,17 +824,17 @@

  def _remove_all_hooks_final(self, optimizer):
  if self.dynamic_enable:
- # after finishing, automatically reset switch to False and wait for the user to re-enable it manually
+ # after finishing, automatically reset dynamic_on to False and wait for the user to re-enable it manually
  try:
  config = load_json(self.config_file_path)
- config['switch'] = False
+ config['dynamic_on'] = False
  save_json(self.config_file_path, config, indent=2)
  config_timestamp = os.path.getmtime(self.config_file_path)
  self.config_timestamp = config_timestamp
  logger.info(
- "Finish monitor, set config'switch=False, will restart by set switch=True and update content")
+ "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
  except Exception as e:
- logger.warning(f"Finish monitor, set config'switch=False fail because {e}, please check!!!")
+ logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
  logger.info("Finish monitor")
  self._remove_all_hooks(optimizer)

@@ -871,7 +895,7 @@

  def _register_param_name(self):
  for vpp_stage, model_chunk in enumerate(self.model):
- prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
+ prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
  self._register_chunk(model_chunk, prefix)

  def _is_target_module(self, module_name, targets, vpp_stage):
@@ -900,35 +924,37 @@
  context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
  if not context.struct:
  context.struct = {
- MonitorConst.ACTV_IN: get_param_struct(module_input),
- MonitorConst.ACTV_OUT: get_param_struct(module_output)
+ Const.INPUT: get_param_struct(module_input),
+ Const.OUTPUT: get_param_struct(module_output)
  }
  if self.print_struct:
  self.module_struct[context.module_name].update(context.struct)
  return
  if not context.format_by_arg:
- context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
- context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
+ context.set_format_by_arg(Const.INPUT, self.config['targets'])
+ context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
  if not context.format_by_arg:
  return
  if not context.verified:
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
+ context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
  module_input, context.module_name,
- MonitorConst.ACTV_IN)
- context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
+ Const.INPUT)
+ context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
  module_output, context.module_name,
- MonitorConst.ACTV_OUT)
+ Const.OUTPUT)
  context.verified = True
  # expect output be tensor type
  tbtag_tensor_map = {}
  cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
- cared_input))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_input))
  cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
- cared_output))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_output))

  get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
  context.micro_step += 1
@@ -940,35 +966,37 @@
  context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
  if not context.struct:
  context.struct = {
- MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
- MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
+ MonitorConst.INPUT_GRAD: get_param_struct(input_grad),
+ MonitorConst.OUTPUT_GRAD: get_param_struct(output_grad)
  }
  if self.print_struct:
  self.module_struct[context.module_name].update(context.struct)
  return
  if not context.format_by_arg:
- context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets'])
- context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets'])
+ context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
+ context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
  if not context.format_by_arg:
  return
  if not context.verified:
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
- input_grad, context.module_name,
- MonitorConst.ACTVGRAD_IN)
- context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
- output_grad, context.module_name,
- MonitorConst.ACTVGRAD_OUT)
+ context.focused_in_col = validate_config_spec(
+ context.format_by_arg[MonitorConst.INPUT_GRAD],
+ input_grad, context.module_name, MonitorConst.INPUT_GRAD)
+ context.focused_out_col = validate_config_spec(
+ context.format_by_arg[MonitorConst.OUTPUT_GRAD],
+ output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
  context.verified = True

  tbtag_tensor_map = {}
  cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN,
- cared_input_grad))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_input_grad))
  cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
- cared_output_grad))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, cared_output_grad))

  if context.micro_step == 0 and context.actvgrad:
  logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -1006,7 +1034,10 @@
  def patch_sync(sync_grad_func):
  def wrapper(bucket):
  grad_dict = {}
- bucket_params_id_list = [id(params) for params in bucket.params_list]
+ # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
+ # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
+ # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
+ bucket_params_id_list = [id(params) for params in bucket.params]
  for param, name in self.param2name.items():
  if id(param) not in bucket_params_id_list:
  continue
@@ -1025,18 +1056,28 @@

  return wrapper

+ if not self.wg_distribution:
+ return
+
  try:
  from megatron.core.distributed.param_and_grad_buffer import Bucket
+ self.origin_start_grad_sync = Bucket.start_grad_sync
+ Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
  self.enable_megatron = True
+ logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
  except ImportError:
  self.enable_megatron = False

- if not self.wg_distribution:
- return
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+ self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
+ _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+ self.enable_megatron = True
+ logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+ except ImportError:
+ self.enable_megatron = False | self.enable_megatron

- if self.enable_megatron:
- Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version
- else:
+ if not self.enable_megatron:
  self._hook_weights()

  def _hook_weights(self):
@@ -1049,10 +1090,14 @@
  if param.micro_step == self.micro_batch_number:
  param.micro_step = 0
  if self.params_have_main_grad:
- context_dict[key] = param.main_grad.clone()
+ grad = param.main_grad
  else:
- context_dict[key] = param.grad.clone()
+ grad = param.grad
+ if is_float8_tensor(grad):
+ grad = grad.float()
+ context_dict[key] = grad.clone()

+ logger.info("hooking weights.")
  for param, name in self.param2name.items():
  key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
  setattr(param, 'micro_step', 0)