mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -15,15 +15,17 @@
15
15
 
16
16
  from msprobe.mindspore.common.const import Const
17
17
  from msprobe.core.common.log import logger
18
+ from msprobe.mindspore.common.utils import is_graph_mode_cell_dump_allowed
18
19
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
19
20
  from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
20
21
  from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
22
+ from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
21
23
 
22
24
 
23
25
  class DumpToolFactory:
24
26
  tools = {
25
27
  Const.CELL: {
26
- Const.GRAPH_KBYK_MODE: None,
28
+ Const.GRAPH_KBYK_MODE: GraphModeCellDump,
27
29
  Const.GRAPH_GE_MODE: None,
28
30
  Const.PYNATIVE_MODE: None
29
31
  },
@@ -40,9 +42,15 @@ class DumpToolFactory:
40
42
  }
41
43
 
42
44
  @staticmethod
43
- def create(config: DebuggerConfig):
44
- if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
45
- raise Exception("data_mode must be one of all, input, output.")
45
+ def create(config: DebuggerConfig, model=None):
46
+ if config.level == Const.CELL:
47
+ if not is_graph_mode_cell_dump_allowed(config):
48
+ raise Exception("Cell dump is not supported in graph mode.")
49
+ if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
50
+ raise Exception("data_mode must be one of all, forward, backward.")
51
+ else:
52
+ if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
53
+ raise Exception("data_mode must be one of all, input, output.")
46
54
  tool = DumpToolFactory.tools.get(config.level)
47
55
  if not tool:
48
56
  raise Exception("Valid level is needed.")
@@ -51,4 +59,4 @@ class DumpToolFactory:
51
59
  logger.error(f"Data dump is not supported in {config.execution_mode} mode "
52
60
  f"when dump level is {config.level}.")
53
61
  raise ValueError
54
- return tool(config)
62
+ return tool(config, model) if tool == GraphModeCellDump else tool(config)
@@ -0,0 +1,139 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import mindspore as ms
19
+ from mindspore import hal, ops, Tensor
20
+ from mindspore.ops.primitive import _run_op
21
+
22
+ from msprobe.core.common.const import Const as CoreConst
23
+ from msprobe.core.common.runtime import Runtime
24
+ from msprobe.mindspore.common.const import Const
25
+ from msprobe.mindspore.common.log import logger
26
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
27
+ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
28
+ import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
29
+
30
+ tensordump_flag = True
31
+ try:
32
+ from mindspore._c_expression import _tensordump_set_step
33
+ except ImportError:
34
+ tensordump_flag = False
35
+
36
+
37
+ class GraphModeCellDump:
38
+ task = CoreConst.STATISTICS
39
+
40
+ def __init__(self, config: DebuggerConfig, model, strict=True):
41
+ self.net = model
42
+ self.white_list = []
43
+ self.black_list = []
44
+ self.execution_mode = config.execution_mode
45
+ self.dump_path = config.dump_path if config.dump_path else "./"
46
+ self.rank = config.rank
47
+ self.step = config.step
48
+ self.scope = config.scope
49
+ self.list = config.list
50
+ self.data_mode = config.data_mode
51
+ self.file_format = config.file_format
52
+ GraphModeCellDump.task = config.task
53
+ self.summary_mode = config.summary_mode
54
+ self.check_config(strict)
55
+ self.set_step()
56
+
57
+ @staticmethod
58
+ def step():
59
+ # 更新TensorDump Step
60
+ if GraphModeCellDump.task == CoreConst.TENSOR:
61
+ hal.synchronize()
62
+ temp_tensor = ms.Tensor([1], dtype=ms.float32)
63
+ step_flag = "<tensordump-update-step>"
64
+ _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
65
+ ops.tensordump(step_flag, temp_tensor)
66
+
67
+ def check_config(self, strict):
68
+ if not self.net:
69
+ raise Exception("The model is empty and cell dump is not enabled.")
70
+
71
+ if strict:
72
+ if self.rank:
73
+ raise Exception("In graph mode, cell dump does not currently support specifying rank.")
74
+ if self.scope:
75
+ raise Exception("In graph mode, cell dump does not currently support specifying scope.")
76
+ if self.list:
77
+ raise Exception("In graph mode, cell dump does not currently support specifying list.")
78
+ if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
79
+ raise Exception("In graph mode and cell dump, data_mode must be one of all, forword, backword.")
80
+ if self.file_format != []:
81
+ logger.warning("In graph mode, cell dump does not currently support specifying file_format."
82
+ " The file will be stored in npy format.")
83
+ if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
84
+ raise Exception("The L0 level statistics dump mode does not support "
85
+ "the calculation of md5 values currently In graph mode.")
86
+ else:
87
+ self.rank = []
88
+ self.scope = []
89
+ self.list = []
90
+ self.file_format = []
91
+ if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
92
+ self.data_mode = [CoreConst.ALL]
93
+ if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
94
+ self.summary_mode = CoreConst.STATISTICS
95
+
96
+ return True
97
+
98
+ def set_step(self):
99
+ if tensordump_flag:
100
+ _tensordump_set_step(self.step)
101
+ else:
102
+ raise Exception(
103
+ "Importing _tensordump_set_step failed, "
104
+ "please use the latest version package of MindSpore."
105
+ )
106
+
107
+ def handle(self):
108
+ os.environ['MS_JIT_MODULES'] = 'msprobe'
109
+
110
+ if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE:
111
+ dump_path = os.path.join(self.dump_path, Const.GRAPH_MODE)
112
+ else:
113
+ dump_path = self.dump_path
114
+
115
+ cell_dumper = cellDumperWithDumpGradient
116
+
117
+ if self.execution_mode == Const.PYNATIVE_MODE:
118
+ enable_dump_gradient = hasattr(ops, 'DumpGradient')
119
+ if hasattr(ops, 'DumpGradient'):
120
+ try:
121
+ ops.DumpGradient()('grad.npy', Tensor([0], dtype=ms.float32), 'in')
122
+ except Exception:
123
+ enable_dump_gradient = False
124
+ logger.warning('the DumpGradient operator failed to execute.')
125
+ if not enable_dump_gradient:
126
+ cell_dumper = cellDumperWithInsertGradient
127
+
128
+ dump_config = cell_dumper.CellDumpConfig(
129
+ net=self.net,
130
+ dump_path=dump_path,
131
+ data_mode=self.data_mode[0],
132
+ task=self.task,
133
+ summary_mode=self.summary_mode,
134
+ step=self.step
135
+ )
136
+
137
+ cell_dumper.start(
138
+ dump_config
139
+ )
@@ -0,0 +1,123 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from collections import OrderedDict
18
+ import mindspore as ms
19
+
20
+
21
+ def _iterate_items(data):
22
+ if isinstance(data, (dict, OrderedDict)):
23
+ return data.items()
24
+ elif isinstance(data, (list, tuple)):
25
+ return enumerate(data)
26
+ else:
27
+ raise TypeError("Unsupported data type")
28
+
29
+
30
+ class _SaveBase:
31
+ def __init__(self, save_dir):
32
+ super(_SaveBase, self).__init__()
33
+ self.path = save_dir
34
+ self.save_func = _npy_save
35
+
36
+ def get_save_func(self):
37
+ return self.save_func
38
+
39
+
40
+ @ms.jit_class
41
+ class _SaveCell(_SaveBase):
42
+ def __call__(self, name, data):
43
+ return self.get_save_func()(self.path, name, data)
44
+
45
+
46
+ class _SaveGradBase:
47
+ def __init__(self, save_dir, name):
48
+ super(_SaveGradBase, self).__init__()
49
+ self.file = save_dir + name
50
+
51
+
52
+ @ms.jit_class
53
+ class _SaveGradCell(_SaveGradBase):
54
+ def __init__(self, save_dir, name):
55
+ super(_SaveGradCell, self).__init__(save_dir, name)
56
+ self.ms_save_grad = ms.ops.InsertGradientOf(
57
+ _wrapper_save_grad_func(self.file))
58
+
59
+ def __call__(self, x):
60
+ if isinstance(x, ms.Tensor):
61
+ return self.ms_save_grad(x)
62
+ else:
63
+ raise TypeError(f"For 'save_grad', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
64
+ f"but got {type(x)}")
65
+
66
+
67
+ def _npy_save_ops(file, data):
68
+ if isinstance(data, ms.Tensor):
69
+ if data.dtype == ms.bfloat16:
70
+ data = data.float()
71
+ ms.ops.TensorDump()(file, data)
72
+ else:
73
+ raise TypeError(f"For 'save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
74
+ f"but got {type(data)}")
75
+
76
+
77
+ def _wrapper_save_grad_func(file):
78
+ def _save_grad_func(grad):
79
+ data = grad
80
+ if data.dtype == ms.bfloat16:
81
+ data = data.float()
82
+ ms.ops.TensorDump()(file, data)
83
+ return grad
84
+ return _save_grad_func
85
+
86
+
87
+ def _npy_save(save_dir, item_name, data):
88
+ if isinstance(data, (list, tuple, dict, OrderedDict)):
89
+ for key, val in _iterate_items(data):
90
+ _npy_save(save_dir, f"{item_name}.{key}", val)
91
+ else:
92
+ if data is None:
93
+ return
94
+ _npy_save_ops(f"{save_dir}{item_name}", data)
95
+
96
+
97
+ def generate_dump_dir(save_dir, sep=os.sep):
98
+ """
99
+ usage: generate dump directory path str in mindspore graph mode
100
+ """
101
+ full_suffix = '{step}' + sep + '{rank}' + sep
102
+ if save_dir and save_dir[-1] != sep:
103
+ result_dir = save_dir + sep + full_suffix
104
+ else:
105
+ result_dir = save_dir + full_suffix
106
+ return result_dir
107
+
108
+
109
+ def save(save_dir, name, data):
110
+ """
111
+ save tensor.
112
+ """
113
+ dump_dir = generate_dump_dir(save_dir)
114
+ _SaveCell(dump_dir)(name, data)
115
+
116
+
117
+ def save_grad(save_dir, name, data):
118
+ """
119
+ save grad.
120
+ """
121
+ dump_dir = generate_dump_dir(save_dir)
122
+ suffix_name = name + '_grad'
123
+ return _SaveGradCell(dump_dir, suffix_name)(data)
@@ -14,14 +14,17 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ import inspect
17
18
 
18
19
  from mindspore import Tensor, ops, mint
20
+ from mindspore.mint import distributed
19
21
  from mindspore.mint.nn import functional
20
22
  from mindspore.communication import comm_func
21
23
 
22
24
  from msprobe.core.common.file_utils import load_yaml
23
25
  from msprobe.core.common.utils import Const
24
26
  from msprobe.core.data_dump.api_registry import ApiRegistry
27
+ from msprobe.mindspore.common.log import logger
25
28
  from msprobe.mindspore.common.const import Const as MsConst
26
29
  from msprobe.mindspore.common.utils import is_mindtorch
27
30
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
@@ -41,7 +44,8 @@ if not is_mindtorch():
41
44
  Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
42
45
  Const.MS_API_TYPE_MINT: (mint, (mint,)),
43
46
  Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
44
- Const.MS_API_TYPE_COM: (comm_func, (comm_func,))
47
+ Const.MS_API_TYPE_COM: (comm_func, (comm_func,)),
48
+ Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,))
45
49
  }
46
50
  }
47
51
  if stub_tensor_existed:
@@ -50,6 +54,7 @@ if not is_mindtorch():
50
54
  )
51
55
 
52
56
  _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
57
+ _backlist = []
53
58
  else:
54
59
  import torch
55
60
  import torch_npu
@@ -64,13 +69,14 @@ else:
64
69
  }
65
70
  _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
66
71
  MsConst.SUPPORTED_API_LIST_FILE),)
72
+ _backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__']
67
73
 
68
74
  _inner_used_api = {
69
75
  Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
70
76
  ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point"
71
77
  ),
72
78
  Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: (
73
- Tensor, "to", "numel"
79
+ Tensor, "to", "numel", 'sum'
74
80
  ),
75
81
  Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: (
76
82
  mint, "max", "min", "mean", "norm"
@@ -84,6 +90,9 @@ class ApiTemplate(HOOKCell):
84
90
  self.api_func = api_func
85
91
  self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
86
92
  super().__init__(hook_build_func)
93
+ distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX
94
+ if prefix == distributed_prefix:
95
+ self.op_is_distributed = True
87
96
 
88
97
  @staticmethod
89
98
  def async_to_sync(output):
@@ -103,9 +112,22 @@ class ApiTemplate(HOOKCell):
103
112
 
104
113
  output = self.api_func(*args, **kwargs)
105
114
 
106
- if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
107
- if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
115
+ if self.prefix_api_name.startswith(
116
+ (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX)
117
+ ):
118
+ try:
119
+ bound = inspect.signature(self.api_func).bind(*args, **kwargs)
120
+ bound.apply_defaults()
121
+ use_async_op_flag = bound.arguments.get("async_op", False)
122
+ except Exception as e:
123
+ use_async_op_flag = False
124
+ logger.warning(f"fail to get dist api's func signature because {e}, no wait")
125
+
126
+ if use_async_op_flag or self.api_name in ["isend", "irecv"]:
108
127
  output = self.async_to_sync(output)
128
+ if self.api_name == "batch_isend_irecv" and isinstance(output, list):
129
+ output = [self.async_to_sync(handle) for handle in output]
130
+
109
131
  return output
110
132
 
111
133
  def forward(self, *args, **kwargs):
@@ -134,9 +156,21 @@ def get_api_register(return_new=False):
134
156
  stub_tensor_set = True
135
157
 
136
158
  if return_new:
137
- return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
159
+ return ApiRegistry(
160
+ _api_types,
161
+ _inner_used_api,
162
+ _supported_api_list_path,
163
+ ApiTemplate,
164
+ _backlist
165
+ )
138
166
 
139
167
  global api_register
140
168
  if api_register is None:
141
- api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
169
+ api_register = ApiRegistry(
170
+ _api_types,
171
+ _inner_used_api,
172
+ _supported_api_list_path,
173
+ ApiTemplate,
174
+ _backlist
175
+ )
142
176
  return api_register
@@ -15,11 +15,16 @@
15
15
 
16
16
  from collections import defaultdict
17
17
 
18
+ import mindspore as ms
18
19
  from mindspore import nn
19
20
 
21
+ from msprobe.core.common.runtime import Runtime
20
22
  from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
21
23
 
22
24
 
25
+ ms_version = ms.__version__
26
+
27
+
23
28
  def add_cell_count(name):
24
29
  HOOKCell.cell_count[name] += 1
25
30
 
@@ -31,25 +36,31 @@ def get_cell_count(name):
31
36
  def __init__(self, hook_build_func) -> None:
32
37
  super(HOOKCell, self).__init__()
33
38
  self.changed_status = False
34
- self.input_kwargs = {}
39
+ self.msprobe_input_kwargs = {}
35
40
  if not HOOKCell.g_stop_hook:
36
41
  HOOKCell.g_stop_hook = True
37
42
  self.changed_status = True
38
43
  self.forward_data_collected = False
39
44
 
45
+ if not Runtime.is_running:
46
+ return
40
47
  prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
41
48
  if callable(hook_build_func):
42
- forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = hook_build_func(prefix)
43
- self.register_forward_pre_hook(forward_pre_hook)
44
- self.register_forward_hook(forward_hook)
45
- register_backward_hook_functions["full"](self, backward_hook)
46
- register_backward_hook_functions["pre"](self, backward_pre_hook)
49
+ hook_set = hook_build_func(prefix)
50
+ if ms_version < "2.6.0" and not is_mindtorch():
51
+ getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
52
+ getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook
53
+ else:
54
+ self.register_forward_pre_hook(hook_set.forward_pre_hook)
55
+ self.register_forward_hook(hook_set.forward_hook)
56
+ register_backward_hook_functions["full"](self, hook_set.backward_hook)
57
+ register_backward_hook_functions["pre"](self, hook_set.backward_pre_hook)
47
58
 
48
59
 
49
60
  # 重载call,加全局标志。
50
61
  def __call__(self, *args, **kwargs):
51
62
  try:
52
- self.input_kwargs = kwargs
63
+ self.msprobe_input_kwargs = kwargs
53
64
  out = super(HOOKCell, self).__call__(*args, **kwargs)
54
65
  except Exception as e:
55
66
  raise e
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from mindspore.common.api import _no_grad
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.common.utils import replace_last_occurrence
19
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs
20
+ from msprobe.core.hook_manager import BaseHookManager, HookSet
21
+ from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook
22
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
23
+
24
+
25
+ class MindsproeHookManager(BaseHookManager):
26
+ @property
27
+ def _is_recompute(self):
28
+ return None
29
+
30
+ @staticmethod
31
+ def _no_grad_context():
32
+ return _no_grad()
33
+
34
+ @staticmethod
35
+ def _add_count(name):
36
+ HOOKCell.add_cell_count(name)
37
+
38
+ @staticmethod
39
+ def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
40
+ if not has_kwargs_in_forward_hook() or hook_type == Const.API:
41
+ kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
42
+ output = kwargs_or_output
43
+ else:
44
+ kwargs = kwargs_or_output
45
+ output = output_or_kwargs
46
+ return kwargs, output
47
+
48
+ def build_hook(self, hook_type, name):
49
+ if hook_type == Const.API:
50
+ full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
51
+ else:
52
+ full_forward_name = name
53
+ full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
54
+ hookset = HookSet(
55
+ forward_hook=self._build_forward_hook(hook_type, full_forward_name),
56
+ forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
57
+ backward_hook=self._build_backward_hook(hook_type, full_backward_name),
58
+ backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name)
59
+ )
60
+ return hookset
61
+
62
+ def _need_exchange(self, module):
63
+ if not hasattr(module, 'has_pre_hook_called') or not module.has_pre_hook_called:
64
+ return False
65
+ else:
66
+ return True
67
+
68
+ def _get_params_dict(self, module):
69
+ params_dict = {}
70
+ if self.config.task != Const.STRUCTURE:
71
+ params_dict = {
72
+ key.split(Const.SEP)[-1]: value
73
+ for key, value in module.parameters_dict(recurse=False).items()
74
+ }
75
+ return params_dict
76
+
77
+ def _build_backward_pre_hook(self, hook_type, name):
78
+ def backward_pre_hook(module, grad_input):
79
+ if self.config.level != Const.LEVEL_L2:
80
+ return
81
+ if not self._should_execute_hook(hook_type, module, False):
82
+ return
83
+ BaseHookManager.inner_switch = True
84
+ module_input = ModuleBackwardInputs(grad_input=grad_input)
85
+ self.data_collector.update_api_or_module_name(name)
86
+ self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
87
+ BaseHookManager.inner_switch = False
88
+ return backward_pre_hook
@@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
21
21
  from msprobe.core.common.utils import Const, DumpException
22
22
  from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
23
23
  ModuleForwardInputsOutputs)
24
+ from msprobe.core.hook_manager import BaseHookManager
24
25
  from msprobe.mindspore.common.log import logger
25
26
 
26
27
 
@@ -58,7 +59,7 @@ class PrimitiveHookService:
58
59
  def backward_hook(grad):
59
60
  captured_grads.extend(grad)
60
61
  backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
61
-
62
+ self.service_instance.inner_switch = True
62
63
  try:
63
64
  if hook_type == Const.INPUT:
64
65
  self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
@@ -77,6 +78,7 @@ class PrimitiveHookService:
77
78
  logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
78
79
  f"updated_primitive_name: {updated_primitive_name}")
79
80
  raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
81
+ self.service_instance.inner_switch = False
80
82
 
81
83
  return backward_hook
82
84
 
@@ -137,6 +139,7 @@ class PrimitiveHookService:
137
139
 
138
140
  def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
139
141
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
142
+ self.service_instance.inner_switch = True
140
143
  try:
141
144
  self.service_instance.data_collector.forward_input_data_collect(
142
145
  primitive_name,
@@ -148,9 +151,11 @@ class PrimitiveHookService:
148
151
  logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
149
152
  f"primitive_name: {primitive_name}")
150
153
  raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
154
+ self.service_instance.inner_switch = False
151
155
 
152
156
  def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
153
157
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
158
+ self.service_instance.inner_switch = True
154
159
  try:
155
160
  self.service_instance.data_collector.forward_output_data_collect(
156
161
  primitive_name,
@@ -162,6 +167,7 @@ class PrimitiveHookService:
162
167
  logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
163
168
  f"primitive_name: {primitive_name}")
164
169
  raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
170
+ self.service_instance.inner_switch = False
165
171
 
166
172
  def wrapped_primitive_call(instance_self, *args, **kwargs):
167
173
  """
@@ -179,7 +185,7 @@ class PrimitiveHookService:
179
185
  current_count = self.primitive_counters.get(primitive_name, 0)
180
186
  updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
181
187
 
182
- if not self.service_instance.primitive_switch:
188
+ if not self.service_instance.primitive_switch or BaseHookManager.inner_switch:
183
189
  return origin_func(*args, **kwargs)
184
190
 
185
191
  captured_grads_input, captured_grads_output = [], []
@@ -1025,3 +1025,21 @@ communication.comm_func:
1025
1025
  - recv
1026
1026
  - isend
1027
1027
  - irecv
1028
+
1029
+ mint.distributed:
1030
+ - send
1031
+ - recv
1032
+ - broadcast
1033
+ - all_reduce
1034
+ - reduce
1035
+ - all_gather
1036
+ - gather
1037
+ - isend
1038
+ - irecv
1039
+ - scatter
1040
+ - reduce_scatter
1041
+ - all_to_all_single
1042
+ - all_to_all
1043
+ - all_gather_into_tensor
1044
+ - reduce_scatter_tensor
1045
+ - batch_isend_irecv