mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
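To reproduce a comparison like this locally, the sketch below downloads both wheels and diffs one shared file. This is not part of the registry tooling; it assumes both versions of mindstudio-probe can be fetched from your configured package index, and the target path is one of the files listed in this diff.

```python
# Hedged sketch: fetch both mindstudio-probe wheels and diff one file they share.
# Assumes "pip download" can reach these exact versions on your package index.
import subprocess, zipfile, difflib, pathlib

def fetch(version: str, dest: pathlib.Path) -> pathlib.Path:
    dest.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["pip", "download", f"mindstudio-probe=={version}", "--no-deps", "--dest", str(dest)],
        check=True,
    )
    return next(dest.glob("*.whl"))

def file_lines(whl: pathlib.Path, name: str) -> list[str]:
    with zipfile.ZipFile(whl) as zf:
        return zf.read(name).decode("utf-8", errors="replace").splitlines()

old_whl = fetch("1.2.2", pathlib.Path("old"))
new_whl = fetch("8.1.0", pathlib.Path("new"))

target = "msprobe/pytorch/monitor/module_hook.py"  # file #215 in the list below
for line in difflib.unified_diff(
        file_lines(old_whl, target), file_lines(new_whl, target),
        fromfile=f"1.2.2/{target}", tofile=f"8.1.0/{target}", lineterm=""):
    print(line)
```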
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/monitor/module_hook.py +336 -241

@@ -22,26 +22,29 @@ from functools import partial
  import pytz
  import torch
  import torch.distributed as dist
+ import pandas as pd
  from torch.utils.hooks import BackwardHook

  from msprobe.core.common.const import MonitorConst, Const
  from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.core.common.decorator import recursion_depth_decorator
+ from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+ from msprobe.core.common.file_utils import write_df_to_csv
+ from msprobe.core.common.utils import analyze_api_call_stack
  from msprobe.pytorch.common.log import logger
- from msprobe.pytorch.common.utils import is_recomputation
- from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
- from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
- CSVWriterWithAD, BaseWriterWithAD, WriterInput
+ from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
+ from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
  from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
  get_process_group
  from msprobe.pytorch.monitor.features import get_sign_matches
  from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
  TensorMetrics, squash_param_name
- from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
  from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
  from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
- get_output_base_dir, get_target_output_dir
+ get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

+
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
  if not torch_version_above_or_equal_2:
  raise ValueError("monitor require torch>=2.0")
@@ -71,36 +74,7 @@ class ModuleHookContext:
  self.actvgrad = []
  self.module_name = module_name
  self.struct = {}
- self.format_by_arg = {}
- self.verified = False
- self.focused_in_col = 0
- self.focused_out_col = 0
-
- def set_format_by_arg(self, key_name: str, target_config: dict):
- """ Configure format_by_arg according to the monitored targets:
- 1) module_name is configured as a monitored target in targets
- 2) module_name is not configured in targets, and all_xy enables full monitoring
- 3) module_name is not configured in targets, and all_xy does not enable full monitoring
-
- :param key_name: str, one of [input, output, input_grad, output_grad]
- :param target_config: target obj in config json.
- :return:
- """
- cared = target_config.get(self.module_name, self.struct)
- if key_name in cared:
- target_module_config = cared[key_name]
- if isinstance(target_module_config, dict):
- # current cared is self.struct, monitor all data for module_name
- self.format_by_arg[key_name] = target_module_config.get('config')
- elif isinstance(target_module_config, str):
- # current cared is target_config[self.module_name]
- self.format_by_arg[key_name] = target_module_config
- else:
- logger.warning_on_rank_0(f"target module config error, result maybe empty."
- f"module_name: {self.module_name}, key_name: {key_name}")
- self.format_by_arg[key_name] = None
- else:
- self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
+ self.stack = ""

  def reset(self):
  self.actv.clear()
@@ -176,15 +150,16 @@ class GradContext:
  class TrainerMon:
  tensor_metrics = TensorMetrics()

- def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+ # Keep the original opt_ty parameter for compatibility with versions earlier than msprobe 1.2.2
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
  # TYPE1: variables initialized only here; they are not reset when the config changes during training
  self.config_file_path = config_file_path
  self.process_group = get_process_group(process_group)
  self.params_have_main_grad = params_have_main_grad
  self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
  self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
- self.origin_step_func = None
  self.origin_start_grad_sync = None
+ self.fsdp_post_backward_hook = None
  self.config_timestamp = 0 # the timestamp is validated later; the first monitoring run need not touch the config file just to refresh it, monitoring can be enabled directly via the dynamic_on switch
  self.config = load_json(config_file_path)
  validate_config(self.config)
@@ -219,9 +194,10 @@ class TrainerMon:
  self.dp_group = None
  self.tp_group = None
  self.enable_megatron = False
+ self.fsdp_wrapped_module = False
  self.micro_batch_number = 1
- self.optimizer_class = None
  self.optimizer_mon = None
+ self.optimizer_trans = None

  # TYPE3: variables that are reset when the config is updated during training or when the monitoring state changes
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -231,7 +207,6 @@ class TrainerMon:
  self.grad_context = GradContext()
  self.handles = defaultdict(list)
  self.param2name = defaultdict(str)
- self.name2index = defaultdict()
  self.name2indices = defaultdict()
  self.name2param = {}
  self.duplicate_param = {}
@@ -244,6 +219,8 @@ class TrainerMon:
  self.optimizer_hooked = False
  self.param_registered = False
  self.struct_printed = False
+ self.pre_step_hooks = []
+ self.post_step_hooks = []

  # distinguish dynamic vs. static monitoring
  self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
@@ -314,6 +291,8 @@ class TrainerMon:
  self.param_distribution = self.config.get("param_distribution", False)
  self.mg_direction = self.config.get('mg_direction', False)
  self.cc_distribution = self.config.get("cc_distribution", {})
+ self.stack_info = self.config.get('stack_info', False)
+ self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)

  if not self.cc_distribution.get('enable', False):
  self.cc_log_only = False
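The two new switches read in the hunk above (`stack_info` and `monitor_mbs_grad`) are plain booleans in the monitor config JSON passed to `TrainerMon`. A hedged sketch of such a config written from Python follows; only keys that appear in this diff are shown, and the authoritative schema is in msprobe/docs/19.monitor.md.

```python
# Hedged sketch of a monitor config for TrainerMon(config_file_path).
# Key names below are the ones read in this diff; "targets"/"format" values are
# illustrative placeholders, not a definitive schema (see docs/19.monitor.md).
import json

monitor_config = {
    "targets": {"module.layers.0": {}},   # example module name to monitor
    "format": "csv",                      # unsupported formats fall back to csv
    "param_distribution": True,
    "mg_direction": False,
    "cc_distribution": {"enable": False},
    "stack_info": True,                   # new: dump a stack_info.csv of module call stacks
    "monitor_mbs_grad": True,             # new: collect per-micro-batch gradients
}

with open("monitor_config.json", "w", encoding="utf-8") as f:
    json.dump(monitor_config, f, indent=2)
```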
@@ -322,8 +301,6 @@ class TrainerMon:
  self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
  self.cc_logged_stack = defaultdict(set)
  self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
- self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
- api_register.redirect_api()

  self.common_info()

@@ -336,11 +313,11 @@ class TrainerMon:

  # initialize the writer and create the output directory
  if self.format not in FORMAT_MAPPING:
- logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+ logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
  self.format = MonitorConst.CSV

  if self.ur_distribution and self.format != 'tensorboard':
- logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
+ logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
  self.ur_distribution = False

  writer = FORMAT_MAPPING[self.format]
@@ -363,19 +340,6 @@ class TrainerMon:
  self.rank)
  self.anomaly_data_writer.init_detected_json()

- def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
- rank = None
- if dist.is_initialized():
- rank = dist.get_rank()
- if (rank not in rank_list) and len(rank_list) != 0:
- return
- self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
-
- def build_tbtag_tensor_map(self, module_name, tag, tensor):
- key = get_summary_writer_tag_name(module_name, tag, self.rank)
- self._register_param_call_id("_hook_module", key)
- return {key: tensor}
-
  def common_info(self):
  if not self.xy_distribution:
  logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
@@ -392,101 +356,38 @@ class TrainerMon:
  if not self.cc_distribution.get('enable', False):
  logger.info_on_rank_0("> cc operator is not monitored.")

- def hook_modules(self):
- if self.module_rank_list and (self.rank not in self.module_rank_list):
- return
-
- targets = self.config['targets']
- module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
- for key in module_in_all_stage:
- struct = targets.pop(key)
- targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
-
- hooked_count = 0
- for vpp_stage, model_chunk in enumerate(self.model):
- vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
- targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
- 'targets'].keys()
- hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-
- logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
-
- def clone_if_tensor(args):
- if isinstance(args, tuple):
- return tuple([clone_if_tensor(arg) for arg in args])
- elif isinstance(args, torch.Tensor):
- return args.clone()
- else:
- return args
-
- @torch.no_grad
- def wrap_hook_setup(setup):
- def wrapped_setup(*args, **kwargs):
- args = setup(*args, **kwargs)
- args = clone_if_tensor(args)
- return args
-
- return wrapped_setup
-
- BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
-
- return
-
- def generate_param_metrics(self, opt_context):
- if not self.param_distribution:
- return
- get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
-
- def generate_mv_metrics(self, opt_context):
- if not self.mv_distribution:
- return
- opt_context.exp_avg_metric = {}
- opt_context.exp_avg_sq_metric = {}
- m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
- v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
- get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
- get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
-
- def generate_wgrad_metrics(self):
- if not self.wg_distribution:
- return {}, {}
-
- if self.weight_hooked:
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
- grad_dict = {}
- for param, name in self.param2name.items():
- if self.duplicate_param.get(name, False):
- continue
- grad = param.main_grad if self.params_have_main_grad else param.grad
- if grad is None:
- logger.warning(f"grad is None: {name}, maybe something wrong happened.")
- continue
- tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
- self._register_param_call_id("hook_optimizer", tag)
- grad_dict[tag] = grad
+ # Keep the original interface for compatibility with versions earlier than msprobe 1.2.2
+ def monitor_gnorm_with_ad(self, model, optimizer=None, grad_acc_steps=1, tp_group=None, dp_group=None,
+ start_iteration=0):
+ if optimizer is None:
+ optimizer = getattr(self, "optimizer_trans", None) # older versions may pass None; fall back to the optimizer stored by set_wrapped_optimizer
+ if optimizer is None:
+ logger.error("monitor_gnorm_with_ad: please set_wrapped_optimizer before it or input optimizer!=None")
+ return
+ self.set_monitor(model, optimizer, grad_acc_steps, tp_group, dp_group, start_iteration)

- get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
- return self.grad_context.post, unreduced_grad
+ # Keep the original interface for compatibility with versions earlier than msprobe 1.2.2
+ def set_wrapped_optimizer(self, optimizer):
+ self.optimizer_trans = optimizer

  def set_monitor(
  self,
  model,
+ optimizer,
  grad_acc_steps=1,
- optimizer=None,
  tp_group=None,
  dp_group=None,
  start_iteration=0
  ):
  """External interface"""
+ grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration)
  global start_step
  start_step = start_iteration
  logger.info(f'grad acc steps {grad_acc_steps}')
  self.micro_batch_number = grad_acc_steps
  self.dp_group = dp_group
  self.tp_group = tp_group
- self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
+ self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
  self.hook_step_final(optimizer)
  if not isinstance(model, list):
  model = [model]
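As the hunk above shows, `optimizer` is now a required positional argument of `set_monitor`, while `monitor_gnorm_with_ad` and `set_wrapped_optimizer` remain only as compatibility shims for pre-1.2.2 call sites. A minimal usage sketch under the new signature follows; the config path, model and optimizer are placeholders, not msprobe sample code.

```python
# Hedged sketch of wiring TrainerMon into a training loop with the new
# set_monitor(model, optimizer, ...) signature shown in this diff.
import torch
from msprobe.pytorch.monitor.module_hook import TrainerMon

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# "monitor_config.json" is a placeholder monitor config; plain torch optimizers
# have no main_grad, hence params_have_main_grad=False here.
monitor = TrainerMon("monitor_config.json", params_have_main_grad=False)
monitor.set_monitor(model, optimizer, grad_acc_steps=1, start_iteration=0)

# Legacy call sites can still use the pre-1.2.2 sequence:
#   monitor.set_wrapped_optimizer(optimizer)
#   monitor.monitor_gnorm_with_ad(model, grad_acc_steps=1)

for _ in range(3):
    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()      # the pre/post step hooks patched in by TrainerMon run here
    optimizer.zero_grad()
```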
@@ -502,18 +403,89 @@ class TrainerMon:
  self.hook_optimizer(optimizer)
  self._patch_grad_sync()
  self.hook_modules()
+ if self.cc_distribution.get('enable', False):
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+ api_register.redirect_api()
  self.monitoring = True

+ def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+ rank = None
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ if (rank not in rank_list) and len(rank_list) != 0:
+ return
+ self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
+
+ def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+ """
+ :param module_name: str of module name
+ :param suffix:
+ :param tag:
+ :param tensor: torch.tensor or tuple/list of torch.tensor
+ :return: tensor_map
+ """
+ tensor_map = {}
+ if isinstance(tensor, torch.Tensor):
+ tensor = [tensor]
+ if isinstance(tensor, tuple) or isinstance(tensor, list):
+ if len(tensor) == 1:
+ key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+ self.register_param_call_id("_hook_module", key)
+ tensor_map[key] = tensor[0]
+ else:
+ for i, tensor_i in enumerate(tensor):
+ key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+ self.register_param_call_id("_hook_module", key)
+ tensor_map[key] = tensor_i
+ return tensor_map
+
  def generate_param_map(self, tag, param_tensor):
  metrics = {}
  for name in self.param2name.values():
  key = get_summary_writer_tag_name(name, tag, self.rank)
- self._register_param_call_id("optimizer_pre_step_hook", key)
+ self.register_param_call_id("optimizer_pre_step_hook", key)
  if name not in param_tensor or param_tensor[name] is None:
  continue
  metrics[key] = param_tensor[name]
  return metrics

+ def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
+ if not self.param_distribution:
+ return
+ tag2param = {
+ self.name2tag.get(name, {}).get(stage): param
+ for name, param in self.name2param.items()
+ if param.numel() != 0
+ }
+ get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
+
+ def generate_mv_metrics(self, opt_context):
+ if not self.mv_distribution:
+ return
+ opt_context.exp_avg_metric = {}
+ opt_context.exp_avg_sq_metric = {}
+ m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+ v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+ get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+ get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+ def generate_wgrad_metrics(self, post_grad_dict):
+ if not self.wg_distribution:
+ return {}, {}
+
+ if self.weight_hooked:
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+
+ get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
+ reduced_grad = self.grad_context.post
+
+ if self.weight_hooked:
+ unreduced_grad = self.grad_context.acc_metric
+ else:
+ unreduced_grad = self.grad_context.pre
+
+ return reduced_grad, unreduced_grad
+
  def generate_xy_metrics(self):
  actv = {}
  for fwd_context in self.module_fwd_hook_context_by_module.values():
@@ -538,6 +510,17 @@ class TrainerMon:
  def write_adhoc_check(self, step):
  self.tensor_metrics.flush(self.summary_writer)

+ def write_stack_info(self):
+ stack_data = []
+ header = ["module_name", "stack_info"]
+ stack_data.append(header)
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+ stack_data.append([fwd_context.module_name, fwd_context.stack])
+ filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+ if not os.path.exists(filepath):
+ data_frame = pd.DataFrame(columns=stack_data)
+ write_df_to_csv(data_frame, filepath)
+
  def write_xy_tb(self, step):
  if not self.xy_distribution:
  return
@@ -552,27 +535,31 @@ class TrainerMon:
  def write_param_tb(self, opt_context):
  if not self.param_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM)
+ param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+ updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+ self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+ self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)

  def write_mv_tb(self, opt_context):
  if not self.mv_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
  opt_context.step, MonitorConst.EXP_AVG)
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
  opt_context.step, MonitorConst.EXP_AVG_SQ)

  def write_grad_tb(self, step):
  if not self.wg_distribution:
  return

- if self.enable_megatron:
- self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
+ if self.weight_hooked:
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+ use_micro_step=self.monitor_mbs_grad)
  else:
- self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

- def hook_optimizer(self, optimizer=None):
+ def hook_optimizer(self, optimizer):
  # in DDP by default use params_have_main_grad
  def optimizer_pre_step_hook(optimizer, args, kwargs):
  context = self.optimizer_context[optimizer]
@@ -591,21 +578,23 @@ class TrainerMon:
  # skip generate metrics
  if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
  return
- if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class: # use deepspeed with zero1/2/3
- if not self.name2indices:
- self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
- mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
- self.param2name = mv_result.grad
- else:
- mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
- context.param_exp_avg = mv_result.exp_avg
- context.param_exp_avg_sq = mv_result.exp_avg_sq
- context.param_adam_update = mv_result.update
- context.param_adam_ratio = mv_result.ratio

- self.generate_wgrad_metrics()
+ grad_dict = {}
+ if self.wg_distribution:
+ grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+ mv_result = None
+ if self.mv_distribution or self.ur_distribution or self.mg_direction:
+ mv_result = self.optimizer_mon.fetch_mv(self, self.param2name)
+ if mv_result:
+ context.param_exp_avg = mv_result.exp_avg
+ context.param_exp_avg_sq = mv_result.exp_avg_sq
+ context.param_adam_update = mv_result.update
+ context.param_adam_ratio = mv_result.ratio
+
+ self.generate_wgrad_metrics(grad_dict)
  self.generate_mv_metrics(context)
- self.generate_param_metrics(context)
+ self.generate_param_metrics(context, MonitorConst.PRE_PARAM)

  tbtag_tensor_map = {}
  if self.mg_direction:
@@ -633,18 +622,15 @@ class TrainerMon:
  context.metric_dict = metric_dict
  return

- def patch_step(func, optimizer):
- def wrapper(*args, **kwargs):
- optimizer_pre_step_hook(optimizer, args, kwargs)
- out = func(*args, **kwargs)
- return out
-
- return wrapper
+ def optimizer_post_step_hook(optimizer, args, kwargs):
+ context = self.optimizer_context[optimizer]
+ self.generate_param_metrics(context, MonitorConst.POST_PARAM)

  if self.optimizer_hooked:
  return

- optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+ self.pre_step_hooks.append(optimizer_pre_step_hook)
+ self.post_step_hooks.append(optimizer_post_step_hook)

  self.optimizer_hooked = True
  return
@@ -674,6 +660,7 @@ class TrainerMon:
  validate_config(config)
  self.config = config
  self.set_config()
+ self.start_step = context.step # when dynamically starting/stopping, ignore the original start_step and always start from the next step
  logger.warning(f"config is updated at step{context.step - 1}, "
  f"will start new hook at step{context.step}.")
  except Exception as e:
@@ -703,6 +690,12 @@ class TrainerMon:
  self.write_mv_tb(context)
  self.write_param_tb(context)
  self.write_adhoc_check(context.step)
+ if self.stack_info:
+ self.write_stack_info()
+ self.stack_info = False
+ for handle in self.handles["stack"]:
+ handle.remove()
+ self.handles["stack"].clear()

  if self.ur_distribution:
  for param_name, _ in context.param_adam_update.items():
@@ -721,6 +714,9 @@ class TrainerMon:
  if self.anomaly_data_factory:
  self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
  self.summary_writer.clear_anomalies()
+
+ if self.format == MonitorConst.TENSORBOARD:
+ chmod_tensorboard_dir(self.tensorboard_dir)
  self.call_id = 0
  self.param_name_call_id.clear()

@@ -732,16 +728,69 @@ class TrainerMon:

  def patch_step(func, optimizer):
  def wrapper(*args, **kwargs):
+ for hook in self.pre_step_hooks:
+ hook(optimizer, args, kwargs)
  out = func(*args, **kwargs)
+ for hook in self.post_step_hooks:
+ hook(optimizer, args, kwargs)
  step_final_hook(optimizer, args, kwargs)
  return out
  return wrapper

  optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
- self.origin_step_func = optimizer.__class__.step
+ return
+
+ def hook_modules(self):
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
+ return

+ targets = self.config['targets']
+ module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
+ for key in module_in_all_stage:
+ struct = targets.pop(key)
+ targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+
+ hooked_count = 0
+ for vpp_stage, model_chunk in enumerate(self.model):
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+ targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
+ 'targets'].keys()
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+
+ logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
+
+ @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor')
+ def clone_if_tensor(args):
+ if isinstance(args, tuple):
+ return tuple([clone_if_tensor(arg) for arg in args])
+ elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+ return args.clone()
+ else:
+ return args
+
+ @torch.no_grad
+ def wrap_hook_setup(setup):
+ def wrapped_setup(*args, **kwargs):
+ args = setup(*args, **kwargs)
+ args = clone_if_tensor(args)
+ return args
+
+ return wrapped_setup
+
+ BackwardHook.setup_input_hook = wrap_hook_setup(BackwardHook.setup_input_hook)
+ BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
  return

+ def register_param_call_id(self, hook_name: str, key: str):
+ """
+ :param hook_name:
+ :param key: str, '0:relu_0/output_grad'
+ :return:
+ """
+ logger.debug(f"{hook_name} {key}: {self.call_id}")
+ self.param_name_call_id[key] = self.call_id
+ self.call_id += 1
+
  def _remove_all_hooks(self, optimizer):
  # clear hook handles
  for handle in self.handles['xy']:
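The replacement of the old per-optimizer `patch_step` with shared `pre_step_hooks`/`post_step_hooks` lists (see the hunk above) boils down to a wrapper around `Optimizer.step` on the class. The following standalone sketch, independent of msprobe, illustrates that pattern only.

```python
# Standalone sketch of the pre/post step hook pattern used by patch_step above.
# Not msprobe code; it only shows wrapping Optimizer.step at the class level.
import torch

pre_step_hooks, post_step_hooks = [], []

def patch_step(func):
    def wrapper(self, *args, **kwargs):
        for hook in pre_step_hooks:
            hook(self, args, kwargs)       # runs before the real step
        out = func(self, *args, **kwargs)
        for hook in post_step_hooks:
            hook(self, args, kwargs)       # runs after parameters are updated
        return out
    return wrapper

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt.__class__.step = patch_step(opt.__class__.step)

pre_step_hooks.append(lambda o, a, k: print("before step"))
post_step_hooks.append(lambda o, a, k: print("after step"))

model(torch.randn(2, 4)).sum().backward()
opt.step()   # prints "before step", then "after step"
```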
@@ -767,14 +816,18 @@ class TrainerMon:
  logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
  except ImportError:
  pass
- else: # not megatron
+ elif self.fsdp_post_backward_hook: # fsdp
+ torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook
+ logger.info("remove patch_post_backward_hook in fsdp.")
+ else: # not megatron and not fsdp
  for handle in self.handles['wgrads']:
  handle.remove()
  self.handles['wgrads'].clear()
  self.weight_hooked = False

  if self.optimizer_hooked:
- optimizer.__class__.step = self.origin_step_func
+ self.pre_step_hooks.clear()
+ self.post_step_hooks.clear()

  for _, context in self.optimizer_context.items():
  context.reset()
@@ -783,12 +836,12 @@ class TrainerMon:
  for handle in self.handles['cc']:
  handle.remove()
  self.handles['cc'].clear()
+ api_register.restore_api()
  for _, context in self.cc_context.items():
  context.reset()

  # clear node caches
  self.param2name.clear()
- self.name2index.clear()
  self.name2indices.clear()
  self.name2param.clear()
  self.duplicate_param.clear()
@@ -848,27 +901,33 @@ class TrainerMon:
  return False

  def _register_chunk(self, model_chunk, prefix):
- index = 0
  for (param_name, param) in model_chunk.named_parameters():
  if not param.requires_grad:
  continue
+ if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
+ self.fsdp_wrapped_module = True
  if self._is_target_param(param_name, param, prefix):
  name = prefix + squash_param_name(param_name, self.squash_name)
  if name in self.param2name.values():
  name = prefix + param_name
  self.param2name[param] = name
  self.name2param[name] = param
- self.name2index[name] = index

  if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
  self.duplicate_param[name] = True
  if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
  self.duplicate_param[name] = True
+
+ keywords = [
+ MonitorConst.PRE_GRAD,
+ MonitorConst.POST_GRAD,
+ MonitorConst.PRE_PARAM,
+ MonitorConst.POST_PARAM
+ ]
  self.name2tag[name] = {
- MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
- MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
+ k: get_summary_writer_tag_name(name, k, self.rank)
+ for k in keywords
  }
- index += 1

  def _register_param_name(self):
  for vpp_stage, model_chunk in enumerate(self.model):
@@ -891,11 +950,17 @@ class TrainerMon:
  # nothing to hook
  return 0

- def fwd_hook_fun(module, module_input, module_output, name):
+ def fwd_hook_fun(module, args, kwargs, module_output, name):
  if not module.training or is_recomputation():
  # 1 only monitor training stage.
  # 2 when open recompute, skip recomputed forward stage.
  return
+
+ module_input = [tensor for tensor in args if torch.is_tensor(tensor)]
+ if kwargs:
+ kwargs_tensors = [tensor for tensor in kwargs.values() if torch.is_tensor(tensor)]
+ module_input.extend(kwargs_tensors)
+
  if module not in self.module_fwd_hook_context_by_module:
  self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
  context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
@@ -904,34 +969,20 @@ class TrainerMon:
  Const.INPUT: get_param_struct(module_input),
  Const.OUTPUT: get_param_struct(module_output)
  }
+
  if self.print_struct:
  self.module_struct[context.module_name].update(context.struct)
  return
- if not context.format_by_arg:
- context.set_format_by_arg(Const.INPUT, self.config['targets'])
- context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
- if not context.format_by_arg:
- return
- if not context.verified:
- context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
- module_input, context.module_name,
- Const.INPUT)
- context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
- module_output, context.module_name,
- Const.OUTPUT)
- context.verified = True
- # expect output be tensor type
+
  tbtag_tensor_map = {}
- cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
  tbtag_tensor_map.update(
  self.build_tbtag_tensor_map(
- f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
- MonitorConst.ACTV, cared_input))
- cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
+ f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, module_input))
  tbtag_tensor_map.update(
  self.build_tbtag_tensor_map(
- f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
- MonitorConst.ACTV, cared_output))
+ f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTV, module_output))

  get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
  context.micro_step += 1
@@ -949,31 +1000,17 @@ class TrainerMon:
  if self.print_struct:
  self.module_struct[context.module_name].update(context.struct)
  return
- if not context.format_by_arg:
- context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
- context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
- if not context.format_by_arg:
- return
- if not context.verified:
- context.focused_in_col = validate_config_spec(
- context.format_by_arg[MonitorConst.INPUT_GRAD],
- input_grad, context.module_name, MonitorConst.INPUT_GRAD)
- context.focused_out_col = validate_config_spec(
- context.format_by_arg[MonitorConst.OUTPUT_GRAD],
- output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
- context.verified = True

  tbtag_tensor_map = {}
- cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
  tbtag_tensor_map.update(
  self.build_tbtag_tensor_map(
- f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
- MonitorConst.ACTV, cared_input_grad))
- cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
+ f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTVGRAD, input_grad))
+
  tbtag_tensor_map.update(
  self.build_tbtag_tensor_map(
- f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
- MonitorConst.ACTV, cared_output_grad))
+ f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+ MonitorConst.ACTVGRAD, output_grad))

  if context.micro_step == 0 and context.actvgrad:
  logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -987,17 +1024,30 @@ class TrainerMon:
  context.micro_step = 0
  return

+ def stack_hook(module, args, kwargs, module_output, name):
+ if module not in self.module_fwd_hook_context_by_module:
+ self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+ context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+ context.stack = analyze_api_call_stack(name)
+ return
+
  if self.backward_only and self.forward_only:
  logger.warning('not enable backward_only and forward_only simultaneously')

  hooked_count = 0
- if self.xy_distribution or self.print_struct:
- for module_name, submodule in module.named_modules():
- name = self._is_target_module(module_name, target_names, vpp_stage)
- if not name:
- continue
+ for module_name, submodule in module.named_modules():
+ if self.stack_info:
+ name = vpp_stage + squash_param_name(module_name, self.squash_name)
+ handle = submodule.register_forward_hook(partial(stack_hook, name=name), with_kwargs=True)
+ self.handles['stack'].append(handle)
+ name = self._is_target_module(module_name, target_names, vpp_stage)
+ if not name:
+ continue
+ if submodule.__class__.__name__ == "FullyShardedDataParallel":
+ continue
+ if self.xy_distribution or self.print_struct:
  if not self.backward_only:
- handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
+ handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name), with_kwargs=True)
  self.handles['xy'].append(handle)
  if not self.forward_only and not self.has_register_backward_hook(name, submodule):
  handle = submodule.register_full_backward_hook(bwd_hook_fun)
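Both `fwd_hook_fun` and the new `stack_hook` in the hunk above are registered with `with_kwargs=True`, a torch>=2.0 option under which the forward hook receives `(module, args, kwargs, output)` instead of `(module, input, output)`. A minimal standalone sketch (not msprobe code):

```python
# Standalone sketch of register_forward_hook(..., with_kwargs=True) as used above.
# Requires torch >= 2.0, which this module already enforces at import time.
import torch

def fwd_hook(module, args, kwargs, output, name="linear"):
    # Positional and keyword tensor inputs are both visible to the hook.
    tensors = [t for t in args if torch.is_tensor(t)]
    tensors += [t for t in kwargs.values() if torch.is_tensor(t)]
    print(name, len(tensors), tuple(output.shape))

layer = torch.nn.Linear(4, 2)
handle = layer.register_forward_hook(fwd_hook, with_kwargs=True)
layer(torch.randn(3, 4))   # prints: linear 1 (3, 2)
handle.remove()
```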
@@ -1026,7 +1076,7 @@ class TrainerMon:
  if tag is None:
  continue
  grad_dict[tag] = grad
- self._register_param_call_id("sync_grad_func", tag)
+ self.register_param_call_id("sync_grad_func", tag)
  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
  out = sync_grad_func(bucket)
  return out
@@ -1035,7 +1085,14 @@ class TrainerMon:

  if not self.wg_distribution:
  return
+ if self.fsdp_wrapped_module:
+ # patch fsdp _runtime_utils._post_backward_hook
+ self._patch_fsdp_post_backward_hook()
+ return

+ if self.monitor_mbs_grad:
+ self._hook_weights()
+ return
  try:
  from megatron.core.distributed.param_and_grad_buffer import Bucket
  self.origin_start_grad_sync = Bucket.start_grad_sync
@@ -1052,44 +1109,82 @@ class TrainerMon:
  self.enable_megatron = True
  logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
  except ImportError:
- self.enable_megatron = False
+ self.enable_megatron = False | self.enable_megatron
+ if self.enable_megatron:
+ return

- if not self.enable_megatron:
- self._hook_weights()
+ # default hook weights
+ self._hook_weights()
+
+ def _patch_fsdp_post_backward_hook(self):
+ """
+ The FSDP runtime drives the whole forward/backward compute-and-communication flow by overriding nn.Module.forward and defining the corresponding logic.
+ Registering a hook on the AccumulateGrad object lets it run right after the gradient is computed in backward, collecting the accumulated gradient before the reduce_scatter (i.e. before communication aggregation).
+ However, FSDP re-registers its hooks on AccumulateGrad in every forward stage, so hooks registered by the monitor tool do not take effect;
+ therefore _post_backward_hook is patched to collect gradients after backward and before reduce_scatter.
+ """
+ def patch_post_backward_hook(_post_backward_hook):
+ def wrapper(state, handle, *unused):
+ grad_dict = {}
+ offset = 0
+ for param, name in self.param2name.items():
+ limit = param.numel()
+ if not limit:
+ continue
+ grad = handle.flat_param.grad[offset:offset + limit]
+ offset += limit
+ tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+ if tag is None:
+ continue
+ grad_dict[tag] = grad
+ self.register_param_call_id("_post_backward_hook", tag)
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+ out = _post_backward_hook(state, handle, *unused)
+ return out
+
+ return wrapper
+
+ logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
+ self.fsdp_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook
+ torch.distributed.fsdp._runtime_utils._post_backward_hook = \
+ patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook)

  def _hook_weights(self):
+ """
+ Walk each parameter's gradient-accumulation function (grad_acc) and attach a hook, so that once all gradients of the parameter have been computed, the gradient data is collected before communication aggregation.
+ """
  context = self.grad_context

  @torch.no_grad
- def param_hook(*args, context_dict, param, key, name):
+ def param_hook(*args, context_dict, param, name):
+ key = name
+ if self.monitor_mbs_grad:
+ key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+
+ key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+ self.register_param_call_id("param_hook", key)
  param.micro_step += 1
- self._register_param_call_id("param_hook", key)
- if param.micro_step == self.micro_batch_number:
- param.micro_step = 0
+
+ if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
  if self.params_have_main_grad:
- context_dict[key] = param.main_grad.clone()
+ grad = param.main_grad
  else:
- context_dict[key] = param.grad.clone()
+ grad = param.grad
+ if is_float8_tensor(grad):
+ grad = grad.float()
+ context_dict[key] = grad.clone()
+
+ if param.micro_step == self.micro_batch_number:
+ param.micro_step = 0

  logger.info("hooking weights.")
  for param, name in self.param2name.items():
- key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
  setattr(param, 'micro_step', 0)
  param_tmp = param.expand_as(param)
  grad_acc = param_tmp.grad_fn.next_functions[0][0]
  handle = grad_acc.register_hook(
- partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
+ partial(param_hook, context_dict=context.acc, param=param, name=name))
  self.grad_accs.append(grad_acc)
  self.handles['wgrads'].append(handle)

  self.weight_hooked = True
-
- def _register_param_call_id(self, hook_name: str, key: str):
- """
- :param hook_name:
- :param key: str, '0:relu_0/output_grad'
- :return:
- """
- logger.debug(f"{hook_name} {key}: {self.call_id}")
- self.param_name_call_id[key] = self.call_id
- self.call_id += 1
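The (translated) docstring of `_hook_weights` above describes hooking each parameter's gradient-accumulation node so the unreduced gradient can be captured after backward. The standalone sketch below illustrates just that mechanism; it is not msprobe code and omits main_grad, micro-step and float8 handling.

```python
# Standalone sketch of hooking a parameter's AccumulateGrad node, the mechanism
# _hook_weights relies on (simplified: no main_grad / micro-step / float8 logic).
import torch
from functools import partial

captured = {}

def param_hook(*unused, param, name):
    # Fires once the gradient for `param` has been accumulated in this backward.
    captured[name] = param.grad.clone()

model = torch.nn.Linear(4, 2)
handles, grad_accs = [], []
for name, param in model.named_parameters():
    param_tmp = param.expand_as(param)                 # view whose graph leads to AccumulateGrad
    grad_acc = param_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node
    handles.append(grad_acc.register_hook(partial(param_hook, param=param, name=name)))
    grad_accs.append(grad_acc)                         # keep the nodes alive, as the monitor does

model(torch.randn(3, 4)).sum().backward()
print(sorted(captured))   # ['bias', 'weight']
for h in handles:
    h.remove()
```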