mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,50 +12,44 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
- import time
16
15
  import json
17
16
  import os
18
17
  import uuid
19
18
  from collections import defaultdict
20
- from datetime import datetime, timezone
19
+ from datetime import datetime
21
20
  from functools import partial
22
21
 
23
22
  import pytz
24
23
  import torch
25
24
  import torch.distributed as dist
25
+ from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
26
+ from torch.utils.hooks import BackwardHook
27
+
26
28
  from msprobe.core.common.const import MonitorConst
27
- from msprobe.core.common.file_utils import load_json
28
- from msprobe.core.common.log import logger
29
+ from msprobe.core.common.file_utils import load_json, save_json
30
+ from msprobe.pytorch.common.log import logger
29
31
  from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
30
32
  from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
31
33
  CSVWriterWithAD, BaseWriterWithAD, WriterInput
32
34
  from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
33
35
  get_process_group
34
36
  from msprobe.pytorch.monitor.features import get_sign_matches
35
- from msprobe.pytorch.monitor.module_metric import get_metrics, write_metrics_base, get_summary_writer_tag_name, \
36
- TensorMetrics, write_metrics_csv, squash_param_name
37
+ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
38
+ TensorMetrics, squash_param_name
37
39
  from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
38
40
  from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
39
- from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation
41
+ from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation, \
42
+ get_output_base_dir, get_target_output_dir
40
43
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
41
- from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
42
- from torch.utils.hooks import BackwardHook
43
-
44
- try:
45
- import torch_npu
46
- except ImportError:
47
- pass
48
44
 
49
45
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
50
46
  if not torch_version_above_or_equal_2:
51
47
  raise ValueError("monitor require torch>=2.0")
52
48
 
53
- output_base_dir = os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
54
-
55
49
  FORMAT_MAPPING = {
56
- MonitorConst.TENSORBOARD: (SummaryWriterWithAD, write_metrics_base),
57
- MonitorConst.CSV: (CSVWriterWithAD, write_metrics_csv),
58
- MonitorConst.API: (BaseWriterWithAD, write_metrics_base)
50
+ MonitorConst.TENSORBOARD: SummaryWriterWithAD,
51
+ MonitorConst.CSV: CSVWriterWithAD,
52
+ MonitorConst.API: BaseWriterWithAD
59
53
  }
60
54
 
61
55
 
@@ -71,7 +65,6 @@ def param_is_data_parallel_duplicate(dp_group):
71
65
 
72
66
  class ModuleHookContext:
73
67
  def __init__(self, module_name) -> None:
74
- self.step = 0
75
68
  self.micro_step = 0
76
69
  self.actv = defaultdict(dict)
77
70
  self.actvgrad = []
@@ -81,26 +74,47 @@ class ModuleHookContext:
81
74
  self.verified = False
82
75
  self.focused_in_col = 0
83
76
  self.focused_out_col = 0
84
- self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found
85
77
 
86
78
  def set_format_by_arg(self, key_name: str, target_config: dict):
79
+ """ 按照监控对象配置format_by_arg
80
+ 1) module_name is configured as a monitored target in targets
81
+ 2) module_name is not configured in targets, and all_xy enables full monitoring
82
+ 3) module_name is not configured in targets, and all_xy does not enable full monitoring
83
+
84
+ :param key_name: str, one of [input, output, input_grad, output_grad]
85
+ :param target_config: target obj in config json.
86
+ :return:
87
+ """
88
+ valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
89
+ if key_name not in valid_key:
90
+ raise ValueError(f"key({key_name}) error, valid_key: {valid_key}")
87
91
  cared = target_config.get(self.module_name, self.struct)
88
92
  if key_name in cared:
89
- if isinstance(cared[key_name], dict):
90
- # current cared is self.struct
91
- config = cared[key_name].get('config')
92
- self.format_by_arg[key_name] = config
93
- else:
93
+ target_module_config = cared[key_name]
94
+ if isinstance(target_module_config, dict):
95
+ # current cared is self.struct, monitor all data for module_name
96
+ self.format_by_arg[key_name] = target_module_config.get('config')
97
+ elif isinstance(target_module_config, str):
94
98
  # current cared is target_config[self.module_name]
95
- self.format_by_arg[key_name] = cared[key_name]
96
- elif key_name in ['input', 'input_grad']:
97
- self.ignore_in = True
99
+ self.format_by_arg[key_name] = target_module_config
100
+ else:
101
+ logger.warning_on_rank_0(f"target module config error, result maybe empty."
102
+ f"module_name: {self.module_name}, key_name: {key_name}")
103
+ self.format_by_arg[key_name] = None
104
+ else:
105
+ self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
106
+
107
+ def reset(self):
108
+ self.actv.clear()
109
+ self.actvgrad.clear()
110
+
111
+
112
+ start_step = 0
98
113
 
99
114
 
100
115
  class OptimizerContext:
101
116
  def __init__(self) -> None:
102
- self.step = 0
103
- self.param_effective_rank = defaultdict(float)
117
+ self.step = start_step
104
118
  self.param_mg_direction = defaultdict(float)
105
119
  self.param_adam_update = defaultdict()
106
120
  self.param_adam_ratio = defaultdict()
@@ -112,6 +126,18 @@ class OptimizerContext:
112
126
  self.metric_dict = {}
113
127
  self.param_metric = {}
114
128
 
129
+ def reset(self):
130
+ self.param_mg_direction.clear()
131
+ self.param_adam_update.clear()
132
+ self.param_adam_ratio.clear()
133
+ self.param_weight_grad.clear()
134
+ self.param_exp_avg.clear()
135
+ self.exp_avg_metric.clear()
136
+ self.param_exp_avg_sq.clear()
137
+ self.exp_avg_sq_metric.clear()
138
+ self.metric_dict.clear()
139
+ self.param_metric.clear()
140
+
115
141
 
116
142
  class CommunicationContext:
117
143
  def __init__(self) -> None:
@@ -156,17 +182,131 @@ class TrainerMon:
156
182
  """
157
183
  opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
158
184
  """
159
- self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
160
- self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
161
- self.optimizer_context = defaultdict(OptimizerContext)
162
- self.cc_context = defaultdict(CommunicationContext)
163
- self.grad_context = GradContext()
185
+ # TYPE1: 只在这里初始化的变量, 不会随着训练中途config配置改变而重置
186
+ self.config_file_path = config_file_path
164
187
  self.process_group = get_process_group(process_group)
165
188
  self.params_have_main_grad = params_have_main_grad
166
189
  self.opt_ty = opt_ty
190
+ self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
191
+ self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
192
+ self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
193
+ self.origin_step_func = None
194
+ self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过switch开关直接打开
167
195
  self.config = load_json(config_file_path)
168
196
  validate_config(self.config)
169
197
 
198
+ self.squash_name = self.config.get('squash_name', True) # 不允许修改防止前后名字对不上
199
+ local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区
200
+ cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
201
+ self.unique_id = str(uuid.uuid4())[:8]
202
+ self.output_base_dir = get_output_base_dir()
203
+ time_tags = self.config.get("append_output", [])
204
+ if dist.is_initialized():
205
+ self.rank = dist.get_rank()
206
+ if time_tags:
207
+ output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
208
+ if str(self.rank) in output_append_dirs:
209
+ self.tensorboard_dir = output_append_dirs[str(self.rank)]
210
+ logger.info(f"append rank({self.rank}) result to {self.tensorboard_dir}")
211
+ else:
212
+ self.tensorboard_dir = os.path.join(self.output_base_dir,
213
+ f"{cur_time}-rank{self.rank}-{self.unique_id}")
214
+ self.pp_stage = dist.get_group_rank(self.process_group, self.rank)
215
+ self.group_mates = dist.get_process_group_ranks(self.process_group)
216
+ else:
217
+ self.rank = 0
218
+ self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
219
+ self.pp_stage = 0
220
+ self.group_mates = [0]
221
+
222
+ # TYPE2: 只会在monitor_gnorm_with_ad()主调中赋值的变量
223
+ self.model = None
224
+ self.vpp = False
225
+ self.dp_group = None
226
+ self.tp_group = None
227
+ self.enable_megatron = False
228
+ self.micro_batch_number = 1
229
+
230
+ # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量
231
+ self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
232
+ self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
233
+ self.optimizer_context = defaultdict(OptimizerContext)
234
+ self.cc_context = defaultdict(CommunicationContext)
235
+ self.grad_context = GradContext()
236
+ self.handles = defaultdict(list)
237
+ self.param2name = defaultdict(str)
238
+ self.name2index = defaultdict()
239
+ self.name2indices = defaultdict()
240
+ self.name2param = {}
241
+ self.duplicate_param = {}
242
+ self.name2tag = {}
243
+ self.param_name_call_id = {}
244
+ self.call_id = 0
245
+ self.module_struct = defaultdict(dict)
246
+ self.grad_accs = []
247
+ self.weight_hooked = False
248
+ self.optimizer_hooked = False
249
+ self.param_registered = False
250
+ self.struct_printed = False
251
+
252
+ # 动静态区分
253
+ self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
254
+ if self.dynamic_enable:
255
+ logger.warning(f"DYNAMIC_MONITOR is set, "
256
+ f"please make sure you have 'switch' and 'collect_times' item in {self.config_file_path}")
257
+ self.monitoring = False
258
+ else:
259
+ self.set_config()
260
+ # 静态且collect_times>0时在第0步self.monitoring就可以True, 动态默认在下一步开启
261
+ if self.collect_times > 0:
262
+ self.monitoring = True
263
+
264
+ def __del__(self):
265
+ if hasattr(self, "summary_writer"):
266
+ self.summary_writer.close()
267
+
268
+ @property
269
+ def ops(self):
270
+ return self._ops
271
+
272
+ @ops.setter
273
+ def ops(self, value):
274
+ self._ops = validate_ops(value)
275
+
276
+ @staticmethod
277
+ def set_wrapped_optimizer(_wrapped_optimizer):
278
+ OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
279
+
280
+ @staticmethod
281
+ def has_register_backward_hook(module_name, module):
282
+ if hasattr(module, '_backward_hooks') and \
283
+ len(module._backward_hooks) > 0 and \
284
+ module._is_full_backward_hook is False:
285
+ logger.warning(
286
+ f"The {module_name} has registered deprecated register_backward_hook,"
287
+ f"which may cause abnormal data dump. The backward input/output for this module will be skipped."
288
+ )
289
+ return True
290
+ return False
291
+
292
+ @staticmethod
293
+ def generate_cc_metrics(cc_name, cc_tensor):
294
+ metrics = defaultdict(dict)
295
+ rank = dist.get_rank() if dist.is_initialized() else None
296
+ for op, tag2tensor in cc_tensor.data.items():
297
+ for tag, tensor in tag2tensor.items():
298
+ key = get_summary_writer_tag_name(cc_name, tag, rank)
299
+ metrics[op].update({key: tensor})
300
+ cc_tensor.reset()
301
+ return metrics
302
+
303
+ def set_config(self):
304
+ logger.info(f"current config: {self.config}")
305
+ self.start_step = self.config.get("start_step", 0)
306
+ self.collect_times = self.config.get("collect_times", 100000000) # 默认大值, 目的是一直采集
307
+ self.step_interval = self.config.get("step_interval", 1)
308
+ self.has_collect_times = 0 # 重设采集计数器
309
+ self.print_struct = self.config.get("print_struct", False)
170
310
  self.module_rank_list = self.config.get("module_ranks", [])
171
311
  self.format = self.config.get('format', 'tensorboard')
172
312
  self.eps = self.config.get('eps', 1e-8)
@@ -182,6 +322,7 @@ class TrainerMon:
182
322
  self.param_distribution = self.config.get("param_distribution", False)
183
323
  self.mg_direction = self.config.get('mg_direction', False)
184
324
  self.cc_distribution = self.config.get("cc_distribution", {})
325
+
185
326
  if not self.cc_distribution.get('enable', False):
186
327
  self.cc_log_only = False
187
328
  else:
@@ -189,49 +330,30 @@ class TrainerMon:
189
330
  self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
190
331
  self.cc_logged_stack = defaultdict(set)
191
332
  self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
192
- api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
333
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
193
334
  api_register.redirect_api()
194
335
 
195
336
  self.common_info()
196
337
 
338
+ # 初始化AnomalyData工厂
197
339
  alert_setting = self.config.get('alert', {"rules": []})
198
340
  self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
199
-
200
- # 设置时区,使用 'UTC' 作为示例
201
- local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区
202
-
203
- cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
204
- unique_id = str(uuid.uuid4())[:8]
205
-
206
- if dist.is_initialized():
207
- rank = dist.get_rank()
208
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
209
- pp_stage = dist.get_group_rank(self.process_group, rank)
210
- group_mates = dist.get_process_group_ranks(self.process_group)
211
- else:
212
- rank = 0
213
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
214
- pp_stage = 0
215
- group_mates = [0]
216
- self.rank = rank
217
-
218
- # 初始化AnomalyData工厂
219
341
  self.anomaly_data_factory = None
220
342
  if alert_setting.get('dump', False):
221
- self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
343
+ self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
222
344
 
345
+ # 初始化writer, 创建输出目录
223
346
  if self.format not in FORMAT_MAPPING:
224
347
  raise ValueError(f"Unsupported format: {self.format}")
225
- writer, self.write_metrics = FORMAT_MAPPING[self.format]
348
+ writer = FORMAT_MAPPING[self.format]
226
349
  self.step_count_per_record = self.config.get('step_count_per_record', 1)
227
350
 
228
- if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
351
+ if (self.rank in self.module_rank_list) or len(self.module_rank_list) == 0:
229
352
  self.summary_writer = writer(
230
353
  WriterInput(
231
- tensorboard_dir,
354
+ self.tensorboard_dir,
232
355
  self.alert_rules,
233
- unique_id,
234
- None,
356
+ self.unique_id,
235
357
  self.anomaly_data_factory,
236
358
  self.ndigits,
237
359
  self.step_count_per_record
@@ -239,83 +361,22 @@ class TrainerMon:
239
361
  )
240
362
  # 初始化anomaly detected文件目录
241
363
  if self.anomaly_data_factory:
242
- self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank)
364
+ self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
365
+ self.rank)
243
366
  self.anomaly_data_writer.init_detected_json()
244
367
 
245
- # A HeatmapVisualizer instance is associated with an image
246
- self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
247
- self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
248
- self.micro_batch_number = 1
249
-
250
- self.model = None
251
- self.weight_hooked = False
252
- self.optimizer_hooked = False
253
- self.param_registered = False
254
- self.vpp = False
255
- self.dp_group = None
256
- self.tp_group = None
257
- self.enable_megatron = False
258
-
259
- self.param2name = defaultdict(str)
260
- self.name2index = defaultdict()
261
- self.name2indices = defaultdict()
262
- self.name2param = {}
263
- self.param_name_call_id = {}
264
- self.duplicate_param = {}
265
- self.name2tag = {}
266
- self.call_id = 0
267
- self.grad_accs = []
268
- self.handles = defaultdict(list)
269
-
270
- self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
271
- self.print_struct = self.config.get("print_struct", False)
272
- self.struct_printed = False
273
- self.module_struct = {}
274
-
275
- def __del__(self):
276
- if hasattr(self, "summary_writer"):
277
- self.summary_writer.close()
278
-
279
- @property
280
- def ops(self):
281
- return self._ops
282
-
283
- @ops.setter
284
- def ops(self, value):
285
- self._ops = validate_ops(value)
286
-
287
- @staticmethod
288
- def set_wrapped_optimizer(_wrapped_optimizer):
289
- OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
290
-
291
- @staticmethod
292
- def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
368
+ def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
293
369
  rank = None
294
370
  if dist.is_initialized():
295
371
  rank = dist.get_rank()
296
372
  if (rank not in rank_list) and len(rank_list) != 0:
297
373
  return
298
- TrainerMon.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
374
+ self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
299
375
 
300
- @staticmethod
301
- def build_tbtag_tensor_map(module_name, tag, tensor):
302
- metrics = {}
303
- rank = dist.get_rank() if dist.is_initialized() else None
304
- key = get_summary_writer_tag_name(module_name, tag, rank)
305
- if torch.is_tensor(tensor):
306
- metrics[key] = tensor
307
- return metrics
308
-
309
- @staticmethod
310
- def generate_cc_metrics(cc_name, cc_tensor):
311
- metrics = defaultdict(dict)
312
- rank = dist.get_rank() if dist.is_initialized() else None
313
- for op, tag2tensor in cc_tensor.data.items():
314
- for tag, tensor in tag2tensor.items():
315
- key = get_summary_writer_tag_name(cc_name, tag, rank)
316
- metrics[op].update({key: tensor})
317
- cc_tensor.reset()
318
- return metrics
376
+ def build_tbtag_tensor_map(self, module_name, tag, tensor):
377
+ key = get_summary_writer_tag_name(module_name, tag, self.rank)
378
+ self._register_param_call_id("_hook_module", key)
379
+ return {key: tensor}
319
380
 
320
381
  def common_info(self):
321
382
  if not self.xy_distribution:
@@ -338,31 +399,24 @@ class TrainerMon:
338
399
  if self.mv_distribution:
339
400
  raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
340
401
 
341
- def hook_modules(self, model: torch.nn.Module, grad_acc_steps):
402
+ def hook_modules(self):
342
403
  if self.module_rank_list and (self.rank not in self.module_rank_list):
343
404
  return
344
405
 
345
- if not isinstance(model, list):
346
- model = [model]
347
- self.model = model
348
- self._register_param_name(model)
349
-
350
- self.micro_batch_number = grad_acc_steps
351
-
352
406
  targets = self.config['targets']
353
407
  module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
354
408
  for key in module_in_all_stage:
355
409
  struct = targets.pop(key)
356
- targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(model))})
410
+ targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(self.model))})
357
411
 
358
412
  hooked_count = 0
359
- for vpp_stage, model_chunk in enumerate(model):
413
+ for vpp_stage, model_chunk in enumerate(self.model):
360
414
  vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
361
415
  targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
362
416
  'targets'].keys()
363
417
  hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
364
418
 
365
- logger.info_on_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
419
+ logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
366
420
 
367
421
  def clone_if_tensor(args):
368
422
  if isinstance(args, tuple):
@@ -383,11 +437,11 @@ class TrainerMon:
383
437
 
384
438
  BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
385
439
 
386
- if not self.optimizer_hooked:
387
- self.hook_optimizer()
388
440
  return
389
441
 
390
442
  def generate_param_metrics(self, opt_context):
443
+ if not self.param_distribution:
444
+ return
391
445
  get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
392
446
 
393
447
  def generate_mv_metrics(self, opt_context):
@@ -416,29 +470,50 @@ class TrainerMon:
416
470
  logger.warning(f"grad is None: {name}, maybe something wrong happened.")
417
471
  continue
418
472
  tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
473
+ self._register_param_call_id("hook_optimizer", tag)
419
474
  grad_dict[tag] = grad
420
475
 
421
476
  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
422
477
  return self.grad_context.post, self.grad_context.pre
423
478
 
424
- def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None):
479
+ def monitor_gnorm_with_ad(
480
+ self,
481
+ model,
482
+ grad_acc_steps=1,
483
+ optimizer=None,
484
+ tp_group=None,
485
+ dp_group=None,
486
+ start_iteration=0
487
+ ):
425
488
  """External interface"""
489
+ global start_step
490
+ start_step = start_iteration
426
491
  logger.info(f'grad acc steps {grad_acc_steps}')
427
- self.hook_optimizer(optimizer)
428
492
  self.micro_batch_number = grad_acc_steps
429
-
430
493
  self.dp_group = dp_group
431
494
  self.tp_group = tp_group
495
+ self.hook_step_final(optimizer)
496
+ if not isinstance(model, list):
497
+ model = [model]
498
+ self.model = model
499
+ if len(model) > 1:
500
+ self.vpp = True
501
+ self._smallest_rank_print('vpp enabled')
502
+ if not self.dynamic_enable:
503
+ self.register_hooks(optimizer)
432
504
 
433
- self._register_param_name(model)
505
+ def register_hooks(self, optimizer):
506
+ self._register_param_name()
507
+ self.hook_optimizer(optimizer)
434
508
  self._patch_grad_sync()
435
- self.hook_modules(model, grad_acc_steps)
509
+ self.hook_modules()
510
+ self.monitoring = True
436
511
 
437
512
  def generate_param_map(self, tag, param_tensor):
438
513
  metrics = {}
439
- rank = dist.get_rank() if dist.is_initialized() else None
440
514
  for name in self.param2name.values():
441
- key = get_summary_writer_tag_name(name, tag, rank)
515
+ key = get_summary_writer_tag_name(name, tag, self.rank)
516
+ self._register_param_call_id("optimizer_pre_step_hook", key)
442
517
  if name not in param_tensor or param_tensor[name] is None:
443
518
  continue
444
519
  metrics[key] = param_tensor[name]
@@ -459,12 +534,12 @@ class TrainerMon:
459
534
  for handle in self.handles['xy']:
460
535
  handle.remove()
461
536
  self.handles['xy'].clear()
462
- self.hook_modules(self.model, self.micro_batch_number)
537
+ self.hook_modules()
463
538
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
464
539
  fwd_context.actv.clear()
465
540
 
466
541
  def write_adhoc_check(self, step):
467
- TrainerMon.tensor_metrics.flush(self.summary_writer)
542
+ self.tensor_metrics.flush(self.summary_writer)
468
543
 
469
544
  def write_xy_tb(self, step):
470
545
  if not self.xy_distribution:
@@ -472,40 +547,53 @@ class TrainerMon:
472
547
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
473
548
  if len(fwd_context.actv) == 0:
474
549
  continue
475
- self.write_metrics(self.ops, self.summary_writer, fwd_context.actv, step, 'actv')
550
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
476
551
  fwd_context.actv.clear()
477
552
  if self.grad_context.actv:
478
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.actv, step, 'actv_grad')
553
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
479
554
 
480
555
  def write_param_tb(self, opt_context):
481
556
  if not self.param_distribution:
482
557
  return
483
- self.write_metrics(self.ops, self.summary_writer, opt_context.param_metric, opt_context.step, 'param')
558
+ self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
484
559
 
485
560
  def write_mv_tb(self, opt_context):
486
561
  if not self.mv_distribution:
487
562
  return
488
- self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_metric,
489
- opt_context.step, 'exp_avg')
490
- self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_sq_metric,
491
- opt_context.step, 'exp_avg_sq')
563
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
564
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
492
565
 
493
566
  def write_grad_tb(self, step):
494
567
  if not self.wg_distribution:
495
568
  return
496
569
 
497
570
  if self.enable_megatron:
498
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.pre, step, 'grad_unreduced')
571
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
499
572
  else:
500
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced')
501
- self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced')
573
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
574
+ self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
502
575
 
503
576
  def hook_optimizer(self, optimizer=None):
504
577
  # in DDP by default use params_have_main_grad
505
578
  def optimizer_pre_step_hook(optimizer, args, kwargs):
506
579
  context = self.optimizer_context[optimizer]
580
+
581
+ if (self.print_struct and not all(value == {} for value in self.module_struct.values())
582
+ and not self.struct_printed):
583
+ self._save_module_struct()
584
+ if not self.cc_log_only:
585
+ raise Exception("exit after first monitor step when print model struct")
586
+ if self.cc_log_only and context.step > 0:
587
+ self._smallest_rank_print("> Used communication ops and corresponding stack")
588
+ self._smallest_rank_print(
589
+ json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
590
+ raise Exception("exit after first step when print cc stack")
591
+
592
+ # skip generate metrics
593
+ if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
594
+ return
507
595
  if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
508
- if context.step == 0:
596
+ if not self.name2indices:
509
597
  self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
510
598
  self.name2index)
511
599
  mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
@@ -518,19 +606,6 @@ class TrainerMon:
518
606
  context.param_adam_update = mv_result.update
519
607
  context.param_adam_ratio = mv_result.ratio
520
608
 
521
- if (self.print_struct and not all(value == {} for value in self.module_struct.values())
522
- and not self.struct_printed):
523
- self._smallest_rank_print("> module struct:")
524
- self._smallest_rank_print(json.dumps(self.module_struct))
525
- self.struct_printed = True
526
- if not self.cc_log_only:
527
- raise Exception("exit after first step when print model struct")
528
- if self.cc_log_only and context.step > 0:
529
- self._smallest_rank_print("> Used communication ops and corresponding stack")
530
- self._smallest_rank_print(
531
- json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
532
- raise Exception("exit after first step when print cc stack")
533
-
534
609
  self.generate_wgrad_metrics()
535
610
  self.generate_mv_metrics(context)
536
611
  self.generate_param_metrics(context)
@@ -561,41 +636,10 @@ class TrainerMon:
561
636
  context.metric_dict = metric_dict
562
637
  return
563
638
 
564
- def optimizer_post_step_hook(optimizer, args, kwargs):
565
- context = self.optimizer_context[optimizer]
566
- rank = dist.get_rank() if dist.is_initialized() else None
567
-
568
- if self.anomaly_data_factory:
569
- self.anomaly_data_factory.set_call_id(self.param_name_call_id)
570
- self.write_xy_tb(context.step)
571
- self.write_grad_tb(context.step)
572
- self.write_mv_tb(context)
573
- self.write_param_tb(context)
574
- self.write_adhoc_check(context.step)
575
-
576
- if self.ur_distribution:
577
- for param_name, _ in context.param_adam_update.items():
578
- self.update_heatmap_visualizer[param_name].visualize(
579
- get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
580
- for param_name, _ in context.param_adam_ratio.items():
581
- self.ratio_heatmap_visualizer[param_name].visualize(
582
- get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
583
-
584
- if context.metric_dict:
585
- self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other')
586
- context.metric_dict.clear()
587
- context.step += 1
588
- if self.anomaly_data_factory:
589
- self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
590
- self.summary_writer.clear_anomalies()
591
- self.call_id = 0
592
- return
593
-
594
639
  def patch_step(func, optimizer):
595
640
  def wrapper(*args, **kwargs):
596
641
  optimizer_pre_step_hook(optimizer, args, kwargs)
597
642
  out = func(*args, **kwargs)
598
- optimizer_post_step_hook(optimizer, args, kwargs)
599
643
  return out
600
644
 
601
645
  return wrapper
@@ -605,14 +649,171 @@ class TrainerMon:
605
649
 
606
650
  if optimizer:
607
651
  optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
608
-
652
+ self.handles['optimizer'] = []
609
653
  else:
610
654
  if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
611
- register_optimizer_step_pre_hook(optimizer_pre_step_hook)
612
- register_optimizer_step_post_hook(optimizer_post_step_hook)
655
+ step_pre_hook = register_optimizer_step_pre_hook(optimizer_pre_step_hook)
656
+ self.handles['optimizer'] = [step_pre_hook]
613
657
  self.optimizer_hooked = True
614
658
  return
615
659
 
660
+ def dynamic_monitor(self, optimizer):
661
+ """
662
+ If dynamic monitor enabled and config.json updated,
663
+ remove hooks and register new hooks according to new configuration.
664
+ """
665
+ context = self.optimizer_context[optimizer]
666
+ if not self.dynamic_enable:
667
+ return
668
+ try:
669
+ # if the file timestamp has not changed, skip re-reading it to save time
670
+ config_timestamp = os.path.getmtime(self.config_file_path)
671
+ if config_timestamp == self.config_timestamp:
672
+ return
673
+ # record the latest modification timestamp of the config file
674
+ self.config_timestamp = config_timestamp
675
+ config = load_json(self.config_file_path)
676
+ except Exception as e:
677
+ logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
678
+ return
679
+
680
+ if config.get("switch", False):
681
+ try:
682
+ validate_config(config)
683
+ self.config = config
684
+ self.set_config()
685
+ logger.warning(f"config is updated at step{context.step - 1}, "
686
+ f"will start new hook at step{context.step}.")
687
+ except Exception as e:
688
+ logger.error(f"set config wrong because {e}, not updated, please check!!!")
689
+ return
690
+
691
+ self._remove_all_hooks(optimizer)
692
+ self.register_hooks(optimizer)
693
+
694
+ def hook_step_final(self, optimizer):
695
+ def step_final_hook(optimizer, args, kwargs):
696
+ context = self.optimizer_context[optimizer]
697
+ rank = dist.get_rank() if dist.is_initialized() else None
698
+ # static mode can save at step 0; dynamic mode cannot, because it is designed to start at the step after a reset, so self.monitoring is still False at step 0
699
+ if self.monitoring:
700
+ module_rank_valid = not self.module_rank_list or (
701
+ dist.is_initialized() and dist.get_rank() in self.module_rank_list)
702
+ step_condition = (context.step >= self.start_step and (
703
+ context.step - self.start_step) % self.step_interval == 0)
704
+ if module_rank_valid and step_condition:
705
+ self.has_collect_times += 1
706
+
707
+ if self.anomaly_data_factory:
708
+ self.anomaly_data_factory.set_call_id(self.param_name_call_id)
709
+ self.write_xy_tb(context.step)
710
+ self.write_grad_tb(context.step)
711
+ self.write_mv_tb(context)
712
+ self.write_param_tb(context)
713
+ self.write_adhoc_check(context.step)
714
+
715
+ if self.ur_distribution:
716
+ for param_name, _ in context.param_adam_update.items():
717
+ self.update_heatmap_visualizer[param_name].visualize(
718
+ get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step,
719
+ self.summary_writer)
720
+ for param_name, _ in context.param_adam_ratio.items():
721
+ self.ratio_heatmap_visualizer[param_name].visualize(
722
+ get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step,
723
+ self.summary_writer)
724
+
725
+ if context.metric_dict:
726
+ self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
727
+ context.metric_dict.clear()
728
+
729
+ if self.anomaly_data_factory:
730
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
731
+ self.summary_writer.clear_anomalies()
732
+ self.call_id = 0
733
+ self.param_name_call_id.clear()
734
+
735
+ if self.has_collect_times >= self.collect_times:
736
+ self._remove_all_hooks_final(optimizer)
737
+
738
+ context.step += 1
739
+ self.dynamic_monitor(optimizer)
740
+
741
+ def patch_step(func, optimizer):
742
+ def wrapper(*args, **kwargs):
743
+ out = func(*args, **kwargs)
744
+ step_final_hook(optimizer, args, kwargs)
745
+ return out
746
+ return wrapper
747
+
748
+ if optimizer:
749
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
750
+ self.origin_step_func = optimizer.__class__.step
751
+ else:
752
+ register_optimizer_step_post_hook(step_final_hook)
753
+ return
754
+
755
+ def _remove_all_hooks(self, optimizer):
756
+ # remove hook handles
757
+ for handle in self.handles['xy']:
758
+ handle.remove()
759
+ self.handles['xy'].clear()
760
+ # clear the corresponding context caches
761
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
762
+ fwd_context.reset()
763
+ for _, bwd_context in self.module_bwd_hook_context_by_module.items():
764
+ bwd_context.reset()
765
+ self.grad_context.reset() # both weight grads and activation grads live here
766
+
767
+ for handle in self.handles['wgrads']:
768
+ handle.remove()
769
+ self.handles['wgrads'].clear()
770
+ self.weight_hooked = False
771
+
772
+ if len(self.handles['optimizer']) == 0 and self.optimizer_hooked:
773
+ optimizer.__class__.step = self.origin_step_func
774
+ else:
775
+ for handle in self.handles['optimizer']:
776
+ handle.remove()
777
+ self.handles['optimizer'].clear()
778
+ for _, context in self.optimizer_context.items():
779
+ context.reset()
780
+ self.optimizer_hooked = False
781
+
782
+ for handle in self.handles['cc']:
783
+ handle.remove()
784
+ self.handles['cc'].clear()
785
+ for _, context in self.cc_context.items():
786
+ context.reset()
787
+
788
+ # clear node caches
789
+ self.param2name.clear()
790
+ self.name2index.clear()
791
+ self.name2indices.clear()
792
+ self.name2param.clear()
793
+ self.duplicate_param.clear()
794
+ self.name2tag.clear()
795
+ self.module_struct.clear()
796
+ self.grad_accs.clear()
797
+
798
+ # turn off monitoring
799
+ self.monitoring = False
800
+
801
+ def _remove_all_hooks_final(self, optimizer):
802
+ if self.dynamic_enable:
803
+ # after finishing, automatically reset switch to False and wait for the user to turn it back on
804
+ try:
805
+ config = load_json(self.config_file_path)
806
+ config['switch'] = False
807
+ save_json(self.config_file_path, config, indent=2)
808
+ config_timestamp = os.path.getmtime(self.config_file_path)
809
+ self.config_timestamp = config_timestamp
810
+ logger.info(
811
+ "Finish monitor, set config'switch=False, will restart by set switch=True and update content")
812
+ except Exception as e:
813
+ logger.warning(f"Finish monitor, set config'switch=False fail because {e}, please check!!!")
814
+ logger.info("Finish monitor")
815
+ self._remove_all_hooks(optimizer)
816
+
616
817
  def _smallest_rank_print(self, msg):
617
818
  if dist.is_initialized():
618
819
  if self.module_rank_list:
@@ -624,9 +825,20 @@ class TrainerMon:
624
825
  else:
625
826
  logger.info(msg)
626
827
 
828
+ def _save_module_struct(self):
829
+ save_module_struct = (not dist.is_initialized()
830
+ or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list))
831
+ or (not self.module_rank_list and dist.get_rank() == 0))
832
+
833
+ if save_module_struct:
834
+ module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
835
+ save_json(module_struct_file, self.module_struct, indent=2)
836
+ logger.info(f"> save module struct to {module_struct_file}")
837
+ self.struct_printed = True
838
+
627
839
  def _is_target_param(self, param_name, param, prefix):
628
- squash_name = prefix + squash_param_name(param_name)
629
840
  name = prefix + param_name
841
+ squash_name = prefix + squash_param_name(param_name, self.squash_name)
630
842
  for target in self.config['targets'].keys():
631
843
  if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
632
844
  setattr(param, "zero_out_wgrad", True)
@@ -635,15 +847,14 @@ class TrainerMon:
635
847
  return False
636
848
 
637
849
  def _register_chunk(self, model_chunk, prefix):
638
- for index, (param_name, param) in enumerate(model_chunk.named_parameters()):
850
+ index = 0
851
+ for (param_name, param) in model_chunk.named_parameters():
639
852
  if not param.requires_grad:
640
853
  continue
641
854
  if self._is_target_param(param_name, param, prefix):
642
- name = prefix + squash_param_name(param_name)
855
+ name = prefix + squash_param_name(param_name, self.squash_name)
643
856
  if name in self.param2name.values():
644
- logger.error(f'same name {name} for different param. Current param is {param_name}. \
645
- May be error of squash_param_name')
646
- raise Exception("param with same name will be overwritten.")
857
+ name = prefix + param_name
647
858
  self.param2name[param] = name
648
859
  self.name2param[name] = param
649
860
  self.name2index[name] = index
@@ -652,34 +863,22 @@ class TrainerMon:
652
863
  self.duplicate_param[name] = True
653
864
  if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
654
865
  self.duplicate_param[name] = True
655
- self.name2tag[name] = {}
656
- self.name2tag[name][MonitorConst.PRE_GRAD] = get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD,
657
- self.rank)
658
- self.name2tag[name][MonitorConst.POST_GRAD] = get_summary_writer_tag_name(name, MonitorConst.POST_GRAD,
659
- self.rank)
660
-
661
- def _register_param_name(self, model):
662
- if self.param_registered:
663
- return
664
-
665
- if not isinstance(model, list):
666
- model = [model]
667
-
668
- if len(model) > 1:
669
- self.vpp = True
670
- self._smallest_rank_print('vpp enabled')
671
-
672
- for vpp_stage, model_chunk in enumerate(model):
866
+ self.name2tag[name] = {
867
+ MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
868
+ MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
869
+ }
870
+ index += 1
871
+
872
+ def _register_param_name(self):
873
+ for vpp_stage, model_chunk in enumerate(self.model):
673
874
  prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
674
875
  self._register_chunk(model_chunk, prefix)
675
876
 
676
- self.param_registered = True
677
-
678
877
  def _is_target_module(self, module_name, targets, vpp_stage):
679
878
  if self.all_xy or self.print_struct:
680
- return vpp_stage + squash_param_name(module_name)
879
+ return vpp_stage + squash_param_name(module_name, self.squash_name)
681
880
  for pattern in [
682
- vpp_stage + squash_param_name(module_name),
881
+ vpp_stage + squash_param_name(module_name, self.squash_name),
683
882
  vpp_stage + module_name,
684
883
  ]:
685
884
  if pattern in targets:
@@ -692,63 +891,59 @@ class TrainerMon:
692
891
  return 0
693
892
 
694
893
  def fwd_hook_fun(module, module_input, module_output, name):
695
- if is_recomputation():
894
+ if not module.training or is_recomputation():
895
+ # 1 only monitor training stage.
896
+ # 2 when open recompute, skip recomputed forward stage.
696
897
  return
697
898
  if module not in self.module_fwd_hook_context_by_module:
698
899
  self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
699
900
  context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
700
901
  if not context.struct:
701
- context.struct = {MonitorConst.ACTV_IN: get_param_struct(module_input),
702
- MonitorConst.ACTV_OUT: get_param_struct(module_output)}
902
+ context.struct = {
903
+ MonitorConst.ACTV_IN: get_param_struct(module_input),
904
+ MonitorConst.ACTV_OUT: get_param_struct(module_output)
905
+ }
703
906
  if self.print_struct:
704
- if context.module_name not in self.module_struct:
705
- self.module_struct[context.module_name] = {}
706
907
  self.module_struct[context.module_name].update(context.struct)
707
908
  return
708
- if not module.training:
709
- return
710
909
  if not context.format_by_arg:
711
910
  context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
712
911
  context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
713
912
  if not context.format_by_arg:
714
913
  return
715
914
  if not context.verified:
716
- if not context.ignore_in:
717
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
718
- module_input, context.module_name,
719
- MonitorConst.ACTV_IN)
915
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
916
+ module_input, context.module_name,
917
+ MonitorConst.ACTV_IN)
720
918
  context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
721
919
  module_output, context.module_name,
722
920
  MonitorConst.ACTV_OUT)
723
921
  context.verified = True
724
922
  # expect output be tensor type
725
923
  tbtag_tensor_map = {}
726
- if not context.ignore_in:
727
- cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
728
- tbtag_tensor_map.update(
729
- self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
730
- cared_input))
924
+ cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
925
+ tbtag_tensor_map.update(
926
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
927
+ cared_input))
731
928
  cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
732
929
  tbtag_tensor_map.update(
733
930
  self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
734
931
  cared_output))
735
932
 
736
933
  get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
737
-
738
934
  context.micro_step += 1
739
935
  if context.micro_step == self.micro_batch_number:
740
936
  context.micro_step = 0
741
- context.step += 1
742
937
  return
743
938
 
744
939
  def bwd_hook_fun(module, input_grad, output_grad):
745
940
  context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
746
941
  if not context.struct:
747
- context.struct = {MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
748
- MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)}
942
+ context.struct = {
943
+ MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
944
+ MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
945
+ }
749
946
  if self.print_struct:
750
- if context.module_name not in self.module_struct:
751
- self.module_struct[context.module_name] = {}
752
947
  self.module_struct[context.module_name].update(context.struct)
753
948
  return
754
949
  if not context.format_by_arg:
@@ -757,21 +952,19 @@ class TrainerMon:
757
952
  if not context.format_by_arg:
758
953
  return
759
954
  if not context.verified:
760
- if not context.ignore_in:
761
- context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
762
- input_grad, context.module_name,
763
- MonitorConst.ACTVGRAD_IN)
955
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
956
+ input_grad, context.module_name,
957
+ MonitorConst.ACTVGRAD_IN)
764
958
  context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
765
959
  output_grad, context.module_name,
766
960
  MonitorConst.ACTVGRAD_OUT)
767
961
  context.verified = True
768
962
 
769
963
  tbtag_tensor_map = {}
770
- if not context.ignore_in:
771
- cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
772
- tbtag_tensor_map.update(
773
- self.build_tbtag_tensor_map(
774
- f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
964
+ cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
965
+ tbtag_tensor_map.update(
966
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN,
967
+ cared_input_grad))
775
968
  cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
776
969
  tbtag_tensor_map.update(
777
970
  self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
@@ -787,7 +980,6 @@ class TrainerMon:
787
980
  context.micro_step += 1
788
981
  if context.micro_step == self.micro_batch_number:
789
982
  context.micro_step = 0
790
- context.step += 1
791
983
  return
792
984
 
793
985
  if self.backward_only and self.forward_only:
@@ -802,7 +994,7 @@ class TrainerMon:
802
994
  if not self.backward_only:
803
995
  handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
804
996
  self.handles['xy'].append(handle)
805
- if not self.forward_only:
997
+ if not self.forward_only and not self.has_register_backward_hook(name, submodule):
806
998
  handle = submodule.register_full_backward_hook(bwd_hook_fun)
807
999
  self.handles['xy'].append(handle)
808
1000
  self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
@@ -814,8 +1006,9 @@ class TrainerMon:
814
1006
  def patch_sync(sync_grad_func):
815
1007
  def wrapper(bucket):
816
1008
  grad_dict = {}
1009
+ bucket_params_id_list = [id(params) for params in bucket.params_list]
817
1010
  for param, name in self.param2name.items():
818
- if param not in bucket.params_list:
1011
+ if id(param) not in bucket_params_id_list:
819
1012
  continue
820
1013
  grad = param.main_grad if self.params_have_main_grad else param.grad
821
1014
  if grad is None:
@@ -825,6 +1018,7 @@ class TrainerMon:
825
1018
  if tag is None:
826
1019
  continue
827
1020
  grad_dict[tag] = grad
1021
+ self._register_param_call_id("sync_grad_func", tag)
828
1022
  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
829
1023
  out = sync_grad_func(bucket)
830
1024
  return out
@@ -837,6 +1031,9 @@ class TrainerMon:
837
1031
  except ImportError:
838
1032
  self.enable_megatron = False
839
1033
 
1034
+ if not self.wg_distribution:
1035
+ return
1036
+
840
1037
  if self.enable_megatron:
841
1038
  Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version
842
1039
  else:
@@ -848,8 +1045,7 @@ class TrainerMon:
848
1045
  @torch.no_grad
849
1046
  def param_hook(*args, context_dict, param, key, name):
850
1047
  param.micro_step += 1
851
- self.param_name_call_id[name] = self.call_id
852
- self.call_id += 1
1048
+ self._register_param_call_id("param_hook", key)
853
1049
  if param.micro_step == self.micro_batch_number:
854
1050
  param.micro_step = 0
855
1051
  if self.params_have_main_grad:
@@ -868,3 +1064,13 @@ class TrainerMon:
868
1064
  self.handles['wgrads'].append(handle)
869
1065
 
870
1066
  self.weight_hooked = True
1067
+
1068
+ def _register_param_call_id(self, hook_name: str, key: str):
1069
+ """
1070
+ :param hook_name:
1071
+ :param key: str, '0:relu_0/output_grad'
1072
+ :return:
1073
+ """
1074
+ logger.debug(f"{hook_name} {key}: {self.call_id}")
1075
+ self.param_name_call_id[key] = self.call_id
1076
+ self.call_id += 1
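
For orientation, the following is a minimal sketch of how the reworked monitor in msprobe/pytorch/monitor/module_hook.py appears to be driven after this change. It is assembled only from what is visible in the hunks above (the config keys read in set_config() and dynamic_monitor(), the DYNAMIC_MONITOR environment variable, and the new start_iteration parameter of monitor_gnorm_with_ad()); the TrainerMon constructor arguments, the file name monitor_config.json, the config values, and the model/optimizer objects are placeholders rather than the package's documented API.

import json
import os

# Keys below are the ones read in set_config()/dynamic_monitor() in this diff;
# the values are illustrative only.
config = {
    "targets": {"language_model.encoder.layers.0": {}},  # hypothetical module name
    "format": "tensorboard",          # writer backend, per FORMAT_MAPPING
    "start_step": 0,
    "step_interval": 1,
    "collect_times": 10,              # hook_step_final stops collecting after this many steps
    "alert": {"rules": []},
    "switch": True,                   # re-read by dynamic_monitor() when DYNAMIC_MONITOR is set
}
with open("monitor_config.json", "w") as f:
    json.dump(config, f, indent=2)

# Opt in to the dynamic re-configuration path added in this version.
os.environ["DYNAMIC_MONITOR"] = "true"

# Hedged usage sketch (commented out because model/optimizer are not defined here):
# from msprobe.pytorch.monitor.module_hook import TrainerMon
# monitor = TrainerMon("monitor_config.json", params_have_main_grad=True, opt_ty=None)
# monitor.monitor_gnorm_with_ad(model, grad_acc_steps=4, optimizer=optimizer,
#                               tp_group=None, dp_group=None, start_iteration=0)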