mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/monitor/module_hook.py

@@ -20,21 +20,24 @@ from collections import defaultdict
  from datetime import datetime

  import pytz
- import mindspore as ms
+ import pandas as pd
+ import mindspore
  from mindspore import Tensor, mint
  from mindspore import nn, _no_grad
- from mindspore.communication import get_rank

  from msprobe.core.common.log import logger
- from msprobe.core.common.const import MonitorConst
+ from msprobe.core.common.const import MonitorConst, Const
  from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+ from msprobe.mindspore.common.utils import is_mindtorch
+ from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
  from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
- is_skip_step, get_metrics, get_single_metrics, get_target_output_dir
- from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
- from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
- CSVWriterWithAD, BaseWriterWithAD, WriterInput
- from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
- get_process_group
+ is_skip_step, get_metrics, get_target_output_dir
+ from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
+ from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
+ from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
+ from msprobe.core.common.file_utils import write_df_to_csv
+ from msprobe.core.common.utils import analyze_api_call_stack

  FORMAT_MAPPING = {
  MonitorConst.CSV: CSVWriterWithAD,
@@ -88,24 +91,7 @@ class ModuleHookContext:
  self.actvgrad = []
  self.module_name = module_name
  self.struct = {}
- self.format_by_arg = {}
- self.verified = False
- self.focused_in_col = 0
- self.focused_out_col = 0
- self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found
-
- def set_format_by_arg(self, key_name: str, target_config: dict):
- cared = target_config.get(self.module_name, self.struct)
- if key_name in cared:
- if isinstance(cared[key_name], dict):
- # current cared is self.struct
- config = cared[key_name].get('config')
- self.format_by_arg[key_name] = config
- else:
- # current cared is target_config[self.module_name]
- self.format_by_arg[key_name] = cared[key_name]
- elif key_name in ['input', 'input_grad']:
- self.ignore_in = True
+ self.stack = ""

  def reset(self):
  self.actv.clear()
@@ -186,6 +172,7 @@ class TrainerMon:
  self.config_file_path = config_file_path
  self.process_group = process_group
  self.params_have_main_grad = params_have_main_grad
+ self.is_mindtorch = is_mindtorch()
  self.config_timestamp = 0 # the timestamp is validated later; the first monitoring pass need not touch it just to refresh the config file, and monitoring can be enabled directly via the dynamic_on switch
  self.config = load_json(config_file_path)
  validate_config(self.config)
@@ -218,6 +205,7 @@ class TrainerMon:
  self.dp_group = None
  self.tp_group = None
  self.micro_batch_number = 1
+ self.optimizer_mon = None

  # TYPE3: variables that are reset when the config is updated mid-training or the monitoring state changes
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -240,6 +228,8 @@ class TrainerMon:
  self.optimizer_hooked = False
  self.param_registered = False
  self.struct_printed = False
+ self.pre_step_hooks = []
+ self.post_step_hooks = []

  # distinguish dynamic vs. static monitoring
  self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
@@ -276,6 +266,9 @@ class TrainerMon:
  self.param_distribution = self.config.get("param_distribution", False)
  self.mg_direction = self.config.get('mg_direction', False) # main grad direction
  self.cc_distribution = self.config.get("cc_distribution", {}) # communication ops
+ self.stack_info = self.config.get('stack_info', False)
+ self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+
  if not self.cc_distribution.get('enable', False):
  self.cc_log_only = False
  else:
@@ -296,18 +289,25 @@ class TrainerMon:
  if self.format not in FORMAT_MAPPING:
  logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
  self.format = MonitorConst.CSV
- writer = FORMAT_MAPPING[self.format]
  self.step_count_per_record = self.config.get('step_count_per_record', 1)
- self.summary_writer = writer(
- WriterInput(
- self.tensorboard_dir,
- self.alert_rules,
- self.unique_id,
- self.anomaly_data_factory,
- self.ndigits,
- self.step_count_per_record
+ if not self.module_rank_list or (self.rank in self.module_rank_list):
+ writer = FORMAT_MAPPING[self.format]
+ self.summary_writer = writer(
+ WriterInput(
+ self.tensorboard_dir,
+ self.alert_rules,
+ self.unique_id,
+ self.anomaly_data_factory,
+ self.ndigits,
+ self.step_count_per_record
+ )
  )
- )
+
+ # initialize the anomaly_detected output directory
+ if self.anomaly_data_factory:
+ self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
+ self.rank)
+ self.anomaly_data_writer.init_detected_json()

  def common_info(self):
  if not self.xy_distribution:
@@ -339,6 +339,7 @@ class TrainerMon:
  self.micro_batch_number = grad_acc_steps
  self.dp_group = dp_group
  self.tp_group = tp_group
+ self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
  self.hook_step_final(optimizer)
  if not isinstance(model, list):
  model = [model]
@@ -359,16 +360,28 @@ class TrainerMon:
  context.step - self.start_step) % self.step_interval == 0)
  if module_rank_valid and step_condition:
  self.has_collect_times += 1
+
+ if self.anomaly_data_factory:
+ self.anomaly_data_factory.set_call_id(self.param_name_call_id)
  self.write_xy_tb(context.step)
  self.write_grad_tb(context.step)
  self.write_mv_tb(context)
  self.write_param_tb(context)
+ if self.stack_info:
+ self.write_stack_info()
+ self.stack_info = False
+ for handle in self.handles["stack"]:
+ handle.remove()
+ self.handles["stack"].clear()

  if context.metric_dict:
  self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
  context.metric_dict.clear()

+ if self.anomaly_data_factory:
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
  self.summary_writer.clear_anomalies()
+
  self.call_id = 0
  self.param_name_call_id.clear()

@@ -378,7 +391,23 @@ class TrainerMon:
  context.step += 1
  self.dynamic_monitor(optimizer)

- optimizer.register_forward_hook(step_final_hook)
+
+ def patch_step(func, optimizer):
+ def wrapper(*args, **kwargs):
+ for hook in self.pre_step_hooks:
+ hook(optimizer, args, kwargs)
+ out = func(*args, **kwargs)
+ for hook in self.post_step_hooks:
+ hook(optimizer, args, kwargs)
+ step_final_hook(optimizer, args, kwargs)
+ return out
+ return wrapper
+
+ if self.is_mindtorch:
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+ else:
+ optimizer.__class__.construct = patch_step(optimizer.__class__.construct, optimizer)
+
  return

  def dynamic_monitor(self, optimizer):
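The hunk above replaces the old `optimizer.register_forward_hook(step_final_hook)` registration with a wrapper around the optimizer's `step` (MindTorch) or `construct` (native MindSpore) method, so that `pre_step_hooks` run before each parameter update and `post_step_hooks` plus the final step hook run after it. A minimal, framework-free sketch of that pattern; the `ToyOptimizer` class and module-level hook lists are illustrative stand-ins, not msprobe APIs:

```python
# Minimal sketch of the pre/post step-hook pattern introduced above.
# ToyOptimizer and the hook lists are stand-ins, not part of msprobe.
pre_step_hooks = []
post_step_hooks = []


class ToyOptimizer:
    def step(self):
        print("optimizer update")


def patch_step(func, optimizer):
    def wrapper(*args, **kwargs):
        for hook in pre_step_hooks:        # runs before the parameter update
            hook(optimizer, args, kwargs)
        out = func(*args, **kwargs)
        for hook in post_step_hooks:       # runs after the parameter update
            hook(optimizer, args, kwargs)
        return out
    return wrapper


opt = ToyOptimizer()
pre_step_hooks.append(lambda o, args, kwargs: print("pre-step: collect grads"))
post_step_hooks.append(lambda o, args, kwargs: print("post-step: collect params"))
ToyOptimizer.step = patch_step(ToyOptimizer.step, opt)
opt.step()  # pre-step: collect grads / optimizer update / post-step: collect params
```

Patching at the class level, as the diff does, means every instance of that optimizer class picks up the hooks without requiring framework support for optimizer forward hooks.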
@@ -413,7 +442,7 @@ class TrainerMon:
  logger.error(f"set config wrong because {e}, not updated, please check!!!")
  return

- self._remove_all_hooks()
+ self._remove_all_hooks(optimizer)
  self.register_hooks(optimizer)

  def register_hooks(self, optimizer):
@@ -438,45 +467,36 @@ class TrainerMon:

  hooked_count = 0
  for vpp_stage, model_chunk in enumerate(self.model):
- if not isinstance(model_chunk, nn.Cell):
+ if not is_valid_instance(model_chunk):
  logger.info("Target Model is not Cell")
  continue
  vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
- targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
+ targets = [x for x, _ in get_submodules(model_chunk)] if self.print_struct else self.targets.keys()
  hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
  logger.info(f"> {hooked_count} modules are monitored.")

  def hook_optimizer(self, optimizer):
- def optimizer_pre_hook_function(opt, grad_names, gradients):
+ def optimizer_pre_step_hook(opt, *args, **kwargs):
  context = self.optimizer_context[opt]
  if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
  self.collect_times):
  return
- gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients
- is_select = self.is_select
- for idx, grad in enumerate(gradient_list):
- grad_name = grad_names[idx]
- if is_select and grad_name not in self.targets:
- continue
- get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad)
-
- if self.mv_distribution:
- # fetch mean
- for param in m_list:
- name = param.name
- if is_select and name not in self.targets:
- continue
- get_single_metrics(self.ops, name, param, context.exp_avg_metric)
- # fetch variance
- for param in v_list:
- name = param.name
- if is_select and name not in self.targets:
- continue
- get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric)
- if self.param_distribution:
- for param in param_list:
- get_single_metrics(self.ops, param.name, param, context.param_metric)
- self.generate_wgrad_metrics()
+
+ grad_dict = {}
+ if self.wg_distribution:
+ grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+ if self.mv_distribution or self.ur_distribution or self.mg_direction:
+ if self.is_mindtorch:
+ context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
+ context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
+ else:
+ context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
+
+ self.generate_wgrad_metrics(grad_dict)
+ self.generate_mv_metrics(context)
+ self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
+
  metric_dict = {}
  for cc in self.cc_context.values():
  cc.aggregate()
@@ -488,63 +508,86 @@ class TrainerMon:
  context.metric_dict = metric_dict
  return

- def optimizer_pre_hook_wrapper(func, grad_names):
- def wrapper(opt, gradients):
- return func(opt, grad_names, gradients)
- return wrapper
+ def optimizer_post_step_hook(optimizer, args, kwargs):
+ context = self.optimizer_context[optimizer]
+ self.generate_param_metrics(context, MonitorConst.POST_PARAM)
+

  if self.optimizer_hooked or not self.is_target_rank():
  return

- m_list = []
- v_list = []
- param_list = []
- grad_names = []
- for param in optimizer.get_parameters():
- if MonitorConst.EXP_AVG_SQ in param.name:
- v_list.append(param)
- elif MonitorConst.EXP_AVG in param.name:
- m_list.append(param)
- elif param.name in ['global_step', 'learning_rate']:
- pass
- else:
- param_list.append(param)
- grad_names.append(param.name)
-
- handle = optimizer.register_forward_pre_hook(
- optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
- self.handles['optimizer'].append(handle)
+ self.pre_step_hooks.append(optimizer_pre_step_hook)
+ self.post_step_hooks.append(optimizer_post_step_hook)
  self.optimizer_hooked = True
  return

- def generate_wgrad_metrics(self):
+ def generate_wgrad_metrics(self, grad_dict):
  if not self.wg_distribution:
- return {}, {}
+ return

- if self.weight_hooked:
- try:
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
- except Exception as e:
- logger.warning(f"An error occurred while generating wgrad pre metrics")
- return {}, {}
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)

- grad_dict = {}
- for param, name in self.param2name.items():
- if self.duplicate_param.get(name, False):
- continue
- grad = param.main_grad if self.params_have_main_grad else param.grad
- if grad is None:
- logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+ def generate_param_map(self, tag, param_tensor):
+ metrics = {}
+ if not self.is_mindtorch:
+ return param_tensor
+ for name in self.param2name.values():
+ key = get_summary_writer_tag_name(name, tag, self.rank)
+ self.register_param_call_id("optimizer_pre_step_hook", key)
+ if name not in param_tensor or param_tensor[name] is None:
  continue
- tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
- self._register_param_call_id("hook_optimizer", tag)
- grad_dict[tag] = grad
- try:
- get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- except Exception as e:
- logger.warning(f"An error occurred while generating wgrad post metrics")
+ metrics[key] = param_tensor[name]
+ return metrics
+
+ def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
+ if not self.param_distribution:
+ return
+ tag2param = {
+ self.name2tag.get(name, {}).get(stage): param
+ for name, param in self.name2param.items()
+ if param.numel() != 0
+ }
+ get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
+
+ def get_mv_for_ms(self, opt):
+ if not self.mv_distribution:
  return {}, {}
- return self.grad_context.post, self.grad_context.pre
+ common_opt = opt
+ if not is_valid_instance(opt):
+ common_opt = getattr(opt, 'optimizer')
+ if not is_valid_instance(common_opt):
+ logger.warning("Optimizer is not valid, please check usage")
+ return {}, {}
+ m_dict = {}
+ v_dict = {}
+ for name, param in get_parameters(common_opt):
+ if MonitorConst.EXP_AVG_SQ in name:
+ v_dict[name] = param
+ elif MonitorConst.EXP_AVG in name:
+ m_dict[name] = param
+ return m_dict, v_dict
+
+ def generate_mv_metrics(self, opt_context):
+ if not self.mv_distribution:
+ return
+ opt_context.exp_avg_metric = {}
+ opt_context.exp_avg_sq_metric = {}
+ m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+ v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+ get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+ get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+ def write_stack_info(self):
+ stack_data = []
+ header = ["module_name", "stack_info"]
+ stack_data.append(header)
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+ stack_data.append([fwd_context.module_name, fwd_context.stack])
+ filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+ if not os.path.exists(filepath):
+ data_frame = pd.DataFrame(columns=stack_data)
+ write_df_to_csv(data_frame, filepath)

  def write_xy_tb(self, step):
  if not self.xy_distribution:
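In the hunk above, `get_mv_for_ms` recovers Adam's first- and second-moment state from a native MindSpore optimizer purely by parameter name, bucketing entries whose names contain the exp_avg / exp_avg_sq markers. A small illustrative sketch of that bucketing with plain values; the two marker strings below stand in for `MonitorConst.EXP_AVG` and `MonitorConst.EXP_AVG_SQ`:

```python
# Illustrative only: bucket optimizer state entries by name substring,
# mirroring the get_mv_for_ms logic above. The two markers are stand-ins
# for MonitorConst.EXP_AVG / MonitorConst.EXP_AVG_SQ.
EXP_AVG = "exp_avg"
EXP_AVG_SQ = "exp_avg_sq"

def split_moments(named_params):
    m_dict, v_dict = {}, {}
    for name, value in named_params:
        if EXP_AVG_SQ in name:       # check the longer marker first, since
            v_dict[name] = value     # "exp_avg" is a substring of "exp_avg_sq"
        elif EXP_AVG in name:
            m_dict[name] = value
    return m_dict, v_dict

state = [("exp_avg.layer1.weight", 0.1),
         ("exp_avg_sq.layer1.weight", 0.01),
         ("layer1.weight", 1.0)]
m, v = split_moments(state)
print(m)  # {'exp_avg.layer1.weight': 0.1}
print(v)  # {'exp_avg_sq.layer1.weight': 0.01}
```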
@@ -552,27 +595,32 @@ class TrainerMon:
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
  if len(fwd_context.actv) == 0:
  continue
- self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
  fwd_context.actv.clear()
  if self.grad_context.actv:
- self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)

  def write_param_tb(self, opt_context):
  if not self.param_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
+ param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+ updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+ self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+ self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)

  def write_mv_tb(self, opt_context):
  if not self.mv_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, MonitorConst.EXP_AVG)
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step,
+ MonitorConst.EXP_AVG_SQ)

  def write_grad_tb(self, step):
  if not self.wg_distribution:
  return

- self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+ use_micro_step=self.monitor_mbs_grad)
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

  def is_target_rank(self):
@@ -580,13 +628,38 @@ class TrainerMon:
  return False
  return True

- def build_tbtag_tensor_map(self, module_name, tag, tensor):
- metrics = {}
- key = get_summary_writer_tag_name(module_name, tag, str(self.rank))
+ def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+ """
+ :param module_name: str of module name
+ :param suffix:
+ :param tag:
+ :param tensor: torch.tensor or tuple/list of torch.tensor
+ :return: tensor_map
+ """
+ tensor_map = {}
  if isinstance(tensor, Tensor):
- self._register_param_call_id("_hook_module", key)
- metrics[key] = tensor
- return metrics
+ tensor = [tensor]
+ if isinstance(tensor, tuple) or isinstance(tensor, list):
+ if len(tensor) == 1:
+ key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+ self.register_param_call_id("_hook_module", key)
+ tensor_map[key] = tensor[0]
+ else:
+ for i, tensor_i in enumerate(tensor):
+ key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+ self.register_param_call_id("_hook_module", key)
+ tensor_map[key] = tensor_i
+ return tensor_map
+
+ def register_param_call_id(self, hook_name: str, key: str):
+ """
+ :param hook_name:
+ :param key: str, '0:relu_0/output_grad'
+ :return:
+ """
+ logger.debug(f"{hook_name} {key}: {self.call_id}")
+ self.param_name_call_id[key] = self.call_id
+ self.call_id += 1

  def _register_param_name(self):
  for vpp_stage, model_chunk in enumerate(self.model):
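The reworked `build_tbtag_tensor_map` above accepts either a single `Tensor` or a tuple/list of tensors and emits one summary-writer tag per element, adding an index suffix when there is more than one. A minimal sketch of that normalization with plain strings in place of tensors; `make_tag` is a hypothetical stand-in for `get_summary_writer_tag_name`:

```python
# Illustrative sketch of the tensor-to-tag normalization above: a single
# value yields one key, a sequence yields an index-suffixed key per element.
# make_tag is a stand-in for get_summary_writer_tag_name.
def make_tag(name, tag, rank):
    return f"{rank}:{name}/{tag}"

def build_tag_map(module_name, suffix, tag, value, rank=0):
    values = list(value) if isinstance(value, (tuple, list)) else [value]
    if len(values) == 1:
        return {make_tag(module_name + suffix, tag, rank): values[0]}
    return {make_tag(f"{module_name}_{i}{suffix}", tag, rank): v
            for i, v in enumerate(values)}

print(build_tag_map("mlp.output", "_0", "actv", "t0"))
# {'0:mlp.output_0/actv': 't0'}
print(build_tag_map("mlp.input", "_0", "actv", ("t0", "t1")))
# {'0:mlp.input_0_0/actv': 't0', '0:mlp.input_1_0/actv': 't1'}
```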
@@ -595,8 +668,7 @@ class TrainerMon:

  def _register_chunk(self, model_chunk, prefix):
  index = 0
- for param in model_chunk.get_parameters():
- param_name = param.name
+ for param_name, param in get_parameters(model_chunk):
  if not param.requires_grad:
  continue
  if self._is_target_param(param_name, param, prefix):
@@ -611,25 +683,37 @@ class TrainerMon:
  self.duplicate_param[name] = True
  if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
  self.duplicate_param[name] = True
+ keywords = [
+ MonitorConst.PRE_GRAD,
+ MonitorConst.POST_GRAD,
+ MonitorConst.PRE_PARAM,
+ MonitorConst.POST_PARAM
+ ]
  self.name2tag[name] = {
- MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
- MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
+ k: get_summary_writer_tag_name(name, k, self.rank)
+ for k in keywords
  }
  index += 1

  def _hook_module(self, target_names, module, vpp_stage=''):
- if not isinstance(module, nn.Cell):
+ if not is_valid_instance(module):
  # nothing to hook
  return 0

- def fwd_hook_fun(module, module_input, module_output, name):
+ def fwd_hook_fun(module, args, kwargs, module_output, name):
+
+ module_input = [tensor for tensor in args if isinstance(tensor, Tensor)]
+ if kwargs:
+ kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
+ module_input.extend(kwargs_tensors)
+
  if module not in self.module_fwd_hook_context_by_module:
  self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
  context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
  if not context.struct:
  context.struct = {
- MonitorConst.ACTV_IN: get_param_struct(module_input),
- MonitorConst.ACTV_OUT: get_param_struct(module_output)
+ Const.INPUT: get_param_struct(module_input),
+ Const.OUTPUT: get_param_struct(module_output)
  }
  if self.print_struct:
  self.module_struct[context.module_name].update(context.struct)
@@ -640,31 +724,18 @@ class TrainerMon:
  self.collect_times):
  step_accumulates_one(context, self.micro_batch_number)
  return
- if not context.format_by_arg:
- context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets)
- context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets)
- if not context.format_by_arg:
- return
- if not context.verified:
- if not context.ignore_in:
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
- module_input, context.module_name,
- MonitorConst.ACTV_IN)
- context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
- module_output, context.module_name,
- MonitorConst.ACTV_OUT)
- context.verified = True

  tbtag_tensor_map = {}
- if not context.ignore_in:
- cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
- tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
- cared_input))
- cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
- cared_output))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, module_input))
+ module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
+ if isinstance(module_output, tuple) else module_output
+ tbtag_tensor_map.update(
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, module_output))
  try:
  get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
  except Exception as e:
@@ -689,31 +760,17 @@ class TrainerMon:
  step_accumulates_one(context, self.micro_batch_number)
  return

- if not context.format_by_arg:
- context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets)
- context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets)
- if not context.format_by_arg:
- return
- if not context.verified:
- if not context.ignore_in:
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
- input_grad, context.module_name,
- MonitorConst.ACTVGRAD_IN)
- context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
- output_grad, context.module_name,
- MonitorConst.ACTVGRAD_OUT)
- context.verified = True
-
+ valid_input_grad = [tensor for tensor in input_grad if isinstance(tensor, Tensor)]
  tbtag_tensor_map = {}
- if not context.ignore_in:
- cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
- tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(
- f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
- cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
  tbtag_tensor_map.update(
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
- cared_output_grad))
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTVGRAD, valid_input_grad))
+
+ tbtag_tensor_map.update(
+ self.build_tbtag_tensor_map(
+ f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTVGRAD, output_grad))

  if context.micro_step == 0 and context.actvgrad:
  logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -727,21 +784,39 @@ class TrainerMon:
  step_accumulates_one(context, self.micro_batch_number)
  return

- def fwd_hook_fun_wrapper(fwd_hook_fun, name):
- def wrapper(module, module_input, module_output):
- return fwd_hook_fun(module, module_input, module_output, name)
- return wrapper
+ def fwd_hook_register(module, fwd_hook_fun, name):
+ if mindspore.__version__ >= '2.6.0':
+ def wrapper(module, args, kwargs, module_output):
+ return fwd_hook_fun(module, args, kwargs, module_output, name)
+ return module.register_forward_hook(wrapper, with_kwargs=True)
+
+ else:
+ def wrapper(module, args, module_output):
+ return fwd_hook_fun(module, args, None, module_output, name)
+ return module.register_forward_hook(wrapper)
+
+ def stack_hook(module, args, kwargs, module_output, name):
+ if module not in self.module_fwd_hook_context_by_module:
+ self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+ context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+ context.stack = analyze_api_call_stack(name)
+ return

  if self.backward_only and self.forward_only:
  logger.warning('not enable backward_only and forward_only simultaneously')
  hooked_count = 0
- if self.xy_distribution or self.print_struct:
- for module_name, submodule in module.cells_and_names():
- name = self._is_target_module(module_name, target_names, vpp_stage)
- if not name:
- continue
+
+ for module_name, submodule in get_submodules(module):
+ if self.stack_info:
+ name = vpp_stage + squash_param_name(module_name)
+ handle = fwd_hook_register(submodule, stack_hook, name=name)
+ self.handles["stack"].append(handle)
+ name = self._is_target_module(module_name, target_names, vpp_stage)
+ if not name:
+ continue
+ if self.xy_distribution or self.print_struct:
  if not self.backward_only:
- handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name))
+ handle = fwd_hook_register(submodule, fwd_hook_fun, name=name)
  self.handles['xy'].append(handle)
  if not self.forward_only:
  handle = submodule.register_backward_hook(bwd_hook_fun)
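`fwd_hook_register` above selects the hook signature by the installed MindSpore version: from 2.6.0 onward it registers the kwargs-aware forward hook (`with_kwargs=True`, as used in this diff), while older releases get the positional-only form. A framework-free sketch of the same dispatch; `DummyCell` and `MINDSPORE_VERSION` are placeholders for `mindspore.nn.Cell` and `mindspore.__version__`:

```python
# Framework-free sketch of the version-gated registration above.
# DummyCell and MINDSPORE_VERSION are placeholders for mindspore.nn.Cell
# and mindspore.__version__; the string comparison mirrors the diff.
MINDSPORE_VERSION = "2.6.0"

class DummyCell:
    def register_forward_hook(self, hook, with_kwargs=False):
        print(f"registered forward hook, with_kwargs={with_kwargs}")
        return object()  # stand-in for a hook handle

def fwd_hook(module, args, kwargs, output, name):
    print(f"{name}: {len(args)} positional inputs")

def fwd_hook_register(module, hook_fun, name):
    if MINDSPORE_VERSION >= "2.6.0":
        def wrapper(module, args, kwargs, output):
            return hook_fun(module, args, kwargs, output, name)
        return module.register_forward_hook(wrapper, with_kwargs=True)
    else:
        def wrapper(module, args, output):
            return hook_fun(module, args, None, output, name)
        return module.register_forward_hook(wrapper)

handle = fwd_hook_register(DummyCell(), fwd_hook, name="0:mlp")
# prints: registered forward hook, with_kwargs=True
```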
@@ -760,22 +835,30 @@ class TrainerMon:
  context = self.grad_context

  @_no_grad()
- def param_hook(grad, context_dict, param, key):
+ def param_hook(grad, context_dict, param, name):
+ key = name
+ if self.monitor_mbs_grad:
+ key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+ key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+ self.register_param_call_id("param_hook", key)
  param.micro_step += 1
- self._register_param_call_id("param_hook", key)
+
+ if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
+ context_dict[key] = grad
  if param.micro_step == self.micro_batch_number:
  param.micro_step = 0
- context_dict[key] = grad

- def param_hook_wrapper(param_hook, context_dict, param, key):
+ def param_hook_wrapper(param_hook, context_dict, param, name):
  def wrapper(grad):
- return param_hook(grad, context_dict, param, key)
+ return param_hook(grad, context_dict, param, name)
+
  return wrapper

+ logger.info("hooking weights.")
  for param, name in self.param2name.items():
- key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
  setattr(param, 'micro_step', 0)
- handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key))
+ handle = param.register_hook(
+ param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name))
  self.handles['wgrads'].append(handle)
  self.weight_hooked = True

@@ -801,17 +884,7 @@ class TrainerMon:
  return pattern
  return ""

- def _register_param_call_id(self, hook_name: str, key: str):
- """
- :param hook_name:
- :param key: str, '0:relu_0/output_grad'
- :return:
- """
- logger.debug(f"{hook_name} {key}: {self.call_id}")
- self.param_name_call_id[key] = self.call_id
- self.call_id += 1
-
- def _remove_all_hooks(self):
+ def _remove_all_hooks(self, optimizer):
  # clear hook handles
  for handle in self.handles['xy']:
  handle.remove()
@@ -829,9 +902,8 @@ class TrainerMon:
  self.weight_hooked = False

  if self.optimizer_hooked:
- for handle in self.handles['optimizer']:
- handle.remove()
- self.handles['optimizer'].clear()
+ self.pre_step_hooks.clear()
+ self.post_step_hooks.clear()
  for _, context in self.optimizer_context.items():
  context.reset()
  self.optimizer_hooked = False
@@ -870,4 +942,4 @@ class TrainerMon:
  except Exception as e:
  logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
  logger.info("Finish monitor")
- self._remove_all_hooks()
+ self._remove_all_hooks(optimizer)