mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -13,11 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from gzip import FEXTRA
16
17
  import os
17
18
  import re
18
19
  import uuid
19
20
  from collections import defaultdict
20
21
  from datetime import datetime
22
+ from functools import partial
21
23
 
22
24
  import pytz
23
25
  import pandas as pd
@@ -27,16 +29,18 @@ from mindspore import nn, _no_grad
27
29
 
28
30
  from msprobe.core.common.log import logger
29
31
  from msprobe.core.common.const import MonitorConst, Const
30
- from msprobe.core.common.file_utils import load_json, save_json
32
+ from msprobe.core.common.file_utils import load_json, save_json, make_dir
31
33
  from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir
32
34
  from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
33
35
  from msprobe.mindspore.common.utils import is_mindtorch
34
- from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
36
+ from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank, \
37
+ comm_is_initialized
35
38
  from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \
36
- get_metrics
39
+ get_metrics, get_entropy_metric, get_sr_metric
37
40
  from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
38
41
  from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
39
42
  from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
43
+ from msprobe.mindspore.monitor.features import cal_qkt
40
44
  from msprobe.core.common.file_utils import write_df_to_csv
41
45
  from msprobe.core.common.utils import analyze_api_call_stack
42
46
 
@@ -76,13 +80,24 @@ def param_is_data_parallel_duplicate(dp_group):
76
80
 
77
81
 
78
82
  def squash_param_name(param_name):
79
- for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
83
+ for pattern in ['^.*\.(layers?\..*)', '^.*\.(embeddings?\..*)', '^.*\.(final.*)', '^.*\.(output.*)',
84
+ '^.*\.(norm.*)']:
80
85
  match = re.findall(pattern, param_name)
81
86
  if match:
82
87
  return match[0]
83
88
  return param_name
84
89
 
85
90
 
91
+ def is_recording_module(module_name, l2_targets, vpp_stage):
92
+ if len(l2_targets) > 0:
93
+ for pattern in [vpp_stage + squash_param_name(module_name), vpp_stage + module_name]:
94
+ if pattern in l2_targets:
95
+ return pattern
96
+ return ""
97
+ else:
98
+ raise NotImplementedError("If monitering l2_features, the targets should be set specifically.")
99
+
100
+
86
101
  # Used For Module Forward & Backward Collect
87
102
  class ModuleHookContext:
88
103
  def __init__(self, module_name) -> None:
@@ -99,6 +114,19 @@ class ModuleHookContext:
99
114
  self.actvgrad.clear()
100
115
 
101
116
 
117
+ class FeatureHookContext:
118
+ def __init__(self, module_name):
119
+ self.step = 0
120
+ self.micro_step = 0
121
+ self.attention_feature = {}
122
+ self.linear_feature = {}
123
+ self.module_name = module_name
124
+
125
+ def reset(self):
126
+ self.attention_feature.clear()
127
+ self.linear_feature.clear()
128
+
129
+
102
130
  start_step = 0
103
131
 
104
132
 
@@ -211,6 +239,7 @@ class TrainerMon:
211
239
  # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量
212
240
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
213
241
  self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
242
+ self.feature_hook_context_by_module = defaultdict(FeatureHookContext)
214
243
  self.optimizer_context = defaultdict(OptimizerContext)
215
244
  self.cc_context = defaultdict(CommunicationContext)
216
245
  self.grad_context = GradContext()
@@ -244,6 +273,18 @@ class TrainerMon:
244
273
  if self.collect_times > 0:
245
274
  self.monitoring = True
246
275
 
276
+ @staticmethod
277
+ def get_linear_hook_target(module):
278
+ if isinstance(module, nn.Embedding):
279
+ return ''
280
+ if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
281
+ return ''
282
+ for weight_name in ["weight", "wg"]:
283
+ if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), Tensor):
284
+ if getattr(module, weight_name).dim() == 2:
285
+ return weight_name
286
+ return ''
287
+
247
288
  def set_config(self):
248
289
  self.start_step = self.config.get("start_step", 0)
249
290
  self.collect_times = self.config.get("collect_times", 100000000) # 默认大值, 目的是一直采集
@@ -268,6 +309,9 @@ class TrainerMon:
268
309
  self.cc_distribution = self.config.get("cc_distribution", {}) # communication ops
269
310
  self.stack_info = self.config.get('stack_info', False)
270
311
  self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
312
+ self.recording_l2_features = self.config.get('recording_l2_features', False)
313
+ self.sa_order = self.config.get('sa_order', "s,b,h,d")
314
+
271
315
 
272
316
  if not self.cc_distribution.get('enable', False):
273
317
  self.cc_log_only = False
@@ -320,6 +364,8 @@ class TrainerMon:
320
364
  logger.info("> momentum and variance of adam is not monitored. ")
321
365
  if not self.wg_distribution:
322
366
  logger.info("> weight grad of specified module is not monitored. ")
367
+ if not self.recording_l2_features:
368
+ logger.info("> l2 features of specified module is not monitored. ")
323
369
  if not self.mg_direction:
324
370
  logger.info('> grad and momentum direction will not be compared.')
325
371
  if not self.cc_distribution.get('enable', False):
@@ -367,6 +413,7 @@ class TrainerMon:
367
413
  self.write_grad_tb(context.step)
368
414
  self.write_mv_tb(context)
369
415
  self.write_param_tb(context)
416
+ self.write_features_tb(context.step)
370
417
  if self.stack_info:
371
418
  self.write_stack_info()
372
419
  self.stack_info = False
@@ -391,7 +438,6 @@ class TrainerMon:
391
438
  context.step += 1
392
439
  self.dynamic_monitor(optimizer)
393
440
 
394
-
395
441
  def patch_step(func, optimizer):
396
442
  def wrapper(*args, **kwargs):
397
443
  for hook in self.pre_step_hooks:
@@ -472,12 +518,18 @@ class TrainerMon:
472
518
  continue
473
519
  vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
474
520
  targets = [x for x, _ in get_submodules(model_chunk)] if self.print_struct else self.targets.keys()
475
- hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
521
+ l2_target_names = self.config.get('l2_targets', {})
522
+ hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage)
476
523
  logger.info(f"> {hooked_count} modules are monitored.")
477
524
 
478
525
  def hook_optimizer(self, optimizer):
479
526
  def optimizer_pre_step_hook(opt, *args, **kwargs):
480
527
  context = self.optimizer_context[opt]
528
+ if (self.print_struct and not all(value == {} for value in self.module_struct.values())
529
+ and not self.struct_printed):
530
+ self._save_module_struct()
531
+ if not self.cc_log_only:
532
+ raise Exception("exit after first monitor step when print model struct")
481
533
  if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
482
534
  self.collect_times):
483
535
  return
@@ -623,6 +675,25 @@ class TrainerMon:
623
675
  use_micro_step=self.monitor_mbs_grad)
624
676
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
625
677
 
678
+ def write_metrics_if_not_empty(self, features, metrics, step, hook_name):
679
+ if not features or len(features) == 0:
680
+ return
681
+ use_micro_step = hook_name not in ["linear_hook"]
682
+ self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step)
683
+ features.clear()
684
+
685
+ def write_features_tb(self, step):
686
+ if not self.recording_l2_features:
687
+ return
688
+ for context in self.feature_hook_context_by_module.values():
689
+ num_features = len(context.attention_feature) + len(context.linear_feature)
690
+ if num_features == 0:
691
+ continue
692
+ self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax"], step,
693
+ "attention_hook")
694
+ self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step,
695
+ "linear_hook")
696
+
626
697
  def is_target_rank(self):
627
698
  if self.module_rank_list and (self.rank not in self.module_rank_list):
628
699
  return False
@@ -695,7 +766,15 @@ class TrainerMon:
695
766
  }
696
767
  index += 1
697
768
 
698
- def _hook_module(self, target_names, module, vpp_stage=''):
769
+ def _save_module_struct(self):
770
+ output_dir = os.path.join(get_output_base_dir(), 'module_struct', f'rank{self.rank}')
771
+ make_dir(output_dir)
772
+ module_struct_file = os.path.realpath(os.path.join(output_dir, 'module_struct.json'))
773
+ save_json(module_struct_file, self.module_struct, indent=2)
774
+ logger.info(f"> save module struct to {module_struct_file}")
775
+ self.struct_printed = True
776
+
777
+ def _hook_module(self, target_names, l2_target_names, module, vpp_stage=''):
699
778
  if not is_valid_instance(module):
700
779
  # nothing to hook
701
780
  return 0
@@ -785,7 +864,8 @@ class TrainerMon:
785
864
  return
786
865
 
787
866
  def fwd_hook_register(module, fwd_hook_fun, name):
788
- if mindspore.__version__ >= '2.6.0':
867
+ from packaging import version
868
+ if version.parse(mindspore.__version__) >= version.parse('2.6.0'):
789
869
  def wrapper(module, args, kwargs, module_output):
790
870
  return fwd_hook_fun(module, args, kwargs, module_output, name)
791
871
  return module.register_forward_hook(wrapper, with_kwargs=True)
@@ -795,6 +875,61 @@ class TrainerMon:
795
875
  return fwd_hook_fun(module, args, None, module_output, name)
796
876
  return module.register_forward_hook(wrapper)
797
877
 
878
+ def extract_attention_feature_hook(module, args, kwargs, module_output, name):
879
+ module_input = [tensor for tensor in args if isinstance(tensor, Tensor)]
880
+ if kwargs:
881
+ kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
882
+ module_input.extend(kwargs_tensors)
883
+
884
+ if module not in self.feature_hook_context_by_module:
885
+ self.feature_hook_context_by_module[module] = FeatureHookContext(name)
886
+ context: FeatureHookContext = self.feature_hook_context_by_module[module]
887
+
888
+ tbtag_tensor_map = {}
889
+ if len(module_input) < 2:
890
+ logger.warning(
891
+ "Calculate attention feature failed, the length of module_input in attention hook's module should "
892
+ "be greater than or equal to 2.")
893
+
894
+ q_h = module_input[0]
895
+ k_h = module_input[1]
896
+ qkt = cal_qkt(q_h, k_h, order=self.sa_order)
897
+ tbtag_tensor_map.update(
898
+ self.build_tbtag_tensor_map(
899
+ f'{context.module_name}.attention', f'{MonitorConst.NAME_SEP}{context.micro_step}',
900
+ 'qkt', qkt))
901
+ get_entropy_metric(tbtag_tensor_map, context.attention_feature)
902
+
903
+ context.micro_step += 1
904
+ if context.micro_step == self.micro_batch_number:
905
+ context.micro_step = 0
906
+ context.step += 1
907
+ return
908
+
909
+ def extract_linear_sr_hook(module, args, kwargs, module_output, name):
910
+ weight_name = self.get_linear_hook_target(module)
911
+ if weight_name == "":
912
+ return
913
+ if module not in self.feature_hook_context_by_module:
914
+ self.feature_hook_context_by_module[module] = FeatureHookContext(name)
915
+ context: FeatureHookContext = self.feature_hook_context_by_module[module]
916
+
917
+ if context.micro_step == self.micro_batch_number - 1:
918
+ tbtag_tensor_map = {}
919
+ value = module.weight.data
920
+ tbtag_tensor_map.update(
921
+ self.build_tbtag_tensor_map(
922
+ f'{context.module_name}.linear', f'{MonitorConst.NAME_SEP}{context.micro_step}',
923
+ 'sr', value))
924
+
925
+ get_sr_metric(tbtag_tensor_map, context.linear_feature)
926
+
927
+ context.micro_step += 1
928
+ if context.micro_step == self.micro_batch_number:
929
+ context.micro_step = 0
930
+ context.step += 1
931
+ return
932
+
798
933
  def stack_hook(module, args, kwargs, module_output, name):
799
934
  if module not in self.module_fwd_hook_context_by_module:
800
935
  self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -824,6 +959,24 @@ class TrainerMon:
824
959
  self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
825
960
  logger.info(f"> {name} is monitored successfully")
826
961
  hooked_count += 1
962
+
963
+ if not self.print_struct and self.recording_l2_features:
964
+ for module_name, submodule in get_submodules(module):
965
+ func_map = {
966
+ "attention_hook": extract_attention_feature_hook,
967
+ "linear_hook": extract_linear_sr_hook
968
+ }
969
+ for hook in func_map.keys():
970
+ if hook in l2_target_names:
971
+ temp_names = l2_target_names[hook]
972
+ name = is_recording_module(module_name, temp_names, vpp_stage)
973
+ if name:
974
+ handle = fwd_hook_register(submodule, func_map[hook], name=name)
975
+ print_feature_name = hook.split('_')[0]
976
+ logger.info_on_rank_0(
977
+ f'> {print_feature_name} features of {name} is monitored successfully')
978
+ self.handles["L2_features"].append(handle)
979
+ hooked_count += 1
827
980
  return hooked_count
828
981
 
829
982
  def _patch_grad_sync(self):
@@ -889,11 +1042,16 @@ class TrainerMon:
889
1042
  for handle in self.handles['xy']:
890
1043
  handle.remove()
891
1044
  self.handles['xy'].clear()
1045
+ for handle in self.handles['L2_features']:
1046
+ handle.remove()
1047
+ self.handles['L2_features'].clear()
892
1048
  # 清空对应context缓存
893
- for _, fwd_context in self.module_fwd_hook_context_by_module.items():
1049
+ for fwd_context in self.module_fwd_hook_context_by_module.values():
894
1050
  fwd_context.reset()
895
- for _, bwd_context in self.module_bwd_hook_context_by_module.items():
1051
+ for bwd_context in self.module_bwd_hook_context_by_module.values():
896
1052
  bwd_context.reset()
1053
+ for feature_context in self.feature_hook_context_by_module.values():
1054
+ feature_context.reset()
897
1055
  self.grad_context.reset() # 权重梯度和激活值梯度都在这
898
1056
 
899
1057
  for handle in self.handles['wgrads']:
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
  from mindspore import dtype as mstype, Tensor
16
16
 
17
- from msprobe.mindspore.monitor.features import FUNC_MAP
17
+ from msprobe.mindspore.monitor.features import FUNC_MAP, cal_entropy, cal_stable_rank
18
18
 
19
19
 
20
20
  def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None):
@@ -75,3 +75,29 @@ def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_t
75
75
  :return: whether skip or not, bool
76
76
  """
77
77
  return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times
78
+
79
+
80
+ def get_entropy_metric(tag2tensor, out_dict=None):
81
+ if out_dict is None:
82
+ out_dict = {}
83
+ for tag, tensor in tag2tensor.items():
84
+ if tag not in out_dict:
85
+ out_dict[tag] = {}
86
+ entropy, softmax = cal_entropy(tensor)
87
+ out_dict[tag]["entropy"] = entropy
88
+ out_dict[tag]["softmax"] = softmax
89
+ return out_dict
90
+
91
+
92
+ def get_sr_metric(tag2tensor, out_dict=None):
93
+ if out_dict is None:
94
+ out_dict = {}
95
+ for tag, tensor in tag2tensor.items():
96
+ if "sr" not in tag:
97
+ continue
98
+ if tag not in out_dict:
99
+ out_dict[tag] = {}
100
+ sr, eig = cal_stable_rank(tensor)
101
+ out_dict[tag]["sr"] = sr
102
+ out_dict[tag]["eig"] = eig
103
+ return out_dict
@@ -57,11 +57,12 @@ class StatisticsConfig(BaseConfig):
57
57
  raise Exception("Config param [precision] is invalid, expected from [\"high\", \"low\"]")
58
58
 
59
59
  def _check_summary_mode(self):
60
- muti_opt = ["md5", "max", "min", "mean", "l2norm"]
60
+ muti_opt = ["max", "min", "mean", "count", "negative zero count", "positive zero count", "nan count",
61
+ "negative inf count", "positive inf count", "zero count", "l2norm", "hash", "md5"]
61
62
  if isinstance(self.summary_mode, str) and self.summary_mode not in Const.SUMMARY_MODE:
62
- raise Exception("summary_mode is invalid")
63
+ raise Exception("summary_mode is an invalid string")
63
64
  if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode):
64
- raise Exception("summary_mode is invalid")
65
+ raise Exception("summary_mode contains invalid option(s)")
65
66
 
66
67
 
67
68
  class OverflowCheckConfig(BaseConfig):
@@ -79,6 +80,12 @@ class OverflowCheckConfig(BaseConfig):
79
80
  raise Exception("check_mode is invalid")
80
81
 
81
82
 
83
+ class ExceptionDumpConfig(BaseConfig):
84
+ def __init__(self, json_config):
85
+ super().__init__(json_config)
86
+ self.data_mode = ["all"]
87
+
88
+
82
89
  class FreeBenchmarkConfig(BaseConfig):
83
90
  def __init__(self, task_config):
84
91
  super().__init__(task_config)
@@ -128,7 +135,8 @@ TaskDict = {
128
135
  Const.OVERFLOW_CHECK: OverflowCheckConfig,
129
136
  Const.FREE_BENCHMARK: FreeBenchmarkConfig,
130
137
  Const.GRAD_PROBE: GradProbeConfig,
131
- Const.STRUCTURE: StructureConfig
138
+ Const.STRUCTURE: StructureConfig,
139
+ Const.EXCEPTION_DUMP: ExceptionDumpConfig
132
140
  }
133
141
 
134
142
 
@@ -32,7 +32,7 @@ class OverflowCheckToolFactory:
32
32
  Const.PYNATIVE_MODE: None
33
33
  },
34
34
  Const.KERNEL: {
35
- Const.GRAPH_KBYK_MODE: None,
35
+ Const.GRAPH_KBYK_MODE: KernelGraphOverflowCheck,
36
36
  Const.GRAPH_GE_MODE: KernelGraphOverflowCheck,
37
37
  Const.PYNATIVE_MODE: None
38
38
  }
@@ -18,6 +18,7 @@ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
18
18
  from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
19
19
  from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
20
20
  from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory
21
+ from msprobe.mindspore.exception_dump.exception_dump_tool_factory import ExceptionDumpToolFactory
21
22
 
22
23
 
23
24
  class TaskHandlerFactory:
@@ -25,7 +26,8 @@ class TaskHandlerFactory:
25
26
  Const.TENSOR: DumpToolFactory,
26
27
  Const.STATISTICS: DumpToolFactory,
27
28
  Const.OVERFLOW_CHECK: OverflowCheckToolFactory,
28
- Const.FREE_BENCHMARK: SelfCheckToolFactory
29
+ Const.FREE_BENCHMARK: SelfCheckToolFactory,
30
+ Const.EXCEPTION_DUMP: ExceptionDumpToolFactory
29
31
  }
30
32
 
31
33
  @staticmethod
@@ -16,8 +16,8 @@
16
16
  from dataclasses import dataclass
17
17
  from msprobe.core.common.const import Const
18
18
  from msprobe.core.common.log import logger
19
- from msprobe.core.common.exceptions import MsprobeException
20
19
  from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst
20
+ from msprobe.core.common.exceptions import MsprobeException
21
21
 
22
22
 
23
23
  @dataclass
@@ -24,8 +24,7 @@ from msprobe.pytorch.pt_config import RunUTConfig
24
24
 
25
25
  RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
26
26
  'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
27
- 'black_list', 'error_data_path', 'online_config'])
28
- OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
27
+ 'black_list', 'error_data_path'])
29
28
 
30
29
 
31
30
  class Config:
@@ -46,13 +45,7 @@ class Config:
46
45
  'white_list': list,
47
46
  'black_list': list,
48
47
  'error_data_path': str,
49
- 'precision': int,
50
- 'is_online': bool,
51
- 'nfs_path': str,
52
- 'host': str,
53
- 'port': int,
54
- 'rank_list': list,
55
- 'tls_path': str
48
+ 'precision': int
56
49
  }
57
50
  if key not in validators:
58
51
  raise ValueError(f"{key} must be one of {validators.keys()}")
@@ -68,10 +61,6 @@ class Config:
68
61
  RunUTConfig.check_filter_list_config(key, value)
69
62
  if key == 'error_data_path':
70
63
  RunUTConfig.check_error_data_path_config(value)
71
- if key == 'nfs_path':
72
- RunUTConfig.check_nfs_path_config(value)
73
- if key == 'tls_path':
74
- RunUTConfig.check_tls_path_config(value)
75
64
  return value
76
65
 
77
66
 
@@ -85,12 +74,6 @@ class CheckerConfig:
85
74
  self.white_list = msCheckerConfig.white_list
86
75
  self.black_list = msCheckerConfig.black_list
87
76
  self.error_data_path = msCheckerConfig.error_data_path
88
- self.is_online = msCheckerConfig.is_online
89
- self.nfs_path = msCheckerConfig.nfs_path
90
- self.host = msCheckerConfig.host
91
- self.port = msCheckerConfig.port
92
- self.rank_list = msCheckerConfig.rank_list
93
- self.tls_path = msCheckerConfig.tls_path
94
77
 
95
78
  if task_config:
96
79
  self.load_config(task_config)
@@ -99,22 +82,7 @@ class CheckerConfig:
99
82
  self.white_list = task_config.white_list
100
83
  self.black_list = task_config.black_list
101
84
  self.error_data_path = task_config.error_data_path
102
- self.is_online = task_config.is_online
103
- self.nfs_path = task_config.nfs_path
104
- self.host = task_config.host
105
- self.port = task_config.port
106
- self.rank_list = task_config.rank_list
107
- self.tls_path = task_config.tls_path
108
85
 
109
- def get_online_config(self):
110
- return OnlineConfig(
111
- is_online=self.is_online,
112
- nfs_path=self.nfs_path,
113
- host=self.host,
114
- port=self.port,
115
- rank_list=self.rank_list,
116
- tls_path=self.tls_path
117
- )
118
86
 
119
87
  def get_run_ut_config(self, **config_params):
120
88
  return RunUtConfig(
@@ -127,6 +95,5 @@ class CheckerConfig:
127
95
  real_data_path=config_params.get('real_data_path'),
128
96
  white_list=self.white_list.copy() if self.white_list else [],
129
97
  black_list=self.black_list.copy() if self.black_list else [],
130
- error_data_path=config_params.get('error_data_path'),
131
- online_config=self.get_online_config()
98
+ error_data_path=config_params.get('error_data_path')
132
99
  )
@@ -117,30 +117,6 @@ def api_precision_compare(config):
117
117
  change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
118
118
 
119
119
 
120
- def online_api_precision_compare(online_config):
121
- rank = online_config.rank
122
- result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
123
- "_rank*.csv", f"_rank{rank}.csv")
124
- details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
125
- "_rank*.csv", f"_rank{rank}.csv")
126
- detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
127
- result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
128
- if not os.path.exists(result_csv_path):
129
- write_csv(result_csv_title, result_csv_path)
130
- if not os.path.exists(details_csv_path):
131
- write_csv(detail_csv_title, details_csv_path)
132
- config = CompareConfig("", "", result_csv_path, details_csv_path)
133
- try:
134
- npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
135
- check_csv_columns(npu_data.columns, "npu_csv")
136
- check_csv_columns(gpu_data.columns, "gpu_csv")
137
- analyse_csv(npu_data, gpu_data, config)
138
- except Exception as err:
139
- logger.error(f"Online api precision compare Error: {str(err)}")
140
- change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
141
- change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
142
-
143
-
144
120
  def analyse_csv(npu_data, gpu_data, config):
145
121
  forward_status, backward_status = [], []
146
122
  last_api_name, last_api_dtype, last_api_full_name = None, None, None
@@ -66,13 +66,6 @@ class Comparator:
66
66
  self.save_path_list = [result_csv_path]
67
67
  self.detail_save_path_list = [details_csv_path]
68
68
 
69
- if config and config.online_config.is_online:
70
- self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
71
- self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
72
- self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
73
- self.detail_save_path_list = \
74
- [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
75
-
76
69
  self.registry = self._register_compare_func()
77
70
 
78
71
  if not is_continue_run_ut:
@@ -245,9 +238,8 @@ class Comparator:
245
238
  self.write_detail_csv(args)
246
239
 
247
240
 
248
- def compare_output(self, full_api_name, data_info, is_online=False):
241
+ def compare_output(self, full_api_name, data_info):
249
242
  """Get compare result and write to result and detail csv.
250
- is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
251
243
  """
252
244
  _, api_name = extract_basic_api_segments(full_api_name)
253
245
  if not api_name:
@@ -280,9 +272,7 @@ class Comparator:
280
272
  fwd_compare_alg_results,
281
273
  bwd_compare_alg_results,
282
274
  data_info.rank)
283
- if is_online:
284
- # get run_ut compare detail
285
- return self._get_run_ut_detail(result_info)
275
+
286
276
  self.record_results(result_info)
287
277
  return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
288
278
  or bwd_success_status == CompareConst.SPACE
@@ -2,9 +2,4 @@ white_list: []
2
2
  black_list: []
3
3
  error_data_path: './'
4
4
  precision: 14
5
- is_online: False
6
- nfs_path: ""
7
- host: ""
8
- port: -1
9
- rank_list: [0]
10
- tls_path: "./"
5
+
@@ -84,8 +84,8 @@ def split_json_file(input_file, num_splits, filter_api):
84
84
  for file in split_files:
85
85
  try:
86
86
  remove_path(file)
87
- except FileNotFoundError:
88
- logger.error(f"File not found and could not be deleted: {file}")
87
+ except Exception:
88
+ logger.error(f"File not found or could not be deleted: {file}")
89
89
  msg = 'ERROR: Split json file failed, please check the input file and try again.'
90
90
  raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
91
91
  return split_files, total_items