mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -14,19 +14,29 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os.path
17
+
17
18
  import torch
19
+
18
20
  from msprobe.core.common.const import FileCheckConst
19
- from msprobe.pytorch.common.log import logger
20
21
  from msprobe.core.common.exceptions import FileCheckException
21
- from msprobe.core.compare.acc_compare import Comparator
22
- from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
23
- CompareException, set_dump_path, get_dump_mode
24
22
  from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
23
+ from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
24
+ set_dump_path
25
+ from msprobe.core.compare.acc_compare import Comparator, ModeConfig
26
+ from msprobe.core.compare.utils import set_stack_json_path
27
+ from msprobe.pytorch.common.log import logger
25
28
  from msprobe.pytorch.common.utils import load_pt
26
29
 
27
30
 
28
- class PTComparator (Comparator):
29
- def __init__(self, data_mapping=None):
31
+ class PTComparator(Comparator):
32
+ def __init__(self, mode_config, data_mapping=None):
33
+ super().__init__(mode_config)
34
+
35
+ self.stack_mode = mode_config.stack_mode
36
+ self.auto_analyze = mode_config.auto_analyze
37
+ self.fuzzy_match = mode_config.fuzzy_match
38
+ self.dump_mode = mode_config.dump_mode
39
+
30
40
  self.frame_name = PTComparator.__name__
31
41
  self.data_mapping = data_mapping
32
42
  if isinstance(self.data_mapping, str) or self.data_mapping is None:
@@ -37,23 +47,24 @@ class PTComparator (Comparator):
37
47
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
38
48
  f"{type(self.data_mapping)}")
39
49
 
40
- def load_mapping_file(self, mapping_file):
50
@staticmethod
def load_mapping_file(mapping_file):
    """Load a user-supplied data mapping from a YAML file.

    Args:
        mapping_file: path to a YAML mapping file. Any value that is not a
            str (including None) means "no mapping file".

    Returns:
        The object produced by ``load_yaml`` for a path, otherwise ``{}``.
    """
    if not isinstance(mapping_file, str):
        return {}
    return load_yaml(mapping_file)
46
-
57
+
47
58
  def read_npy_data(self, dir_path, file_name):
48
59
  if not file_name:
49
60
  return None
50
61
  data_path = os.path.join(dir_path, file_name)
51
62
  path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
52
- FileCheckConst.PT_SUFFIX, False)
63
+ FileCheckConst.PT_SUFFIX, False)
53
64
  data_path = path_checker.common_check()
54
65
  try:
55
- data_value = load_pt(data_path,
56
- to_cpu=True).detach() # detach because numpy can not process gradient information
66
+ # detach because numpy can not process gradient information
67
+ data_value = load_pt(data_path, to_cpu=True).detach()
57
68
  except RuntimeError as e:
58
69
  # 这里捕获 load_pt 中抛出的异常
59
70
  logger.error(f"Failed to load the .pt file at {data_path}.")
@@ -65,20 +76,29 @@ class PTComparator (Comparator):
65
76
  if data_value.dtype == torch.bfloat16:
66
77
  data_value = data_value.to(torch.float32)
67
78
  data_value = data_value.numpy()
68
- return data_value
69
-
70
-
71
- def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
79
+ return data_value
80
+
81
+
82
def compare(input_param, output_path, **kwargs):
    """Run a PyTorch accuracy comparison and write the result to ``output_path``.

    Args:
        input_param: dict describing the dumps to compare. When it has no
            ``"stack_json_path"`` key, ``set_stack_json_path`` derives the
            stack mode and records ``"stack_json_path"`` in it.
        output_path: output directory, passed to ``create_directory``.
        **kwargs: optional settings:
            ``auto_analyze`` (default True), ``fuzzy_match`` (default False),
            ``data_mapping`` (dict, mapping-file path str, or None; default None),
            ``suffix`` (default ''), and ``stack_mode`` (default False; only
            honoured when ``input_param`` already contains ``"stack_json_path"``).

    Raises:
        CompareException: when the argument or file checks fail.
    """
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    data_mapping = kwargs.get('data_mapping', None)
    suffix = kwargs.get('suffix', '')

    try:
        set_dump_path(input_param)
        dump_mode = get_dump_mode(input_param)
        if "stack_json_path" not in input_param:
            # Derives stack_mode and sets "stack_json_path" in input_param.
            stack_mode = set_stack_json_path(input_param)
        else:
            stack_mode = kwargs.get('stack_mode', False)
        print_compare_log = input_param.get('is_print_compare_log', True)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match, print_compare_log)
        create_directory(output_path)
        check_compare_param(input_param, output_path, dump_mode, stack_mode)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error

    comparator = PTComparator(ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode), data_mapping)
    comparator.compare_core(input_param, output_path, suffix=suffix)
@@ -34,6 +34,7 @@ class DebuggerConfig:
34
34
  self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
35
35
  self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
36
36
  self.framework = Const.PT_FRAMEWORK
37
+ self.async_dump = common_config.async_dump if common_config.async_dump else False
37
38
 
38
39
  if self.level == Const.LEVEL_L2:
39
40
  self.is_backward_kernel_dump = False
@@ -74,29 +75,43 @@ class DebuggerConfig:
74
75
  if not self.dump_path:
75
76
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
76
77
  f"The dump_path not found.")
78
+ if not isinstance(self.async_dump, bool):
79
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
80
+ f"The parameters async_dump should be bool.")
77
81
 
78
82
def check(self):
    """Validate the debugger configuration.

    Returns:
        True once ``check_kwargs`` completes without raising.

    Raises:
        Whatever ``check_kwargs`` raises (MsprobeException for invalid params).
    """
    self.check_kwargs()  # raises on an invalid configuration
    return True
81
85
 
82
86
def check_model(self, instance, start_model):
    """Resolve and validate the model(s) that module-level dumping will hook.

    Only the L0 and mix levels use a model. For any other level, a supplied
    model is ignored and an info message is logged. For L0/mix, a model passed
    to ``start()`` takes precedence over the one given to ``PrecisionDebugger``.
    The chosen value is stored back on ``instance.model``.

    Args:
        instance: the PrecisionDebugger instance. Its ``model`` attribute is
            read and may be overwritten.
        start_model: the model passed to ``start()``, or None.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR if no model is available, or if
            the model is neither a ``torch.nn.Module`` nor a list/tuple whose
            elements are all ``torch.nn.Module``.
    """
    if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
        if instance.model is not None or start_model is not None:
            logger.info_on_rank_0(
                "The current level is not L0 or mix level, so the model parameters will not be used.")
        return
    if start_model is None and instance.model is None:
        logger.error_on_rank_0(
            f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "missing the parameter 'model'")

    # The model given to start() overrides the one given at construction time.
    instance.model = start_model if start_model is not None else instance.model
    if isinstance(instance.model, torch.nn.Module):
        return

    if isinstance(instance.model, (list, tuple)):
        invalid_models = [model for model in instance.model if not isinstance(model, torch.nn.Module)]
    else:
        invalid_models = [instance.model]

    # Test list emptiness rather than "is not None", so that a None element
    # inside a list/tuple is also rejected instead of silently accepted.
    if invalid_models:
        error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
                      f"type, currently there is a {type(invalid_models[0])} type.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
100
115
 
101
116
  def _check_and_adjust_config_with_l2(self):
102
117
  if self.scope:
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,6 +22,7 @@ from msprobe.core.common.file_utils import FileChecker
22
22
  from msprobe.core.common.utils import get_real_step_or_rank
23
23
  from msprobe.pytorch.common.log import logger
24
24
  from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
25
+ from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
25
26
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
26
27
  from msprobe.pytorch.pt_config import parse_json_config
27
28
  from msprobe.pytorch.service import Service
@@ -49,7 +50,7 @@ class PrecisionDebugger:
49
50
  dump_path=None,
50
51
  level=None,
51
52
  model=None,
52
- step=None,
53
+ step=None
53
54
  ):
54
55
  if not hasattr(self, "initialized"):
55
56
  config_params = ConfigParameters(config_path,
@@ -59,7 +60,6 @@ class PrecisionDebugger:
59
60
  model)
60
61
  self.check_input_params(config_params)
61
62
 
62
- self.api_origin = False
63
63
  self.initialized = True
64
64
  self.model = model
65
65
  common_config, task_config = parse_json_config(config_path, task)
@@ -67,12 +67,13 @@ class PrecisionDebugger:
67
67
  if self.task == Const.GRAD_PROBE:
68
68
  self.gm = GradientMonitor(common_config, task_config)
69
69
  return
70
- if step:
70
+ if step is not None:
71
71
  common_config.step = get_real_step_or_rank(step, Const.STEP)
72
72
  self.config = DebuggerConfig(
73
73
  common_config, task_config, task, dump_path, level
74
74
  )
75
75
  self.service = Service(self.config)
76
+ self.module_dumper = ModuleDumper(self.service)
76
77
  self.enable_dataloader = self.config.enable_dataloader
77
78
  if self.enable_dataloader:
78
79
  logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
@@ -105,9 +106,11 @@ class PrecisionDebugger:
105
106
  raise MsprobeException(
106
107
  MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
107
108
 
108
- if args.model is not None and not isinstance(args.model, torch.nn.Module):
109
- raise MsprobeException(
110
- MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
109
+ if args.model is not None:
110
+ logger.warning_on_rank_0(
111
+ "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
112
+ "It is recommended to pass the 'model' parameter in the start interface instead."
113
+ )
111
114
 
112
115
  @classmethod
113
116
  def start(cls, model=None):
@@ -120,15 +123,12 @@ class PrecisionDebugger:
120
123
  if instance.enable_dataloader:
121
124
  logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
122
125
  else:
123
- instance.service.start(instance.model, instance.api_origin)
124
- instance.api_origin = False
126
+ instance.service.start(instance.model)
125
127
 
126
- # 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
127
128
@classmethod
def forward_backward_dump_end(cls):
    """End the dumped region by delegating to ``stop()`` on the singleton instance."""
    cls._instance.stop()
132
132
 
133
133
  @classmethod
134
134
  def stop(cls):
@@ -159,6 +159,36 @@ class PrecisionDebugger:
159
159
  cls._instance.gm.monitor(model)
160
160
 
161
161
 
162
def module_dump(module, dump_name):
    """Start dumping data for one ``torch.nn.Module`` under a user-chosen name.

    Args:
        module: the module to hook. It must be a ``torch.nn.Module``.
        dump_name: name embedded in the dump prefix. It must be a str.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR for a wrong argument type, or
            INTERFACE_USAGE_ERROR when no PrecisionDebugger exists yet.
    """
    if not isinstance(module, torch.nn.Module):
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR,
            "the module argument in module_dump must be a torch.nn.Module subclass"
        )
    if not isinstance(dump_name, str):
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR,
            "the dump_name argument in module_dump must be a str type"
        )

    debugger = PrecisionDebugger._instance
    if not debugger:
        raise MsprobeException(
            MsprobeException.INTERFACE_USAGE_ERROR,
            "PrecisionDebugger must be instantiated before using module_dump interface"
        )
    debugger.module_dumper.start_module_dump(module, dump_name)
180
+
181
+
182
def module_dump_end():
    """Finish a ``module_dump`` session by calling ``stop_module_dump`` on the module dumper.

    Raises:
        MsprobeException: INTERFACE_USAGE_ERROR when no PrecisionDebugger exists yet.
    """
    debugger = PrecisionDebugger._instance
    if debugger:
        debugger.module_dumper.stop_module_dump()
        return
    raise MsprobeException(
        MsprobeException.INTERFACE_USAGE_ERROR,
        "PrecisionDebugger must be instantiated before using module_dump_end interface"
    )
190
+
191
+
162
192
  def iter_tracer(func):
163
193
  def func_wrapper(*args, **kwargs):
164
194
  debugger_instance = PrecisionDebugger.instance
File without changes
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.data_dump.scope import BaseScope
19
+ from msprobe.pytorch.common.log import logger
20
+ from msprobe.pytorch.hook_module.api_registry import api_register
21
+
22
def torch_major_version_at_least(version, major=2):
    """Return True if the major component of a torch version string is >= ``major``.

    The major component is compared as an integer, so "10.0" counts as newer
    than "2.0". A plain string comparison ('10.0' >= '2.0') gets this wrong.
    A local build tag such as "+cpu" is ignored.
    """
    release = str(version).split('+')[0]
    return int(release.split('.')[0]) >= major


# True for torch 2.x and newer; selects which hook-registration APIs this module uses.
torch_version_above_or_equal_2 = torch_major_version_at_least(torch.__version__)
23
+
24
+
25
class ModuleDumper:
    """Attach and later remove dump hooks on one user-selected module.

    Backs the ``module_dump`` / ``module_dump_end`` interfaces.
    ``start_module_dump`` calls ``api_register.api_originality()`` and hooks the
    module. ``stop_module_dump`` calls ``api_register.api_modularity()`` and
    removes every hook handle recorded by ``register_hook``.
    """

    def __init__(self, service):
        # Dump service; register_hook uses its module_processor and build_hook.
        self.service = service
        # Handles returned by register_*_hook, kept so stop_module_dump can remove them.
        self.hook_handle_list = []

    def start_module_dump(self, module, dump_name):
        """Begin dumping ``module`` under ``dump_name``."""
        api_register.api_originality()
        self.register_hook(module, dump_name)

    def stop_module_dump(self):
        """Restore API wrapping and remove every hook added by ``register_hook``."""
        api_register.api_modularity()
        for hook_handle in self.hook_handle_list:
            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
                hook_handle.remove()
        self.hook_handle_list.clear()

    def register_hook(self, module, dump_name):
        """Register data-dump and scope-tracking hooks on ``module``.

        The dump prefix joins ``BaseScope.Module_Type_Module``, ``dump_name`` and
        the module's class name with ``Const.SEP`` (with a trailing separator).
        Backward hooks are skipped when the module already uses the deprecated
        ``register_backward_hook`` API.
        """
        prefix_name = (
            BaseScope.Module_Type_Module + Const.SEP +
            dump_name + Const.SEP +
            module.__class__.__name__ + Const.SEP
        )
        module_processor = self.service.module_processor
        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
            BaseScope.Module_Type_Module,
            prefix_name
        )

        # Registering full backward hooks below never turns this result True,
        # so it is safe to evaluate it once.
        skip_backward = module_processor.has_register_backward_hook(module)
        if skip_backward:
            logger.warning(
                f"The {dump_name} module has registered deprecated register_backward_hook, "
                f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
            )

        if torch_version_above_or_equal_2:
            forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
        else:
            if not skip_backward:
                backward_hook_handle = module.register_full_backward_hook(
                    module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
                )
                self.hook_handle_list.append(backward_hook_handle)
            forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
        # Track the forward data hook on both version paths so stop_module_dump removes it.
        self.hook_handle_list.append(forward_hook_handle)
        if not skip_backward:
            backward_hook_handle = module.register_full_backward_hook(backward_hook)
            self.hook_handle_list.append(backward_hook_handle)

        forward_pre_hook_handle = module.register_forward_pre_hook(
            module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
        )
        forward_hook_handle = module.register_forward_hook(
            module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
        )
        self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
        if torch_version_above_or_equal_2 and not skip_backward:
            backward_pre_hook_handle = module.register_full_backward_pre_hook(
                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
            )
            backward_hook_handle = module.register_full_backward_hook(
                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
            )
            self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,12 +17,24 @@ from functools import wraps
17
17
 
18
18
  import torch
19
19
  from msprobe.core.common.const import Const
20
- from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
20
+ from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
21
+ from msprobe.pytorch.common.log import logger
22
+ from torch.utils.checkpoint import checkpoint as origin_checkpoint
23
+ from torch.utils.checkpoint import set_checkpoint_early_stop
21
24
  from torch.utils.hooks import BackwardHook
22
25
 
23
26
def torch_major_version_at_least(version, major=2):
    """Return True if the major component of a torch version string is >= ``major``.

    The major component is compared as an integer, so "10.0" counts as newer
    than "2.0". A plain string comparison ('10.0' >= '2.0') gets this wrong.
    A local build tag such as "+cpu" is ignored.
    """
    release = str(version).split('+')[0]
    return int(release.split('.')[0]) >= major


# True for torch 2.x and newer; selects which hook-registration APIs this module uses.
torch_version_above_or_equal_2 = torch_major_version_at_least(torch.__version__)
24
27
 
25
28
 
29
+ def checkpoint_without_early_stop(*args, **kwargs):
30
+ with set_checkpoint_early_stop(False):
31
+ return origin_checkpoint(*args, **kwargs)
32
+
33
+
34
def replace_checkpoint():
    """Monkey-patch ``torch.utils.checkpoint.checkpoint`` with ``checkpoint_without_early_stop``.

    Only callers that look the function up on the module attribute see the
    patch. A name bound earlier via ``from torch.utils.checkpoint import
    checkpoint`` still refers to the original function.
    """
    checkpoint_module = torch.utils.checkpoint
    checkpoint_module.checkpoint = checkpoint_without_early_stop
36
+
37
+
26
38
  class ModuleProcesser:
27
39
  module_count = {}
28
40
  module_stack = []
@@ -34,6 +46,7 @@ class ModuleProcesser:
34
46
  BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
35
47
  BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
36
48
  BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
49
+ replace_checkpoint()
37
50
 
38
51
  @staticmethod
39
52
  def filter_tensor_and_tuple(func):
@@ -63,7 +76,7 @@ class ModuleProcesser:
63
76
  return ModuleProcesser.clone_if_tensor(result)
64
77
 
65
78
  return clone_return_value_func
66
-
79
+
67
80
  @staticmethod
68
81
  def clone_if_tensor(result):
69
82
  if isinstance(result, torch.Tensor):
@@ -85,6 +98,22 @@ class ModuleProcesser:
85
98
  ModuleProcesser.module_count[module_name] += 1
86
99
  return ModuleProcesser.module_count[module_name]
87
100
 
101
@staticmethod
def has_register_backward_hook(module):
    """Report whether ``module`` uses the deprecated ``register_backward_hook`` API.

    Returns True only when the module has a ``_backward_hooks`` attribute that
    is non-empty and ``_is_full_backward_hook`` is exactly ``False``. Returns
    False otherwise.
    """
    if not hasattr(module, '_backward_hooks'):
        return False
    if len(module._backward_hooks) == 0:
        return False
    return module._is_full_backward_hook is False
106
+
107
@staticmethod
def get_modules_and_names(models):
    """Map a model index to that model's ``named_modules()`` result.

    A list/tuple of models is keyed by each position as a string ("0", "1",
    ...). A single model is stored under the key "-1".
    """
    if not isinstance(models, (list, tuple)):
        return {"-1": models.named_modules()}
    return {str(position): model.named_modules() for position, model in enumerate(models)}
116
+
88
117
  @classmethod
89
118
  def reset_module_stats(cls):
90
119
  cls.module_count = {}
@@ -92,6 +121,42 @@ class ModuleProcesser:
92
121
  cls.api_parent_node = ""
93
122
  cls.module_node = {}
94
123
 
124
def register_module_hook(self, models, build_hook):
    """Register data-dump and scope-tracking hooks on every sub-module of *models*.

    Args:
        models: a single model, or a list/tuple of models. Each must provide
            ``named_modules()``. The root model(s) themselves are skipped.
        build_hook: factory called as ``build_hook(module_type, prefix_name)``.
            It returns a 4-tuple ``(pre_forward_hook, forward_hook,
            backward_hook, forward_hook_torch_version_below_2)``; the first
            element is not used here.

    Modules that already carry hooks from the deprecated
    ``register_backward_hook`` get no backward hooks from this method, and a
    warning is logged for them.

    NOTE(review): the registration order below is deliberate and matches the
    previous implementation. Check hook firing order before reordering.
    """
    logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
    modules_and_names_with_index = self.get_modules_and_names(models)
    for index, modules_and_names in modules_and_names_with_index.items():
        # "-1" marks the single-model case; otherwise the key is the model's position in the list.
        model = models if index == "-1" else models[int(index)]
        module_index = "" if index == "-1" else index + Const.SEP
        for name, module in modules_and_names:
            if module is model:
                continue
            prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
                           name + Const.SEP + module.__class__.__name__ + Const.SEP)
            _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
                BaseScope.Module_Type_Module,
                prefix_name
            )

            # Decided once per module and reused for every backward registration below.
            skip_backward = self.has_register_backward_hook(module)
            if skip_backward:
                logger.warning(
                    f"The {prefix_name[:-1]} has registered deprecated register_backward_hook, "
                    "which may cause abnormal data dump. The backward data dump for this module will be skipped."
                )

            if torch_version_above_or_equal_2:
                module.register_forward_hook(forward_hook, with_kwargs=True)
            else:
                if not skip_backward:
                    module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
                module.register_forward_hook(forward_hook_torch_version_below_2)
                if not skip_backward:
                    module.register_full_backward_hook(backward_hook)

            module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
            module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
            if torch_version_above_or_equal_2 and not skip_backward:
                module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
                module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
159
+
95
160
  def node_hook(self, name_prefix, start_or_stop, **kwargs):
96
161
 
97
162
  def pre_hook(module, input, output=None):
@@ -100,7 +165,10 @@ class ModuleProcesser:
100
165
  except IndexError as e:
101
166
  index = None
102
167
  pass
103
- module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
168
+ full_name = name_prefix + Const.SEP + str(index)
169
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
170
+ module.mindstudio_reserved_name = []
171
+ module.mindstudio_reserved_name.append(full_name)
104
172
  if self.module_stack:
105
173
  ModuleProcesser.module_node[full_name] = self.module_stack[-1]
106
174
  else:
@@ -119,8 +187,11 @@ class ModuleProcesser:
119
187
  ModuleProcesser.api_parent_node = self.module_stack[-1]
120
188
  else:
121
189
  ModuleProcesser.api_parent_node = None
190
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
191
+ raise RuntimeError(f"module reserve name is None when pop")
192
+ current_name = module.mindstudio_reserved_name.pop()
122
193
  if self.scope:
123
- self.scope.end_module(module.mindstudio_reserved_name)
194
+ self.scope.end_module(current_name)
124
195
 
125
196
  def backward_hook(module, input, output=None):
126
197
  try:
@@ -128,7 +199,10 @@ class ModuleProcesser:
128
199
  except IndexError as e:
129
200
  index = None
130
201
  pass
131
- module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
202
+ full_name = name_prefix + Const.SEP + str(index)
203
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
204
+ module.mindstudio_reserved_name = []
205
+ module.mindstudio_reserved_name.append(full_name)
132
206
  forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
133
207
  ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
134
208
  Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
@@ -39,7 +39,6 @@ class DataParams:
39
39
  origin_func: Optional[Callable] = None
40
40
  api_type: Optional[str] = None
41
41
  fuzz_stage: Optional[str] = None
42
- grad_unequal_flag: Optional[bool] = True
43
42
 
44
43
 
45
44
  @dataclass
@@ -127,6 +126,8 @@ def make_unequal_row(
127
126
  )
128
127
  if isinstance(ratio, float):
129
128
  row.max_rel = ratio - 1
129
+ if isinstance(ratio, str):
130
+ row.max_rel = ratio
130
131
  origin_tensor = data_params.original_result
131
132
  perturbed_tensor = data_params.perturbed_result
132
133
  if index is not None:
@@ -124,6 +124,7 @@ class TorchC:
124
124
  abs = torch._C._VariableFunctionsClass.abs
125
125
  where = torch._C._VariableFunctionsClass.where
126
126
  div = torch._C._VariableFunctionsClass.div
127
+ mul = torch._C._VariableFunctionsClass.mul
127
128
  max = torch._C._VariableFunctionsClass.max
128
129
  min = torch._C._VariableFunctionsClass.min
129
130
  gt = torch._C._VariableFunctionsClass.gt
@@ -138,3 +139,5 @@ class TorchC:
138
139
  tensor_split = torch._C._VariableFunctionsClass.tensor_split
139
140
  stack = torch._C._VariableFunctionsClass.stack
140
141
  reshape = torch._C._VariableFunctionsClass.reshape
142
+ nan_to_num = torch._C._VariableFunctionsClass.nan_to_num
143
+ aminmax = torch._C._VariableFunctionsClass.aminmax
@@ -82,13 +82,11 @@ class GradSaver:
82
82
  data_params = DataParams()
83
83
  data_params.original_result = origin_grad
84
84
  data_params.perturbed_result = perturbed_grad
85
- data_params.grad_unequal_flag = False
86
85
  data_params.valid_input_index = index
87
86
  try:
88
87
  handler.handle(data_params)
89
88
  if not data_params.is_consistent:
90
89
  self.is_compare = False
91
- data_params.grad_unequal_flag = True
92
90
  data_params.is_consistent = True
93
91
  data_params.perturbed_result = self.perturbed_grad_input
94
92
  data_params.original_result = self.origin_grad_input