mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,39 +12,42 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
- import os
17
16
  import copy
18
17
  import functools
18
+ import os
19
19
  from collections import defaultdict
20
20
 
21
21
  import mindspore as ms
22
- from mindspore.common.tensor import Tensor
23
- from mindspore import ops
24
22
  from mindspore import nn
23
+ from mindspore.common.api import _no_grad
24
+ from mindspore.ops.primitive import Primitive
25
25
  try:
26
26
  from mindspore.common._pijit_context import PIJitCaptureContext
27
- pijit_label = True
28
27
  except ImportError:
29
28
  pijit_label = False
29
+ else:
30
+ pijit_label = True
30
31
 
31
-
32
+ from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
33
+ from msprobe.core.common.file_utils import create_directory
34
+ from msprobe.core.common.utils import Const, print_tools_ends_info
32
35
  from msprobe.core.data_dump.data_collector import build_data_collector
36
+ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
37
+ ModuleBackwardInputs)
33
38
  from msprobe.core.data_dump.scope import BaseScope
34
- from msprobe.mindspore.common.utils import get_rank_if_initialized
35
- from msprobe.core.common.file_utils import create_directory
39
+ from msprobe.mindspore.cell_processor import CellProcessor
36
40
  from msprobe.mindspore.common.log import logger
37
- from msprobe.core.common.utils import Const, print_tools_ends_info
38
- from msprobe.core.common.exceptions import DistributedNotInitializedError
41
+ from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
42
+ is_mindtorch, register_backward_hook_functions)
39
43
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
40
44
  from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
41
- from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
42
- ModuleBackwardInputs, ModuleBackwardOutputs
43
- from msprobe.core.common.exceptions import MsprobeException
44
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
45
- from msprobe.mindspore.cell_processor import CellProcessor
46
45
  from msprobe.mindspore.dump.jit_dump import JitDump
46
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
47
+ from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
48
+
49
+ if is_mindtorch():
50
+ import torch
47
51
 
48
52
 
49
53
  class Service:
@@ -55,75 +59,196 @@ class Service:
55
59
  self.cell_processor = CellProcessor(self.data_collector.scope)
56
60
  self.primitive_hook_service = PrimitiveHookService(self)
57
61
  self.switch = False
62
+ self.inner_switch = False
58
63
  self.primitive_switch = False
59
64
  self.current_iter = 0
60
65
  self.first_start = True
61
66
  self.current_rank = None
62
67
  self.dump_iter_dir = None
63
68
  self.start_call = False
64
- self.check_level_valid()
65
69
  self.should_stop_service = False
70
+ self.params_grad_info = {}
71
+ # 提前注册,确保注册尽可能多的API hook
72
+ self.register_api_hook()
66
73
 
67
74
  @staticmethod
68
- def check_model_valid(model):
69
- if not model or isinstance(model, nn.Cell):
70
- return model
71
- raise MsprobeException(
72
- MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
73
- )
75
+ def check_model_valid(models):
76
+ target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
77
+ if models is None or isinstance(models, target_module_type[0]):
78
+ return models
79
+ error_model = None
80
+ if isinstance(models, (list, tuple)):
81
+ for model in models:
82
+ if not isinstance(model, target_module_type[0]):
83
+ error_model = model
84
+ break
85
+ else:
86
+ error_model = models
74
87
 
75
- def check_level_valid(self):
76
- if self.config.level == Const.LEVEL_L2:
88
+ if error_model is not None:
89
+ error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
90
+ f"type, currently there is a {type(error_model)} type.")
77
91
  raise MsprobeException(
78
- MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
79
- )
92
+ MsprobeException.INVALID_PARAM_ERROR, error_info)
93
+ return models
94
+
95
+ @staticmethod
96
+ def prepare_module_input_output(target_type, cell, input_data, output):
97
+ if target_type == BaseScope.Module_Type_Module:
98
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
99
+ else:
100
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
101
+ return module_input_output
80
102
 
81
103
  def build_hook(self, target_type, name):
82
- def forward_hook(api_or_cell_name, cell, input, output):
83
- if not self.should_excute_hook():
104
+ def pre_hook(api_or_cell_name, cell, input_data):
105
+ if not self.should_execute_hook(target_type, cell, True):
106
+ clean_input_kwargs(cell)
84
107
  return None
85
108
 
86
- if target_type == BaseScope.Module_Type_Module:
87
- api_or_cell_name = cell.mindstudio_reserved_name
88
- module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
89
- else:
90
- module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs,
91
- output=output)
109
+ with _no_grad():
110
+ self.inner_switch = True
111
+ if target_type == BaseScope.Module_Type_Module:
112
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
113
+ else:
114
+ cell.forward_data_collected = True
115
+ HOOKCell.add_cell_count(name)
116
+ module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
117
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
118
+ self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
119
+ self.inner_switch = False
120
+ return input_data
121
+
122
+ def grad_hook(cell, ori_name, param_name):
123
+ def hook_fn(grad):
124
+ if not self.should_execute_hook(target_type, cell, False):
125
+ return None
126
+ self.inner_switch = True
127
+ self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
128
+ self.inner_switch = False
129
+ return None
92
130
 
93
- self.data_collector.update_api_or_module_name(api_or_cell_name)
94
- self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
95
- if self.data_collector.if_return_forward_new_output():
96
- return self.data_collector.get_forward_new_output()
97
- if target_type == BaseScope.Module_Type_API:
98
- del cell.input_kwargs
99
- return output
131
+ return hook_fn
132
+
133
+ def register_param_hook(ori_name, cell, params_dict):
134
+ '''
135
+ 注册参数hook
136
+ '''
137
+ # data_mode为forward时,不注册参数hook
138
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
139
+ for param_name, param in params_dict.items():
140
+ if param.requires_grad:
141
+ param.register_hook(grad_hook(cell, ori_name, param_name))
142
+
143
+ def init_params_grad_info(cell, params_dict):
144
+ '''
145
+ 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
146
+ '''
147
+ if not params_dict:
148
+ return
149
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
150
+ grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
151
+ # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
152
+ if not self.params_grad_info.get(grad_name):
153
+ data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
154
+ # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
155
+ if data_info.get(grad_name):
156
+ # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
157
+ self.data_collector.handle_data(grad_name, data_info,
158
+ flush=self.data_collector.data_processor.is_terminated)
159
+ # 记录当前模块的参数梯度信息已占位
160
+ self.params_grad_info[grad_name] = True
161
+
162
+ def forward_hook(api_or_cell_name, cell, input_data, output):
163
+ if not self.should_execute_hook(target_type, cell, True):
164
+ clean_input_kwargs(cell)
165
+ return None
166
+ with _no_grad():
167
+ self.inner_switch = True
168
+ module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
169
+ if target_type == BaseScope.Module_Type_Module:
170
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
171
+ params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(
172
+ recurse=False).items()}
173
+ setattr(module_input_output, Const.PARAMS, params_dict)
174
+ # 判断是否需要注册参数hook
175
+ if not hasattr(cell, 'params_grad_name') and params_dict:
176
+ ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
177
+ grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
178
+ # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
179
+ setattr(cell, 'params_grad_name', grad_name)
180
+ register_param_hook(ori_name, cell, params_dict)
181
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
182
+ self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
183
+ init_params_grad_info(cell, params_dict)
184
+ else:
185
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
186
+ self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
187
+
188
+ if self.data_collector.if_return_forward_new_output():
189
+ forward_new_output = self.data_collector.get_forward_new_output()
190
+ self.inner_switch = False
191
+ return forward_new_output
192
+ clean_input_kwargs(cell)
193
+ self.inner_switch = False
194
+ return output
100
195
 
101
196
  def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
102
- if not self.should_excute_hook():
197
+ if not self.should_execute_hook(target_type, cell, False):
103
198
  return
199
+ self.inner_switch = True
104
200
 
201
+ need_exchange = True
105
202
  if target_type == BaseScope.Module_Type_Module:
106
- api_or_cell_name = cell.mindstudio_reserved_name
203
+ if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
204
+ need_exchange = False
205
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
206
+
107
207
  self.data_collector.update_api_or_module_name(api_or_cell_name)
108
208
  if self.data_collector:
109
209
  # 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
110
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
210
+ if need_exchange:
211
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
212
+ else:
213
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
111
214
  self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
215
+ self.inner_switch = False
216
+
217
+ def pre_backward_hook(api_or_cell_name, cell, grad_input):
218
+ if not self.should_execute_hook(target_type, cell, False):
219
+ return
220
+ self.inner_switch = True
221
+ module_input = ModuleBackwardInputs(grad_input=grad_input)
222
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
223
+ self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
224
+
225
+ self.inner_switch = False
112
226
 
113
227
  pid = os.getpid()
114
- forward_name_template = name + Const.FORWARD
115
- backward_name_template = name + Const.BACKWARD
116
- forward_hook = functools.partial(forward_hook, forward_name_template)
117
- backward_hook = functools.partial(backward_hook, backward_name_template)
228
+ if target_type == BaseScope.Module_Type_Module:
229
+ full_forward_name = name + Const.FORWARD
230
+ full_backward_name = name + Const.BACKWARD
231
+ else:
232
+ full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
233
+ full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
234
+ pre_forward_hook = functools.partial(pre_hook, full_forward_name)
235
+ forward_hook = functools.partial(forward_hook, full_forward_name)
236
+ backward_hook = functools.partial(backward_hook, full_backward_name)
237
+ pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
238
+
239
+ def wrap_pre_forward_hook(cell, input_data):
240
+ return pre_forward_hook(cell, input_data)
118
241
 
119
- def wrap_forward_hook(cell, input, output):
120
- return forward_hook(cell, input, output)
242
+ def wrap_forward_hook(cell, input_data, output_data):
243
+ return forward_hook(cell, input_data, output_data)
121
244
 
122
245
  def wrap_backward_hook(cell, grad_input, grad_output):
123
246
  return backward_hook(cell, grad_input, grad_output)
124
247
 
125
- return wrap_forward_hook, wrap_backward_hook
248
+ def wrap_pre_backward_hook(cell, grad_input):
249
+ return pre_backward_hook(cell, grad_input)
126
250
 
251
+ return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
127
252
 
128
253
  def update_primitive_counters(self, primitive_name):
129
254
  if primitive_name not in self.primitive_counters:
@@ -131,32 +256,20 @@ class Service:
131
256
  else:
132
257
  self.primitive_counters[primitive_name] += 1
133
258
 
134
- def register_primitive_hooks(self):
135
- primitive_set = set()
136
- for _, cell in self.model.cells_and_names():
137
- for pname, primitive in cell._primitives.items():
138
- primitive_set.add((pname, primitive))
139
-
140
- for pname, primitive in primitive_set:
141
- NewPrimitive = type('NewPrimitive', (primitive.__class__,),
142
- {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, pname)})
143
- primitive.__class__ = NewPrimitive
144
-
145
259
  def step(self):
260
+ if self.config.async_dump:
261
+ self.data_collector.fill_stack_tensor_data()
262
+ self.data_collector.data_processor.dump_async_data()
263
+ self.data_collector.write_json()
146
264
  self.current_iter += 1
147
265
  self.data_collector.update_iter(self.current_iter)
148
- HOOKCell.cell_count = defaultdict(int)
149
- CellProcessor.reset_cell_stats()
150
- self.primitive_hook_service.primitive_counters.clear()
151
- self.data_collector.data_writer.reset_cache()
152
- JitDump.jit_count = defaultdict(int)
266
+ self.reset_status()
153
267
 
154
268
  def start(self, model=None):
155
269
  self.start_call = True
156
270
  if self.should_stop_service:
157
271
  return
158
272
  if self.need_end_service():
159
- api_register.api_set_ori_func()
160
273
  self.should_stop_service = True
161
274
  self.switch = False
162
275
  self.primitive_switch = False
@@ -176,7 +289,8 @@ class Service:
176
289
 
177
290
  if self.config.rank and self.current_rank not in self.config.rank:
178
291
  return
179
- self.register_hook_new()
292
+ self.register_primitive_hook()
293
+ self.register_cell_hook()
180
294
  if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
181
295
  JitDump.set_config(self.config)
182
296
  JitDump.set_data_collector(self.data_collector)
@@ -195,24 +309,6 @@ class Service:
195
309
  logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
196
310
  JitDump.jit_dump_switch = True
197
311
 
198
- def forward_backward_dump_end(self):
199
- if self.should_stop_service:
200
- return
201
- logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
202
- if not self.start_call:
203
- logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
204
- raise Exception("debugger.start() is not set in the current scope.")
205
- if not self.switch:
206
- logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
207
- "debugger.start() and debugger.stop() ")
208
- raise Exception("debugger.stop() is already called. ")
209
- if self.config.step and self.current_iter not in self.config.step:
210
- return
211
- if self.config.rank and self.current_rank not in self.config.rank:
212
- return
213
- self.primitive_switch = False
214
- api_register.api_set_ori_func()
215
-
216
312
  def stop(self):
217
313
  if self.should_stop_service:
218
314
  return
@@ -228,6 +324,9 @@ class Service:
228
324
  self.switch = False
229
325
  self.primitive_switch = False
230
326
  self.start_call = False
327
+ if self.config.async_dump:
328
+ self.data_collector.fill_stack_tensor_data()
329
+ self.data_collector.data_processor.dump_async_data()
231
330
  self.data_collector.write_json()
232
331
  JitDump.jit_dump_switch = False
233
332
 
@@ -238,8 +337,16 @@ class Service:
238
337
  return True
239
338
  return False
240
339
 
241
- def should_excute_hook(self):
242
- if not self.switch:
340
+ def should_execute_hook(self, hook_type, cell, is_forward):
341
+ is_cell_hook = hook_type == BaseScope.Module_Type_Module
342
+ if is_cell_hook and not self.switch:
343
+ return False
344
+ elif not is_cell_hook and is_forward and not self.switch:
345
+ return False
346
+ elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
347
+ return False
348
+
349
+ if self.inner_switch:
243
350
  return False
244
351
  if not self.data_collector or self.data_collector.data_processor.is_terminated:
245
352
  return False
@@ -249,6 +356,12 @@ class Service:
249
356
  create_directory(self.config.dump_path)
250
357
  self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
251
358
  cur_rank = self.current_rank if self.current_rank is not None else ''
359
+ if self.config.level == Const.LEVEL_L2:
360
+ create_directory(self.dump_iter_dir)
361
+ kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
362
+ self.config.kernel_config_path = kernel_config_path
363
+ return
364
+
252
365
  dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
253
366
  create_directory(dump_dir)
254
367
  if self.config.task in self.data_collector.tasks_need_tensor_data:
@@ -261,37 +374,96 @@ class Service:
261
374
  stack_file_path = os.path.join(dump_dir, "stack.json")
262
375
  construct_file_path = os.path.join(dump_dir, "construct.json")
263
376
  self.data_collector.update_dump_paths(
264
- dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
377
+ dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
378
+ )
379
+ self.data_collector.initialize_json_file(
380
+ framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
381
+ )
265
382
 
266
383
  def empty(self, *args, **kwargs):
267
384
  pass
268
385
 
269
- def register_hook_new(self):
270
- logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
271
- if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
386
+ def register_api_hook(self):
387
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
388
+ logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
272
389
  api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
273
390
  api_register.api_set_hook_func()
274
- if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
275
- self.register_primitive_hooks()
276
391
 
392
+ def get_cells_and_names(self):
393
+ cells_and_names_with_index = {}
394
+
395
+ def get_cell_or_module(model):
396
+ return model.named_modules() if is_mindtorch() else model.cells_and_names()
397
+
398
+ if isinstance(self.model, (list, tuple)):
399
+ for index, model in enumerate(self.model):
400
+ cells_and_names_with_index[str(index)] = get_cell_or_module(model)
401
+ else:
402
+ cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
403
+ return cells_and_names_with_index
404
+
405
+ def register_primitive_hook(self):
406
+ if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
407
+ return
408
+ if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
409
+ return
410
+
411
+ primitive_set = set()
412
+ cells_and_names_with_index = self.get_cells_and_names()
413
+ for cells_and_names in cells_and_names_with_index.values():
414
+ for _, cell in cells_and_names:
415
+ for attribute, value in vars(cell).items():
416
+ if isinstance(value, Primitive):
417
+ primitive_set.add((attribute, value))
418
+
419
+ for pname, primitive in primitive_set:
420
+ primitive_class_name = primitive.__class__.__name__
421
+ primitive_combined_name = pname + Const.SEP + primitive_class_name
422
+ new_primitive = type('NewPrimitive', (primitive.__class__,),
423
+ {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
424
+ primitive_combined_name)})
425
+ primitive.__class__ = new_primitive
426
+
427
+ def register_cell_hook(self):
277
428
  if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
429
+ logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
278
430
  if not self.model:
279
431
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
280
432
  f"The current level is {self.config.level}, the model cannot be None")
281
- for name, cell in self.model.cells_and_names():
282
- if cell == self.model:
283
- continue
284
- prefix = 'Cell' + Const.SEP + name + Const.SEP + \
285
- cell.__class__.__name__ + Const.SEP
286
- forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
287
- cell.register_forward_hook(forward_hook)
288
- cell.register_backward_hook(backward_hook)
289
-
290
- cell.register_forward_pre_hook(
291
- self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
292
- cell.register_forward_hook(
293
- self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
294
- cell.register_backward_pre_hook(
295
- self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
296
- cell.register_backward_hook(
297
- self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
433
+ model_type = Const.MODULE if is_mindtorch() else Const.CELL
434
+ cells_and_names_with_index = self.get_cells_and_names()
435
+
436
+ for index, cells_and_names in cells_and_names_with_index.items():
437
+ model = self.model if index == "-1" else self.model[int(index)]
438
+ for name, cell in cells_and_names:
439
+ if cell == model:
440
+ continue
441
+ cell_index = (index + Const.SEP) if index != "-1" else ""
442
+ prefix = (model_type + Const.SEP + cell_index + name +
443
+ Const.SEP + cell.__class__.__name__ + Const.SEP)
444
+ _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
445
+ cell.register_forward_hook(forward_hook)
446
+ cell.register_forward_pre_hook(
447
+ self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
448
+ cell.register_forward_hook(
449
+ self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
450
+
451
+ register_backward_hook_functions["full"](cell, backward_hook)
452
+ register_backward_hook_functions["pre"](
453
+ cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
454
+ register_backward_hook_functions["full"](
455
+ cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
456
+
457
+ def reset_status(self):
458
+ self.primitive_hook_service.primitive_counters.clear()
459
+ self.data_collector.data_writer.reset_cache()
460
+ JitDump.jit_count = defaultdict(int)
461
+ self.params_grad_info.clear()
462
+
463
+ if self.config.level == Const.LEVEL_L2:
464
+ self.data_collector.data_processor.reset_status()
465
+ return
466
+ if self.config.step and self.current_iter not in self.config.step:
467
+ return
468
+ if self.config.rank and self.current_rank not in self.config.rank:
469
+ return
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.const import Const
2
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
18
  from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
msprobe/msprobe.py CHANGED
@@ -16,10 +16,12 @@
16
16
  import argparse
17
17
  import sys
18
18
  import importlib.util
19
- from msprobe.core.compare.utils import _compare_parser
19
+
20
+ from msprobe.core.common.const import Const
20
21
  from msprobe.core.common.log import logger
22
+ from msprobe.core.compare.utils import _compare_parser
21
23
  from msprobe.core.compare.compare_cli import compare_cli
22
- from msprobe.core.common.const import Const
24
+ from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
23
25
 
24
26
 
25
27
  def is_module_available(module_name):
@@ -45,10 +47,20 @@ def main():
45
47
  multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
46
48
  api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
47
49
  run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
50
+ code_mapping_cmd_parser = subparsers.add_parser('code_mapping')
51
+ graph_service_cmd_parser = subparsers.add_parser('graph')
52
+ op_generate_cmd_parser = subparsers.add_parser('op_generate')
53
+ merge_result_parser = subparsers.add_parser('merge_result')
48
54
  _compare_parser(compare_cmd_parser)
49
- is_torch_available=is_module_available("torch")
50
- is_mindspore_available = is_module_available("mindspore")
51
- if is_torch_available:
55
+ _merge_result_parser(merge_result_parser)
56
+
57
+ is_torch_available = is_module_available("torch")
58
+
59
+ if len(sys.argv) < 4:
60
+ parser.print_help()
61
+ sys.exit(0)
62
+ framework_args = parser.parse_args(sys.argv[1:3])
63
+ if framework_args.framework == Const.PT_FRAMEWORK:
52
64
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
53
65
  from msprobe.pytorch.parse_tool.cli import parse as cli_parse
54
66
  from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
@@ -56,20 +68,29 @@ def main():
56
68
  _api_precision_compare_command
57
69
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
58
70
  _run_overflow_check_command
71
+ from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
72
+ from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
73
+ _run_operator_generate_commond
59
74
 
60
75
  _run_ut_parser(run_ut_cmd_parser)
61
76
  _run_ut_parser(multi_run_ut_cmd_parser)
62
77
  multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
63
- help='Number of splits for parallel processing. Range: 1-64')
78
+ help='Number of splits for parallel processing. Range: 1-64')
64
79
  _api_precision_compare_parser(api_precision_compare_cmd_parser)
65
80
  _run_overflow_check_parser(run_overflow_check_cmd_parser)
66
- elif is_mindspore_available:
81
+ _pt_graph_service_parser(graph_service_cmd_parser)
82
+ _op_generator_parser(op_generate_cmd_parser)
83
+ elif framework_args.framework == Const.MS_FRAMEWORK:
67
84
  from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
85
+ from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
68
86
  add_api_accuracy_checker_argument(run_ut_cmd_parser)
87
+ from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
88
+ multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
89
+ from msprobe.mindspore.code_mapping.cmd_parser import add_ir_parser_arguments
90
+ add_ir_parser_arguments(code_mapping_cmd_parser)
91
+
92
+ _ms_graph_service_parser(graph_service_cmd_parser)
69
93
 
70
- if len(sys.argv) == 1:
71
- parser.print_help()
72
- sys.exit(0)
73
94
  args = parser.parse_args(sys.argv[1:])
74
95
  if sys.argv[2] == Const.PT_FRAMEWORK:
75
96
  if not is_torch_available:
@@ -86,20 +107,37 @@ def main():
86
107
  _api_precision_compare_command(args)
87
108
  elif sys.argv[3] == "run_overflow_check":
88
109
  _run_overflow_check_command(args)
110
+ elif sys.argv[3] == "graph":
111
+ _pt_graph_service_command(args)
112
+ elif sys.argv[3] == 'op_generate':
113
+ _run_operator_generate_commond(args)
89
114
  elif sys.argv[3] == "compare":
90
115
  if args.cell_mapping is not None or args.api_mapping is not None:
91
116
  logger.error("Argument -cm or -am is not supported in PyTorch framework")
92
117
  raise Exception("Argument -cm or -am is not supported in PyTorch framework")
93
118
  compare_cli(args)
119
+ elif sys.argv[3] == "merge_result":
120
+ merge_result_cli(args)
94
121
  else:
95
122
  if not is_module_available(Const.MS_FRAMEWORK):
96
123
  logger.error("MindSpore does not exist, please install MindSpore library")
97
124
  raise Exception("MindSpore does not exist, please install MindSpore library")
98
125
  if sys.argv[3] == "compare":
99
126
  compare_cli(args)
127
+ elif sys.argv[3] == "merge_result":
128
+ merge_result_cli(args)
100
129
  elif sys.argv[3] == "run_ut":
101
130
  from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
102
131
  api_checker_main(args)
132
+ elif sys.argv[3] == "multi_run_ut":
133
+ from msprobe.mindspore.api_accuracy_checker.main import mul_api_checker_main
134
+ mul_api_checker_main(args)
135
+ elif sys.argv[3] == "graph":
136
+ _ms_graph_service_command(args)
137
+ elif sys.argv[3] == "code_mapping":
138
+ from msprobe.mindspore.code_mapping.main import code_mapping_main
139
+ code_mapping_main(args)
140
+
103
141
 
104
142
  if __name__ == "__main__":
105
143
  main()