mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,15 +14,18 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.mindspore.common.const import Const
17
+ from msprobe.core.common.log import logger
18
+ from msprobe.mindspore.common.utils import is_graph_mode_cell_dump_allowed
17
19
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
18
20
  from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
19
21
  from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
22
+ from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
20
23
 
21
24
 
22
25
  class DumpToolFactory:
23
26
  tools = {
24
27
  Const.CELL: {
25
- Const.GRAPH_KBYK_MODE: None,
28
+ Const.GRAPH_KBYK_MODE: GraphModeCellDump,
26
29
  Const.GRAPH_GE_MODE: None,
27
30
  Const.PYNATIVE_MODE: None
28
31
  },
@@ -39,14 +42,21 @@ class DumpToolFactory:
39
42
  }
40
43
 
41
44
  @staticmethod
42
- def create(config: DebuggerConfig):
43
- if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
44
- raise Exception("data_mode must be one of all, input, output.")
45
+ def create(config: DebuggerConfig, model=None):
46
+ if config.level == Const.CELL:
47
+ if not is_graph_mode_cell_dump_allowed(config):
48
+ raise Exception("Cell dump is not supported in graph mode.")
49
+ if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
50
+ raise Exception("data_mode must be one of all, forward, backward.")
51
+ else:
52
+ if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
53
+ raise Exception("data_mode must be one of all, input, output.")
45
54
  tool = DumpToolFactory.tools.get(config.level)
46
55
  if not tool:
47
56
  raise Exception("Valid level is needed.")
48
57
  tool = tool.get(config.execution_mode)
49
58
  if not tool:
50
- raise Exception(f"Data dump is not supported in {config.execution_mode} mode "
51
- f"when dump level is {config.level}.")
52
- return tool(config)
59
+ logger.error(f"Data dump is not supported in {config.execution_mode} mode "
60
+ f"when dump level is {config.level}.")
61
+ raise ValueError
62
+ return tool(config, model) if tool == GraphModeCellDump else tool(config)
@@ -0,0 +1,139 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import mindspore as ms
19
+ from mindspore import hal, ops, Tensor
20
+ from mindspore.ops.primitive import _run_op
21
+
22
+ from msprobe.core.common.const import Const as CoreConst
23
+ from msprobe.core.common.runtime import Runtime
24
+ from msprobe.mindspore.common.const import Const
25
+ from msprobe.mindspore.common.log import logger
26
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
27
+ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
28
+ import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
29
+
30
+ tensordump_flag = True
31
+ try:
32
+ from mindspore._c_expression import _tensordump_set_step
33
+ except ImportError:
34
+ tensordump_flag = False
35
+
36
+
37
+ class GraphModeCellDump:
38
+ task = CoreConst.STATISTICS
39
+
40
+ def __init__(self, config: DebuggerConfig, model, strict=True):
41
+ self.net = model
42
+ self.white_list = []
43
+ self.black_list = []
44
+ self.execution_mode = config.execution_mode
45
+ self.dump_path = config.dump_path if config.dump_path else "./"
46
+ self.rank = config.rank
47
+ self.step = config.step
48
+ self.scope = config.scope
49
+ self.list = config.list
50
+ self.data_mode = config.data_mode
51
+ self.file_format = config.file_format
52
+ GraphModeCellDump.task = config.task
53
+ self.summary_mode = config.summary_mode
54
+ self.check_config(strict)
55
+ self.set_step()
56
+
57
+ @staticmethod
58
+ def step():
59
+ # 更新TensorDump Step
60
+ if GraphModeCellDump.task == CoreConst.TENSOR:
61
+ hal.synchronize()
62
+ temp_tensor = ms.Tensor([1], dtype=ms.float32)
63
+ step_flag = "<tensordump-update-step>"
64
+ _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
65
+ ops.tensordump(step_flag, temp_tensor)
66
+
67
+ def check_config(self, strict):
68
+ if not self.net:
69
+ raise Exception("The model is empty and cell dump is not enabled.")
70
+
71
+ if strict:
72
+ if self.rank:
73
+ raise Exception("In graph mode, cell dump does not currently support specifying rank.")
74
+ if self.scope:
75
+ raise Exception("In graph mode, cell dump does not currently support specifying scope.")
76
+ if self.list:
77
+ raise Exception("In graph mode, cell dump does not currently support specifying list.")
78
+ if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
79
+ raise Exception("In graph mode and cell dump, data_mode must be one of all, forword, backword.")
80
+ if self.file_format != []:
81
+ logger.warning("In graph mode, cell dump does not currently support specifying file_format."
82
+ " The file will be stored in npy format.")
83
+ if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
84
+ raise Exception("The L0 level statistics dump mode does not support "
85
+ "the calculation of md5 values currently In graph mode.")
86
+ else:
87
+ self.rank = []
88
+ self.scope = []
89
+ self.list = []
90
+ self.file_format = []
91
+ if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST:
92
+ self.data_mode = [CoreConst.ALL]
93
+ if self.task == CoreConst.STATISTICS and self.summary_mode == CoreConst.MD5:
94
+ self.summary_mode = CoreConst.STATISTICS
95
+
96
+ return True
97
+
98
+ def set_step(self):
99
+ if tensordump_flag:
100
+ _tensordump_set_step(self.step)
101
+ else:
102
+ raise Exception(
103
+ "Importing _tensordump_set_step failed, "
104
+ "please use the latest version package of MindSpore."
105
+ )
106
+
107
+ def handle(self):
108
+ os.environ['MS_JIT_MODULES'] = 'msprobe'
109
+
110
+ if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE:
111
+ dump_path = os.path.join(self.dump_path, Const.GRAPH_MODE)
112
+ else:
113
+ dump_path = self.dump_path
114
+
115
+ cell_dumper = cellDumperWithDumpGradient
116
+
117
+ if self.execution_mode == Const.PYNATIVE_MODE:
118
+ enable_dump_gradient = hasattr(ops, 'DumpGradient')
119
+ if hasattr(ops, 'DumpGradient'):
120
+ try:
121
+ ops.DumpGradient()('grad.npy', Tensor([0], dtype=ms.float32), 'in')
122
+ except Exception:
123
+ enable_dump_gradient = False
124
+ logger.warning('the DumpGradient operator failed to execute.')
125
+ if not enable_dump_gradient:
126
+ cell_dumper = cellDumperWithInsertGradient
127
+
128
+ dump_config = cell_dumper.CellDumpConfig(
129
+ net=self.net,
130
+ dump_path=dump_path,
131
+ data_mode=self.data_mode[0],
132
+ task=self.task,
133
+ summary_mode=self.summary_mode,
134
+ step=self.step
135
+ )
136
+
137
+ cell_dumper.start(
138
+ dump_config
139
+ )
@@ -0,0 +1,123 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from collections import OrderedDict
18
+ import mindspore as ms
19
+
20
+
21
+ def _iterate_items(data):
22
+ if isinstance(data, (dict, OrderedDict)):
23
+ return data.items()
24
+ elif isinstance(data, (list, tuple)):
25
+ return enumerate(data)
26
+ else:
27
+ raise TypeError("Unsupported data type")
28
+
29
+
30
+ class _SaveBase:
31
+ def __init__(self, save_dir):
32
+ super(_SaveBase, self).__init__()
33
+ self.path = save_dir
34
+ self.save_func = _npy_save
35
+
36
+ def get_save_func(self):
37
+ return self.save_func
38
+
39
+
40
+ @ms.jit_class
41
+ class _SaveCell(_SaveBase):
42
+ def __call__(self, name, data):
43
+ return self.get_save_func()(self.path, name, data)
44
+
45
+
46
+ class _SaveGradBase:
47
+ def __init__(self, save_dir, name):
48
+ super(_SaveGradBase, self).__init__()
49
+ self.file = save_dir + name
50
+
51
+
52
+ @ms.jit_class
53
+ class _SaveGradCell(_SaveGradBase):
54
+ def __init__(self, save_dir, name):
55
+ super(_SaveGradCell, self).__init__(save_dir, name)
56
+ self.ms_save_grad = ms.ops.InsertGradientOf(
57
+ _wrapper_save_grad_func(self.file))
58
+
59
+ def __call__(self, x):
60
+ if isinstance(x, ms.Tensor):
61
+ return self.ms_save_grad(x)
62
+ else:
63
+ raise TypeError(f"For 'save_grad', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
64
+ f"but got {type(x)}")
65
+
66
+
67
+ def _npy_save_ops(file, data):
68
+ if isinstance(data, ms.Tensor):
69
+ if data.dtype == ms.bfloat16:
70
+ data = data.float()
71
+ ms.ops.TensorDump()(file, data)
72
+ else:
73
+ raise TypeError(f"For 'save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
74
+ f"but got {type(data)}")
75
+
76
+
77
+ def _wrapper_save_grad_func(file):
78
+ def _save_grad_func(grad):
79
+ data = grad
80
+ if data.dtype == ms.bfloat16:
81
+ data = data.float()
82
+ ms.ops.TensorDump()(file, data)
83
+ return grad
84
+ return _save_grad_func
85
+
86
+
87
+ def _npy_save(save_dir, item_name, data):
88
+ if isinstance(data, (list, tuple, dict, OrderedDict)):
89
+ for key, val in _iterate_items(data):
90
+ _npy_save(save_dir, f"{item_name}.{key}", val)
91
+ else:
92
+ if data is None:
93
+ return
94
+ _npy_save_ops(f"{save_dir}{item_name}", data)
95
+
96
+
97
+ def generate_dump_dir(save_dir, sep=os.sep):
98
+ """
99
+ usage: generate dump directory path str in mindspore graph mode
100
+ """
101
+ full_suffix = '{step}' + sep + '{rank}' + sep
102
+ if save_dir and save_dir[-1] != sep:
103
+ result_dir = save_dir + sep + full_suffix
104
+ else:
105
+ result_dir = save_dir + full_suffix
106
+ return result_dir
107
+
108
+
109
+ def save(save_dir, name, data):
110
+ """
111
+ save tensor.
112
+ """
113
+ dump_dir = generate_dump_dir(save_dir)
114
+ _SaveCell(dump_dir)(name, data)
115
+
116
+
117
+ def save_grad(save_dir, name, data):
118
+ """
119
+ save grad.
120
+ """
121
+ dump_dir = generate_dump_dir(save_dir)
122
+ suffix_name = name + '_grad'
123
+ return _SaveGradCell(dump_dir, suffix_name)(data)
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import inspect
18
+
19
+ from mindspore import Tensor, ops, mint
20
+ from mindspore.mint import distributed
21
+ from mindspore.mint.nn import functional
22
+ from mindspore.communication import comm_func
23
+
24
+ from msprobe.core.common.file_utils import load_yaml
25
+ from msprobe.core.common.utils import Const
26
+ from msprobe.core.data_dump.api_registry import ApiRegistry
27
+ from msprobe.mindspore.common.log import logger
28
+ from msprobe.mindspore.common.const import Const as MsConst
29
+ from msprobe.mindspore.common.utils import is_mindtorch
30
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
31
+
32
+
33
+ stub_tensor_existed = True
34
+ try:
35
+ from mindspore.common._stub_tensor import StubTensor
36
+ except ImportError:
37
+ stub_tensor_existed = False
38
+
39
+ cur_path = os.path.dirname(os.path.realpath(__file__))
40
+ if not is_mindtorch():
41
+ _api_types = {
42
+ Const.MS_FRAMEWORK: {
43
+ Const.MS_API_TYPE_OPS: (ops, (ops,)),
44
+ Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
45
+ Const.MS_API_TYPE_MINT: (mint, (mint,)),
46
+ Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
47
+ Const.MS_API_TYPE_COM: (comm_func, (comm_func,)),
48
+ Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,))
49
+ }
50
+ }
51
+ if stub_tensor_existed:
52
+ _api_types.get(Const.MS_FRAMEWORK).update(
53
+ {Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,))}
54
+ )
55
+
56
+ _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
57
+ _backlist = []
58
+ else:
59
+ import torch
60
+ import torch_npu
61
+ _api_types = {
62
+ Const.MT_FRAMEWORK: {
63
+ Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
64
+ Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
65
+ Const.PT_API_TYPE_TORCH: (torch, (torch,)),
66
+ Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)),
67
+ Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d))
68
+ }
69
+ }
70
+ _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
71
+ MsConst.SUPPORTED_API_LIST_FILE),)
72
+ _backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__']
73
+
74
+ _inner_used_api = {
75
+ Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
76
+ ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point"
77
+ ),
78
+ Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: (
79
+ Tensor, "to", "numel", 'sum'
80
+ ),
81
+ Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: (
82
+ mint, "max", "min", "mean", "norm"
83
+ )
84
+ }
85
+
86
+
87
+ class ApiTemplate(HOOKCell):
88
+ def __init__(self, api_name, api_func, prefix, hook_build_func):
89
+ self.api_name = api_name
90
+ self.api_func = api_func
91
+ self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
92
+ super().__init__(hook_build_func)
93
+ distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX
94
+ if prefix == distributed_prefix:
95
+ self.op_is_distributed = True
96
+
97
+ @staticmethod
98
+ def async_to_sync(output):
99
+ # Fake handle, used to return after the CommHandle executes the wait method
100
+ fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
101
+ if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
102
+ output[1].wait()
103
+ output = (output[0], fake_handle)
104
+ elif hasattr(output, "wait"):
105
+ output.wait()
106
+ output = fake_handle
107
+ return output
108
+
109
+ def construct(self, *args, **kwargs):
110
+ if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
111
+ return args[0] if args else kwargs.get(Const.INPUT)
112
+
113
+ output = self.api_func(*args, **kwargs)
114
+
115
+ if self.prefix_api_name.startswith(
116
+ (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX)
117
+ ):
118
+ try:
119
+ bound = inspect.signature(self.api_func).bind(*args, **kwargs)
120
+ bound.apply_defaults()
121
+ use_asyn_op_flag = bound.arguments.get("asyn_op", False)
122
+ except Exception as e:
123
+ use_asyn_op_flag = False
124
+ logger.warning(f"fail to get dist api's func signature because {e}, no wait")
125
+
126
+ if use_asyn_op_flag or self.api_name in ["isend", "irecv"]:
127
+ output = self.async_to_sync(output)
128
+ if self.api_name == "batch_isend_irecv" and isinstance(output, list):
129
+ output = [self.async_to_sync(handle) for handle in output]
130
+
131
+ return output
132
+
133
+ def forward(self, *args, **kwargs):
134
+ if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
135
+ return args[0] if args else kwargs.get(Const.INPUT)
136
+ return self.api_func(*args, **kwargs)
137
+
138
+
139
+ api_register = None
140
+ stub_tensor_set = False
141
+
142
+
143
+ def get_api_register(return_new=False):
144
+ global stub_tensor_set
145
+
146
+ def stub_method(method):
147
+ def wrapped_method(*args, **kwargs):
148
+ return method(*args, **kwargs)
149
+ return wrapped_method
150
+ if not is_mindtorch() and stub_tensor_existed and not stub_tensor_set:
151
+ api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, [])
152
+ for attr_name in dir(StubTensor):
153
+ attr = getattr(StubTensor, attr_name)
154
+ if attr_name in api_names and callable(attr):
155
+ setattr(StubTensor, attr_name, stub_method(attr))
156
+ stub_tensor_set = True
157
+
158
+ if return_new:
159
+ return ApiRegistry(
160
+ _api_types,
161
+ _inner_used_api,
162
+ _supported_api_list_path,
163
+ ApiTemplate,
164
+ _backlist
165
+ )
166
+
167
+ global api_register
168
+ if api_register is None:
169
+ api_register = ApiRegistry(
170
+ _api_types,
171
+ _inner_used_api,
172
+ _supported_api_list_path,
173
+ ApiTemplate,
174
+ _backlist
175
+ )
176
+ return api_register
@@ -15,11 +15,16 @@
15
15
 
16
16
  from collections import defaultdict
17
17
 
18
+ import mindspore as ms
18
19
  from mindspore import nn
19
20
 
21
+ from msprobe.core.common.runtime import Runtime
20
22
  from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
21
23
 
22
24
 
25
+ ms_version = ms.__version__
26
+
27
+
23
28
def add_cell_count(name):
    # Bump the per-API invocation counter; cell_count is a defaultdict(int)
    # on HOOKCell, keyed by the api name prefix.
    HOOKCell.cell_count[name] += 1
25
30
 
@@ -28,29 +33,34 @@ def get_cell_count(name):
28
33
  return HOOKCell.cell_count[name]
29
34
 
30
35
 
31
def __init__(self, hook_build_func) -> None:
    """Install msprobe data-collection hooks on this cell.

    Args:
        hook_build_func: callable taking the api-name prefix and returning
            a HookSet (forward/backward pre- and post-hooks). Non-callable
            values are silently ignored and no hooks are registered.
    """

    def _version_lt(ver, target):
        # Compare a dotted version string numerically against a tuple.
        # Plain string comparison is wrong once a component reaches two
        # digits ("2.10.0" < "2.6.0" is True lexicographically); trailing
        # non-digit suffixes (e.g. "0rc1") are truncated, which matches
        # the old string comparison for current releases.
        parts = []
        for piece in ver.split(".")[:3]:
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits) if digits else 0)
        return tuple(parts) < target

    super(HOOKCell, self).__init__()
    self.changed_status = False
    self.msprobe_input_kwargs = {}
    if not HOOKCell.g_stop_hook:
        # Global re-entrancy guard: only the outermost HOOKCell in a call
        # chain registers hooks; changed_status marks it for later reset.
        HOOKCell.g_stop_hook = True
        self.changed_status = True
        self.forward_data_collected = False

        if not Runtime.is_running:
            return
        prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
        if callable(hook_build_func):
            hook_set = hook_build_func(prefix)
            if _version_lt(ms_version, (2, 6, 0)) and not is_mindtorch():
                # Older MindSpore: inject directly into the internal hook
                # dicts instead of using the public registration API.
                getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook
                getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook
            else:
                self.register_forward_pre_hook(hook_set.forward_pre_hook)
                self.register_forward_hook(hook_set.forward_hook)
            register_backward_hook_functions["full"](self, hook_set.backward_hook)
            register_backward_hook_functions["pre"](self, hook_set.backward_pre_hook)
48
58
 
49
59
 
50
60
  # 重载call,加全局标志。
51
61
  def __call__(self, *args, **kwargs):
52
62
  try:
53
- self.input_kwargs = kwargs
63
+ self.msprobe_input_kwargs = kwargs
54
64
  out = super(HOOKCell, self).__call__(*args, **kwargs)
55
65
  except Exception as e:
56
66
  raise e
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from mindspore.common.api import _no_grad
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.common.utils import replace_last_occurrence
19
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs
20
+ from msprobe.core.hook_manager import BaseHookManager, HookSet
21
+ from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook
22
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
23
+
24
+
25
class MindsproeHookManager(BaseHookManager):
    """MindSpore-backed implementation of BaseHookManager.

    NOTE(review): the class name carries a historical typo ("Mindsproe");
    it is kept unchanged because external code imports it by this name.
    """

    @property
    def _is_recompute(self):
        # Recompute detection is not available on this backend.
        return None

    @staticmethod
    def _no_grad_context():
        return _no_grad()

    @staticmethod
    def _add_count(name):
        HOOKCell.add_cell_count(name)

    @staticmethod
    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
        # Older MindSpore forward hooks receive no kwargs argument; for that
        # case (and always for API hooks) the kwargs were stashed on the
        # module as msprobe_input_kwargs before the call.
        if not has_kwargs_in_forward_hook() or hook_type == Const.API:
            kwargs = getattr(module, 'msprobe_input_kwargs', {})
            output = kwargs_or_output
        else:
            kwargs = kwargs_or_output
            output = output_or_kwargs
        return kwargs, output

    def build_hook(self, hook_type, name):
        """Build the HookSet (forward/backward pre- and post-hooks) for `name`.

        For API hooks the forward name is suffixed with the current call
        count so repeated invocations of the same API stay distinguishable.
        """
        if hook_type == Const.API:
            full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
        else:
            full_forward_name = name
        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
        return HookSet(
            forward_hook=self._build_forward_hook(hook_type, full_forward_name),
            forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
            backward_hook=self._build_backward_hook(hook_type, full_backward_name),
            backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name)
        )

    def _need_exchange(self, module):
        # Inputs/outputs only need swapping once the pre-hook has run.
        return bool(getattr(module, 'has_pre_hook_called', False))

    def _get_params_dict(self, module):
        # STRUCTURE task collects no parameter values; otherwise strip the
        # scope prefix from each parameter name (keep the last SEP segment).
        params_dict = {}
        if self.config.task != Const.STRUCTURE:
            params_dict = {
                key.split(Const.SEP)[-1]: value
                for key, value in module.parameters_dict(recurse=False).items()
            }
        return params_dict

    def _build_backward_pre_hook(self, hook_type, name):
        def backward_pre_hook(module, grad_input):
            # Only L2 level collects data in the backward pre-hook.
            if self.config.level != Const.LEVEL_L2:
                return
            if not self._should_execute_hook(hook_type, module, False):
                return
            # inner_switch suppresses nested hook re-entry during collection.
            BaseHookManager.inner_switch = True
            module_input = ModuleBackwardInputs(grad_input=grad_input)
            self.data_collector.update_api_or_module_name(name)
            self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
            BaseHookManager.inner_switch = False
        return backward_pre_hook