mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0

msprobe/pytorch/monitor/csv2tb.py

@@ -22,13 +22,15 @@ from torch.utils.tensorboard import SummaryWriter
  from tqdm import tqdm

  from msprobe.core.common.const import MonitorConst
- from msprobe.core.common.file_utils import read_csv, create_directory, remove_path
+ from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod
  from msprobe.core.common.utils import is_int
+ from msprobe.core.common.decorator import recursion_depth_decorator
  from msprobe.pytorch.common.log import logger
  from msprobe.pytorch.monitor.utils import get_target_output_dir

  all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]
  CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
+ MAX_PROCESS_NUM = 128


  def parse_step_line(line, ops):
@@ -76,6 +78,7 @@ def write_step(output_dirpath, parse_step_result, rank, data_type):
  writer.add_scalar(tag, value, step)


+ @recursion_depth_decorator("update_dict", max_depth=50)
  def update_dict(dict1, dict2):
  for key, value in dict2.items():
  if key in dict1:
@@ -115,11 +118,13 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list):
  def check_process_num(process_num):
  if not is_int(process_num) or process_num <= 0:
  raise ValueError(f"process_num({process_num}) is not a positive integer")
+ if process_num > MAX_PROCESS_NUM:
+ raise ValueError(f"The maximum supported process_num is {MAX_PROCESS_NUM}, current value: {process_num}.")


  def check_data_type_list(data_type_list):
  if data_type_list is None:
- logger.info(f"data_type_list is None, use defualt all_data_type_list: {all_data_type_list}")
+ logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}")
  return
  if not isinstance(data_type_list, list):
  raise ValueError(f"data_type_list({data_type_list}) is not a list")
@@ -161,4 +166,5 @@ def csv2tensorboard_by_step(
  p.start()
  for p in processes:
  p.join()
+ recursive_chmod(output_dirpath)
  logger.info(f"output has been saved to: {output_dirpath}")

msprobe/pytorch/monitor/distributed/wrap_distributed.py

@@ -24,6 +24,7 @@ import torch.nn as nn
  from msprobe.core.common.const import MonitorConst
  from msprobe.core.common.file_utils import load_yaml
  from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name
+ from msprobe.pytorch.common.log import logger

  try:
  import torch_npu
@@ -37,6 +38,7 @@ WrapDistributedOps = load_yaml(OpsPath).get("distributed", [])

  StackBlackListPath = os.path.join(os.path.dirname(__file__), "stack_blacklist.yaml")
  StackBlackList = load_yaml(StackBlackListPath).get("stack", [])
+ MAX_STRING_LENGTH = 1000

  distributed_func = {}
  for f in dir(dist):
@@ -139,6 +141,8 @@ def get_process_group(process_group):


  def stack_filter(stack):
+ if len(stack) > MAX_STRING_LENGTH:
+ logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.')
  for pattern in StackBlackList:
  if re.search(pattern, stack):
  return False
@@ -188,10 +192,12 @@ def update_data(old, new):


  def is_target_line(codeline):
- stack = get_callstack()
- whole_stack = ';'.join(stack)
  if codeline == []:
  return True
+ stack = get_callstack()
+ whole_stack = ';'.join(stack)
+ if len(whole_stack) > MAX_STRING_LENGTH:
+ logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.')
  for pattern in codeline:
  if re.search(pattern, whole_stack):
  return True

msprobe/pytorch/monitor/module_hook.py

@@ -26,8 +26,9 @@ from torch.utils.hooks import BackwardHook

  from msprobe.core.common.const import MonitorConst, Const
  from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.core.common.decorator import recursion_depth_decorator
  from msprobe.pytorch.common.log import logger
- from msprobe.pytorch.common.utils import is_recomputation
+ from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
  from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
  from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
  CSVWriterWithAD, BaseWriterWithAD, WriterInput
@@ -39,7 +40,7 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_write
  from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
  from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
  from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
- get_output_base_dir, get_target_output_dir
+ get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
@@ -176,7 +177,8 @@ class GradContext:
  class TrainerMon:
  tensor_metrics = TensorMetrics()

- def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+ # Keep the original opt_ty parameter for compatibility with versions before msprobe 1.2.2
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
  # TYPE1: variables initialized only here; they are not reset when the config changes mid-training
  self.config_file_path = config_file_path
  self.process_group = get_process_group(process_group)
@@ -222,6 +224,7 @@ class TrainerMon:
  self.micro_batch_number = 1
  self.optimizer_class = None
  self.optimizer_mon = None
+ self.optimizer_trans = None

  # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -322,8 +325,6 @@ class TrainerMon:
  self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
  self.cc_logged_stack = defaultdict(set)
  self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
- self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
- api_register.redirect_api()

  self.common_info()

@@ -336,11 +337,11 @@

  # Initialize the writer and create the output directory
  if self.format not in FORMAT_MAPPING:
- logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+ logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
  self.format = MonitorConst.CSV

  if self.ur_distribution and self.format != 'tensorboard':
- logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
+ logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
  self.ur_distribution = False

  writer = FORMAT_MAPPING[self.format]
@@ -363,19 +364,6 @@ class TrainerMon:
  self.rank)
  self.anomaly_data_writer.init_detected_json()

- def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
- rank = None
- if dist.is_initialized():
- rank = dist.get_rank()
- if (rank not in rank_list) and len(rank_list) != 0:
- return
- self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
-
- def build_tbtag_tensor_map(self, module_name, tag, tensor):
- key = get_summary_writer_tag_name(module_name, tag, self.rank)
- self._register_param_call_id("_hook_module", key)
- return {key: tensor}
-
  def common_info(self):
  if not self.xy_distribution:
  logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
@@ -392,94 +380,31 @@
  if not self.cc_distribution.get('enable', False):
  logger.info_on_rank_0("> cc operator is not monitored.")

- def hook_modules(self):
- if self.module_rank_list and (self.rank not in self.module_rank_list):
- return
-
- targets = self.config['targets']
- module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
- for key in module_in_all_stage:
- struct = targets.pop(key)
- targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
-
- hooked_count = 0
- for vpp_stage, model_chunk in enumerate(self.model):
- vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
- targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
- 'targets'].keys()
- hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-
- logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
-
- def clone_if_tensor(args):
- if isinstance(args, tuple):
- return tuple([clone_if_tensor(arg) for arg in args])
- elif isinstance(args, torch.Tensor):
- return args.clone()
- else:
- return args
-
- @torch.no_grad
- def wrap_hook_setup(setup):
- def wrapped_setup(*args, **kwargs):
- args = setup(*args, **kwargs)
- args = clone_if_tensor(args)
- return args
-
- return wrapped_setup
-
- BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
-
- return
-
- def generate_param_metrics(self, opt_context):
- if not self.param_distribution:
- return
- get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
-
- def generate_mv_metrics(self, opt_context):
- if not self.mv_distribution:
- return
- opt_context.exp_avg_metric = {}
- opt_context.exp_avg_sq_metric = {}
- m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
- v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
- get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
- get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
-
- def generate_wgrad_metrics(self):
- if not self.wg_distribution:
- return {}, {}
-
- if self.weight_hooked:
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
- grad_dict = {}
- for param, name in self.param2name.items():
- if self.duplicate_param.get(name, False):
- continue
- grad = param.main_grad if self.params_have_main_grad else param.grad
- if grad is None:
- logger.warning(f"grad is None: {name}, maybe something wrong happened.")
- continue
- tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
- self._register_param_call_id("hook_optimizer", tag)
- grad_dict[tag] = grad
+ # Keep the original interface for compatibility with versions before msprobe 1.2.2
+ def monitor_gnorm_with_ad(self, model, optimizer=None, grad_acc_steps=1, tp_group=None, dp_group=None,
+ start_iteration=0):
+ if optimizer is None:
+ optimizer = getattr(self, "optimizer_trans", None)  # old versions may pass None; fetched via set_wrapped_optimizer
+ if optimizer is None:
+ logger.error("monitor_gnorm_with_ad: please set_wrapped_optimizer before it or input optimizer!=None")
+ return
+ self.set_monitor(model, optimizer, grad_acc_steps, tp_group, dp_group, start_iteration)

- get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
- return self.grad_context.post, unreduced_grad
+ # Keep the original interface for compatibility with versions before msprobe 1.2.2
+ def set_wrapped_optimizer(self, optimizer):
+ self.optimizer_trans = optimizer

  def set_monitor(
  self,
  model,
+ optimizer,
  grad_acc_steps=1,
- optimizer=None,
  tp_group=None,
  dp_group=None,
  start_iteration=0
  ):
  """External interface"""
+ grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration)
  global start_step
  start_step = start_iteration
  logger.info(f'grad acc steps {grad_acc_steps}')
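Illustrative caller for the reworked interface (not part of the diff): set_monitor now takes the optimizer as a required positional argument, while monitor_gnorm_with_ad and set_wrapped_optimizer are kept as compatibility shims. The model, optimizer, and config path below are placeholders, and the import path is assumed rather than taken from this diff:

    import torch
    from msprobe.pytorch import TrainerMon  # import path assumed

    model = [torch.nn.Linear(4, 4)]                      # placeholder; TrainerMon iterates model chunks
    optimizer = torch.optim.Adam(model[0].parameters())  # placeholder optimizer
    monitor = TrainerMon("./monitor_config.json")        # placeholder config path

    # 1.3.0 interface: optimizer is a required positional argument.
    monitor.set_monitor(model, optimizer, grad_acc_steps=1, start_iteration=0)

    # Pre-1.3.0 call pattern, still accepted through the shims added above:
    # monitor.set_wrapped_optimizer(optimizer)
    # monitor.monitor_gnorm_with_ad(model, grad_acc_steps=1)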
@@ -502,8 +427,24 @@
  self.hook_optimizer(optimizer)
  self._patch_grad_sync()
  self.hook_modules()
+ if self.cc_distribution.get('enable', False):
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+ api_register.redirect_api()
  self.monitoring = True

+ def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+ rank = None
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ if (rank not in rank_list) and len(rank_list) != 0:
+ return
+ self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
+
+ def build_tbtag_tensor_map(self, module_name, tag, tensor):
+ key = get_summary_writer_tag_name(module_name, tag, self.rank)
+ self._register_param_call_id("_hook_module", key)
+ return {key: tensor}
+
  def generate_param_map(self, tag, param_tensor):
  metrics = {}
  for name in self.param2name.values():
@@ -514,6 +455,44 @@
  metrics[key] = param_tensor[name]
  return metrics

+ def generate_param_metrics(self, opt_context):
+ if not self.param_distribution:
+ return
+ get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
+
+ def generate_mv_metrics(self, opt_context):
+ if not self.mv_distribution:
+ return
+ opt_context.exp_avg_metric = {}
+ opt_context.exp_avg_sq_metric = {}
+ m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+ v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+ get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+ get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+ def generate_wgrad_metrics(self):
+ if not self.wg_distribution:
+ return {}, {}
+
+ if self.weight_hooked:
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+
+ grad_dict = {}
+ for param, name in self.param2name.items():
+ if self.duplicate_param.get(name, False):
+ continue
+ grad = param.main_grad if self.params_have_main_grad else param.grad
+ if grad is None:
+ logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+ continue
+ tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+ self._register_param_call_id("hook_optimizer", tag)
+ grad_dict[tag] = grad
+
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+ unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
+ return self.grad_context.post, unreduced_grad
+
  def generate_xy_metrics(self):
  actv = {}
  for fwd_context in self.module_fwd_hook_context_by_module.values():
@@ -557,9 +536,9 @@
  def write_mv_tb(self, opt_context):
  if not self.mv_distribution:
  return
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
  opt_context.step, MonitorConst.EXP_AVG)
- self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
  opt_context.step, MonitorConst.EXP_AVG_SQ)

  def write_grad_tb(self, step):
@@ -572,7 +551,7 @@
  self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

- def hook_optimizer(self, optimizer=None):
+ def hook_optimizer(self, optimizer):
  # in DDP by default use params_have_main_grad
  def optimizer_pre_step_hook(optimizer, args, kwargs):
  context = self.optimizer_context[optimizer]
@@ -638,7 +617,6 @@
  optimizer_pre_step_hook(optimizer, args, kwargs)
  out = func(*args, **kwargs)
  return out
-
  return wrapper

  if self.optimizer_hooked:
@@ -674,6 +652,7 @@
  validate_config(config)
  self.config = config
  self.set_config()
+ self.start_step = context.step  # during dynamic start/stop, ignore the original start_step and always begin at the next step
  logger.warning(f"config is updated at step{context.step - 1}, "
  f"will start new hook at step{context.step}.")
  except Exception as e:
@@ -721,6 +700,9 @@
  if self.anomaly_data_factory:
  self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
  self.summary_writer.clear_anomalies()
+
+ if self.format == MonitorConst.TENSORBOARD:
+ chmod_tensorboard_dir(self.tensorboard_dir)
  self.call_id = 0
  self.param_name_call_id.clear()

@@ -739,7 +721,47 @@

  optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
  self.origin_step_func = optimizer.__class__.step
+ return
+
+ def hook_modules(self):
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
+ return
+
+ targets = self.config['targets']
+ module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
+ for key in module_in_all_stage:
+ struct = targets.pop(key)
+ targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+
+ hooked_count = 0
+ for vpp_stage, model_chunk in enumerate(self.model):
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+ targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
+ 'targets'].keys()
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+
+ logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
+
+ @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor')
+ def clone_if_tensor(args):
+ if isinstance(args, tuple):
+ return tuple([clone_if_tensor(arg) for arg in args])
+ elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+ return args.clone()
+ else:
+ return args

+ @torch.no_grad
+ def wrap_hook_setup(setup):
+ def wrapped_setup(*args, **kwargs):
+ args = setup(*args, **kwargs)
+ args = clone_if_tensor(args)
+ return args
+
+ return wrapped_setup
+
+ BackwardHook.setup_input_hook = wrap_hook_setup(BackwardHook.setup_input_hook)
+ BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
  return

  def _remove_all_hooks(self, optimizer):
@@ -783,6 +805,7 @@
  for handle in self.handles['cc']:
  handle.remove()
  self.handles['cc'].clear()
+ api_register.restore_api()
  for _, context in self.cc_context.items():
  context.reset()

@@ -956,7 +979,7 @@
  return
  if not context.verified:
  context.focused_in_col = validate_config_spec(
- context.format_by_arg[MonitorConst.INPUT_GRAD],
+ context.format_by_arg[MonitorConst.INPUT_GRAD],
  input_grad, context.module_name, MonitorConst.INPUT_GRAD)
  context.focused_out_col = validate_config_spec(
  context.format_by_arg[MonitorConst.OUTPUT_GRAD],
@@ -1052,7 +1075,7 @@
  self.enable_megatron = True
  logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
  except ImportError:
- self.enable_megatron = False
+ self.enable_megatron = False | self.enable_megatron

  if not self.enable_megatron:
  self._hook_weights()
@@ -1067,9 +1090,12 @@
  if param.micro_step == self.micro_batch_number:
  param.micro_step = 0
  if self.params_have_main_grad:
- context_dict[key] = param.main_grad.clone()
+ grad = param.main_grad
  else:
- context_dict[key] = param.grad.clone()
+ grad = param.grad
+ if is_float8_tensor(grad):
+ grad = grad.float()
+ context_dict[key] = grad.clone()

  logger.info("hooking weights.")
  for param, name in self.param2name.items():

msprobe/pytorch/monitor/module_metric.py

@@ -16,6 +16,7 @@ import re

  import torch

+ from msprobe.pytorch.common.utils import is_float8_tensor
  from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
  from msprobe.pytorch.monitor.utils import get_nan_tensor

@@ -166,6 +167,8 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
  # Non-tensor in/output filled with nan.
  out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
  continue
+ if is_float8_tensor(tensor):
+ tensor = tensor.float()
  for metric_name in ops:
  fun_metric = config_metric_registry.get(metric_name)
  out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
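The is_float8_tensor helper referenced above lives in msprobe/pytorch/common/utils.py and its body is not part of this hunk. A plausible minimal sketch, assuming it only tests for the torch float8 dtypes (hypothetical reconstruction, not the actual implementation):

    import torch

    # Hypothetical reconstruction; the real helper may differ.
    _FLOAT8_DTYPES = tuple(
        getattr(torch, name) for name in ("float8_e4m3fn", "float8_e5m2")
        if hasattr(torch, name)
    )

    def is_float8_tensor(tensor):
        return isinstance(tensor, torch.Tensor) and tensor.dtype in _FLOAT8_DTYPES

Casting such tensors to float32 before computing metrics sidesteps reductions (norm, mean, nan counts) that are generally not implemented for float8 dtypes.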

msprobe/pytorch/monitor/optimizer_collect.py

@@ -185,7 +185,7 @@ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
  for opt in torch_opt.chained_optimizers:
  self.map_fp16_tp_fp32_param(opt)

- if not isinstance(torch_opt, torch.optim.Optimizer):
+ if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
  torch_opt.state = {}
  for opt in torch_opt.chained_optimizers:
  torch_opt.state.update(opt.optimizer.state)
@@ -198,7 +198,7 @@ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
  for opt in torch_opt.chained_optimizers:
  self.map_fp16_tp_fp32_param(opt)

- if not isinstance(torch_opt, torch.optim.Optimizer):
+ if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
  torch_opt.state = {}
  for opt in torch_opt.chained_optimizers:
  torch_opt.state.update(opt.optimizer.state)
@@ -206,9 +206,60 @@


  class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
- def fetch_mv(self, monitor, torch_opt, params2name):
- return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+ def get_group_index(self, torch_opt):
+ bit16_groups = torch_opt.bf16_groups
+ param2group = defaultdict()
+ for group_idx, bit16_group in enumerate(bit16_groups):
+ for param in bit16_group:
+ param2group[param] = group_idx
+ return param2group
+
+ def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+ param2group = self.get_group_index(torch_opt)
+ exp_avg_dict = defaultdict(float)
+ exp_avg_sq_dict = defaultdict(float)
+ update_dict = defaultdict()
+ ratio_dict = defaultdict()
+
+ param_slice_mappings = torch_opt.state_dict()['param_slice_mappings']
+ for param, name in params2name.items():
+ group_idx = param2group[param]
+ state = torch_opt.optimizer.state[torch_opt.fp32_groups_flat_partition[group_idx]]
+ if state.get('exp_avg', None) is None:
+ logger.warning(f"optimizer state is None. Something is wrong if this is not the first step")
+ break
+ param_slice_mapping = param_slice_mappings[group_idx]
+ hp_address = param_slice_mapping.get(torch_opt.param_names[param])
+ if hp_address is None:
+ continue
+ start = hp_address.start
+ numel = hp_address.numel

+ if monitor.mv_distribution:
+ exp_avg_dict[name] = state['exp_avg'].narrow(0, start, numel)
+ exp_avg_sq_dict[name] = state['exp_avg_sq'].narrow(0, start, numel)
+ if monitor.mg_direction:
+ exp_avg_dict[name] = state['exp'].narrow(0, start, numel)
+ if monitor.ur_distribution:
+ if len(torch_opt.param_groups) > 1:
+ logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+ if 'step' in state:
+ step = state['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
+ elif 'step' in torch_opt.param_groups[0]:
+ step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+ else:
+ logger.warning(f"step of {name} is None, maybe something wrong happened.")
+ continue
+ exp_avg = state['exp_avg'].narrow(0, start, numel)
+ exp_avg_sq = state['exp_avg_sq'].narrow(0, start, numel)
+ exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+ exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+ update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+ ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+ monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+ monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+ return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+

  class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
  def get_param_index(self, params2name, name2index, torch_opt):
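For reference, the update and ratio statistics assembled in the ur_distribution branch above are the standard bias-corrected Adam quantities; restated standalone (beta1, beta2, and eps play the role of torch_opt.defaults['betas'] and torch_opt.defaults['eps']):

    import torch

    def adam_update_and_ratio(exp_avg, exp_avg_sq, step, beta1=0.9, beta2=0.999, eps=1e-8):
        # Bias-corrected first and second moments for the monitored parameter slice.
        m_hat = exp_avg / (1 - beta1 ** step)
        v_hat = exp_avg_sq / (1 - beta2 ** step)
        update = m_hat / (torch.sqrt(v_hat) + eps)  # effective Adam update direction
        ratio = m_hat / torch.sqrt(v_hat)           # same quantity without the eps floor
        return update, ratio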

msprobe/pytorch/monitor/unittest/test_monitor.py

@@ -92,7 +92,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
  if errors:
  logger.info(errors)
  else:
- logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitord.')
+ logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitored.')


  def assert_equal(a, b):