mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/pytorch/monitor/module_hook.py
@@ -15,18 +15,21 @@
  import json
  import os
  import uuid
+ import importlib
  from collections import defaultdict
  from datetime import datetime
  from functools import partial
+ from itertools import cycle

  import pytz
  import torch
  import torch.distributed as dist
  import pandas as pd
  from torch.utils.hooks import BackwardHook
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

  from msprobe.core.common.const import MonitorConst, Const
- from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.core.common.file_utils import load_json, save_json, make_dir
  from msprobe.core.common.decorator import recursion_depth_decorator
  from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
  from msprobe.core.common.file_utils import write_df_to_csv
@@ -39,9 +42,9 @@ from msprobe.pytorch.monitor.utils import get_param_struct
  from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
  from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
      get_process_group
- from msprobe.pytorch.monitor.features import get_sign_matches
+ from msprobe.pytorch.monitor.features import get_sign_matches, cal_qkt
  from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
-     TensorMetrics, squash_param_name
+     TensorMetrics, squash_param_name, get_entropy_metric, get_sr_metric
  from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

@@ -56,6 +59,7 @@ FORMAT_MAPPING = {
      MonitorConst.CSV: CSVWriterWithAD,
      MonitorConst.API: BaseWriterWithAD
  }
+ start_step = 0


  def param_is_not_tensor_parallel_duplicate(param, tp_group):
@@ -82,7 +86,17 @@ class ModuleHookContext:
          self.actvgrad.clear()


- start_step = 0
+ class FeatureHookContext:
+     def __init__(self, module_name):
+         self.step = 0
+         self.micro_step = 0
+         self.attention_feature = {}
+         self.linear_feature = {}
+         self.module_name = module_name
+
+     def reset(self):
+         self.attention_feature.clear()
+         self.linear_feature.clear()


  class OptimizerContext:
@@ -159,8 +173,8 @@ class TrainerMon:
          self.params_have_main_grad = params_have_main_grad
          self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
          self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-         self.origin_start_grad_sync = None
          self.fsdp_post_backward_hook = None
+         self.fsdp2_foreach_reduce = None
          self.config_timestamp = 0  # the timestamp is validated later; the first monitoring run does not need to touch it just to refresh the config file, monitoring can be switched on directly via dynamic_on
          self.config = load_json(config_file_path)
          validate_config(self.config)
@@ -195,7 +209,9 @@
          self.dp_group = None
          self.tp_group = None
          self.enable_megatron = False
+         self.enable_deepspeed = False
          self.fsdp_wrapped_module = False
+         self.fsdp2_wrapped_module = False
          self.micro_batch_number = 1
          self.optimizer_mon = None
          self.optimizer_trans = None
@@ -203,6 +219,7 @@
          # TYPE3: variables that are reset when the config is updated mid-training or the monitoring state changes
          self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
          self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
+         self.feature_hook_context_by_module = defaultdict(FeatureHookContext)
          self.optimizer_context = defaultdict(OptimizerContext)
          self.cc_context = defaultdict(CommunicationContext)
          self.grad_context = GradContext()
@@ -210,9 +227,12 @@
          self.param2name = defaultdict(str)
          self.name2indices = defaultdict()
          self.name2param = {}
+         self.origin2squash = {}
          self.duplicate_param = {}
          self.name2tag = {}
          self.param_name_call_id = {}
+         self.flat_prefix_names = []
+         self.flat_prefix_reverse_iter = None
          self.call_id = 0
          self.module_struct = defaultdict(dict)
          self.grad_accs = []
@@ -270,6 +290,18 @@
              cc_tensor.reset()
          return metrics

+     @staticmethod
+     def get_linear_hook_target(module):
+         if isinstance(module, torch.nn.Embedding):
+             return ''
+         if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
+             return ''
+         for weight_name in ["weight", "wg"]:
+             if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
+                 if getattr(module, weight_name).dim() == 2:
+                     return weight_name
+         return ''
+
      def set_config(self):
          logger.info(f"current config: {self.config}")
          self.start_step = self.config.get("start_step", 0)
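The new `get_linear_hook_target` helper is what `extract_linear_sr_hook` later uses to decide whether a module carries a 2-D weight worth monitoring. A minimal sketch of its behaviour on standard `torch.nn` modules, assuming mindstudio-probe 8.2.1 is installed (the module choices below are illustrative, not from the package):

```python
import torch
from msprobe.pytorch.monitor.module_hook import TrainerMon

# The helper returns the name of a 2-D weight attribute ("weight" or "wg"),
# and '' for embeddings or anything without a 2-D weight.
print(TrainerMon.get_linear_hook_target(torch.nn.Linear(16, 32)))      # "weight" (2-D weight)
print(TrainerMon.get_linear_hook_target(torch.nn.Embedding(100, 32)))  # ""       (embeddings are excluded)
print(TrainerMon.get_linear_hook_target(torch.nn.LayerNorm(32)))       # ""       (weight is 1-D)
```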
@@ -294,6 +326,8 @@
          self.cc_distribution = self.config.get("cc_distribution", {})
          self.stack_info = self.config.get('stack_info', False)
          self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+         self.recording_l2_features = self.config.get("recording_l2_features", False)
+         self.sa_order = self.config.get("sa_order", "s,b,h,d")

          if not self.cc_distribution.get('enable', False):
              self.cc_log_only = False
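A hedged sketch of how the new monitor config keys might look. Only the key names (`recording_l2_features`, `sa_order`, `l2_targets`, `start_step`) come from this diff; the shape of `l2_targets` is inferred from how `_hook_module` indexes it by hook name, and every value below is illustrative rather than taken from the package documentation:

```python
# Illustrative monitor-config fragment (loaded via load_json in TrainerMon).
# Key names are from this diff; value formats are assumptions.
example_config = {
    "start_step": 0,
    "recording_l2_features": True,   # enables the attention/linear L2 feature hooks
    "sa_order": "s,b,h,d",           # layout passed to cal_qkt for the q/k inputs
    "l2_targets": {                  # hook name -> module names, matched in _is_recording_module
        "attention_hook": ["layers.0.self_attention"],
        "linear_hook": ["layers.0.mlp"],
    },
}
```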
@@ -352,6 +386,8 @@
              logger.info_on_rank_0("> momentum and variance of adam is not monitored. ")
          if not self.wg_distribution:
              logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
+         if not self.recording_l2_features:
+             logger.info_on_rank_0("> l2 features of specified module is not monitored. ")
          if not self.mg_direction:
              logger.info_on_rank_0('> grad and momentum direction will not be compared.')
          if not self.cc_distribution.get('enable', False):
@@ -533,6 +569,24 @@
          if self.grad_context.actv:
              self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)

+     def write_metrics_if_not_empty(self, features, metrics, step, hook_name):
+         if not features or len(features) == 0:
+             return
+         use_micro_step = hook_name not in ["linear_hook"]
+         self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step)
+         features.clear()
+
+     def write_features_tb(self, step):
+         if not self.recording_l2_features:
+             return
+         for context in self.feature_hook_context_by_module.values():
+             num_features = len(context.attention_feature) + len(context.linear_feature)
+             if num_features == 0:
+                 continue
+             self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"],
+                                             step, "attention_hook")
+             self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, "linear_hook")
+
      def write_param_tb(self, opt_context):
          if not self.param_distribution:
              return
@@ -687,6 +741,7 @@
          if self.anomaly_data_factory:
              self.anomaly_data_factory.set_call_id(self.param_name_call_id)
          self.write_xy_tb(context.step)
+         self.write_features_tb(context.step)
          self.write_grad_tb(context.step)
          self.write_mv_tb(context)
          self.write_param_tb(context)
@@ -756,7 +811,8 @@
              vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
              targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
                  'targets'].keys()
-             hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+             l2_target_names = self.config.get('l2_targets', '')
+             hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage)

          logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")

@@ -797,6 +853,9 @@
          for handle in self.handles['xy']:
              handle.remove()
          self.handles['xy'].clear()
+         for handle in self.handles['L2_features']:
+             handle.remove()
+         self.handles['L2_features'].clear()
          # clear the corresponding context caches
          for _, fwd_context in self.module_fwd_hook_context_by_module.items():
              fwd_context.reset()
@@ -804,22 +863,14 @@
              bwd_context.reset()
          self.grad_context.reset()  # both weight grads and activation grads are kept here

-         if self.origin_start_grad_sync:  # megatron
-             try:
-                 from megatron.core.distributed.param_and_grad_buffer import Bucket
-                 Bucket.start_grad_sync = self.origin_start_grad_sync
-                 logger.info("remove Bucket start_grad_sync")
-             except ImportError:
-                 pass
-             try:
-                 from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
-                 _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
-                 logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
-             except ImportError:
-                 pass
-         elif self.fsdp_post_backward_hook:  # fsdp
+         self.optimizer_mon.restore_grad_sync(self)
+         if self.fsdp_post_backward_hook:  # fsdp
              torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook
              logger.info("remove patch_post_backward_hook in fsdp.")
+         if self.fsdp2_foreach_reduce:  # fsdp2
+             torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce
+             importlib.reload(torch.distributed.fsdp._fully_shard._fsdp_param_group)
+             logger.info("remove patch_foreach_reduce_hook in fsdp2.")
          else:  # not megatron and not fsdp
              for handle in self.handles['wgrads']:
                  handle.remove()
@@ -881,14 +932,11 @@
          logger.info(msg)

      def _save_module_struct(self):
-         save_module_struct = (not dist.is_initialized()
-                               or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list))
-                               or (not self.module_rank_list and dist.get_rank() == 0))
-
-         if save_module_struct:
-             module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
-             save_json(module_struct_file, self.module_struct, indent=2)
-             logger.info(f"> save module struct to {module_struct_file}")
+         output_dir = os.path.join(get_output_base_dir(), 'module_struct', f'rank{self.rank}')
+         make_dir(output_dir)
+         module_struct_file = os.path.realpath(os.path.join(output_dir, 'module_struct.json'))
+         save_json(module_struct_file, self.module_struct, indent=2)
+         logger.info(f"> save module struct to {module_struct_file}")
          self.struct_printed = True

      def _is_target_param(self, param_name, param, prefix):
@@ -896,23 +944,32 @@
          squash_name = prefix + squash_param_name(param_name, self.squash_name)
          for target in self.config['targets'].keys():
              if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
-                 setattr(param, "zero_out_wgrad", True)
                  return True

          return False

      def _register_chunk(self, model_chunk, prefix):
+         if isinstance(model_chunk, FSDP):
+             if not model_chunk._use_orig_params:
+                 raise ValueError("Only Support fsdp1 with use_orig_params=True")
+             self.fsdp_wrapped_module = True
          for (param_name, param) in model_chunk.named_parameters():
              if not param.requires_grad:
                  continue
-             if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
-                 self.fsdp_wrapped_module = True
+             if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor":
+                 self.fsdp2_wrapped_module = True
+             if self.fsdp_wrapped_module:  # FSDP1: record every flat-weight prefix name, not limited by targets, so the flat param can be unpacked later
+                 flat_prefix_name, _ = param_name.rsplit(MonitorConst.FSDP_FLAT_SEP, 1)
+                 if flat_prefix_name not in self.flat_prefix_names:
+                     self.flat_prefix_names.append(flat_prefix_name)
+
              if self._is_target_param(param_name, param, prefix):
                  name = prefix + squash_param_name(param_name, self.squash_name)
                  if name in self.param2name.values():
                      name = prefix + param_name
                  self.param2name[param] = name
                  self.name2param[name] = param
+                 self.origin2squash[param_name] = name

                  if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
                      self.duplicate_param[name] = True
@@ -929,6 +986,8 @@
                  k: get_summary_writer_tag_name(name, k, self.rank)
                  for k in keywords
              }
+         if self.fsdp_wrapped_module:
+             self.flat_prefix_reverse_iter = cycle(reversed(self.flat_prefix_names))  # post_backward_hook is invoked in reverse order

      def _register_param_name(self):
          for vpp_stage, model_chunk in enumerate(self.model):
@@ -946,7 +1005,20 @@
                  return pattern
          return ""

-     def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
+     def _is_recording_module(self, module_name, l2_targets, vpp_stage, hook_name):
+
+         if len(l2_targets) > 0:
+             for pattern in [
+                 vpp_stage + squash_param_name(module_name, self.squash_name),
+                 vpp_stage + module_name,
+             ]:
+                 if pattern in l2_targets:
+                     return pattern
+         elif hook_name in ["linear_hook"]:
+             return vpp_stage + squash_param_name(module_name, self.squash_name)
+         return ""
+
+     def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
          if '_modules' not in module.__dict__:
              # nothing to hook
              return 0
@@ -1025,6 +1097,61 @@
                  context.micro_step = 0
              return

+         def extract_attention_feature_hook(module, module_input, module_output, name):
+             if is_recomputation() or not module.training:
+                 return
+
+             if module not in self.feature_hook_context_by_module:
+                 self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+             context: FeatureHookContext = self.feature_hook_context_by_module[module]
+             tbtag_tensor_map = {}
+             if len(module_input) < 2:
+                 logger.warning(
+                     f"Length of module_input in attention hook ({name}) is {len(module_input)}, "
+                     "expected >= 2. Skipping feature extraction for this module."
+                 )
+                 return
+             q_h = module_input[0]
+             k_h = module_input[1]
+             qkt = cal_qkt(q_h, k_h, order=self.sa_order)
+             tbtag_tensor_map.update(
+                 self.build_tbtag_tensor_map(f'{context.module_name}.attention',
+                                             f'{MonitorConst.NAME_SEP}{context.micro_step}', 'qkt', qkt)
+             )
+             get_entropy_metric(tbtag_tensor_map, context.attention_feature)
+
+             context.micro_step += 1
+             if context.micro_step == self.micro_batch_number:
+                 context.micro_step = 0
+                 context.step += 1
+             return
+
+         def extract_linear_sr_hook(module, module_input, module_output, name):
+             if is_recomputation() or not module.training:
+                 return
+             weight_name = self.get_linear_hook_target(module)
+             if weight_name == '':
+                 return
+
+             if module not in self.feature_hook_context_by_module:
+                 self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+             context: FeatureHookContext = self.feature_hook_context_by_module[module]
+
+             if context.micro_step == (self.micro_batch_number - 1):
+                 tbtag_tensor_map = {}
+                 value = getattr(module, weight_name).data
+                 tbtag_tensor_map.update(
+                     self.build_tbtag_tensor_map(f'{context.module_name}.linear',
+                                                 '', 'sr', value)
+                 )
+                 get_sr_metric(tbtag_tensor_map, context.linear_feature)
+
+             context.micro_step += 1
+             if context.micro_step == self.micro_batch_number:
+                 context.micro_step = 0
+                 context.step += 1
+             return
+
          def stack_hook(module, args, kwargs, module_output, name):
              if module not in self.module_fwd_hook_context_by_module:
                  self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1056,34 +1183,29 @@
                  self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
                  logger.info_on_rank_0(f"> {name} is monitored successfully")
                  hooked_count += 1
-         return hooked_count
-
-     def _patch_grad_sync(self):
-         def patch_sync(sync_grad_func):
-             def wrapper(bucket):
-                 grad_dict = {}
-                 # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
-                 # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
-                 # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
-                 bucket_params_id_list = [id(params) for params in bucket.params]
-                 for param, name in self.param2name.items():
-                     if id(param) not in bucket_params_id_list:
-                         continue
-                     grad = param.main_grad if self.params_have_main_grad else param.grad
-                     if grad is None:
-                         logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                         continue
-                     tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
-                     if tag is None:
+         if not self.print_struct and self.recording_l2_features:
+             for module_name, submodule in module.named_modules():
+                 func_map = {
+                     "attention_hook": extract_attention_feature_hook,
+                     "linear_hook": extract_linear_sr_hook,
+                 }
+                 for hook_name in func_map.keys():
+                     if hook_name not in l2_target_names:
                          continue
-                     grad_dict[tag] = grad
-                     self.register_param_call_id("sync_grad_func", tag)
-                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
-                 out = sync_grad_func(bucket)
-                 return out
+                     temp_names = l2_target_names[hook_name]
+                     name = self._is_recording_module(module_name, temp_names, vpp_stage, hook_name)
+                     if name:
+                         handle = submodule.register_forward_hook(partial(func_map[hook_name], name=name))
+                         print_feature_name = hook_name.split('_')[0]
+                         logger.info_on_rank_0(
+                             f'> {print_feature_name} features of {name} is monitored successfully')
+                         self.handles["L2_features"].append(handle)
+                         hooked_count += 1
+                     continue

-             return wrapper
+         return hooked_count

+     def _patch_grad_sync(self):
          if not self.wg_distribution:
              return
          if self.fsdp_wrapped_module:
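A self-contained illustration (not msprobe code) of the registration pattern used above: `functools.partial` pins the extra `name` argument so the callable still matches PyTorch's standard forward-hook signature `hook(module, input, output)`.

```python
import torch
from functools import partial

def feature_hook(module, module_input, module_output, name):
    # Toy hook; real hooks above compute qkt entropy or weight stable rank.
    print(f"{name}: output norm = {module_output.norm().item():.4f}")

layer = torch.nn.Linear(8, 8)
handle = layer.register_forward_hook(partial(feature_hook, name="toy.linear"))
layer(torch.randn(2, 8))
handle.remove()  # TrainerMon keeps such handles in self.handles["L2_features"] and removes them the same way
```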
@@ -1091,27 +1213,18 @@
              self._patch_fsdp_post_backward_hook()
              return

+         if self.fsdp2_wrapped_module:
+             # patch fsdp2 _fully_shard._fsdp_collectives.foreach_reduce
+             self._patch_fsdp2_foreach_reduce()
+             return
+
          if self.monitor_mbs_grad:
              self._hook_weights()
              return
-         try:
-             from megatron.core.distributed.param_and_grad_buffer import Bucket
-             self.origin_start_grad_sync = Bucket.start_grad_sync
-             Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
-             self.enable_megatron = True
-             logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
-         except ImportError:
-             self.enable_megatron = False
+
+         self.optimizer_mon.patch_grad_sync(self)

-         try:
-             from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
-             self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
-             _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
-             self.enable_megatron = True
-             logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
-         except ImportError:
-             self.enable_megatron = False | self.enable_megatron
-         if self.enable_megatron:
+         if self.enable_megatron or self.enable_deepspeed:
              return

          # default hook weights
@@ -1124,17 +1237,22 @@
          In every forward phase, FSDP re-registers hooks on AccumulateGrad, so hooks registered by the monitor tool cannot take effect;
          therefore _post_backward_hook is patched to collect gradients after backward and before reduce_scatter.
          """
+
          def patch_post_backward_hook(_post_backward_hook):
              def wrapper(state, handle, *unused):
                  grad_dict = {}
-                 offset = 0
-                 for param, name in self.param2name.items():
-                     limit = param.numel()
-                     if not limit:
+                 local_names = handle.flat_param._fqns
+                 offsets = handle._get_flat_param_offsets()
+                 shapes = handle.flat_param._shapes
+                 flat_prefix = next(self.flat_prefix_reverse_iter)
+                 for local_name, (start, end), local_shape in zip(local_names, offsets, shapes):
+                     grad_clip = handle.flat_param.grad[start:end + 1]
+                     grad = grad_clip.reshape(local_shape)
+                     total_name = f"{flat_prefix}{MonitorConst.FSDP_FLAT_SEP}{local_name}"
+                     if total_name not in self.origin2squash:
+                         logger.warning(f"{total_name} not in model.named_parameters(), skip.")
                          continue
-                     grad = handle.flat_param.grad[offset:offset + limit]
-                     offset += limit
-                     tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                     tag = self.name2tag.get(self.origin2squash[total_name], {}).get(MonitorConst.PRE_GRAD)
                      if tag is None:
                          continue
                      grad_dict[tag] = grad
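The patched hook above walks FSDP1's flat parameter using its recorded fully-qualified names, inclusive offsets, and shapes. A toy analogue of that slice-and-reshape step (plain tensors, no FSDP; all names, offsets, and shapes below are invented for illustration):

```python
import torch

flat_grad = torch.arange(12, dtype=torch.float32)   # stand-in for handle.flat_param.grad
local_names = ["layer.weight", "layer.bias_like"]   # stand-in for flat_param._fqns
offsets = [(0, 5), (6, 11)]                          # stand-in for handle._get_flat_param_offsets(), inclusive (start, end)
shapes = [torch.Size([2, 3]), torch.Size([3, 2])]    # stand-in for flat_param._shapes

for local_name, (start, end), local_shape in zip(local_names, offsets, shapes):
    # slice the flat buffer as grad[start:end + 1], then restore the parameter's shape
    grad = flat_grad[start:end + 1].reshape(local_shape)
    print(local_name, tuple(grad.shape))
```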
@@ -1150,6 +1268,28 @@
          torch.distributed.fsdp._runtime_utils._post_backward_hook = \
              patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook)

+     def _patch_fsdp2_foreach_reduce(self):
+         def patch_foreach_reduce(foreach_reduce):
+             def wrapper(fsdp_params, unsharded_grads, *unused):
+                 grad_dict = {}
+                 for param, grad in zip(fsdp_params, unsharded_grads):
+                     tag = self.name2tag.get(self.origin2squash[param._param_fqn], {}).get(MonitorConst.PRE_GRAD)
+                     if tag is None:
+                         continue
+                     grad_dict[tag] = grad
+                     self.register_param_call_id("foreach_reduce", tag)
+                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                 out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
+                 return out
+             return wrapper
+
+         logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
+         import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group
+         import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives
+         self.fsdp2_foreach_reduce = _fsdp_collectives.foreach_reduce
+         _fsdp_collectives.foreach_reduce = patch_foreach_reduce(_fsdp_collectives.foreach_reduce)
+         importlib.reload(_fsdp_param_group)  # key step: torch binds foreach_reduce at import time, so without this reload the patch would not take effect
+
      def _hook_weights(self):
          """
          Iterate over each parameter's gradient accumulation function (grad_acc) and attach hooks, so that once all of the parameter's gradients are computed, the gradient data can be collected before communication aggregation.
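A generic, self-contained illustration of why `_patch_fsdp2_foreach_reduce` has to call `importlib.reload()` on the consumer module: a module that does `from provider import func` binds the name at import time, so patching `provider.func` afterwards is invisible to it until it is reloaded. The toy `provider`/`consumer` modules below are invented for illustration and have nothing to do with torch or msprobe.

```python
import importlib
import os
import sys
import tempfile

tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "provider.py"), "w") as f:
    f.write("def func():\n    return 'original'\n")
with open(os.path.join(tmp, "consumer.py"), "w") as f:
    f.write("from provider import func\n\n\ndef call():\n    return func()\n")
sys.path.insert(0, tmp)

import provider
import consumer

provider.func = lambda: "patched"
print(consumer.call())        # 'original' -- consumer still holds the binding made at import time
importlib.reload(consumer)    # re-executes the from-import against the patched provider
print(consumer.call())        # 'patched'
```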
msprobe/pytorch/monitor/module_metric.py
@@ -17,6 +17,7 @@ import re
  import torch

  from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+ from msprobe.pytorch.monitor.features import cal_entropy, cal_stable_rank
  from msprobe.pytorch.monitor.utils import get_nan_tensor


@@ -31,7 +32,8 @@ def squash_param_name(param_name, enable=True):
      if not enable:
          return param_name
      name = ''
-     for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
+     for pattern in ['^.*\.(layers?\..*)', '^.*\.(embeddings?\..*)', '^.*\.(final.*)', '^.*\.(output.*)',
+                     '^.*\.(norm.*)']:
          match = re.findall(pattern, param_name)
          if match:
              name += match[0]
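A small check of the regex change, using a made-up Megatron-style parameter name (only the `re.findall` behaviour is shown; the parameter name is illustrative, not from the package):

```python
import re

param_name = "module.module.decoder.layers.0.mlp.linear_fc1.weight"

print(re.findall(r'layers?\.(.*)', param_name))
# ['0.mlp.linear_fc1.weight']  -> old pattern drops the "layers." prefix

print(re.findall(r'^.*\.(layers?\..*)', param_name))
# ['layers.0.mlp.linear_fc1.weight']  -> new anchored pattern keeps it
```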
@@ -184,3 +186,27 @@
          fun_metric = config_metric_registry.get(metric_name)
          out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
      return out_dict
+
+
+ def get_sr_metric(tag2tensor, out_dict=None):
+     if out_dict is None:
+         out_dict = {}
+     for tag, tensor in tag2tensor.items():
+         if "sr" not in tag:
+             continue
+         if tag not in out_dict:
+             out_dict[tag] = {}
+         sr, eig = cal_stable_rank(tensor)
+         out_dict[tag]['sr'] = sr
+         out_dict[tag]['kernel_norm'] = eig
+
+
+ def get_entropy_metric(tag2tensor, out_dict=None):
+     if out_dict is None:
+         out_dict = {}
+     for tag, tensor in tag2tensor.items():
+         if tag not in out_dict:
+             out_dict[tag] = {}
+         entropy, softmax_max = cal_entropy(tensor)
+         out_dict[tag]['entropy'] = entropy
+         out_dict[tag]['softmax_max'] = softmax_max
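`cal_stable_rank` and `cal_entropy` are imported from `msprobe/pytorch/monitor/features.py`, which this release extends but which is not shown in this diff. A hedged sketch of what those helpers plausibly compute, using the standard definitions of stable rank and softmax-row entropy; these formulas are assumptions, not the package's actual implementation:

```python
import torch

def cal_stable_rank_sketch(weight: torch.Tensor):
    """Assumed: stable rank ||W||_F^2 / sigma_max^2 plus the spectral norm of a 2-D weight."""
    sigma_max = torch.linalg.matrix_norm(weight.float(), ord=2)     # largest singular value
    fro = torch.linalg.matrix_norm(weight.float(), ord="fro")
    return (fro ** 2) / (sigma_max ** 2), sigma_max

def cal_entropy_sketch(qkt: torch.Tensor):
    """Assumed: mean row entropy and mean max probability of softmax(q @ k^T)."""
    probs = torch.softmax(qkt.float(), dim=-1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1).mean()
    return entropy, probs.max(dim=-1).values.mean()

print(cal_stable_rank_sketch(torch.randn(128, 64)))
print(cal_entropy_sketch(torch.randn(4, 16, 16)))
```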