mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,50 +12,45 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import time
  import json
  import os
  import uuid
  from collections import defaultdict
- from datetime import datetime, timezone
+ from datetime import datetime
  from functools import partial

  import pytz
  import torch
  import torch.distributed as dist
- from msprobe.core.common.const import MonitorConst
- from msprobe.core.common.file_utils import load_json
- from msprobe.core.common.log import logger
+ from torch.utils.hooks import BackwardHook
+
+ from msprobe.core.common.const import MonitorConst, Const
+ from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import is_recomputation
  from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
  from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
  CSVWriterWithAD, BaseWriterWithAD, WriterInput
  from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
  get_process_group
  from msprobe.pytorch.monitor.features import get_sign_matches
- from msprobe.pytorch.monitor.module_metric import get_metrics, write_metrics_base, get_summary_writer_tag_name, \
- TensorMetrics, write_metrics_csv, squash_param_name
+ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
+ TensorMetrics, squash_param_name
  from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
- from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
- from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation
+ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
+ from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
+ get_output_base_dir, get_target_output_dir
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
- from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
- from torch.utils.hooks import BackwardHook
-
- try:
- import torch_npu
- except ImportError:
- pass

  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
  if not torch_version_above_or_equal_2:
  raise ValueError("monitor require torch>=2.0")

- output_base_dir = os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)

  FORMAT_MAPPING = {
- MonitorConst.TENSORBOARD: (SummaryWriterWithAD, write_metrics_base),
- MonitorConst.CSV: (CSVWriterWithAD, write_metrics_csv),
- MonitorConst.API: (BaseWriterWithAD, write_metrics_base)
+ MonitorConst.TENSORBOARD: SummaryWriterWithAD,
+ MonitorConst.CSV: CSVWriterWithAD,
+ MonitorConst.API: BaseWriterWithAD
  }


@@ -71,7 +66,6 @@ def param_is_data_parallel_duplicate(dp_group):

  class ModuleHookContext:
  def __init__(self, module_name) -> None:
- self.step = 0
  self.micro_step = 0
  self.actv = defaultdict(dict)
  self.actvgrad = []
@@ -81,26 +75,44 @@ class ModuleHookContext:
  self.verified = False
  self.focused_in_col = 0
  self.focused_out_col = 0
- self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found

  def set_format_by_arg(self, key_name: str, target_config: dict):
+ """ Configure format_by_arg according to the monitored targets:
+ 1) module_name is configured as a monitored target in targets
+ 2) module_name is not configured in targets, and all_xy enables full monitoring
+ 3) module_name is not configured in targets, and all_xy does not enable full monitoring
+
+ :param key_name: str, one of [input, output, input_grad, output_grad]
+ :param target_config: target obj in config json.
+ :return:
+ """
  cared = target_config.get(self.module_name, self.struct)
  if key_name in cared:
- if isinstance(cared[key_name], dict):
- # current cared is self.struct
- config = cared[key_name].get('config')
- self.format_by_arg[key_name] = config
- else:
+ target_module_config = cared[key_name]
+ if isinstance(target_module_config, dict):
+ # current cared is self.struct, monitor all data for module_name
+ self.format_by_arg[key_name] = target_module_config.get('config')
+ elif isinstance(target_module_config, str):
  # current cared is target_config[self.module_name]
- self.format_by_arg[key_name] = cared[key_name]
- elif key_name in ['input', 'input_grad']:
- self.ignore_in = True
+ self.format_by_arg[key_name] = target_module_config
+ else:
+ logger.warning_on_rank_0(f"target module config error, result maybe empty."
+ f"module_name: {self.module_name}, key_name: {key_name}")
+ self.format_by_arg[key_name] = None
+ else:
+ self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
+
+ def reset(self):
+ self.actv.clear()
+ self.actvgrad.clear()
+
+
+ start_step = 0


  class OptimizerContext:
  def __init__(self) -> None:
- self.step = 0
- self.param_effective_rank = defaultdict(float)
+ self.step = start_step
  self.param_mg_direction = defaultdict(float)
  self.param_adam_update = defaultdict()
  self.param_adam_ratio = defaultdict()
@@ -112,6 +124,18 @@ class OptimizerContext:
  self.metric_dict = {}
  self.param_metric = {}

+ def reset(self):
+ self.param_mg_direction.clear()
+ self.param_adam_update.clear()
+ self.param_adam_ratio.clear()
+ self.param_weight_grad.clear()
+ self.param_exp_avg.clear()
+ self.exp_avg_metric.clear()
+ self.param_exp_avg_sq.clear()
+ self.exp_avg_sq_metric.clear()
+ self.metric_dict.clear()
+ self.param_metric.clear()
+

  class CommunicationContext:
  def __init__(self) -> None:
@@ -152,23 +176,131 @@ class GradContext:
  class TrainerMon:
  tensor_metrics = TensorMetrics()

- def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
- """
- opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
- """
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+ # TYPE1: variables initialized only here; they are not reset when the config changes mid-training
+ self.config_file_path = config_file_path
+ self.process_group = get_process_group(process_group)
+ self.params_have_main_grad = params_have_main_grad
+ self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
+ self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
+ self.origin_step_func = None
+ self.origin_start_grad_sync = None
+ self.config_timestamp = 0 # the timestamp is validated later; the first monitoring run does not need to touch the config file just to refresh it, monitoring can be enabled directly via the dynamic_on switch
+ self.config = load_json(config_file_path)
+ validate_config(self.config)
+
+ self.squash_name = self.config.get('squash_name', True) # must not be modified, otherwise names recorded before and after would not match
+ local_tz = pytz.timezone("Asia/Shanghai") # adjust to the target time zone as needed
+ cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+ self.unique_id = str(uuid.uuid4())[:8]
+ self.output_base_dir = get_output_base_dir()
+ time_tags = self.config.get("append_output", [])
+ if dist.is_initialized():
+ self.rank = dist.get_rank()
+ if time_tags:
+ output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+ if str(self.rank) in output_append_dirs:
+ self.tensorboard_dir = output_append_dirs[str(self.rank)]
+ logger.info(f"append rank({self.rank}) result to {self.tensorboard_dir}")
+ else:
+ self.tensorboard_dir = os.path.join(self.output_base_dir,
+ f"{cur_time}-rank{self.rank}-{self.unique_id}")
+ self.pp_stage = dist.get_group_rank(self.process_group, self.rank)
+ self.group_mates = dist.get_process_group_ranks(self.process_group)
+ else:
+ self.rank = 0
+ self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+ self.pp_stage = 0
+ self.group_mates = [0]
+
+ # TYPE2: variables assigned only from the set_monitor() entry point
+ self.model = None
+ self.vpp = False
+ self.dp_group = None
+ self.tp_group = None
+ self.enable_megatron = False
+ self.micro_batch_number = 1
+ self.optimizer_class = None
+ self.optimizer_mon = None
+
+ # TYPE3: variables reset whenever the config is updated mid-training or the monitoring state changes
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
  self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
  self.optimizer_context = defaultdict(OptimizerContext)
  self.cc_context = defaultdict(CommunicationContext)
  self.grad_context = GradContext()
- self.process_group = get_process_group(process_group)
- self.params_have_main_grad = params_have_main_grad
- self.opt_ty = opt_ty
- self.config = load_json(config_file_path)
- validate_config(self.config)
+ self.handles = defaultdict(list)
+ self.param2name = defaultdict(str)
+ self.name2index = defaultdict()
+ self.name2indices = defaultdict()
+ self.name2param = {}
+ self.duplicate_param = {}
+ self.name2tag = {}
+ self.param_name_call_id = {}
+ self.call_id = 0
+ self.module_struct = defaultdict(dict)
+ self.grad_accs = []
+ self.weight_hooked = False
+ self.optimizer_hooked = False
+ self.param_registered = False
+ self.struct_printed = False
+
+ # distinguish dynamic from static monitoring
+ self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+ if self.dynamic_enable:
+ logger.warning(f"DYNAMIC_MONITOR is set, "
+ f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+ self.monitoring = False
+ else:
+ self.set_config()
+ # in static mode with collect_times>0, self.monitoring can already be True at step 0; dynamic mode starts at the next step by default
+ if self.collect_times > 0:
+ self.monitoring = True
+
+ def __del__(self):
+ if hasattr(self, "summary_writer"):
+ self.summary_writer.close()
+
+ @property
+ def ops(self):
+ return self._ops

+ @ops.setter
+ def ops(self, value):
+ self._ops = validate_ops(value)
+
+ @staticmethod
+ def has_register_backward_hook(module_name, module):
+ if hasattr(module, '_backward_hooks') and \
+ len(module._backward_hooks) > 0 and \
+ module._is_full_backward_hook is False:
+ logger.warning(
+ f"The {module_name} has registered deprecated register_backward_hook,"
+ f"which may cause abnormal data dump. The backward input/output for this module will be skipped."
+ )
+ return True
+ return False
+
+ @staticmethod
+ def generate_cc_metrics(cc_name, cc_tensor):
+ metrics = defaultdict(dict)
+ rank = dist.get_rank() if dist.is_initialized() else None
+ for op, tag2tensor in cc_tensor.data.items():
+ for tag, tensor in tag2tensor.items():
+ key = get_summary_writer_tag_name(cc_name, tag, rank)
+ metrics[op].update({key: tensor})
+ cc_tensor.reset()
+ return metrics
+
+ def set_config(self):
+ logger.info(f"current config: {self.config}")
+ self.start_step = self.config.get("start_step", 0)
+ self.collect_times = self.config.get("collect_times", 100000000) # defaults to a large value so that collection never stops
+ self.step_interval = self.config.get("step_interval", 1)
+ self.has_collect_times = 0 # reset the collection counter
+ self.print_struct = self.config.get("print_struct", False)
  self.module_rank_list = self.config.get("module_ranks", [])
- self.format = self.config.get('format', 'tensorboard')
+ self.format = self.config.get('format', MonitorConst.CSV)
  self.eps = self.config.get('eps', 1e-8)
  self.ops = self.config.get('ops', [])
  self.ndigits = self.config.get('ndigits', 6)
@@ -182,6 +314,7 @@ class TrainerMon:
  self.param_distribution = self.config.get("param_distribution", False)
  self.mg_direction = self.config.get('mg_direction', False)
  self.cc_distribution = self.config.get("cc_distribution", {})
+
  if not self.cc_distribution.get('enable', False):
  self.cc_log_only = False
  else:
@@ -189,49 +322,36 @@ class TrainerMon:
  self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
  self.cc_logged_stack = defaultdict(set)
  self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
- api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
  api_register.redirect_api()

  self.common_info()

+ # initialize the AnomalyData factory
  alert_setting = self.config.get('alert', {"rules": []})
  self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
- # set the time zone, using 'UTC' as an example
- local_tz = pytz.timezone("Asia/Shanghai") # adjust to the target time zone as needed
-
- cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
- unique_id = str(uuid.uuid4())[:8]
-
- if dist.is_initialized():
- rank = dist.get_rank()
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
- pp_stage = dist.get_group_rank(self.process_group, rank)
- group_mates = dist.get_process_group_ranks(self.process_group)
- else:
- rank = 0
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
- pp_stage = 0
- group_mates = [0]
- self.rank = rank
-
- # initialize the AnomalyData factory
  self.anomaly_data_factory = None
  if alert_setting.get('dump', False):
- self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
+ self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)

+ # initialize the writer and create the output directory
  if self.format not in FORMAT_MAPPING:
- raise ValueError(f"Unsupported format: {self.format}")
- writer, self.write_metrics = FORMAT_MAPPING[self.format]
+ logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+ self.format = MonitorConst.CSV
+
+ if self.ur_distribution and self.format != 'tensorboard':
+ logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
+ self.ur_distribution = False
+
+ writer = FORMAT_MAPPING[self.format]
  self.step_count_per_record = self.config.get('step_count_per_record', 1)

- if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
+ if (self.rank in self.module_rank_list) or len(self.module_rank_list) == 0:
  self.summary_writer = writer(
  WriterInput(
- tensorboard_dir,
+ self.tensorboard_dir,
  self.alert_rules,
- unique_id,
- None,
+ self.unique_id,
  self.anomaly_data_factory,
  self.ndigits,
  self.step_count_per_record
@@ -239,83 +359,22 @@ class TrainerMon:
  )
  # initialize the anomaly-detected file directory
  if self.anomaly_data_factory:
- self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank)
+ self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
+ self.rank)
  self.anomaly_data_writer.init_detected_json()

- # A HeatmapVisualizer instance is associated with an image
- self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
- self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
- self.micro_batch_number = 1
-
- self.model = None
- self.weight_hooked = False
- self.optimizer_hooked = False
- self.param_registered = False
- self.vpp = False
- self.dp_group = None
- self.tp_group = None
- self.enable_megatron = False
-
- self.param2name = defaultdict(str)
- self.name2index = defaultdict()
- self.name2indices = defaultdict()
- self.name2param = {}
- self.param_name_call_id = {}
- self.duplicate_param = {}
- self.name2tag = {}
- self.call_id = 0
- self.grad_accs = []
- self.handles = defaultdict(list)
-
- self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
- self.print_struct = self.config.get("print_struct", False)
- self.struct_printed = False
- self.module_struct = {}
-
- def __del__(self):
- if hasattr(self, "summary_writer"):
- self.summary_writer.close()
-
- @property
- def ops(self):
- return self._ops
-
- @ops.setter
- def ops(self, value):
- self._ops = validate_ops(value)
-
- @staticmethod
- def set_wrapped_optimizer(_wrapped_optimizer):
- OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
-
- @staticmethod
- def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+ def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
  rank = None
  if dist.is_initialized():
  rank = dist.get_rank()
  if (rank not in rank_list) and len(rank_list) != 0:
  return
- TrainerMon.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
+ self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)

- @staticmethod
- def build_tbtag_tensor_map(module_name, tag, tensor):
- metrics = {}
- rank = dist.get_rank() if dist.is_initialized() else None
- key = get_summary_writer_tag_name(module_name, tag, rank)
- if torch.is_tensor(tensor):
- metrics[key] = tensor
- return metrics
-
- @staticmethod
- def generate_cc_metrics(cc_name, cc_tensor):
- metrics = defaultdict(dict)
- rank = dist.get_rank() if dist.is_initialized() else None
- for op, tag2tensor in cc_tensor.data.items():
- for tag, tensor in tag2tensor.items():
- key = get_summary_writer_tag_name(cc_name, tag, rank)
- metrics[op].update({key: tensor})
- cc_tensor.reset()
- return metrics
+ def build_tbtag_tensor_map(self, module_name, tag, tensor):
+ key = get_summary_writer_tag_name(module_name, tag, self.rank)
+ self._register_param_call_id("_hook_module", key)
+ return {key: tensor}

  def common_info(self):
  if not self.xy_distribution:
@@ -332,37 +391,25 @@ class TrainerMon:
  logger.info_on_rank_0('> grad and momentum direction will not be compared.')
  if not self.cc_distribution.get('enable', False):
  logger.info_on_rank_0("> cc operator is not monitored.")
- if not self.opt_ty:
- if self.ur_distribution:
- raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
- if self.mv_distribution:
- raise Exception("mv_distribution cannot be enabled with unknown optimizer.")

- def hook_modules(self, model: torch.nn.Module, grad_acc_steps):
+ def hook_modules(self):
  if self.module_rank_list and (self.rank not in self.module_rank_list):
  return

- if not isinstance(model, list):
- model = [model]
- self.model = model
- self._register_param_name(model)
-
- self.micro_batch_number = grad_acc_steps
-
  targets = self.config['targets']
- module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
+ module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
  for key in module_in_all_stage:
  struct = targets.pop(key)
- targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(model))})
+ targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})

  hooked_count = 0
- for vpp_stage, model_chunk in enumerate(model):
- vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
+ for vpp_stage, model_chunk in enumerate(self.model):
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
  targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
  'targets'].keys()
  hooked_count += self._hook_module(targets, model_chunk, vpp_stage)

- logger.info_on_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
+ logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")

  def clone_if_tensor(args):
  if isinstance(args, tuple):
@@ -383,11 +430,11 @@

  BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)

- if not self.optimizer_hooked:
- self.hook_optimizer()
  return

  def generate_param_metrics(self, opt_context):
+ if not self.param_distribution:
+ return
  get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)

  def generate_mv_metrics(self, opt_context):
@@ -395,8 +442,8 @@
  return
  opt_context.exp_avg_metric = {}
  opt_context.exp_avg_sq_metric = {}
- m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
- v_tag_tensor_map = self.generate_param_map('efxp_avg_sq', opt_context.param_exp_avg_sq)
+ m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+ v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
  get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
  get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)

@@ -416,29 +463,52 @@
  logger.warning(f"grad is None: {name}, maybe something wrong happened.")
  continue
  tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+ self._register_param_call_id("hook_optimizer", tag)
  grad_dict[tag] = grad

  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- return self.grad_context.post, self.grad_context.pre
-
- def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None):
+ unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
+ return self.grad_context.post, unreduced_grad
+
+ def set_monitor(
+ self,
+ model,
+ grad_acc_steps=1,
+ optimizer=None,
+ tp_group=None,
+ dp_group=None,
+ start_iteration=0
+ ):
  """External interface"""
+ global start_step
+ start_step = start_iteration
  logger.info(f'grad acc steps {grad_acc_steps}')
- self.hook_optimizer(optimizer)
  self.micro_batch_number = grad_acc_steps
-
  self.dp_group = dp_group
  self.tp_group = tp_group
+ self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
+ self.hook_step_final(optimizer)
+ if not isinstance(model, list):
+ model = [model]
+ self.model = model
+ if len(model) > 1:
+ self.vpp = True
+ self._smallest_rank_print('vpp enabled')
+ if not self.dynamic_enable:
+ self.register_hooks(optimizer)

- self._register_param_name(model)
+ def register_hooks(self, optimizer):
+ self._register_param_name()
+ self.hook_optimizer(optimizer)
  self._patch_grad_sync()
- self.hook_modules(model, grad_acc_steps)
+ self.hook_modules()
+ self.monitoring = True

  def generate_param_map(self, tag, param_tensor):
  metrics = {}
- rank = dist.get_rank() if dist.is_initialized() else None
  for name in self.param2name.values():
- key = get_summary_writer_tag_name(name, tag, rank)
+ key = get_summary_writer_tag_name(name, tag, self.rank)
+ self._register_param_call_id("optimizer_pre_step_hook", key)
  if name not in param_tensor or param_tensor[name] is None:
  continue
  metrics[key] = param_tensor[name]
@@ -454,17 +524,19 @@
  return actv, actv_grad

  def reload_xy(self, xy_distribution=False):
+ logger.warning("reload_xy() is deprecated and will be removed in a future version. "
+ "Use DYNAMIC_MONITOR instead.")
  self.xy_distribution = xy_distribution

  for handle in self.handles['xy']:
  handle.remove()
  self.handles['xy'].clear()
- self.hook_modules(self.model, self.micro_batch_number)
+ self.hook_modules()
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
  fwd_context.actv.clear()

  def write_adhoc_check(self, step):
- TrainerMon.tensor_metrics.flush(self.summary_writer)
+ self.tensor_metrics.flush(self.summary_writer)

  def write_xy_tb(self, step):
  if not self.xy_distribution:
@@ -472,65 +544,65 @@
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
  if len(fwd_context.actv) == 0:
  continue
- self.write_metrics(self.ops, self.summary_writer, fwd_context.actv, step, 'actv')
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
  fwd_context.actv.clear()
  if self.grad_context.actv:
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.actv, step, 'actv_grad')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)

  def write_param_tb(self, opt_context):
  if not self.param_distribution:
  return
- self.write_metrics(self.ops, self.summary_writer, opt_context.param_metric, opt_context.step, 'param')
+ self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM)

  def write_mv_tb(self, opt_context):
  if not self.mv_distribution:
  return
- self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_metric,
- opt_context.step, 'exp_avg')
- self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_sq_metric,
- opt_context.step, 'exp_avg_sq')
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+ opt_context.step, MonitorConst.EXP_AVG)
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+ opt_context.step, MonitorConst.EXP_AVG_SQ)

  def write_grad_tb(self, step):
  if not self.wg_distribution:
  return

  if self.enable_megatron:
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.pre, step, 'grad_unreduced')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
  else:
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced')
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

  def hook_optimizer(self, optimizer=None):
  # in DDP by default use params_have_main_grad
  def optimizer_pre_step_hook(optimizer, args, kwargs):
  context = self.optimizer_context[optimizer]
- if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
- if context.step == 0:
- self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
- self.name2index)
- mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
- self.name2indices)
- self.param2name = mv_result.grad
- else:
- mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
- context.param_exp_avg = mv_result.exp_avg
- context.param_exp_avg_sq = mv_result.exp_avg_sq
- context.param_adam_update = mv_result.update
- context.param_adam_ratio = mv_result.ratio

  if (self.print_struct and not all(value == {} for value in self.module_struct.values())
  and not self.struct_printed):
- self._smallest_rank_print("> module struct:")
- self._smallest_rank_print(json.dumps(self.module_struct))
- self.struct_printed = True
+ self._save_module_struct()
  if not self.cc_log_only:
- raise Exception("exit after first step when print model struct")
+ raise Exception("exit after first monitor step when print model struct")
  if self.cc_log_only and context.step > 0:
  self._smallest_rank_print("> Used communication ops and corresponding stack")
  self._smallest_rank_print(
  json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
  raise Exception("exit after first step when print cc stack")

+ # skip generate metrics
+ if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
+ return
+ if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class: # use deepspeed with zero1/2/3
+ if not self.name2indices:
+ self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
+ mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
+ self.param2name = mv_result.grad
+ else:
+ mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
+ context.param_exp_avg = mv_result.exp_avg
+ context.param_exp_avg_sq = mv_result.exp_avg_sq
+ context.param_adam_update = mv_result.update
+ context.param_adam_ratio = mv_result.ratio
+
  self.generate_wgrad_metrics()
  self.generate_mv_metrics(context)
  self.generate_param_metrics(context)
@@ -561,41 +633,10 @@ class TrainerMon:
  context.metric_dict = metric_dict
  return

- def optimizer_post_step_hook(optimizer, args, kwargs):
- context = self.optimizer_context[optimizer]
- rank = dist.get_rank() if dist.is_initialized() else None
-
- if self.anomaly_data_factory:
- self.anomaly_data_factory.set_call_id(self.param_name_call_id)
- self.write_xy_tb(context.step)
- self.write_grad_tb(context.step)
- self.write_mv_tb(context)
- self.write_param_tb(context)
- self.write_adhoc_check(context.step)
-
- if self.ur_distribution:
- for param_name, _ in context.param_adam_update.items():
- self.update_heatmap_visualizer[param_name].visualize(
- get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
- for param_name, _ in context.param_adam_ratio.items():
- self.ratio_heatmap_visualizer[param_name].visualize(
- get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
-
- if context.metric_dict:
- self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other')
- context.metric_dict.clear()
- context.step += 1
- if self.anomaly_data_factory:
- self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
- self.summary_writer.clear_anomalies()
- self.call_id = 0
- return
-
  def patch_step(func, optimizer):
  def wrapper(*args, **kwargs):
  optimizer_pre_step_hook(optimizer, args, kwargs)
  out = func(*args, **kwargs)
- optimizer_post_step_hook(optimizer, args, kwargs)
  return out

  return wrapper
@@ -603,16 +644,177 @@ class TrainerMon:
  if self.optimizer_hooked:
  return

- if optimizer:
- optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)

- else:
- if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
- register_optimizer_step_pre_hook(optimizer_pre_step_hook)
- register_optimizer_step_post_hook(optimizer_post_step_hook)
  self.optimizer_hooked = True
  return

+ def dynamic_monitor(self, optimizer):
+ """
+ If dynamic monitor enabled and config.json updated,
+ remove hooks and register new hooks according to new configuration.
+ """
+ context = self.optimizer_context[optimizer]
+ if not self.dynamic_enable:
+ return
+ try:
+ # if the file timestamp has not changed, skip reading to save time
+ config_timestamp = os.path.getmtime(self.config_file_path)
+ if config_timestamp == self.config_timestamp:
+ return
+ # record the latest modification timestamp of the config file
+ self.config_timestamp = config_timestamp
+ config = load_json(self.config_file_path)
+ except Exception as e:
+ logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
+ return
+
+ if config.get("dynamic_on", False):
+ try:
+ validate_config(config)
+ self.config = config
+ self.set_config()
+ logger.warning(f"config is updated at step{context.step - 1}, "
+ f"will start new hook at step{context.step}.")
+ except Exception as e:
+ logger.error(f"set config wrong because {e}, not updated, please check!!!")
+ return
+
+ self._remove_all_hooks(optimizer)
+ self.register_hooks(optimizer)
+
+ def hook_step_final(self, optimizer):
+ def step_final_hook(optimizer, args, kwargs):
+ context = self.optimizer_context[optimizer]
+ rank = dist.get_rank() if dist.is_initialized() else None
+ # static mode can already save at step 0; dynamic mode cannot, because it is designed to start one step after a reset, so self.monitoring is still False at step 0
+ if self.monitoring:
+ module_rank_valid = not self.module_rank_list or (
+ dist.is_initialized() and dist.get_rank() in self.module_rank_list)
+ step_condition = (context.step >= self.start_step and (
+ context.step - self.start_step) % self.step_interval == 0)
+ if module_rank_valid and step_condition:
+ self.has_collect_times += 1
+
+ if self.anomaly_data_factory:
+ self.anomaly_data_factory.set_call_id(self.param_name_call_id)
+ self.write_xy_tb(context.step)
+ self.write_grad_tb(context.step)
+ self.write_mv_tb(context)
+ self.write_param_tb(context)
+ self.write_adhoc_check(context.step)
+
+ if self.ur_distribution:
+ for param_name, _ in context.param_adam_update.items():
+ self.update_heatmap_visualizer[param_name].visualize(
+ get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step,
+ self.summary_writer)
+ for param_name, _ in context.param_adam_ratio.items():
+ self.ratio_heatmap_visualizer[param_name].visualize(
+ get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step,
+ self.summary_writer)
+
+ if context.metric_dict:
+ self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+ context.metric_dict.clear()
+
+ if self.anomaly_data_factory:
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
+ self.summary_writer.clear_anomalies()
+ self.call_id = 0
+ self.param_name_call_id.clear()
+
+ if self.has_collect_times >= self.collect_times:
+ self._remove_all_hooks_final(optimizer)
+
+ context.step += 1
+ self.dynamic_monitor(optimizer)
+
+ def patch_step(func, optimizer):
+ def wrapper(*args, **kwargs):
+ out = func(*args, **kwargs)
+ step_final_hook(optimizer, args, kwargs)
+ return out
+ return wrapper
+
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+ self.origin_step_func = optimizer.__class__.step
+
+ return
+
+ def _remove_all_hooks(self, optimizer):
+ # clear hook handles
+ for handle in self.handles['xy']:
+ handle.remove()
+ self.handles['xy'].clear()
+ # clear the corresponding context caches
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+ fwd_context.reset()
+ for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+ bwd_context.reset()
+ self.grad_context.reset() # both weight gradients and activation gradients live here
+
+ if self.origin_start_grad_sync: # megatron
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import Bucket
+ Bucket.start_grad_sync = self.origin_start_grad_sync
+ logger.info("remove Bucket start_grad_sync")
+ except ImportError:
+ pass
+ try:
+ from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+ _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
+ logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
+ except ImportError:
+ pass
+ else: # not megatron
+ for handle in self.handles['wgrads']:
+ handle.remove()
+ self.handles['wgrads'].clear()
+ self.weight_hooked = False
+
+ if self.optimizer_hooked:
+ optimizer.__class__.step = self.origin_step_func
+
+ for _, context in self.optimizer_context.items():
+ context.reset()
+ self.optimizer_hooked = False
+
+ for handle in self.handles['cc']:
+ handle.remove()
+ self.handles['cc'].clear()
+ for _, context in self.cc_context.items():
+ context.reset()
+
+ # clear node caches
+ self.param2name.clear()
+ self.name2index.clear()
+ self.name2indices.clear()
+ self.name2param.clear()
+ self.duplicate_param.clear()
+ self.name2tag.clear()
+ self.module_struct.clear()
+ self.grad_accs.clear()
+
+ # turn off the collection state
+ self.monitoring = False
+
+ def _remove_all_hooks_final(self, optimizer):
+ if self.dynamic_enable:
+ # after finishing, automatically reset dynamic_on to False and wait for the user to turn it on again manually
+ try:
+ config = load_json(self.config_file_path)
+ config['dynamic_on'] = False
+ save_json(self.config_file_path, config, indent=2)
+ config_timestamp = os.path.getmtime(self.config_file_path)
+ self.config_timestamp = config_timestamp
+ logger.info(
+ "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
+ except Exception as e:
+ logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
+ logger.info("Finish monitor")
+ self._remove_all_hooks(optimizer)
+
  def _smallest_rank_print(self, msg):
  if dist.is_initialized():
  if self.module_rank_list:
@@ -624,9 +826,20 @@
  else:
  logger.info(msg)

+ def _save_module_struct(self):
+ save_module_struct = (not dist.is_initialized()
+ or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list))
+ or (not self.module_rank_list and dist.get_rank() == 0))
+
+ if save_module_struct:
+ module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
+ save_json(module_struct_file, self.module_struct, indent=2)
+ logger.info(f"> save module struct to {module_struct_file}")
+ self.struct_printed = True
+
  def _is_target_param(self, param_name, param, prefix):
- squash_name = prefix + squash_param_name(param_name)
  name = prefix + param_name
+ squash_name = prefix + squash_param_name(param_name, self.squash_name)
  for target in self.config['targets'].keys():
  if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
  setattr(param, "zero_out_wgrad", True)
@@ -635,15 +848,14 @@
  return False

  def _register_chunk(self, model_chunk, prefix):
- for index, (param_name, param) in enumerate(model_chunk.named_parameters()):
+ index = 0
+ for (param_name, param) in model_chunk.named_parameters():
  if not param.requires_grad:
  continue
  if self._is_target_param(param_name, param, prefix):
- name = prefix + squash_param_name(param_name)
+ name = prefix + squash_param_name(param_name, self.squash_name)
  if name in self.param2name.values():
- logger.error(f'same name {name} for different param. Current param is {param_name}. \
- May be error of squash_param_name')
- raise Exception("param with same name will be overwritten.")
+ name = prefix + param_name
  self.param2name[param] = name
  self.name2param[name] = param
  self.name2index[name] = index
652
864
  self.duplicate_param[name] = True
653
865
  if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
654
866
  self.duplicate_param[name] = True
655
- self.name2tag[name] = {}
656
- self.name2tag[name][MonitorConst.PRE_GRAD] = get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD,
657
- self.rank)
658
- self.name2tag[name][MonitorConst.POST_GRAD] = get_summary_writer_tag_name(name, MonitorConst.POST_GRAD,
659
- self.rank)
660
-
661
- def _register_param_name(self, model):
662
- if self.param_registered:
663
- return
664
-
665
- if not isinstance(model, list):
666
- model = [model]
667
-
668
- if len(model) > 1:
669
- self.vpp = True
670
- self._smallest_rank_print('vpp enabled')
671
-
672
- for vpp_stage, model_chunk in enumerate(model):
673
- prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
867
+ self.name2tag[name] = {
868
+ MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
869
+ MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
870
+ }
871
+ index += 1
872
+
873
+ def _register_param_name(self):
874
+ for vpp_stage, model_chunk in enumerate(self.model):
875
+ prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
674
876
  self._register_chunk(model_chunk, prefix)
675
877
 
676
- self.param_registered = True
677
-
678
878
  def _is_target_module(self, module_name, targets, vpp_stage):
679
879
  if self.all_xy or self.print_struct:
680
- return vpp_stage + squash_param_name(module_name)
880
+ return vpp_stage + squash_param_name(module_name, self.squash_name)
681
881
  for pattern in [
682
- vpp_stage + squash_param_name(module_name),
882
+ vpp_stage + squash_param_name(module_name, self.squash_name),
683
883
  vpp_stage + module_name,
684
884
  ]:
685
885
  if pattern in targets:
@@ -692,90 +892,88 @@ class TrainerMon:
             return 0
 
         def fwd_hook_fun(module, module_input, module_output, name):
-            if is_recomputation():
+            if not module.training or is_recomputation():
+                # 1. Only monitor the training stage.
+                # 2. When recompute is enabled, skip the recomputed forward pass.
                 return
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
             if not context.struct:
-                context.struct = {MonitorConst.ACTV_IN: get_param_struct(module_input),
-                                  MonitorConst.ACTV_OUT: get_param_struct(module_output)}
+                context.struct = {
+                    Const.INPUT: get_param_struct(module_input),
+                    Const.OUTPUT: get_param_struct(module_output)
+                }
             if self.print_struct:
-                if context.module_name not in self.module_struct:
-                    self.module_struct[context.module_name] = {}
                 self.module_struct[context.module_name].update(context.struct)
                 return
-            if not module.training:
-                return
             if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
-                context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
+                context.set_format_by_arg(Const.INPUT, self.config['targets'])
+                context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
             if not context.format_by_arg:
                 return
             if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
-                                                                  module_input, context.module_name,
-                                                                  MonitorConst.ACTV_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
+                context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
+                                                              module_input, context.module_name,
+                                                              Const.INPUT)
+                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
                                                                module_output, context.module_name,
-                                                               MonitorConst.ACTV_OUT)
+                                                               Const.OUTPUT)
                 context.verified = True
             # expect the output to be a tensor
             tbtag_tensor_map = {}
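+            # Tags now combine the module name, Const.INPUT / Const.OUTPUT and the micro step via
+            # MonitorConst.NAME_SEP, grouped under the single MonitorConst.ACTV category
+            # (previously ACTV_IN / ACTV_OUT with a plain _{micro_step} suffix).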
-            if not context.ignore_in:
-                cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
-                                                cared_input))
+            cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_input))
             cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
-                                            cared_output))
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_output))
 
             get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
-
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
-                context.step += 1
             return
 
         def bwd_hook_fun(module, input_grad, output_grad):
             context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
             if not context.struct:
-                context.struct = {MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
-                                  MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)}
+                context.struct = {
+                    MonitorConst.INPUT_GRAD: get_param_struct(input_grad),
+                    MonitorConst.OUTPUT_GRAD: get_param_struct(output_grad)
+                }
             if self.print_struct:
-                if context.module_name not in self.module_struct:
-                    self.module_struct[context.module_name] = {}
                 self.module_struct[context.module_name].update(context.struct)
                 return
             if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets'])
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets'])
+                context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
+                context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
             if not context.format_by_arg:
                 return
             if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
-                                                                  input_grad, context.module_name,
-                                                                  MonitorConst.ACTVGRAD_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
-                                                               output_grad, context.module_name,
-                                                               MonitorConst.ACTVGRAD_OUT)
+                context.focused_in_col = validate_config_spec(
+                    context.format_by_arg[MonitorConst.INPUT_GRAD],
+                    input_grad, context.module_name, MonitorConst.INPUT_GRAD)
+                context.focused_out_col = validate_config_spec(
+                    context.format_by_arg[MonitorConst.OUTPUT_GRAD],
+                    output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
                 context.verified = True
 
             tbtag_tensor_map = {}
-            if not context.ignore_in:
-                cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(
-                        f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
+            cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_input_grad))
             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
-                                            cared_output_grad))
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -787,7 +985,6 @@ class TrainerMon:
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
-                context.step += 1
             return
 
         if self.backward_only and self.forward_only:
@@ -802,7 +999,7 @@ class TrainerMon:
                 if not self.backward_only:
                     handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
                     self.handles['xy'].append(handle)
-                if not self.forward_only:
+                if not self.forward_only and not self.has_register_backward_hook(name, submodule):
                     handle = submodule.register_full_backward_hook(bwd_hook_fun)
                     self.handles['xy'].append(handle)
                     self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
@@ -814,8 +1011,12 @@ class TrainerMon:
         def patch_sync(sync_grad_func):
             def wrapper(bucket):
                 grad_dict = {}
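+                # Collect each monitored parameter's gradient from this bucket and record pre-sync metrics
+                # (grad_context.pre) before delegating to the original start_grad_sync.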
+                # In Megatron core_r0.6.0 through core_r0.8.0, `bucket` here is a Bucket.
+                # In Megatron core_r0.9.0 it is a _ParamAndGradBucketGroup,
+                # since start_grad_sync moved from Bucket to _ParamAndGradBucketGroup.
+                bucket_params_id_list = [id(params) for params in bucket.params]
                 for param, name in self.param2name.items():
-                    if param not in bucket.params_list:
+                    if id(param) not in bucket_params_id_list:
                         continue
                     grad = param.main_grad if self.params_have_main_grad else param.grad
                     if grad is None:
@@ -825,21 +1026,35 @@ class TrainerMon:
                     if tag is None:
                         continue
                     grad_dict[tag] = grad
+                    self._register_param_call_id("sync_grad_func", tag)
                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                 out = sync_grad_func(bucket)
                 return out
 
             return wrapper
 
+        if not self.wg_distribution:
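+            # Weight-gradient monitoring (wg_distribution) is disabled, so there is nothing to patch or hook here.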
+            return
+
         try:
             from megatron.core.distributed.param_and_grad_buffer import Bucket
+            self.origin_start_grad_sync = Bucket.start_grad_sync
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
             self.enable_megatron = True
+            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
         except ImportError:
             self.enable_megatron = False
 
-        if self.enable_megatron:
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
-        else:
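+        # Fallback for Megatron core_r0.9.0, where start_grad_sync lives on _ParamAndGradBucketGroup.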
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+            self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
+            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+            self.enable_megatron = True
+            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+        except ImportError:
+            self.enable_megatron = False
+
+        if not self.enable_megatron:
             self._hook_weights()
 
     def _hook_weights(self):
@@ -848,8 +1063,7 @@ class TrainerMon:
         @torch.no_grad
         def param_hook(*args, context_dict, param, key, name):
             param.micro_step += 1
-            self.param_name_call_id[name] = self.call_id
-            self.call_id += 1
+            self._register_param_call_id("param_hook", key)
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
                 if self.params_have_main_grad:
@@ -857,6 +1071,7 @@ class TrainerMon:
                 else:
                     context_dict[key] = param.grad.clone()
 
+        logger.info("hooking weights.")
         for param, name in self.param2name.items():
             key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
@@ -868,3 +1083,13 @@ class TrainerMon:
             self.handles['wgrads'].append(handle)
 
         self.weight_hooked = True
+
+    def _register_param_call_id(self, hook_name: str, key: str):
+        """
+        Record the order in which gradient hooks fire for each tag.
+        :param hook_name: name of the calling hook, used only for debug logging
+        :param key: str, e.g. '0:relu_0/output_grad'
+        :return: None
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1