mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,7 +12,6 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
16
  from mindspore import Tensor, ops, mint
17
17
  from mindspore.mint.nn import functional
@@ -20,8 +20,21 @@ from mindspore.communication import comm_func
20
20
 
21
21
  from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
22
22
  HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
23
- get_wrap_api_list, setup_hooks)
23
+ HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
24
+ HOOKTorchDistributedOP, HOOKTorchNpuOP,
25
+ get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
24
26
  from msprobe.core.common.utils import Const
27
+ from msprobe.mindspore.common.utils import is_mindtorch
28
+
29
+ if is_mindtorch():
30
+ import torch
31
+ import torch_npu
32
+
33
+
34
+ def stub_method(method):
35
+ def wrapped_method(*args, **kwargs):
36
+ return method(*args, **kwargs)
37
+ return wrapped_method
25
38
 
26
39
 
27
40
  class ApiRegistry:
@@ -34,6 +47,12 @@ class ApiRegistry:
34
47
  self.distributed_ori_attr = {}
35
48
  self.norm_inner_ops_ori_attr = {}
36
49
 
50
+ self.torch_ori_attr = {}
51
+ self.torch_tensor_ori_attr = {}
52
+ self.torch_functional_ori_attr = {}
53
+ self.torch_distributed_ori_attr = {}
54
+ self.torch_npu_ori_attr = {}
55
+
37
56
  self.tensor_hook_attr = {}
38
57
  self.stub_tensor_hook_attr = {}
39
58
  self.functional_hook_attr = {}
@@ -42,6 +61,12 @@ class ApiRegistry:
42
61
  self.distibuted_hook_attr = {}
43
62
  self.norm_inner_ops_hook_attr = {}
44
63
 
64
+ self.torch_hook_attr = {}
65
+ self.torch_tensor_hook_attr = {}
66
+ self.torch_functional_hook_attr = {}
67
+ self.torch_distributed_hook_attr = {}
68
+ self.torch_npu_hook_attr = {}
69
+
45
70
  self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
46
71
 
47
72
  @staticmethod
@@ -50,9 +75,13 @@ class ApiRegistry:
50
75
  if Const.SEP in api:
51
76
  sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
52
77
  sub_module = getattr(ori_api_group, sub_module_name)
53
- api_ori_attr[api] = getattr(sub_module, sub_op)
78
+ ori_api_func = getattr(sub_module, sub_op)
54
79
  else:
55
- api_ori_attr[api] = getattr(ori_api_group, api)
80
+ ori_api_func = getattr(ori_api_group, api)
81
+ if ori_api_group == StubTensor:
82
+ api_ori_attr[api] = stub_method(ori_api_func)
83
+ continue
84
+ api_ori_attr[api] = ori_api_func
56
85
 
57
86
  @staticmethod
58
87
  def set_api_attr(api_group, attr_dict):
@@ -72,22 +101,71 @@ class ApiRegistry:
72
101
  self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
73
102
 
74
103
  def api_set_hook_func(self):
75
- self.set_api_attr(Tensor, self.tensor_hook_attr)
76
- self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
77
- self.set_api_attr(ops, self.functional_hook_attr)
78
- self.set_api_attr(mint, self.mint_ops_hook_attr)
79
- self.set_api_attr(functional, self.mint_func_ops_hook_attr)
80
- self.set_api_attr(comm_func, self.distibuted_hook_attr)
104
+ if is_mindtorch():
105
+ self.set_api_attr(torch, self.torch_hook_attr)
106
+ self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
107
+ self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
108
+ self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
109
+ self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
110
+ else:
111
+ self.set_api_attr(Tensor, self.tensor_hook_attr)
112
+ self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
113
+ self.set_api_attr(ops, self.functional_hook_attr)
114
+ self.set_api_attr(mint, self.mint_ops_hook_attr)
115
+ self.set_api_attr(functional, self.mint_func_ops_hook_attr)
116
+ self.set_api_attr(comm_func, self.distibuted_hook_attr)
81
117
 
82
118
  def api_set_ori_func(self):
83
- self.set_api_attr(Tensor, self.tensor_ori_attr)
84
- self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
85
- self.set_api_attr(ops, self.functional_ori_attr)
86
- self.set_api_attr(mint, self.mint_ops_ori_attr)
87
- self.set_api_attr(functional, self.mint_func_ops_ori_attr)
88
- self.set_api_attr(comm_func, self.distributed_ori_attr)
119
+ if is_mindtorch():
120
+ self.set_api_attr(torch, self.torch_ori_attr)
121
+ self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
122
+ self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
123
+ self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
124
+ self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
125
+ else:
126
+ self.set_api_attr(Tensor, self.tensor_ori_attr)
127
+ self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
128
+ self.set_api_attr(ops, self.functional_ori_attr)
129
+ self.set_api_attr(mint, self.mint_ops_ori_attr)
130
+ self.set_api_attr(functional, self.mint_func_ops_ori_attr)
131
+ self.set_api_attr(comm_func, self.distributed_ori_attr)
89
132
 
90
133
  def initialize_hook(self, hook):
134
+ setup_hooks(hook)
135
+ if is_mindtorch():
136
+ wrap_torch_api_name = get_wrap_torch_api_list()
137
+ self.store_ori_attr(torch,
138
+ wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
139
+ self.store_ori_attr(torch.Tensor,
140
+ wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
141
+ self.store_ori_attr(torch.nn.functional,
142
+ wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
143
+ self.store_ori_attr(torch.distributed,
144
+ wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
145
+ self.store_ori_attr(torch_npu,
146
+ wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
147
+ for attr_name in dir(HOOKTorchOP):
148
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
149
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
150
+ self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
151
+ for attr_name in dir(HOOKTorchTensor):
152
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
153
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
154
+ self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
155
+ for attr_name in dir(HOOKTorchFunctionalOP):
156
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
157
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
158
+ self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
159
+ for attr_name in dir(HOOKTorchDistributedOP):
160
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
161
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
162
+ self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
163
+ for attr_name in dir(HOOKTorchNpuOP):
164
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
165
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
166
+ self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
167
+ return
168
+
91
169
  wrap_api_name = get_wrap_api_list()
92
170
  self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
93
171
  self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
@@ -96,7 +174,6 @@ class ApiRegistry:
96
174
  self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
97
175
  self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
98
176
  self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
99
- setup_hooks(hook)
100
177
  for attr_name in dir(HOOKTensor):
101
178
  if attr_name.startswith(Const.ATTR_NAME_PREFIX):
102
179
  api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,45 +12,66 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
16
  from collections import defaultdict
17
17
 
18
18
  from mindspore import nn
19
19
 
20
- from msprobe.core.common.const import Const
21
-
22
-
23
- class HOOKCell(nn.Cell):
24
- cell_count = defaultdict(int)
25
- g_stop_hook = False
26
-
27
- def __init__(self, build_hook) -> None:
28
- super(HOOKCell, self).__init__()
29
- self.changed_status = False
30
- self.input_kwargs = {}
31
- self.prefix = ""
32
- if not HOOKCell.g_stop_hook:
33
- HOOKCell.g_stop_hook = True
34
- self.changed_status = True
35
- if hasattr(self, "prefix_api_name"):
36
- self.prefix = self.prefix_api_name
37
-
38
- HOOKCell.cell_count[self.prefix] += 1
39
- self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
40
- forward_hook, backward_hook = build_hook(self.prefix)
41
- self.register_forward_hook(forward_hook)
42
- self.register_backward_hook(backward_hook)
43
-
44
- # 重载call,加全局标志。
45
- def __call__(self, *args, **kwargs):
46
- try:
47
- self.input_kwargs = kwargs
48
- out = super(HOOKCell, self).__call__(*args, **kwargs)
49
- except Exception as e:
50
- raise e
51
- finally:
52
- if self.changed_status:
53
- self.changed_status = False
54
- HOOKCell.g_stop_hook = False
55
- return out
20
+ from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
21
+
22
+
23
+ def add_cell_count(name):
24
+ HOOKCell.cell_count[name] += 1
25
+
26
+
27
+ def get_cell_count(name):
28
+ return HOOKCell.cell_count[name]
29
+
30
+
31
+ def __init__(self, build_hook) -> None:
32
+ super(HOOKCell, self).__init__()
33
+ self.changed_status = False
34
+ self.input_kwargs = {}
35
+ self.prefix = ""
36
+ if not HOOKCell.g_stop_hook:
37
+ HOOKCell.g_stop_hook = True
38
+ self.changed_status = True
39
+ if hasattr(self, "prefix_api_name"):
40
+ self.prefix = self.prefix_api_name
41
+
42
+ self.forward_data_collected = False
43
+ forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
44
+ self.register_forward_pre_hook(forward_pre_hook)
45
+ self.register_forward_hook(forward_hook)
46
+ register_backward_hook_functions["full"](self, backward_hook)
47
+ register_backward_hook_functions["pre"](self, backward_pre_hook)
48
+
49
+
50
+ # 重载call,加全局标志。
51
+ def __call__(self, *args, **kwargs):
52
+ try:
53
+ self.input_kwargs = kwargs
54
+ out = super(HOOKCell, self).__call__(*args, **kwargs)
55
+ except Exception as e:
56
+ raise e
57
+ finally:
58
+ if self.changed_status:
59
+ self.changed_status = False
60
+ HOOKCell.g_stop_hook = False
61
+ return out
62
+
63
+
64
+ hook_cell_dict = {
65
+ "cell_count": defaultdict(int),
66
+ "g_stop_hook": False,
67
+ "add_cell_count": staticmethod(add_cell_count),
68
+ "get_cell_count": staticmethod(get_cell_count),
69
+ "__init__": __init__,
70
+ "__call__": __call__
71
+ }
72
+
73
+ if is_mindtorch():
74
+ import torch
75
+ HOOKCell = type("HOOKCell", (torch.nn.Module,), hook_cell_dict)
76
+ else:
77
+ HOOKCell = type("HOOKCell", (nn.Cell,), hook_cell_dict)
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,18 +12,16 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
16
  import os
17
17
 
18
- import mindspore as ms
19
- from mindspore.common.tensor import Tensor
20
18
  from mindspore import ops
19
+ from mindspore.common.tensor import Tensor
21
20
 
22
- from msprobe.mindspore.common.log import logger
23
21
  from msprobe.core.common.utils import Const, DumpException
24
- from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
25
- ModuleBackwardInputs, ModuleBackwardOutputs
22
+ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
23
+ ModuleForwardInputsOutputs)
24
+ from msprobe.mindspore.common.log import logger
26
25
 
27
26
 
28
27
  class PrimitiveHookService:
@@ -41,6 +40,7 @@ class PrimitiveHookService:
41
40
  Returns:
42
41
  callable: 包装后的 primitive 函数。
43
42
  """
43
+
44
44
  def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
45
45
  """
46
46
  创建反向 hook 函数,用于捕获梯度。
@@ -54,26 +54,24 @@ class PrimitiveHookService:
54
54
  Returns:
55
55
  callable: 反向 hook 函数。
56
56
  """
57
- def backward_hook(grad):
58
57
 
59
- captured_grads.append(grad)
58
+ def backward_hook(grad):
59
+ captured_grads.extend(grad)
60
60
  backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
61
61
 
62
62
  try:
63
- if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
63
+ if hook_type == Const.INPUT:
64
64
  self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
65
65
  new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
66
66
  self.service_instance.data_collector.backward_output_data_collect(
67
67
  backward_primitive_name, self, os.getpid(), new_module_input_output
68
68
  )
69
- captured_grads.clear()
70
- elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
69
+ elif hook_type == Const.OUTPUT:
71
70
  self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
72
71
  new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
73
72
  self.service_instance.data_collector.backward_input_data_collect(
74
73
  backward_primitive_name, self, os.getpid(), new_module_input_output
75
74
  )
76
- captured_grads.clear()
77
75
 
78
76
  except Exception as exception:
79
77
  logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
@@ -104,7 +102,7 @@ class PrimitiveHookService:
104
102
  hooked_inputs.append(arg_hooked)
105
103
  else:
106
104
  hooked_inputs.append(arg)
107
- return hooked_inputs
105
+ return tuple(hooked_inputs)
108
106
 
109
107
  def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
110
108
  """
@@ -137,6 +135,34 @@ class PrimitiveHookService:
137
135
  return tuple(hooked_outputs)
138
136
  return out
139
137
 
138
def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
    """Hand the inputs of a primitive call to the service's data collector.

    Args:
        primitive_name: Name passed to the collector and used in the error log.
        primitive_instance: The primitive object the call belongs to.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Raises:
        DumpException: With FORWARD_DATA_COLLECTION_ERROR when collection
            fails; the original exception is chained as the cause.
    """
    # No output exists yet at this point, so it is recorded as None.
    forward_io = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
    try:
        collect_inputs = self.service_instance.data_collector.forward_input_data_collect
        collect_inputs(primitive_name, primitive_instance, os.getpid(), forward_io)
    except Exception as exception:
        message = (f"This is a primitive op dump error during forward input data collection: {exception}, "
                   f"primitive_name: {primitive_name}")
        logger.error(message)
        raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
151
+
152
def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
    """Hand the inputs and output of a primitive call to the service's data collector.

    Args:
        primitive_name: Name passed to the collector and used in the error log.
        primitive_instance: The primitive object the call belongs to.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.
        output: Value returned by the primitive.

    Raises:
        DumpException: With FORWARD_DATA_COLLECTION_ERROR when collection
            fails; the original exception is chained as the cause.
    """
    forward_io = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
    try:
        collect_outputs = self.service_instance.data_collector.forward_output_data_collect
        collect_outputs(primitive_name, primitive_instance, os.getpid(), forward_io)
    except Exception as exception:
        message = (f"This is a primitive op dump error during forward output data collection: {exception}, "
                   f"primitive_name: {primitive_name}")
        logger.error(message)
        raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
165
+
140
166
  def wrapped_primitive_call(instance_self, *args, **kwargs):
141
167
  """
142
168
  包装后的 primitive 调用函数,添加输入和输出的 hook。
@@ -165,27 +191,17 @@ class PrimitiveHookService:
165
191
  f"primitive_name: {primitive_name}")
166
192
  raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
167
193
 
194
+ forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
195
+ self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
196
+
197
+ pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs)
168
198
  try:
169
199
  out = origin_func(*hooked_inputs, **kwargs)
170
200
  except Exception as exception:
171
201
  logger.error(f"This is a primitive op dump error during function call: {exception}, "
172
202
  f"primitive_name: {primitive_name}")
173
203
  raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception
174
-
175
- forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
176
- self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
177
- if self.service_instance.data_collector:
178
- module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
179
- try:
180
- self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
181
- os.getpid(), module_input_output)
182
- except Exception as exception:
183
- logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
184
- f"primitive_name: {primitive_name}")
185
- raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
186
-
187
- if self.service_instance.data_collector.if_return_forward_new_output():
188
- out = self.service_instance.data_collector.get_forward_new_output()
204
+ post_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs, out)
189
205
 
190
206
  try:
191
207
  out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
@@ -203,4 +219,3 @@ class PrimitiveHookService:
203
219
  self.primitive_counters[primitive_name] = 0
204
220
  else:
205
221
  self.primitive_counters[primitive_name] += 1
206
-
@@ -15,7 +15,7 @@
15
15
 
16
16
  # List of ops that register hooks
17
17
 
18
-
18
+
19
19
  ops:
20
20
  - adaptive_avg_pool1d
21
21
  - adaptive_avg_pool2d
@@ -85,6 +85,7 @@ ops:
85
85
  - relu6
86
86
  - celu
87
87
  - rrelu
88
+ - rms_norm
88
89
  - selu
89
90
  - sigmoid
90
91
  - silu
@@ -490,6 +491,31 @@ ops:
490
491
  - scatter_update
491
492
  - derivative
492
493
  - jet
494
+ - row_stack
495
+ - gather
496
+ - arange
497
+ - cond
498
+ - slice_scatter
499
+ - clip_by_norm
500
+ - eps
501
+ - layer_norm
502
+ - cast
503
+ - numel
504
+ - permute
505
+ - select_scatter
506
+ - group_norm
507
+ - eq
508
+ - embedding
509
+ - ones_like
510
+ - zeros
511
+ - nanmean
512
+ - shape
513
+ - zeros_like
514
+ - ones
515
+ - diagonal_scatter
516
+ - vander
517
+ - is_nonzero
518
+ - rotary_position_embedding
493
519
 
494
520
  tensor:
495
521
  - __abs__
@@ -528,6 +554,7 @@ tensor:
528
554
  - acos
529
555
  - acosh
530
556
  - add
557
+ - add_
531
558
  - addbmm
532
559
  - addcdiv
533
560
  - addcmul
@@ -582,6 +609,7 @@ tensor:
582
609
  - diff
583
610
  - digamma
584
611
  - div
612
+ - div_
585
613
  - divide
586
614
  - equal
587
615
  - erf
@@ -714,6 +742,8 @@ tensor:
714
742
  - square
715
743
  - squeeze
716
744
  - std
745
+ - sub
746
+ - sub_
717
747
  - subtract
718
748
  - subtract
719
749
  - svd
@@ -958,6 +988,7 @@ mint.nn.functional:
958
988
  - one_hot_ext
959
989
  - pad
960
990
  - relu
991
+ - relu_
961
992
  - sigmoid
962
993
  - silu
963
994
  - softmax
@@ -992,3 +1023,7 @@ communication.comm_func:
992
1023
  - broadcast
993
1024
  - gather_into_tensor
994
1025
  - scatter_tensor
1026
+ - send
1027
+ - recv
1028
+ - isend
1029
+ - irecv
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,10 +23,16 @@ from mindspore.mint.nn import functional
23
23
  from msprobe.core.common.const import Const
24
24
  from msprobe.core.common.file_utils import load_yaml
25
25
  from msprobe.mindspore.common.const import Const as MsConst
26
+ from msprobe.mindspore.common.utils import is_mindtorch
26
27
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
27
28
 
29
+ if is_mindtorch():
30
+ import torch
31
+ import torch_npu
32
+
28
33
  cur_path = os.path.dirname(os.path.realpath(__file__))
29
34
  yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
35
+ torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
30
36
 
31
37
 
32
38
  class HOOKTensor(object):
@@ -53,6 +59,26 @@ class HOOKDistributedOP(object):
53
59
  pass
54
60
 
55
61
 
62
class HOOKTorchOP:
    """Empty holder class; setup_hooks passes it as the hook class for the `torch` API list."""
64
+
65
+
66
class HOOKTorchTensor:
    """Empty holder class; setup_hooks passes it as the hook class for the `torch.Tensor` API list."""
68
+
69
+
70
class HOOKTorchFunctionalOP:
    """Empty holder class; setup_hooks passes it as the hook class for the `torch.nn.functional` API list."""
72
+
73
+
74
class HOOKTorchDistributedOP:
    """Empty holder class; setup_hooks passes it as the hook class for the `torch.distributed` API list."""
76
+
77
+
78
class HOOKTorchNpuOP:
    """Empty holder class; setup_hooks passes it as the hook class for the `torch_npu` API list."""
80
+
81
+
56
82
  class ApiTemplate(HOOKCell):
57
83
  def __init__(self, api_name, api_dict, prefix, hook):
58
84
  self.api_name = api_name
@@ -60,7 +86,30 @@ class ApiTemplate(HOOKCell):
60
86
  self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
61
87
  super().__init__(hook)
62
88
 
89
@staticmethod
def async_to_sync(output):
    """Wait on an asynchronous result and replace its handle with a no-op one.

    Args:
        output: Either an object exposing ``wait()``, a 2-tuple whose second
            element exposes ``wait()``, or any other value.

    Returns:
        For a waitable object: a fake handle, after ``wait()`` was called on
        the real one. For a ``(value, handle)`` tuple: ``(value, fake_handle)``
        after ``handle.wait()``. Any other value is returned unchanged.
    """
    def completed_handle():
        # Fake handle, used to return after the CommHandle executes the wait method
        return type("FakeHandle", (), {"wait": lambda self: None})()

    is_pair = isinstance(output, tuple) and len(output) == 2
    if is_pair and hasattr(output[1], "wait"):
        value, handle = output
        handle.wait()
        return (value, completed_handle())

    if hasattr(output, "wait"):
        output.wait()
        return completed_handle()

    return output
100
+
63
101
def construct(self, *args, **kwargs):
    """Run the wrapped API, synchronizing asynchronous distributed results.

    APIs whose name starts with MsConst.DROPOUT_API_NAME_PREFIX are not
    executed: the first positional argument (or, without positional
    arguments, the Const.INPUT keyword argument) is returned as-is.

    For APIs whose prefixed name starts with MsConst.DISTRIBUTED_DATA_PREFIX,
    the result is passed through async_to_sync when ``async_op`` is truthy
    in kwargs or the API is ``isend``/``irecv``.
    """
    if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
        if args:
            return args[0]
        return kwargs.get(Const.INPUT)

    result = self.api_func(*args, **kwargs)

    if not self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
        return result

    needs_sync = kwargs.get("async_op") or self.api_name in ["isend", "irecv"]
    if needs_sync:
        result = self.async_to_sync(result)
    return result
111
+
112
def forward(self, *args, **kwargs):
    """Run the wrapped API and return its result.

    APIs whose name starts with MsConst.DROPOUT_API_NAME_PREFIX are not
    executed: the first positional argument (or, without positional
    arguments, the Const.INPUT keyword argument) is returned as-is.
    """
    # NOTE(review): unlike construct, asynchronous distributed results are
    # not passed through async_to_sync here - confirm this is intended.
    is_dropout = self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX)
    if not is_dropout:
        return self.api_func(*args, **kwargs)
    if args:
        return args[0]
    return kwargs.get(Const.INPUT)
@@ -77,6 +126,15 @@ class WrapApiName:
77
126
  self.distributed_api_names = distributed_api_names
78
127
 
79
128
 
129
class WrapTorchApiName:
    """Plain container for the five torch-side API name collections."""

    def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
        # Values are stored exactly as given; nothing is copied or validated.
        (
            self.torch_api_names,
            self.tensor_api_names,
            self.functional_api_names,
            self.distributed_api_names,
            self.npu_api_names,
        ) = (
            torch_api_names,
            tensor_api_names,
            functional_api_names,
            distributed_api_names,
            npu_api_names,
        )
136
+
137
+
80
138
  def get_wrap_api_list():
81
139
  api_list = load_yaml(yaml_path)
82
140
  tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
@@ -93,6 +151,21 @@ def get_wrap_api_list():
93
151
  return wrap_api_name
94
152
 
95
153
 
154
def get_wrap_torch_api_list():
    """Load the torch API list and keep only the names each namespace really has.

    Reads the YAML file at ``torch_yaml_path`` and, for each of the sections
    ``torch``, ``tensor``, ``functional``, ``distributed`` and ``torch_npu``,
    intersects the listed names with ``dir()`` of the matching namespace.

    Returns:
        WrapTorchApiName: Holder of the five resulting name sets.
    """
    api_list = load_yaml(torch_yaml_path)

    def available(section, namespace):
        # A section that is missing or has no entries comes back as None;
        # treat it as "nothing to wrap" instead of failing in set(None).
        listed = api_list.get(section) or ()
        return set(listed) & set(dir(namespace))

    return WrapTorchApiName(
        available("torch", torch),
        available("tensor", torch.Tensor),
        available("functional", torch.nn.functional),
        available("distributed", torch.distributed),
        available("torch_npu", torch_npu),
    )
167
+
168
+
96
169
  def wrap_api_func(api_name, api_dict, prefix, hook):
97
170
  def api_function(*args, **kwargs):
98
171
  return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
@@ -106,6 +179,24 @@ def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
106
179
 
107
180
 
108
181
  def setup_hooks(hook):
182
+ if is_mindtorch():
183
+ torch_wrap_api_name = get_wrap_torch_api_list()
184
+ wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
185
+ {f: getattr(torch, f) for f in dir(torch)},
186
+ MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
187
+ wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
188
+ {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
189
+ MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
190
+ wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
191
+ {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
192
+ MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
193
+ wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
194
+ {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
195
+ MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
196
+ wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
197
+ MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
198
+ return
199
+
109
200
  wrap_api_name = get_wrap_api_list()
110
201
  wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
111
202
  MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)