mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in one of the supported public registries. It is provided for informational purposes only.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/mindspore/debugger/precision_debugger.py +100 -12

@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,15 +14,20 @@
 # limitations under the License.

 import os
-from collections import defaultdict
+from collections import defaultdict, namedtuple

 import mindspore as ms
 from mindspore._c_expression import MSContext

-from msprobe.core.common.const import Const, MsgConst
+from msprobe.core.common.const import Const, FileCheckConst, MsgConst
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common.file_utils import FileChecker
+from msprobe.core.common.utils import get_real_step_or_rank
 from msprobe.mindspore.cell_processor import CellProcessor
 from msprobe.mindspore.common.const import Const as MsConst
+from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.hook_cell.api_registry import api_register
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
 from msprobe.mindspore.ms_config import parse_json_config
@@ -30,12 +35,21 @@ from msprobe.mindspore.runtime import Runtime
 from msprobe.mindspore.service import Service
 from msprobe.mindspore.task_handler_factory import TaskHandlerFactory

+try:
+    from msprobe.lib import _msprobe_c
+except ImportError:
+    _msprobe_c = None
+
+
+ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_path", "level"])
+

 class PrecisionDebugger:
     _instance = None
     task_not_need_service = [Const.GRAD_PROBE]

-    def __new__(cls, config_path=None, opt=None):
+    def __new__(cls, config_path=None, task=None, dump_path=None,
+                level=None, step=None, opt=None):
         if not cls._instance:
             cls._instance = super().__new__(cls)
             cls._instance.initialized = False
@@ -44,22 +58,66 @@ class PrecisionDebugger:
             cls.first_start = False
         return cls._instance

-    def __init__(self, config_path=None):
+    def __init__(self, config_path=None, task=None, dump_path=None,
+                 level=None, step=None):
         if self.initialized:
             return
         self.initialized = True
+
+        set_register_backward_hook_functions()
+
         if not config_path:
             config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
+
+        config_params = ConfigParameters(config_path, task, dump_path, level)
+        self.check_input_params(config_params)
+
         common_config, task_config = parse_json_config(config_path)
+        common_config.task = task if task else common_config.task
         self.task = common_config.task
         if self.task == Const.GRAD_PROBE:
             self.gm = GradientMonitor(common_config, task_config)
             return
+        common_config.step = get_real_step_or_rank(
+            step, Const.STEP) if step is not None else common_config.step
+        common_config.level = level if level else common_config.level
+        common_config.dump_path = dump_path if dump_path else common_config.dump_path
         self.config = DebuggerConfig(common_config, task_config)

+        if _msprobe_c:
+            _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
+
+        self.config.execution_mode = self._get_execution_mode()
+        if self._need_service():
+            self.config.check_config_with_l2()
+            self.service = Service(self.config)
+
         Runtime.step_count = 0
         Runtime.is_running = False

+    @staticmethod
+    def check_input_params(args):
+        if args.config_path is not None:
+            if not isinstance(args.config_path, str):
+                raise MsprobeException(
+                    MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
+            file_checker = FileChecker(
+                file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
+            file_checker.common_check()
+
+        if args.task is not None and args.task not in Const.TASK_LIST:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
+
+        if args.dump_path is not None:
+            if not isinstance(args.dump_path, str):
+                raise MsprobeException(
+                    MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
+
+        if args.level is not None and args.level not in Const.LEVEL_LIST:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
+
     @staticmethod
     def _get_execution_mode():
         jit_level = ms.context.get_jit_config().get(MsConst.JIT_LEVEL)
@@ -78,11 +136,23 @@ class PrecisionDebugger:
         else:
             return MsConst.PYNATIVE_MODE

+    @staticmethod
+    def _is_graph_dump(config):
+        if config.level != MsConst.KERNEL:
+            return False
+        if not config.list:
+            return True
+        is_graph = any(item.startswith("name-regex") for item in config.list)
+        is_graph |= all("." not in item for item in config.list)
+        return is_graph
+
     @classmethod
     def start(cls, model=None):
         instance = cls._instance
         if not instance:
             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        if _msprobe_c:
+            _msprobe_c._PrecisionDebugger().start()
         if instance.task in PrecisionDebugger.task_not_need_service:
             return
@@ -93,6 +163,7 @@ class PrecisionDebugger:
             instance.service.start(model)
         else:
             if not instance.first_start:
+                api_register.api_set_ori_func()
                 handler = TaskHandlerFactory.create(instance.config)
                 handler.handle()

@@ -102,18 +173,15 @@ class PrecisionDebugger:
     @classmethod
     def forward_backward_dump_end(cls):
         instance = cls._instance
-        if not instance:
-            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-        if instance.task in PrecisionDebugger.task_not_need_service:
-            return
-        if instance.service:
-            instance.service.forward_backward_dump_end()
+        instance.stop()

     @classmethod
     def stop(cls):
         instance = cls._instance
         if not instance:
             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        if _msprobe_c:
+            _msprobe_c._PrecisionDebugger().stop()
         if instance.task == Const.GRAD_PROBE:
             instance.gm.stop()
         if instance.task in PrecisionDebugger.task_not_need_service:
@@ -127,6 +195,8 @@ class PrecisionDebugger:
         instance = cls._instance
         if not instance:
             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        if _msprobe_c:
+            _msprobe_c._PrecisionDebugger().step()
         if instance.task in PrecisionDebugger.task_not_need_service:
             return
         if instance.service:
@@ -145,6 +215,24 @@ class PrecisionDebugger:
             return
         instance.gm.monitor(opt)

+    @classmethod
+    def save(cls, variable, name, save_backward=True):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level_ori != Const.LEVEL_DEBUG:
+            return
+        try:
+            check_save_param(variable, name, save_backward)
+        except ValueError:
+            return
+
+        instance.config.execution_mode = cls._get_execution_mode()
+        if cls._need_service():
+            if not instance.service:
+                instance.service = Service(instance.config)
+            instance.service.save(variable, name, save_backward)
+
     @classmethod
     def _need_service(cls):
         instance = cls._instance
@@ -153,4 +241,4 @@ class PrecisionDebugger:
         if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
             return False
         else:
-            return instance.config.task != Const.FREE_BENCHMARK and instance.config.level != MsConst.KERNEL
+            return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
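The hunks above widen the MindSpore PrecisionDebugger entry point: task, dump_path, level, and step can now be passed directly to the constructor and override the values read from config.json, and the C++ backend (_msprobe_c) is driven alongside the Python service when present. A minimal usage sketch of the new surface, assuming a PyNative training loop; net and train_one_step are hypothetical user code and the argument values are illustrative, not taken from the package:

    from msprobe.mindspore import PrecisionDebugger

    # New in 1.2.2: any argument left as None falls back to config.json.
    debugger = PrecisionDebugger(config_path="./config.json",
                                 task="statistics",
                                 dump_path="./dump_output",
                                 level="L1",
                                 step=[0, 1, 2])

    for _ in range(3):
        debugger.start(model=net)   # install hooks, begin collection
        train_one_step(net)         # hypothetical training step
        debugger.stop()             # flush data collected for this step
        debugger.step()             # advance the internal step counter

The new save() classmethod is gated by the guard shown in the hunk: it only dumps when the task is tensor or statistics and the configured level is "debug" (see 28.debugger_save_instruction.md in the file list).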
msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16

@@ -1,4 +1,5 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ============================================================================

 from mindspore import Tensor, ops, mint
 from mindspore.mint.nn import functional
@@ -20,8 +20,15 @@ from mindspore.communication import comm_func

 from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
                                                        HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
-                                                       get_wrap_api_list, setup_hooks)
+                                                       HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
+                                                       HOOKTorchDistributedOP, HOOKTorchNpuOP,
+                                                       get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
 from msprobe.core.common.utils import Const
+from msprobe.mindspore.common.utils import is_mindtorch
+
+if is_mindtorch():
+    import torch
+    import torch_npu


 def stub_method(method):
@@ -40,6 +47,12 @@ class ApiRegistry:
         self.distributed_ori_attr = {}
         self.norm_inner_ops_ori_attr = {}

+        self.torch_ori_attr = {}
+        self.torch_tensor_ori_attr = {}
+        self.torch_functional_ori_attr = {}
+        self.torch_distributed_ori_attr = {}
+        self.torch_npu_ori_attr = {}
+
         self.tensor_hook_attr = {}
         self.stub_tensor_hook_attr = {}
         self.functional_hook_attr = {}
@@ -48,6 +61,12 @@ class ApiRegistry:
         self.distibuted_hook_attr = {}
         self.norm_inner_ops_hook_attr = {}

+        self.torch_hook_attr = {}
+        self.torch_tensor_hook_attr = {}
+        self.torch_functional_hook_attr = {}
+        self.torch_distributed_hook_attr = {}
+        self.torch_npu_hook_attr = {}
+
         self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]

     @staticmethod
@@ -82,22 +101,73 @@ class ApiRegistry:
         self.set_api_attr(ops, self.norm_inner_ops_ori_attr)

     def api_set_hook_func(self):
-        self.set_api_attr(Tensor, self.tensor_hook_attr)
-        self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
-        self.set_api_attr(ops, self.functional_hook_attr)
-        self.set_api_attr(mint, self.mint_ops_hook_attr)
-        self.set_api_attr(functional, self.mint_func_ops_hook_attr)
-        self.set_api_attr(comm_func, self.distibuted_hook_attr)
+        if is_mindtorch():
+            self.set_api_attr(torch, self.torch_hook_attr)
+            self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
+            self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
+            self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
+            self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr)
+            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
+        else:
+            self.set_api_attr(Tensor, self.tensor_hook_attr)
+            self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
+            self.set_api_attr(ops, self.functional_hook_attr)
+            self.set_api_attr(mint, self.mint_ops_hook_attr)
+            self.set_api_attr(functional, self.mint_func_ops_hook_attr)
+            self.set_api_attr(comm_func, self.distibuted_hook_attr)

     def api_set_ori_func(self):
-        self.set_api_attr(Tensor, self.tensor_ori_attr)
-        self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
-        self.set_api_attr(ops, self.functional_ori_attr)
-        self.set_api_attr(mint, self.mint_ops_ori_attr)
-        self.set_api_attr(functional, self.mint_func_ops_ori_attr)
-        self.set_api_attr(comm_func, self.distributed_ori_attr)
+        if is_mindtorch():
+            self.set_api_attr(torch, self.torch_ori_attr)
+            self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
+            self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
+            self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
+            self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr)
+            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
+        else:
+            self.set_api_attr(Tensor, self.tensor_ori_attr)
+            self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
+            self.set_api_attr(ops, self.functional_ori_attr)
+            self.set_api_attr(mint, self.mint_ops_ori_attr)
+            self.set_api_attr(functional, self.mint_func_ops_ori_attr)
+            self.set_api_attr(comm_func, self.distributed_ori_attr)

     def initialize_hook(self, hook):
+        setup_hooks(hook)
+        if is_mindtorch():
+            wrap_torch_api_name = get_wrap_torch_api_list()
+            self.store_ori_attr(torch,
+                                wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
+            self.store_ori_attr(torch.Tensor,
+                                wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
+            self.store_ori_attr(torch.nn.functional,
+                                wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
+            self.store_ori_attr(torch.distributed,
+                                wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
+            self.store_ori_attr(torch_npu,
+                                wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
+            for attr_name in dir(HOOKTorchOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
+            for attr_name in dir(HOOKTorchTensor):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
+            for attr_name in dir(HOOKTorchFunctionalOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
+            for attr_name in dir(HOOKTorchDistributedOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
+            for attr_name in dir(HOOKTorchNpuOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
+            return
+
         wrap_api_name = get_wrap_api_list()
         self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
         self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
@@ -106,7 +176,6 @@ class ApiRegistry:
         self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
         self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
         self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
-        setup_hooks(hook)
         for attr_name in dir(HOOKTensor):
             if attr_name.startswith(Const.ATTR_NAME_PREFIX):
                 api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
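Underneath the new MindTorch branching, api_set_hook_func/api_set_ori_func above are a plain save-patch-restore of module attributes. A reduced, self-contained sketch of that pattern; fake_ops stands in for a real API module such as mindspore.ops, and traced_add for a generated hook wrapper:

    import types

    fake_ops = types.SimpleNamespace(add=lambda a, b: a + b)  # stand-in for mindspore.ops
    ori_attr = {}

    def store_ori_attr(module, api_names, storage):
        # snapshot the original callables before patching (initialize_hook)
        for name in api_names:
            storage[name] = getattr(module, name)

    def traced_add(*args, **kwargs):
        print("ops.add called")                # a real hook would collect dump data here
        return ori_attr["add"](*args, **kwargs)

    store_ori_attr(fake_ops, ["add"], ori_attr)
    fake_ops.add = traced_add                  # api_set_hook_func: install the hook
    print(fake_ops.add(1, 2))                  # traced call, still returns 3
    fake_ops.add = ori_attr["add"]             # api_set_ori_func: restore the original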
msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38

@@ -1,4 +1,5 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,45 +12,66 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ============================================================================

 from collections import defaultdict

 from mindspore import nn

-from msprobe.core.common.const import Const
-
-
-class HOOKCell(nn.Cell):
-    cell_count = defaultdict(int)
-    g_stop_hook = False
-
-    def __init__(self, build_hook) -> None:
-        super(HOOKCell, self).__init__()
-        self.changed_status = False
-        self.input_kwargs = {}
-        self.prefix = ""
-        if not HOOKCell.g_stop_hook:
-            HOOKCell.g_stop_hook = True
-            self.changed_status = True
-            if hasattr(self, "prefix_api_name"):
-                self.prefix = self.prefix_api_name
-
-            HOOKCell.cell_count[self.prefix] += 1
-            self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
-            forward_hook, backward_hook = build_hook(self.prefix)
-            self.register_forward_hook(forward_hook)
-            self.register_backward_hook(backward_hook)
-
-    # Overload __call__ and set the global flag.
-    def __call__(self, *args, **kwargs):
-        try:
-            self.input_kwargs = kwargs
-            out = super(HOOKCell, self).__call__(*args, **kwargs)
-        except Exception as e:
-            raise e
-        finally:
-            if self.changed_status:
-                self.changed_status = False
-                HOOKCell.g_stop_hook = False
-        return out
+from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
+
+
+def add_cell_count(name):
+    HOOKCell.cell_count[name] += 1
+
+
+def get_cell_count(name):
+    return HOOKCell.cell_count[name]
+
+
+def __init__(self, build_hook) -> None:
+    super(HOOKCell, self).__init__()
+    self.changed_status = False
+    self.input_kwargs = {}
+    self.prefix = ""
+    if not HOOKCell.g_stop_hook:
+        HOOKCell.g_stop_hook = True
+        self.changed_status = True
+        if hasattr(self, "prefix_api_name"):
+            self.prefix = self.prefix_api_name
+
+        self.forward_data_collected = False
+        forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
+        self.register_forward_pre_hook(forward_pre_hook)
+        self.register_forward_hook(forward_hook)
+        register_backward_hook_functions["full"](self, backward_hook)
+        register_backward_hook_functions["pre"](self, backward_pre_hook)
+
+
+# Overload __call__ and set the global flag.
+def __call__(self, *args, **kwargs):
+    try:
+        self.input_kwargs = kwargs
+        out = super(HOOKCell, self).__call__(*args, **kwargs)
+    except Exception as e:
+        raise e
+    finally:
+        if self.changed_status:
+            self.changed_status = False
+            HOOKCell.g_stop_hook = False
+    return out
+
+
+hook_cell_dict = {
+    "cell_count": defaultdict(int),
+    "g_stop_hook": False,
+    "add_cell_count": staticmethod(add_cell_count),
+    "get_cell_count": staticmethod(get_cell_count),
+    "__init__": __init__,
+    "__call__": __call__
+}
+
+if is_mindtorch():
+    import torch
+    HOOKCell = type("HOOKCell", (torch.nn.Module,), hook_cell_dict)
+else:
+    HOOKCell = type("HOOKCell", (nn.Cell,), hook_cell_dict)
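HOOKCell is now assembled with type() so that a single class body can inherit from torch.nn.Module under MindTorch or from mindspore.nn.Cell otherwise. A self-contained sketch of that dynamic-base-class trick; BaseA, BaseB, and use_torch are illustrative stand-ins, not msprobe names:

    class BaseA:
        pass

    class BaseB:
        pass

    def __init__(self):
        # super() resolves against whichever base was chosen at class-creation time
        super(Hooked, self).__init__()
        self.calls = 0

    use_torch = False  # stand-in for is_mindtorch()
    Hooked = type("Hooked", (BaseA,) if use_torch else (BaseB,), {"__init__": __init__})

    assert isinstance(Hooked(), BaseB)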
msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15

@@ -135,6 +135,34 @@ class PrimitiveHookService:
                 return tuple(hooked_outputs)
             return out

+        def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
+            try:
+                self.service_instance.data_collector.forward_input_data_collect(
+                    primitive_name,
+                    primitive_instance,
+                    os.getpid(),
+                    module_input_output
+                )
+            except Exception as exception:
+                logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
+                             f"primitive_name: {primitive_name}")
+                raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
+
+        def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+            try:
+                self.service_instance.data_collector.forward_output_data_collect(
+                    primitive_name,
+                    primitive_instance,
+                    os.getpid(),
+                    module_input_output
+                )
+            except Exception as exception:
+                logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
+                             f"primitive_name: {primitive_name}")
+                raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
+
         def wrapped_primitive_call(instance_self, *args, **kwargs):
             """
             The wrapped primitive call, which adds input and output hooks.
@@ -163,27 +191,17 @@ class PrimitiveHookService:
                              f"primitive_name: {primitive_name}")
                 raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception

+            forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
+            self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
+
+            pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs)
             try:
                 out = origin_func(*hooked_inputs, **kwargs)
             except Exception as exception:
                 logger.error(f"This is a primitive op dump error during function call: {exception}, "
                              f"primitive_name: {primitive_name}")
                 raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception
-
-            forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
-            self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
-            if self.service_instance.data_collector:
-                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
-                try:
-                    self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
-                                                                              os.getpid(), module_input_output)
-                except Exception as exception:
-                    logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
-                                 f"primitive_name: {primitive_name}")
-                    raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
-
-                if self.service_instance.data_collector.if_return_forward_new_output():
-                    out = self.service_instance.data_collector.get_forward_new_output()
+            post_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs, out)

             try:
                 out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
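The refactor above splits the single post-hoc forward_data_collect into pre_forward_hook (inputs, before the primitive runs) and post_forward_hook (outputs, after it succeeds), so input data is captured even when the call itself raises. A reduced sketch of the resulting control flow; collector is a hypothetical object exposing the two collect methods named in the hunk:

    import os

    def wrap_primitive(origin_func, name, collector):
        def wrapped(*args, **kwargs):
            # pre-forward: inputs are recorded before the op executes
            collector.forward_input_data_collect(name, None, os.getpid(), (args, kwargs))
            out = origin_func(*args, **kwargs)
            # post-forward: outputs are recorded only after a successful call
            collector.forward_output_data_collect(name, None, os.getpid(), (args, kwargs, out))
            return out
        return wrapped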
msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1

@@ -15,7 +15,7 @@

 # List of ops that register hooks

-
+
 ops:
   - adaptive_avg_pool1d
   - adaptive_avg_pool2d
@@ -85,6 +85,7 @@ ops:
   - relu6
   - celu
   - rrelu
+  - rms_norm
   - selu
   - sigmoid
   - silu
@@ -553,6 +554,7 @@ tensor:
   - acos
   - acosh
   - add
+  - add_
   - addbmm
   - addcdiv
   - addcmul
@@ -607,6 +609,7 @@ tensor:
   - diff
   - digamma
   - div
+  - div_
   - divide
   - equal
   - erf
@@ -739,6 +742,8 @@ tensor:
   - square
   - squeeze
   - std
+  - sub
+  - sub_
   - subtract
   - subtract
   - svd
@@ -983,6 +988,7 @@ mint.nn.functional:
   - one_hot_ext
   - pad
   - relu
+  - relu_
   - sigmoid
   - silu
   - softmax
@@ -1017,3 +1023,7 @@ communication.comm_func:
   - broadcast
   - gather_into_tensor
   - scatter_tensor
+  - send
+  - recv
+  - isend
+  - irecv
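The YAML additions extend the hookable-API whitelists, most notably send/recv/isend/irecv under communication.comm_func. A hedged sketch of consuming such a file; the loader below is illustrative, not the msprobe implementation:

    import yaml

    with open("support_wrap_ops.yaml") as f:
        supported = yaml.safe_load(f)

    comm_ops = supported.get("communication.comm_func", [])
    assert {"send", "recv", "isend", "irecv"} <= set(comm_ops)  # new in 1.2.2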