mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/monitor/module_hook.py
@@ -22,27 +22,29 @@ from functools import partial
 import pytz
 import torch
 import torch.distributed as dist
+import pandas as pd
 from torch.utils.hooks import BackwardHook
 
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json
 from msprobe.core.common.decorator import recursion_depth_decorator
+from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+from msprobe.core.common.file_utils import write_df_to_csv
+from msprobe.core.common.utils import analyze_api_call_stack
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
-from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
-from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
-    CSVWriterWithAD, BaseWriterWithAD, WriterInput
+from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
 from msprobe.pytorch.monitor.features import get_sign_matches
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
     TensorMetrics, squash_param_name
-from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
     get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
+
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
@@ -72,36 +74,7 @@ class ModuleHookContext:
         self.actvgrad = []
         self.module_name = module_name
         self.struct = {}
-        self.format_by_arg = {}
-        self.verified = False
-        self.focused_in_col = 0
-        self.focused_out_col = 0
-
-    def set_format_by_arg(self, key_name: str, target_config: dict):
-        """Configure format_by_arg according to the monitored targets:
-        1) module_name is configured as a monitored target in targets
-        2) module_name is not configured in targets, and all_xy full monitoring is enabled
-        3) module_name is not configured in targets, and all_xy full monitoring is not enabled
-
-        :param key_name: str, one of [input, output, input_grad, output_grad]
-        :param target_config: target obj in config json.
-        :return:
-        """
-        cared = target_config.get(self.module_name, self.struct)
-        if key_name in cared:
-            target_module_config = cared[key_name]
-            if isinstance(target_module_config, dict):
-                # current cared is self.struct, monitor all data for module_name
-                self.format_by_arg[key_name] = target_module_config.get('config')
-            elif isinstance(target_module_config, str):
-                # current cared is target_config[self.module_name]
-                self.format_by_arg[key_name] = target_module_config
-            else:
-                logger.warning_on_rank_0(f"target module config error, result maybe empty."
-                                         f"module_name: {self.module_name}, key_name: {key_name}")
-                self.format_by_arg[key_name] = None
-        else:
-            self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
+        self.stack = ""
 
     def reset(self):
         self.actv.clear()
@@ -185,8 +158,8 @@ class TrainerMon:
         self.params_have_main_grad = params_have_main_grad
         self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.origin_step_func = None
         self.origin_start_grad_sync = None
+        self.fsdp_post_backward_hook = None
         self.config_timestamp = 0  # the timestamp is validated later; the first run need not touch the config file just to refresh it, monitoring can be enabled directly via the dynamic_on switch
         self.config = load_json(config_file_path)
         validate_config(self.config)
@@ -221,8 +194,8 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
+        self.fsdp_wrapped_module = False
         self.micro_batch_number = 1
-        self.optimizer_class = None
         self.optimizer_mon = None
         self.optimizer_trans = None
 
@@ -234,7 +207,6 @@
         self.grad_context = GradContext()
         self.handles = defaultdict(list)
         self.param2name = defaultdict(str)
-        self.name2index = defaultdict()
         self.name2indices = defaultdict()
         self.name2param = {}
         self.duplicate_param = {}
@@ -247,6 +219,8 @@
         self.optimizer_hooked = False
         self.param_registered = False
         self.struct_printed = False
+        self.pre_step_hooks = []
+        self.post_step_hooks = []
 
         # distinguish dynamic and static monitoring
         self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
@@ -317,6 +291,8 @@
         self.param_distribution = self.config.get("param_distribution", False)
         self.mg_direction = self.config.get('mg_direction', False)
         self.cc_distribution = self.config.get("cc_distribution", {})
+        self.stack_info = self.config.get('stack_info', False)
+        self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
 
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
@@ -411,7 +387,7 @@
         self.micro_batch_number = grad_acc_steps
         self.dp_group = dp_group
         self.tp_group = tp_group
-        self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
+        self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
         self.hook_step_final(optimizer)
         if not isinstance(model, list):
             model = [model]
@@ -440,25 +416,48 @@
             return
         self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
 
-    def build_tbtag_tensor_map(self, module_name, tag, tensor):
-        key = get_summary_writer_tag_name(module_name, tag, self.rank)
-        self._register_param_call_id("_hook_module", key)
-        return {key: tensor}
+    def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+        """
+        :param module_name: str of module name
+        :param suffix:
+        :param tag:
+        :param tensor: torch.tensor or tuple/list of torch.tensor
+        :return: tensor_map
+        """
+        tensor_map = {}
+        if isinstance(tensor, torch.Tensor):
+            tensor = [tensor]
+        if isinstance(tensor, tuple) or isinstance(tensor, list):
+            if len(tensor) == 1:
+                key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+                self.register_param_call_id("_hook_module", key)
+                tensor_map[key] = tensor[0]
+            else:
+                for i, tensor_i in enumerate(tensor):
+                    key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+                    self.register_param_call_id("_hook_module", key)
+                    tensor_map[key] = tensor_i
+        return tensor_map
 
     def generate_param_map(self, tag, param_tensor):
         metrics = {}
         for name in self.param2name.values():
             key = get_summary_writer_tag_name(name, tag, self.rank)
-            self._register_param_call_id("optimizer_pre_step_hook", key)
+            self.register_param_call_id("optimizer_pre_step_hook", key)
             if name not in param_tensor or param_tensor[name] is None:
                 continue
            metrics[key] = param_tensor[name]
         return metrics
 
-    def generate_param_metrics(self, opt_context):
+    def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
         if not self.param_distribution:
             return
-        get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
+        tag2param = {
+            self.name2tag.get(name, {}).get(stage): param
+            for name, param in self.name2param.items()
+            if param.numel() != 0
+        }
+        get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
 
     def generate_mv_metrics(self, opt_context):
         if not self.mv_distribution:
@@ -470,28 +469,22 @@
         get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
         get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
 
-    def generate_wgrad_metrics(self):
+    def generate_wgrad_metrics(self, post_grad_dict):
         if not self.wg_distribution:
             return {}, {}
 
         if self.weight_hooked:
             get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
 
-        grad_dict = {}
-        for param, name in self.param2name.items():
-            if self.duplicate_param.get(name, False):
-                continue
-            grad = param.main_grad if self.params_have_main_grad else param.grad
-            if grad is None:
-                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                continue
-            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
-            self._register_param_call_id("hook_optimizer", tag)
-            grad_dict[tag] = grad
+        get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
+        reduced_grad = self.grad_context.post
+
+        if self.weight_hooked:
+            unreduced_grad = self.grad_context.acc_metric
+        else:
+            unreduced_grad = self.grad_context.pre
 
-        get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
-        unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
-        return self.grad_context.post, unreduced_grad
+        return reduced_grad, unreduced_grad
 
     def generate_xy_metrics(self):
         actv = {}
@@ -517,6 +510,17 @@
     def write_adhoc_check(self, step):
         self.tensor_metrics.flush(self.summary_writer)
 
+    def write_stack_info(self):
+        stack_data = []
+        header = ["module_name", "stack_info"]
+        stack_data.append(header)
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            stack_data.append([fwd_context.module_name, fwd_context.stack])
+        filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+        if not os.path.exists(filepath):
+            data_frame = pd.DataFrame(columns=stack_data)
+            write_df_to_csv(data_frame, filepath)
+
     def write_xy_tb(self, step):
         if not self.xy_distribution:
             return
@@ -531,7 +535,10 @@
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-        self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM)
+        param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+        updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+        self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+        self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)
 
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
@@ -545,10 +552,11 @@
         if not self.wg_distribution:
             return
 
-        if self.enable_megatron:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
+        if self.weight_hooked:
+            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+                                              use_micro_step=self.monitor_mbs_grad)
         else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
     def hook_optimizer(self, optimizer):
@@ -570,21 +578,23 @@
             # skip generate metrics
             if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
                 return
-            if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class:  # use deepspeed with zero1/2/3
-                if not self.name2indices:
-                    self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
-                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
-                self.param2name = mv_result.grad
-            else:
-                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
-            context.param_exp_avg = mv_result.exp_avg
-            context.param_exp_avg_sq = mv_result.exp_avg_sq
-            context.param_adam_update = mv_result.update
-            context.param_adam_ratio = mv_result.ratio
 
-            self.generate_wgrad_metrics()
+            grad_dict = {}
+            if self.wg_distribution:
+                grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+            mv_result = None
+            if self.mv_distribution or self.ur_distribution or self.mg_direction:
+                mv_result = self.optimizer_mon.fetch_mv(self, self.param2name)
+            if mv_result:
+                context.param_exp_avg = mv_result.exp_avg
+                context.param_exp_avg_sq = mv_result.exp_avg_sq
+                context.param_adam_update = mv_result.update
+                context.param_adam_ratio = mv_result.ratio
+
+            self.generate_wgrad_metrics(grad_dict)
             self.generate_mv_metrics(context)
-            self.generate_param_metrics(context)
+            self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
 
             tbtag_tensor_map = {}
             if self.mg_direction:
@@ -612,17 +622,15 @@
             context.metric_dict = metric_dict
             return
 
-        def patch_step(func, optimizer):
-            def wrapper(*args, **kwargs):
-                optimizer_pre_step_hook(optimizer, args, kwargs)
-                out = func(*args, **kwargs)
-                return out
-            return wrapper
+        def optimizer_post_step_hook(optimizer, args, kwargs):
+            context = self.optimizer_context[optimizer]
+            self.generate_param_metrics(context, MonitorConst.POST_PARAM)
 
         if self.optimizer_hooked:
             return
 
-        optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+        self.pre_step_hooks.append(optimizer_pre_step_hook)
+        self.post_step_hooks.append(optimizer_post_step_hook)
 
         self.optimizer_hooked = True
         return
@@ -682,6 +690,12 @@
             self.write_mv_tb(context)
             self.write_param_tb(context)
             self.write_adhoc_check(context.step)
+            if self.stack_info:
+                self.write_stack_info()
+                self.stack_info = False
+                for handle in self.handles["stack"]:
+                    handle.remove()
+                self.handles["stack"].clear()
 
             if self.ur_distribution:
                 for param_name, _ in context.param_adam_update.items():
@@ -714,13 +728,16 @@
 
         def patch_step(func, optimizer):
             def wrapper(*args, **kwargs):
+                for hook in self.pre_step_hooks:
+                    hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
+                for hook in self.post_step_hooks:
+                    hook(optimizer, args, kwargs)
                 step_final_hook(optimizer, args, kwargs)
                 return out
             return wrapper
 
         optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
-        self.origin_step_func = optimizer.__class__.step
         return
 
     def hook_modules(self):
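Note on the hunks above: the per-call monkey patching of optimizer.step is replaced by two hook lists (pre_step_hooks, post_step_hooks) that a single patch_step wrapper dispatches around the original step. A minimal, self-contained sketch of that dispatch pattern (standalone example, not msprobe code; the hook functions are made up):

import torch

# Sketch only: dispatch registered pre/post hooks around optimizer.step()
# by wrapping the class method once (this patches every instance of the class).
pre_step_hooks, post_step_hooks = [], []

def patch_step(func):
    def wrapper(optimizer, *args, **kwargs):
        for hook in pre_step_hooks:
            hook(optimizer, args, kwargs)   # runs before the real step()
        out = func(optimizer, *args, **kwargs)
        for hook in post_step_hooks:
            hook(optimizer, args, kwargs)   # runs after parameters were updated
        return out
    return wrapper

param = torch.nn.Parameter(torch.randn(2))
opt = torch.optim.SGD([param], lr=0.1)
opt.__class__.step = patch_step(opt.__class__.step)
pre_step_hooks.append(lambda o, a, k: print("pre-step hook"))
post_step_hooks.append(lambda o, a, k: print("post-step hook"))
param.grad = torch.ones(2)
opt.step()   # prints "pre-step hook", then "post-step hook"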
@@ -764,6 +781,16 @@
         BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
         return
 
+    def register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name:
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1
+
     def _remove_all_hooks(self, optimizer):
         # clear hook handles
         for handle in self.handles['xy']:
@@ -789,14 +816,18 @@
                 logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
             except ImportError:
                 pass
-        else:  # not megatron
+        elif self.fsdp_post_backward_hook:  # fsdp
+            torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook
+            logger.info("remove patch_post_backward_hook in fsdp.")
+        else:  # not megatron and not fsdp
             for handle in self.handles['wgrads']:
                 handle.remove()
             self.handles['wgrads'].clear()
             self.weight_hooked = False
 
         if self.optimizer_hooked:
-            optimizer.__class__.step = self.origin_step_func
+            self.pre_step_hooks.clear()
+            self.post_step_hooks.clear()
 
         for _, context in self.optimizer_context.items():
             context.reset()
@@ -811,7 +842,6 @@
 
         # clear node caches
         self.param2name.clear()
-        self.name2index.clear()
         self.name2indices.clear()
         self.name2param.clear()
         self.duplicate_param.clear()
@@ -871,27 +901,33 @@
         return False
 
     def _register_chunk(self, model_chunk, prefix):
-        index = 0
         for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
+            if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
+                self.fsdp_wrapped_module = True
             if self._is_target_param(param_name, param, prefix):
                 name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
                     name = prefix + param_name
                 self.param2name[param] = name
                 self.name2param[name] = param
-                self.name2index[name] = index
 
                 if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
                     self.duplicate_param[name] = True
                 if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                     self.duplicate_param[name] = True
+
+                keywords = [
+                    MonitorConst.PRE_GRAD,
+                    MonitorConst.POST_GRAD,
+                    MonitorConst.PRE_PARAM,
+                    MonitorConst.POST_PARAM
+                ]
                 self.name2tag[name] = {
-                    MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
-                    MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
+                    k: get_summary_writer_tag_name(name, k, self.rank)
+                    for k in keywords
                 }
-            index += 1
 
     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
@@ -914,11 +950,17 @@
             # nothing to hook
             return 0
 
-        def fwd_hook_fun(module, module_input, module_output, name):
+        def fwd_hook_fun(module, args, kwargs, module_output, name):
             if not module.training or is_recomputation():
                 # 1 only monitor training stage.
                 # 2 when open recompute, skip recomputed forward stage.
                 return
+
+            module_input = [tensor for tensor in args if torch.is_tensor(tensor)]
+            if kwargs:
+                kwargs_tensors = [tensor for tensor in kwargs.values() if torch.is_tensor(tensor)]
+                module_input.extend(kwargs_tensors)
+
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
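Note on the hunk above: the forward hook now takes the (module, args, kwargs, output) signature, which PyTorch passes when a hook is registered with register_forward_hook(..., with_kwargs=True) (available since torch 2.0, which this module already requires); that is what lets tensors passed as keyword arguments be monitored too. A small standalone illustration of the API (the example module and names are made up):

import torch
import torch.nn as nn

# Sketch only: with with_kwargs=True the hook also receives keyword arguments,
# so tensors passed as kwargs can be collected alongside positional ones.
def fwd_hook(module, args, kwargs, output):
    tensors = [t for t in args if torch.is_tensor(t)]
    tensors += [t for t in kwargs.values() if torch.is_tensor(t)]
    print(f"{module.__class__.__name__}: saw {len(tensors)} input tensor(s)")

class Scale(nn.Module):
    def forward(self, x, scale=None):
        return x * scale if scale is not None else x

m = Scale()
m.register_forward_hook(fwd_hook, with_kwargs=True)
m(torch.ones(2), scale=torch.tensor(3.0))   # hook reports 2 input tensors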
@@ -927,34 +969,20 @@
                 Const.INPUT: get_param_struct(module_input),
                 Const.OUTPUT: get_param_struct(module_output)
             }
+
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
-            if not context.format_by_arg:
-                context.set_format_by_arg(Const.INPUT, self.config['targets'])
-                context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
-                                                              module_input, context.module_name,
-                                                              Const.INPUT)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
-                                                               module_output, context.module_name,
-                                                               Const.OUTPUT)
-                context.verified = True
-            # expect output be tensor type
+
             tbtag_tensor_map = {}
-            cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV, cared_input))
-            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_input))
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV, cared_output))
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_output))
 
             get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
             context.micro_step += 1
@@ -972,31 +1000,17 @@
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
-            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
-                context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                context.focused_in_col = validate_config_spec(
-                    context.format_by_arg[MonitorConst.INPUT_GRAD],
-                    input_grad, context.module_name, MonitorConst.INPUT_GRAD)
-                context.focused_out_col = validate_config_spec(
-                    context.format_by_arg[MonitorConst.OUTPUT_GRAD],
-                    output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
-                context.verified = True
 
             tbtag_tensor_map = {}
-            cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV, cared_input_grad))
-            cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, input_grad))
+
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
-                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
-                    MonitorConst.ACTV, cared_output_grad))
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -1010,17 +1024,30 @@
                 context.micro_step = 0
             return
 
+        def stack_hook(module, args, kwargs, module_output, name):
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            context.stack = analyze_api_call_stack(name)
+            return
+
         if self.backward_only and self.forward_only:
             logger.warning('not enable backward_only and forward_only simultaneously')
 
         hooked_count = 0
-        if self.xy_distribution or self.print_struct:
-            for module_name, submodule in module.named_modules():
-                name = self._is_target_module(module_name, target_names, vpp_stage)
-                if not name:
-                    continue
+        for module_name, submodule in module.named_modules():
+            if self.stack_info:
+                name = vpp_stage + squash_param_name(module_name, self.squash_name)
+                handle = submodule.register_forward_hook(partial(stack_hook, name=name), with_kwargs=True)
+                self.handles['stack'].append(handle)
+            name = self._is_target_module(module_name, target_names, vpp_stage)
+            if not name:
+                continue
+            if submodule.__class__.__name__ == "FullyShardedDataParallel":
+                continue
+            if self.xy_distribution or self.print_struct:
                 if not self.backward_only:
-                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
+                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name), with_kwargs=True)
                     self.handles['xy'].append(handle)
                 if not self.forward_only and not self.has_register_backward_hook(name, submodule):
                     handle = submodule.register_full_backward_hook(bwd_hook_fun)
@@ -1049,7 +1076,7 @@
                     if tag is None:
                         continue
                     grad_dict[tag] = grad
-                    self._register_param_call_id("sync_grad_func", tag)
+                    self.register_param_call_id("sync_grad_func", tag)
                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                 out = sync_grad_func(bucket)
                 return out
@@ -1058,7 +1085,14 @@
 
         if not self.wg_distribution:
             return
+        if self.fsdp_wrapped_module:
+            # patch fsdp _runtime_utils._post_backward_hook
+            self._patch_fsdp_post_backward_hook()
+            return
 
+        if self.monitor_mbs_grad:
+            self._hook_weights()
+            return
         try:
             from megatron.core.distributed.param_and_grad_buffer import Bucket
             self.origin_start_grad_sync = Bucket.start_grad_sync
@@ -1076,19 +1110,62 @@
             logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
         except ImportError:
             self.enable_megatron = False | self.enable_megatron
+        if self.enable_megatron:
+            return
 
-        if not self.enable_megatron:
-            self._hook_weights()
+        # default hook weights
+        self._hook_weights()
+
+    def _patch_fsdp_post_backward_hook(self):
+        """
+        The FSDP runtime drives the whole forward/backward compute and communication flow by overriding nn.Module's forward.
+        A hook registered on the AccumulateGrad object runs right after the grad is computed in backward, so the accumulated gradient can be collected before the reduce_scatter communication.
+        However, FSDP re-registers its hooks on AccumulateGrad in every forward pass, so hooks registered by the monitor tool never take effect.
+        Therefore _post_backward_hook is patched to collect gradients after backward and before reduce_scatter.
+        """
+        def patch_post_backward_hook(_post_backward_hook):
+            def wrapper(state, handle, *unused):
+                grad_dict = {}
+                offset = 0
+                for param, name in self.param2name.items():
+                    limit = param.numel()
+                    if not limit:
+                        continue
+                    grad = handle.flat_param.grad[offset:offset + limit]
+                    offset += limit
+                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                    self.register_param_call_id("_post_backward_hook", tag)
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = _post_backward_hook(state, handle, *unused)
+                return out
+
+            return wrapper
+
+        logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
+        self.fsdp_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook
+        torch.distributed.fsdp._runtime_utils._post_backward_hook = \
+            patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook)
 
     def _hook_weights(self):
+        """
+        Walk each parameter's gradient-accumulation function (grad_acc) and register a hook so that, once all of the parameter's gradients have been computed, the pre-aggregation gradient can be collected.
+        """
         context = self.grad_context
 
         @torch.no_grad
-        def param_hook(*args, context_dict, param, key, name):
+        def param_hook(*args, context_dict, param, name):
+            key = name
+            if self.monitor_mbs_grad:
+                key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+
+            key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+            self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-            self._register_param_call_id("param_hook", key)
-            if param.micro_step == self.micro_batch_number:
-                param.micro_step = 0
+
+            if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
                 if self.params_have_main_grad:
                     grad = param.main_grad
                 else:
@@ -1097,25 +1174,17 @@
                     grad = grad.float()
                 context_dict[key] = grad.clone()
 
+            if param.micro_step == self.micro_batch_number:
+                param.micro_step = 0
+
         logger.info("hooking weights.")
         for param, name in self.param2name.items():
-            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
             param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
             handle = grad_acc.register_hook(
-                partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
+                partial(param_hook, context_dict=context.acc, param=param, name=name))
             self.grad_accs.append(grad_acc)
             self.handles['wgrads'].append(handle)
 
         self.weight_hooked = True
-
-    def _register_param_call_id(self, hook_name: str, key: str):
-        """
-        :param hook_name:
-        :param key: str, '0:relu_0/output_grad'
-        :return:
-        """
-        logger.debug(f"{hook_name} {key}: {self.call_id}")
-        self.param_name_call_id[key] = self.call_id
-        self.call_id += 1
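Note on the final two hunks: _hook_weights relies on a standard autograd trick, param.expand_as(param).grad_fn.next_functions[0][0] yields the parameter's AccumulateGrad node, and a hook registered on that node fires once the parameter's gradient for the current backward pass has been accumulated, i.e. before any later gradient synchronization. A minimal standalone sketch of the pattern (not msprobe code; the tensor and callback names are made up):

import torch

# Sketch only: hook a parameter's AccumulateGrad node so a callback runs
# each time its gradient has been accumulated during backward.
w = torch.nn.Parameter(torch.randn(3))

# expand_as builds a tiny graph whose next node is w's AccumulateGrad
grad_acc = w.expand_as(w).grad_fn.next_functions[0][0]

def on_grad_accumulated(*unused):
    # w.grad is already populated here, before any framework-level
    # gradient reduction that would run later in the step
    if w.grad is not None:
        print("accumulated grad norm:", w.grad.norm().item())

handle = grad_acc.register_hook(on_grad_accumulated)
(w * 2).sum().backward()   # triggers on_grad_accumulated once
handle.remove()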