mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/monitor/module_hook.py

@@ -20,22 +20,24 @@ from collections import defaultdict
  from datetime import datetime
 
  import pytz
- import mindspore as ms
- import mindspore.common.dtype as mstype
- from mindspore import Tensor, ops, mint
+ import pandas as pd
+ import mindspore
+ from mindspore import Tensor, mint
  from mindspore import nn, _no_grad
- from mindspore.communication import get_rank
 
  from msprobe.core.common.log import logger
- from msprobe.core.common.const import MonitorConst
- from msprobe.core.common.file_utils import load_json
+ from msprobe.core.common.const import MonitorConst, Const
+ from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+ from msprobe.mindspore.common.utils import is_mindtorch
+ from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
  from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
- is_skip_step, get_metrics, get_single_metrics
- from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
- from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
- CSVWriterWithAD, BaseWriterWithAD, WriterInput
- from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
- get_process_group
+ is_skip_step, get_metrics, get_target_output_dir
+ from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
+ from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
+ from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
+ from msprobe.core.common.file_utils import write_df_to_csv
+ from msprobe.core.common.utils import analyze_api_call_stack
 
  FORMAT_MAPPING = {
  MonitorConst.CSV: CSVWriterWithAD,
@@ -89,24 +91,11 @@ class ModuleHookContext:
  self.actvgrad = []
  self.module_name = module_name
  self.struct = {}
- self.format_by_arg = {}
- self.verified = False
- self.focused_in_col = 0
- self.focused_out_col = 0
- self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found
-
- def set_format_by_arg(self, key_name: str, target_config: dict):
- cared = target_config.get(self.module_name, self.struct)
- if key_name in cared:
- if isinstance(cared[key_name], dict):
- # current cared is self.struct
- config = cared[key_name].get('config')
- self.format_by_arg[key_name] = config
- else:
- # current cared is target_config[self.module_name]
- self.format_by_arg[key_name] = cared[key_name]
- elif key_name in ['input', 'input_grad']:
- self.ignore_in = True
+ self.stack = ""
+
+ def reset(self):
+ self.actv.clear()
+ self.actvgrad.clear()


  start_step = 0
@@ -116,7 +105,6 @@ start_step = 0
  class OptimizerContext:
  def __init__(self) -> None:
  self.step = start_step
- self.param_effective_rank = defaultdict(float)
  self.param_mg_direction = defaultdict(float)
  self.param_adam_update = defaultdict()
  self.param_adam_ratio = defaultdict()
@@ -131,6 +119,7 @@ class OptimizerContext:
  def reset(self) -> None:
  self.param_mg_direction.clear()
  self.param_adam_update.clear()
+ self.param_adam_ratio.clear()
  self.param_weight_grad.clear()
  self.param_exp_avg.clear()
  self.exp_avg_metric.clear()
@@ -179,50 +168,107 @@ class CommunicationContext:
 
  class TrainerMon:
  def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+ # TYPE1: variables initialized only here; never reset when the config changes mid-training
+ self.config_file_path = config_file_path
+ self.process_group = process_group
+ self.params_have_main_grad = params_have_main_grad
+ self.is_mindtorch = is_mindtorch()
+ self.config_timestamp = 0 # the timestamp is checked later; the first run need not touch the config file just to refresh it, monitoring can be enabled directly via the dynamic_on switch
+ self.config = load_json(config_file_path)
+ validate_config(self.config)
+
+ local_tz = pytz.timezone("Asia/Shanghai") # adjust to the desired time zone if needed
+ cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+ self.unique_id = str(uuid.uuid4())[:8]
+ self.output_base_dir = get_output_base_dir()
+ time_tags = self.config.get("append_output", [])
+ try:
+ self.rank = get_rank()
+ if time_tags:
+ output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+ if str(self.rank) in output_append_dirs:
+ self.tensorboard_dir = output_append_dirs[str(self.rank)]
+ logger.info(f"Append rank({self.rank}) result to {self.tensorboard_dir}")
+ else:
+ self.tensorboard_dir = os.path.join(self.output_base_dir,
+ f"{cur_time}-rank{self.rank}-{self.unique_id}")
+ except Exception as e:
+ self.rank = 0
+ self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+
+ self.pp_stage = 0
+ self.group_mates = [0]
+
+ # TYPE2: variables assigned only in the set_monitor() entry call
+ self.model = None
+ self.vpp = False
+ self.dp_group = None
+ self.tp_group = None
+ self.micro_batch_number = 1
+ self.optimizer_mon = None
+
+ # TYPE3: variables reset when the config is updated or the monitoring state changes mid-training
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
  self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
  self.optimizer_context = defaultdict(OptimizerContext)
  self.cc_context = defaultdict(CommunicationContext)
  self.grad_context = GradContext()
- self.params_have_main_grad = params_have_main_grad
  self.handles = defaultdict(list)
- self.config = load_json(config_file_path)
- validate_config(self.config)
+ self.param2name = defaultdict(str)
+ self.name2index = defaultdict()
+ self.name2indices = defaultdict()
+ self.name2param = {}
+ self.duplicate_param = {}
+ self.name2tag = {}
+ self.param_name_call_id = {}
+ self.call_id = 0
+ self.module_struct = defaultdict(dict)
+ self.grad_accs = []
+ self.weight_hooked = False
+ self.optimizer_hooked = False
+ self.param_registered = False
+ self.struct_printed = False
+ self.pre_step_hooks = []
+ self.post_step_hooks = []
+
+ # distinguish dynamic from static monitoring
+ self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+ if self.dynamic_enable:
+ logger.warning(f"DYNAMIC_MONITOR is set, "
+ f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+ self.monitoring = False
+ else:
+ self.set_config()
+ # in static mode with collect_times > 0, self.monitoring can already be True at step 0; dynamic mode starts at the next step by default
+ if self.collect_times > 0:
+ self.monitoring = True
 
+ def set_config(self):
  self.start_step = self.config.get("start_step", 0)
  self.collect_times = self.config.get("collect_times", 100000000) # defaults to a large value so collection never stops
  self.step_interval = self.config.get("step_interval", 1)
- self.has_collect_times = 0
-
- # monitor target in module, such as layer, weight, grad
+ self.has_collect_times = 0 # reset the collection counter
+ self.print_struct = self.config.get("print_struct", False)
  self.targets = self.config.get("targets", None)
  self.is_select = self.config.get("is_select", False)
  self.module_rank_list = self.config.get("module_ranks", [])
- # only csv supported in mindspore
- self.format = self.config.get('format', MonitorConst.CSV)
+ self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore
  self.eps = self.config.get('eps', 1e-8)
- # monitor mean/max/norm/min/nan...
- self.ops = self.config.get('ops', [])
+ self.ops = self.config.get('ops', []) # monitor mean/max/norm/min/nan...
  self.ndigits = self.config.get('ndigits', 6)
  self.all_xy = self.config.get('all_xy', False)
- # module input/output input_grad/output_grad
  self.xy_distribution = self.config.get('xy_distribution', False)
- # activation forward
  self.forward_only = self.config.get('forward_only', False)
- # activation backward
  self.backward_only = self.config.get('backward_only', False)
- # update vector and ratio vector of adam
- self.ur_distribution = self.config.get('ur_distribution', False)
- # m/v of adam
- self.mv_distribution = self.config.get("mv_distribution", False)
- # weight grad
+ self.ur_distribution = self.config.get('ur_distribution', False) # vector and ratio vector of adam
+ self.mv_distribution = self.config.get("mv_distribution", False) # m/v of adam
  self.wg_distribution = self.config.get("wg_distribution", False)
- # optimizer param
  self.param_distribution = self.config.get("param_distribution", False)
- # main grad direction
- self.mg_direction = self.config.get('mg_direction', False)
- # communication ops
- self.cc_distribution = self.config.get("cc_distribution", {})
+ self.mg_direction = self.config.get('mg_direction', False) # main grad direction
+ self.cc_distribution = self.config.get("cc_distribution", {}) # communication ops
+ self.stack_info = self.config.get('stack_info', False)
+ self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+
  if not self.cc_distribution.get('enable', False):
  self.cc_log_only = False
  else:
@@ -230,167 +276,227 @@ class TrainerMon:
  self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
  self.cc_logged_stack = defaultdict(set)
  self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
- self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
- api_register.redirect_api()
  self.common_info()
 
+ # initialize the AnomalyData factory
  alert_setting = self.config.get('alert', {"rules": []})
  self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
- local_tz = pytz.timezone("Asia/Shanghai") # adjust to the desired time zone if needed
-
- cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
- unique_id = str(uuid.uuid4())[:8]
- output_base_dir = get_output_base_dir()
-
- time_tags = self.config.get("append_output", [])
- if time_tags:
- output_append_dirs = get_target_output_dir(output_base_dir, time_tags[0], time_tags[1])
- try:
- rank = get_rank()
- except Exception as e:
- rank = 0
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
- logger.error(f"Failed to get rank, setting tensorboard_dir to {tensorboard_dir}")
- pp_stage = 0
- group_mates = [0]
- else:
- if time_tags and str(rank) in output_append_dirs:
- tensorboard_dir = output_append_dirs[str(rank)]
- logger.info(f"Append rank({rank}) result to {tensorboard_dir}")
- else:
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
- pp_stage = 0
- group_mates = [0]
-
- self.rank = rank
-
- # initialize the AnomalyData factory
  self.anomaly_data_factory = None
  if alert_setting.get('dump', False):
- self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
+ self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
 
+ # initialize the writer and create the output directory
  if self.format not in FORMAT_MAPPING:
  logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
  self.format = MonitorConst.CSV
- writer = FORMAT_MAPPING[self.format]
  self.step_count_per_record = self.config.get('step_count_per_record', 1)
-
- self.summary_writer = writer(
- WriterInput(
- tensorboard_dir,
- self.alert_rules,
- unique_id,
- self.anomaly_data_factory,
- self.ndigits,
- self.step_count_per_record
+ if not self.module_rank_list or (self.rank in self.module_rank_list):
+ writer = FORMAT_MAPPING[self.format]
+ self.summary_writer = writer(
+ WriterInput(
+ self.tensorboard_dir,
+ self.alert_rules,
+ self.unique_id,
+ self.anomaly_data_factory,
+ self.ndigits,
+ self.step_count_per_record
+ )
  )
- )
-
- self.micro_batch_number = 1
-
- self.model = None
- self.weight_hooked = False
- self.optimizer_hooked = False
- self.param_registered = False
- self.vpp = False
- self.dp_group = None
- self.tp_group = None
- self.enable_megatron = False
 
- self.param2name = defaultdict(str)
- self.name2index = defaultdict()
- self.name2indices = defaultdict()
- self.name2param = {}
- self.param_name_call_id = {}
- self.duplicate_param = {}
- self.name2tag = {}
- self.call_id = 0
- self.grad_accs = []
- self.handles = defaultdict(list)
+ # initialize the anomaly_detected output directory
+ if self.anomaly_data_factory:
+ self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
+ self.rank)
+ self.anomaly_data_writer.init_detected_json()
 
- self.print_struct = self.config.get("print_struct", False)
- self.struct_printed = False
- self.module_struct = defaultdict(dict)
+ def common_info(self):
+ if not self.xy_distribution:
+ logger.info("> module input/output input_grad/output_grad is not monitored. ")
+ if self.forward_only:
+ logger.info("> only module forward is monitored. ")
+ if not self.ur_distribution:
+ logger.info("> update vector and ratio vector of adam is not monitored. ")
+ if not self.mv_distribution:
+ logger.info("> momentum and variance of adam is not monitored. ")
+ if not self.wg_distribution:
+ logger.info("> weight grad of specified module is not monitored. ")
+ if not self.mg_direction:
+ logger.info('> grad and momentum direction will not be compared.')
+ if not self.cc_distribution.get('enable', False):
+ logger.info("> cc operator is not monitored.")
 
- # Start
  def set_monitor(
  self,
  model,
+ optimizer,
  grad_acc_steps=1,
- optimizer=None,
  tp_group=None,
  dp_group=None,
- start_iteration=0):
+ start_iteration=0
+ ):
  global start_step
  start_step = start_iteration
- logger.info(f'grad acc steps {grad_acc_steps}')
- self.hook_optimizer(optimizer)
  self.micro_batch_number = grad_acc_steps
  self.dp_group = dp_group
  self.tp_group = tp_group
+ self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
+ self.hook_step_final(optimizer)
+ if not isinstance(model, list):
+ model = [model]
+ self.model = model
+ if len(model) > 1:
+ self.vpp = True
+ logger.info('vpp enabled')
+ if not self.dynamic_enable:
+ self.register_hooks(optimizer)
+
+ def hook_step_final(self, optimizer):
+ def step_final_hook(optimizer, *args, **kwargs):
+ context = self.optimizer_context[optimizer]
+ # static mode can already save at step 0; dynamic mode cannot, since it is designed to start at the step after a reset, so self.monitoring is still False at step 0
+ if self.monitoring:
+ module_rank_valid = self.is_target_rank()
+ step_condition = (context.step >= self.start_step and (
+ context.step - self.start_step) % self.step_interval == 0)
+ if module_rank_valid and step_condition:
+ self.has_collect_times += 1
+
+ if self.anomaly_data_factory:
+ self.anomaly_data_factory.set_call_id(self.param_name_call_id)
+ self.write_xy_tb(context.step)
+ self.write_grad_tb(context.step)
+ self.write_mv_tb(context)
+ self.write_param_tb(context)
+ if self.stack_info:
+ self.write_stack_info()
+ self.stack_info = False
+ for handle in self.handles["stack"]:
+ handle.remove()
+ self.handles["stack"].clear()
+
+ if context.metric_dict:
+ self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+ context.metric_dict.clear()
+
+ if self.anomaly_data_factory:
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
+ self.summary_writer.clear_anomalies()
+
+ self.call_id = 0
+ self.param_name_call_id.clear()
+
+ if self.has_collect_times >= self.collect_times:
+ self._remove_all_hooks_final(optimizer)
 
- self.hook_modules(model, grad_acc_steps)
- self._patch_grad_sync()
+ context.step += 1
+ self.dynamic_monitor(optimizer)
 
- """
- Start
- """
- def hook_optimizer(self, optimizer):
- rank_id = str(get_rank())
- if self.optimizer_hooked:
+
+ def patch_step(func, optimizer):
+ def wrapper(*args, **kwargs):
+ for hook in self.pre_step_hooks:
+ hook(optimizer, args, kwargs)
+ out = func(*args, **kwargs)
+ for hook in self.post_step_hooks:
+ hook(optimizer, args, kwargs)
+ step_final_hook(optimizer, args, kwargs)
+ return out
+ return wrapper
+
+ if self.is_mindtorch:
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+ else:
+ optimizer.__class__.construct = patch_step(optimizer.__class__.construct, optimizer)
+
+ return
+
+ def dynamic_monitor(self, optimizer):
+ """
+ If dynamic monitor enabled and config.json updated,
+ remove hooks and register new hooks according to new configuration.
+ """
+ context = self.optimizer_context[optimizer]
+ if not self.dynamic_enable:
+ return
+ try:
+ # if the file timestamp has not changed, skip re-reading to save time
+ config_timestamp = os.path.getmtime(self.config_file_path)
+ if config_timestamp == self.config_timestamp:
+ return
+ # record the config file's latest modification timestamp
+ self.config_timestamp = config_timestamp
+ config = load_json(self.config_file_path)
+ except Exception as e:
+ logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
  return
 
+ if config.get("dynamic_on", False):
+ try:
+ validate_config(config)
+ self.config = config
+ self.set_config()
+ self.start_step = context.step # dynamic start/stop ignores the original start_step and always starts from the next step
+ logger.warning(f"config is updated at step{context.step - 1}, "
+ f"will start new hook at step{context.step}.")
+ except Exception as e:
+ logger.error(f"set config wrong because {e}, not updated, please check!!!")
+ return
+
+ self._remove_all_hooks(optimizer)
+ self.register_hooks(optimizer)
+
+ def register_hooks(self, optimizer):
+ self._register_param_name()
+ self.hook_modules()
+ self.hook_optimizer(optimizer)
+ self._patch_grad_sync()
+ if self.cc_distribution.get('enable', False):
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+ api_register.redirect_api()
+ self.monitoring = True
+
+ def hook_modules(self):
  if not self.is_target_rank():
  return
+ module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
 
- m_list = []
- v_list = []
- param_list = []
- grad_names = []
- for param in optimizer.get_parameters():
- if MonitorConst.EXP_AVG_SQ in param.name:
- v_list.append(param)
- elif MonitorConst.EXP_AVG in param.name:
- m_list.append(param)
- else:
- param_list.append(param)
- grad_names.append(param.name)
+ for key in module_in_all_stage:
+ struct = self.targets.pop(key)
+ self.targets.update(
+ {f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
 
- """
- grad reduced
- m/v
- """
- def optimizer_pre_hook_function(opt, grad_names, gradients):
+ hooked_count = 0
+ for vpp_stage, model_chunk in enumerate(self.model):
+ if not is_valid_instance(model_chunk):
+ logger.info("Target Model is not Cell")
+ continue
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+ targets = [x for x, _ in get_submodules(model_chunk)] if self.print_struct else self.targets.keys()
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+ logger.info(f"> {hooked_count} modules are monitored.")
+
+ def hook_optimizer(self, optimizer):
+ def optimizer_pre_step_hook(opt, *args, **kwargs):
  context = self.optimizer_context[opt]
- if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
+ if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
  self.collect_times):
  return
- gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients
- is_select = self.is_select
- for idx, grad in enumerate(gradient_list):
- grad_name = grad_names[idx]
- if is_select and grad_name not in self.targets:
- continue
- get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad)
-
- if self.mv_distribution:
- # fetch mean
- for param in m_list:
- name = param.name
- if is_select and name not in self.targets:
- continue
- get_single_metrics(self.ops, name, param, context.exp_avg_metric)
- # fetch variance
- for param in v_list:
- name = param.name
- if is_select and name not in self.targets:
- continue
- get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric)
- if self.param_distribution:
- for param in param_list:
- get_single_metrics(self.ops, param.name, param, context.param_metric)
- self.generate_wgrad_metrics()
+
+ grad_dict = {}
+ if self.wg_distribution:
+ grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+ if self.mv_distribution or self.ur_distribution or self.mg_direction:
+ if self.is_mindtorch:
+ context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
+ context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
+ else:
+ context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
+
+ self.generate_wgrad_metrics(grad_dict)
+ self.generate_mv_metrics(context)
+ self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
+
  metric_dict = {}
  for cc in self.cc_context.values():
  cc.aggregate()
@@ -402,191 +508,167 @@ class TrainerMon:
  context.metric_dict = metric_dict
  return
 
- def optimizer_post_hook_function(opt, args, gradients, outputs):
- context = self.optimizer_context[opt]
- step_skip = is_skip_step(context.step, self.start_step, self.step_interval, \
- self.has_collect_times, self.collect_times)
- if step_skip:
- context.step += 1
- return
- self.write_xy_tb(context.step)
- self.write_grad_tb(context.step)
- self.write_mv_tb(context)
- self.write_param_tb(context)
-
- if context.metric_dict:
- self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
- context.metric_dict.clear()
- self.has_collect_times += 1
- context.step += 1
- if self.anomaly_data_factory:
- self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
- self.summary_writer.clear_anomalies()
- self.call_id = 0
- self.param_name_call_id.clear()
- return
+ def optimizer_post_step_hook(optimizer, args, kwargs):
+ context = self.optimizer_context[optimizer]
+ self.generate_param_metrics(context, MonitorConst.POST_PARAM)
 
- def optimizer_pre_hook_wrapper(func, grad_names):
- def wrapper(opt, gradients):
- return func(opt, grad_names, gradients)
- return wrapper
 
- def optimizer_post_hook_wrapper(func, args=None):
- def wrapper(opt, gradients, outputs):
- return func(opt, args, gradients, outputs)
- return wrapper
-
- optimizer.register_forward_pre_hook(optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
- optimizer.register_forward_hook(optimizer_post_hook_wrapper(optimizer_post_hook_function))
+ if self.optimizer_hooked or not self.is_target_rank():
+ return
 
+ self.pre_step_hooks.append(optimizer_pre_step_hook)
+ self.post_step_hooks.append(optimizer_post_step_hook)
  self.optimizer_hooked = True
  return
 
+ def generate_wgrad_metrics(self, grad_dict):
+ if not self.wg_distribution:
+ return
+
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+
+ def generate_param_map(self, tag, param_tensor):
+ metrics = {}
+ if not self.is_mindtorch:
+ return param_tensor
+ for name in self.param2name.values():
+ key = get_summary_writer_tag_name(name, tag, self.rank)
+ self.register_param_call_id("optimizer_pre_step_hook", key)
+ if name not in param_tensor or param_tensor[name] is None:
+ continue
+ metrics[key] = param_tensor[name]
+ return metrics
+
+ def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
+ if not self.param_distribution:
+ return
+ tag2param = {
+ self.name2tag.get(name, {}).get(stage): param
+ for name, param in self.name2param.items()
+ if param.numel() != 0
+ }
+ get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
+
+ def get_mv_for_ms(self, opt):
+ if not self.mv_distribution:
+ return {}, {}
+ common_opt = opt
+ if not is_valid_instance(opt):
+ common_opt = getattr(opt, 'optimizer')
+ if not is_valid_instance(common_opt):
+ logger.warning("Optimizer is not valid, please check usage")
+ return {}, {}
+ m_dict = {}
+ v_dict = {}
+ for name, param in get_parameters(common_opt):
+ if MonitorConst.EXP_AVG_SQ in name:
+ v_dict[name] = param
+ elif MonitorConst.EXP_AVG in name:
+ m_dict[name] = param
+ return m_dict, v_dict
+
+ def generate_mv_metrics(self, opt_context):
+ if not self.mv_distribution:
+ return
+ opt_context.exp_avg_metric = {}
+ opt_context.exp_avg_sq_metric = {}
+ m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+ v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+ get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+ get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+ def write_stack_info(self):
+ stack_data = []
+ header = ["module_name", "stack_info"]
+ stack_data.append(header)
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+ stack_data.append([fwd_context.module_name, fwd_context.stack])
+ filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+ if not os.path.exists(filepath):
+ data_frame = pd.DataFrame(columns=stack_data)
+ write_df_to_csv(data_frame, filepath)
+
  def write_xy_tb(self, step):
  if not self.xy_distribution:
  return
  for _, fwd_context in self.module_fwd_hook_context_by_module.items():
  if len(fwd_context.actv) == 0:
  continue
- self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
  fwd_context.actv.clear()
  if self.grad_context.actv:
- self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
 
  def write_param_tb(self, opt_context):
  if not self.param_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
+ param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+ updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+ self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+ self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)
 
  def write_mv_tb(self, opt_context):
  if not self.mv_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, MonitorConst.EXP_AVG)
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step,
+ MonitorConst.EXP_AVG_SQ)
 
  def write_grad_tb(self, step):
  if not self.wg_distribution:
  return
 
- if self.enable_megatron:
- self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
- else:
- self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+ use_micro_step=self.monitor_mbs_grad)
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
- def common_info(self):
- if not self.xy_distribution:
- logger.info("> module input/output input_grad/output_grad is not monitored. ")
- if self.forward_only:
- logger.info("> only module forward is monitored. ")
- if not self.ur_distribution:
- logger.info("> update vector and ratio vector of adam is not monitored. ")
- if not self.mv_distribution:
- logger.info("> momentum and variance of adam is not monitored. ")
- if not self.wg_distribution:
- logger.info("> weight grad of specified module is not monitored. ")
- if not self.mg_direction:
- logger.info('> grad and momentum direction will not be compared.')
- if not self.cc_distribution.get('enable', False):
- logger.info("> cc operator is not monitored.")
-
  def is_target_rank(self):
- rank_id = str(get_rank())
- if self.module_rank_list and (rank_id not in self.module_rank_list):
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
  return False
  return True
 
- def hook_modules(self, model, grad_acc_steps):
- if not self.is_target_rank():
- return
- if not isinstance(model, list):
- model = [model]
- self.model = model # list
- self._register_param_name(model)
- self.micro_batch_number = grad_acc_steps
- module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
-
- for key in module_in_all_stage:
- struct = self.targets.pop(key)
- self.targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(model))})
-
- hooked_count = 0
- for vpp_stage, model_chunk in enumerate(model):
- if not isinstance(model_chunk, nn.Cell):
- logger.info("Target Model is not Cell")
- continue
- vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
- targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
- hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
- logger.info(f"> {hooked_count} modules are monitored.")
-
- def build_tbtag_tensor_map(self, module_name, tag, tensor):
- rank_id = str(get_rank())
- metrics = {}
- key = get_summary_writer_tag_name(module_name, tag, rank_id)
+ def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+ """
+ :param module_name: str of module name
+ :param suffix:
+ :param tag:
+ :param tensor: torch.tensor or tuple/list of torch.tensor
+ :return: tensor_map
+ """
+ tensor_map = {}
  if isinstance(tensor, Tensor):
- self._register_param_call_id("_hook_module", key)
- metrics[key] = tensor
- return metrics
-
- def generate_wgrad_metrics(self):
- if not self.wg_distribution:
- return {}, {}
-
- if self.weight_hooked:
- try:
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
- except Exception as e:
- logger.warning(f"An error occurred while generating wgrad pre metrics")
- return {}, {}
-
- grad_dict = {}
- for param, name in self.param2name.items():
- if self.duplicate_param.get(name, False):
- continue
- grad = param.main_grad if self.params_have_main_grad else param.grad
- if grad is None:
- logger.warning(f"grad is None: {name}, maybe something wrong happened.")
- continue
- tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
- self._register_param_call_id("hook_optimizer", tag)
- grad_dict[tag] = grad
- try:
- get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- except Exception as e:
- logger.warning(f"An error occurred while generating wgrad post metrics")
- return {}, {}
- return self.grad_context.post, self.grad_context.pre
-
- def _register_param_name(self, model):
- if self.param_registered:
- return
+ tensor = [tensor]
+ if isinstance(tensor, tuple) or isinstance(tensor, list):
+ if len(tensor) == 1:
+ key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+ self.register_param_call_id("_hook_module", key)
+ tensor_map[key] = tensor[0]
+ else:
+ for i, tensor_i in enumerate(tensor):
+ key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+ self.register_param_call_id("_hook_module", key)
+ tensor_map[key] = tensor_i
+ return tensor_map
 
- if len(model) > 1:
- self.vpp = True
- logger.info('vpp enabled')
+ def register_param_call_id(self, hook_name: str, key: str):
+ """
+ :param hook_name:
+ :param key: str, '0:relu_0/output_grad'
+ :return:
+ """
+ logger.debug(f"{hook_name} {key}: {self.call_id}")
+ self.param_name_call_id[key] = self.call_id
+ self.call_id += 1
 
- for vpp_stage, model_chunk in enumerate(model):
+ def _register_param_name(self):
+ for vpp_stage, model_chunk in enumerate(self.model):
  prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
  self._register_chunk(model_chunk, prefix)
 
- self.param_registered = True
-
- def _is_target_param(self, param_name, param, prefix):
- if not self.targets:
- return True
- squash_name = prefix + squash_param_name(param_name)
- name = prefix + param_name
- for target in self.targets.keys():
- if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
- setattr(param, "zero_out_wgrad", True)
- return True
- return False
-
  def _register_chunk(self, model_chunk, prefix):
  index = 0
- for param in model_chunk.get_parameters():
- param_name = param.name
+ for param_name, param in get_parameters(model_chunk):
  if not param.requires_grad:
  continue
  if self._is_target_param(param_name, param, prefix):
@@ -601,71 +683,59 @@ class TrainerMon:
601
683
  self.duplicate_param[name] = True
602
684
  if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
603
685
  self.duplicate_param[name] = True
686
+ keywords = [
687
+ MonitorConst.PRE_GRAD,
688
+ MonitorConst.POST_GRAD,
689
+ MonitorConst.PRE_PARAM,
690
+ MonitorConst.POST_PARAM
691
+ ]
604
692
  self.name2tag[name] = {
605
- MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
606
- MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
693
+ k: get_summary_writer_tag_name(name, k, self.rank)
694
+ for k in keywords
607
695
  }
608
696
             index += 1
 
-    def _is_target_module(self, module_name, targets, vpp_stage):
-        if self.all_xy or self.print_struct:
-            return vpp_stage + squash_param_name(module_name)
-        for pattern in [
-            vpp_stage + squash_param_name(module_name),
-            vpp_stage + module_name,
-        ]:
-            if pattern in targets:
-                return pattern
-        return ""
-
     def _hook_module(self, target_names, module, vpp_stage=''):
-        if not isinstance(module, nn.Cell):
+        if not is_valid_instance(module):
             # nothing to hook
             return 0
 
-        def fwd_hook_fun(module, module_input, module_output, name):
+        def fwd_hook_fun(module, args, kwargs, module_output, name):
+
+            module_input = [tensor for tensor in args if isinstance(tensor, Tensor)]
+            if kwargs:
+                kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
+                module_input.extend(kwargs_tensors)
+
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
             if not context.struct:
                 context.struct = {
-                    MonitorConst.ACTV_IN: get_param_struct(module_input),
-                    MonitorConst.ACTV_OUT: get_param_struct(module_output)
+                    Const.INPUT: get_param_struct(module_input),
+                    Const.OUTPUT: get_param_struct(module_output)
                 }
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
             if not module.training:
                 return
-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 step_accumulates_one(context, self.micro_batch_number)
                 return
-            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets)
-                context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets)
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
-                                                                  module_input, context.module_name,
-                                                                  MonitorConst.ACTV_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
-                                                               module_output, context.module_name,
-                                                               MonitorConst.ACTV_OUT)
-                context.verified = True
 
             tbtag_tensor_map = {}
-            if not context.ignore_in:
-                cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
-                                                cared_input))
-            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
-                                            cared_output))
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_input))
+            module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
+                if isinstance(module_output, tuple) else module_output
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_output))
             try:
                 get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
             except Exception as e:
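
The reworked forward hook above drops the per-target format-spec machinery and simply keeps every Tensor found in the positional and keyword arguments before tagging inputs and outputs per micro-step. A minimal sketch of that gathering step, using a stand-in `Tensor` class so it runs without MindSpore (all names here are illustrative, not part of the package):

```python
from typing import Any, Dict, List, Tuple


class Tensor:
    """Stand-in for mindspore.Tensor so the sketch runs anywhere."""
    def __init__(self, value):
        self.value = value


def gather_module_inputs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> List[Tensor]:
    # Keep positional tensors first, then tensor-valued keyword arguments,
    # mirroring how fwd_hook_fun builds module_input before get_metrics runs.
    inputs = [a for a in args if isinstance(a, Tensor)]
    if kwargs:
        inputs.extend(v for v in kwargs.values() if isinstance(v, Tensor))
    return inputs


# Non-tensor arguments (e.g. a flag or mask object) are silently dropped.
print(len(gather_module_inputs((Tensor(1), "flag"), {"bias": Tensor(2)})))  # 2
```
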
@@ -685,36 +755,22 @@ class TrainerMon:
                 self.module_struct[context.module_name].update(context.struct)
                 return
 
-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 step_accumulates_one(context, self.micro_batch_number)
                 return
 
-            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets)
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets)
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
-                                                                  input_grad, context.module_name,
-                                                                  MonitorConst.ACTVGRAD_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
-                                                               output_grad, context.module_name,
-                                                               MonitorConst.ACTVGRAD_OUT)
-                context.verified = True
-
+            valid_input_grad = [tensor for tensor in input_grad if isinstance(tensor, Tensor)]
             tbtag_tensor_map = {}
-            if not context.ignore_in:
-                cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(
-                        f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
-            cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
-                                            cared_output_grad))
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, valid_input_grad))
+
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -728,21 +784,39 @@ class TrainerMon:
                 step_accumulates_one(context, self.micro_batch_number)
                 return
 
-        def fwd_hook_fun_wrapper(fwd_hook_fun, name):
-            def wrapper(module, module_input, module_output):
-                return fwd_hook_fun(module, module_input, module_output, name)
-            return wrapper
+        def fwd_hook_register(module, fwd_hook_fun, name):
+            if mindspore.__version__ >= '2.6.0':
+                def wrapper(module, args, kwargs, module_output):
+                    return fwd_hook_fun(module, args, kwargs, module_output, name)
+                return module.register_forward_hook(wrapper, with_kwargs=True)
+
+            else:
+                def wrapper(module, args, module_output):
+                    return fwd_hook_fun(module, args, None, module_output, name)
+                return module.register_forward_hook(wrapper)
+
+        def stack_hook(module, args, kwargs, module_output, name):
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            context.stack = analyze_api_call_stack(name)
+            return
 
         if self.backward_only and self.forward_only:
             logger.warning('not enable backward_only and forward_only simultaneously')
         hooked_count = 0
-        if self.xy_distribution or self.print_struct:
-            for module_name, submodule in module.cells_and_names():
-                name = self._is_target_module(module_name, target_names, vpp_stage)
-                if not name:
-                    continue
+
+        for module_name, submodule in get_submodules(module):
+            if self.stack_info:
+                name = vpp_stage + squash_param_name(module_name)
+                handle = fwd_hook_register(submodule, stack_hook, name=name)
+                self.handles["stack"].append(handle)
+            name = self._is_target_module(module_name, target_names, vpp_stage)
+            if not name:
+                continue
+            if self.xy_distribution or self.print_struct:
                 if not self.backward_only:
-                    handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name))
+                    handle = fwd_hook_register(submodule, fwd_hook_fun, name=name)
                     self.handles['xy'].append(handle)
                 if not self.forward_only:
                     handle = submodule.register_backward_hook(bwd_hook_fun)
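
The new `fwd_hook_register` above gates the hook signature on the installed MindSpore release: the diff passes `with_kwargs=True` only when `mindspore.__version__ >= '2.6.0'`, and falls back to the positional-args-only signature otherwise. A short sketch of the same gating, assuming MindSpore is installed; `register_fwd_hook` and `record` are illustrative names, not package APIs:

```python
import mindspore
from mindspore import nn


def register_fwd_hook(cell: nn.Cell, record, name: str):
    """Register `record` as a forward hook, normalising to an (args, kwargs, output) call."""
    if mindspore.__version__ >= '2.6.0':
        def hook(cell, args, kwargs, output):
            return record(cell, args, kwargs, output, name)
        return cell.register_forward_hook(hook, with_kwargs=True)

    def hook(cell, args, output):
        # Older MindSpore never passes kwargs to forward hooks.
        return record(cell, args, None, output, name)
    return cell.register_forward_hook(hook)
```

Note the version check is a plain string comparison, as in the diff; it works for the current 2.x line but would misorder a hypothetical 2.10 release.
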
@@ -752,70 +826,120 @@ class TrainerMon:
             hooked_count += 1
         return hooked_count
 
-    def _register_param_call_id(self, hook_name: str, key: str):
-        """
-        :param hook_name:
-        :param key: str, '0:relu_0/output_grad'
-        :return:
-        """
-        logger.debug(f"{hook_name} {key}: {self.call_id}")
-        self.param_name_call_id[key] = self.call_id
-        self.call_id += 1
-
     def _patch_grad_sync(self):
-        # mindspore does not use megatron yet
-        def patch_sync(sync_grad_func):
-            def wrapper(bucket):
-                grad_dict = {}
-                for param, name in self.param2name.items():
-                    if param not in bucket.params_list:
-                        continue
-                    grad = param.main_grad if self.params_have_main_grad else param.grad
-                    if grad is None:
-                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                        continue
-                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
-                    if tag is None:
-                        continue
-                    grad_dict[tag] = grad
-                try:
-                    get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
-                except Exception as e:
-                    logger.warning(f"An error occurred while generating weight grad metrics")
-                out = sync_grad_func(bucket)
-                return out
-
-            return wrapper
-
-        self.enable_megatron = False
-
         if not self.wg_distribution:
             return
-
-        if self.enable_megatron:
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
-        else:
-            self._hook_weights()
+        self._hook_weights()
 
     def _hook_weights(self):
         context = self.grad_context
 
         @_no_grad()
-        def param_hook(grad, context_dict, param, key):
+        def param_hook(grad, context_dict, param, name):
+            key = name
+            if self.monitor_mbs_grad:
+                key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+            key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+            self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-            self._register_param_call_id("param_hook", key)
+
+            if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
+                context_dict[key] = grad
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
-                context_dict[key] = grad
 
-        def param_hook_wrapper(param_hook, context_dict, param, key):
+        def param_hook_wrapper(param_hook, context_dict, param, name):
             def wrapper(grad):
-                return param_hook(grad, context_dict, param, key)
+                return param_hook(grad, context_dict, param, name)
+
             return wrapper
 
+        logger.info("hooking weights.")
         for param, name in self.param2name.items():
-            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
-            handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key))
+            handle = param.register_hook(
+                param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name))
             self.handles['wgrads'].append(handle)
         self.weight_hooked = True
+
+    def _is_target_param(self, param_name, param, prefix):
+        if not self.targets:
+            return True
+        squash_name = prefix + squash_param_name(param_name)
+        name = prefix + param_name
+        for target in self.targets.keys():
+            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
+                setattr(param, "zero_out_wgrad", True)
+                return True
+        return False
+
+    def _is_target_module(self, module_name, targets, vpp_stage):
+        if self.all_xy or self.print_struct:
+            return vpp_stage + squash_param_name(module_name)
+        for pattern in [
+            vpp_stage + squash_param_name(module_name),
+            vpp_stage + module_name,
+        ]:
+            if pattern in targets:
+                return pattern
+        return ""
+
+    def _remove_all_hooks(self, optimizer):
+        # clear hook handles
+        for handle in self.handles['xy']:
+            handle.remove()
+        self.handles['xy'].clear()
+        # clear the corresponding context caches
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            fwd_context.reset()
+        for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+            bwd_context.reset()
+        self.grad_context.reset()  # weight grads and activation grads both live here
+
+        for handle in self.handles['wgrads']:
+            handle.remove()
+        self.handles['wgrads'].clear()
+        self.weight_hooked = False
+
+        if self.optimizer_hooked:
+            self.pre_step_hooks.clear()
+            self.post_step_hooks.clear()
+            for _, context in self.optimizer_context.items():
+                context.reset()
+        self.optimizer_hooked = False
+
+        for handle in self.handles['cc']:
+            handle.remove()
+        self.handles['cc'].clear()
+        api_register.restore_api()
+        for _, context in self.cc_context.items():
+            context.reset()
+
+        # clear node caches
+        self.param2name.clear()
+        self.name2index.clear()
+        self.name2indices.clear()
+        self.name2param.clear()
+        self.duplicate_param.clear()
+        self.name2tag.clear()
+        self.module_struct.clear()
+        self.grad_accs.clear()
+
+        # turn off the monitoring state
+        self.monitoring = False
+
+    def _remove_all_hooks_final(self, optimizer):
+        if self.dynamic_enable:
+            # after finishing, reset dynamic_on to False; the user must set it to True again to restart
+            try:
+                config = load_json(self.config_file_path)
+                config['dynamic_on'] = False
+                save_json(self.config_file_path, config, indent=2)
+                config_timestamp = os.path.getmtime(self.config_file_path)
+                self.config_timestamp = config_timestamp
+                logger.info(
+                    "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
+            except Exception as e:
+                logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
+        logger.info("Finish monitor")
+        self._remove_all_hooks(optimizer)
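
`param_hook` in the rewritten `_hook_weights` keeps a micro-step counter on each parameter and decides whether the accumulated gradient is recorded for every micro-batch (`monitor_mbs_grad`) or only once per global step. A framework-agnostic sketch of that bookkeeping; the names and the plain-float "gradients" are illustrative only:

```python
def make_param_hook(name, store, micro_batch_number, monitor_mbs_grad=False):
    """Build a gradient hook that mimics param_hook's micro-step bookkeeping."""
    state = {"micro_step": 0}

    def hook(grad):
        # The key is resolved before the counter advances, as in param_hook.
        key = f"{name}:micro{state['micro_step']}" if monitor_mbs_grad else name
        state["micro_step"] += 1
        if monitor_mbs_grad or state["micro_step"] == micro_batch_number:
            store[key] = grad
        if state["micro_step"] == micro_batch_number:
            state["micro_step"] = 0
        return grad

    return hook


# Record only the fully accumulated gradient for weight "w" over 4 micro-batches.
acc = {}
hook = make_param_hook("w", acc, micro_batch_number=4)
for g in (0.1, 0.2, 0.3, 0.4):
    hook(g)
print(acc)  # {'w': 0.4}
```

With `monitor_mbs_grad=True` every micro-batch gets its own key (`name:micro0`, `name:micro1`, ...), matching the `MonitorConst.NAME_SEP` suffix the diff appends before building the summary-writer tag.
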