mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,354 +1,378 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- import os
17
- import copy
18
- from pathlib import Path
19
- import functools
20
- from collections import defaultdict
21
-
22
- import mindspore as ms
23
- from mindspore.common.tensor import Tensor
24
- from mindspore import ops
25
- from mindspore import nn
26
- try:
27
- from mindspore.common._pijit_context import PIJitCaptureContext
28
- pijit_label = True
29
- except ImportError:
30
- pijit_label = False
31
-
32
-
33
- from msprobe.core.data_dump.data_collector import build_data_collector
34
- from msprobe.core.data_dump.scope import BaseScope
35
- from msprobe.mindspore.common.utils import get_rank_if_initialized
36
- from msprobe.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create
37
- from msprobe.mindspore.common.log import logger
38
- from msprobe.core.common.utils import Const
39
- from msprobe.core.common.exceptions import DistributedNotInitializedError
40
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
41
- from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
42
- ModuleBackwardInputs, ModuleBackwardOutputs
43
- from msprobe.core.common.exceptions import MsprobeException
44
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
45
- from msprobe.mindspore.cell_processor import CellProcessor
46
- from msprobe.mindspore.dump.jit_dump import JitDump
47
-
48
-
49
class Service:
    """Hook-based dump/debug service for MindSpore PyNative execution.

    Depending on ``config.level`` the service instruments the network at
    different granularities:

    * ``"L1"`` – wraps the MindSpore API surface (via ``api_register``) and,
      when a model is supplied, the model's primitive ops and the jit entry
      points (``JitDump``).
    * ``"L0"`` – registers forward/backward hooks on every sub-``Cell`` of
      the model (the top-level model cell itself is skipped).
    * ``"L2"`` – rejected by ``check_level_valid``.

    Collected inputs/outputs/gradients are forwarded to the
    ``data_collector`` built from the configuration; ``start`` / ``stop`` /
    ``step`` drive the per-iteration dump life cycle.
    """

    def __init__(self, config):
        self.model = None
        # Deep-copy so the level normalisation below does not mutate the
        # caller's config object.
        self.config = copy.deepcopy(config)
        self.config.level = self.config.level_ori
        self.data_collector = build_data_collector(self.config)
        self.cell_processor = CellProcessor(self.data_collector.scope)
        self.switch = False            # True only between start() and stop() on dumped steps
        self.current_iter = 0          # advanced by step()
        self.first_start = True        # hooks are registered once, on the first effective start()
        self.current_rank = None       # resolved lazily; None when distributed is not initialised
        self.primitive_counters = {}   # per-primitive call counter, used to build unique dump names
        self.dump_iter_dir = None      # step-level output directory, set by create_dirs()
        self.start_call = False        # guards stop() being called without a preceding start()
        self.check_level_valid()

    @staticmethod
    def check_model_valid(model):
        """Return *model* unchanged if it is falsy or an ``nn.Cell``.

        Raises:
            MsprobeException: ``INVALID_PARAM_ERROR`` for any other type.
                (The message is user-facing Chinese text meaning
                "the model argument must be of type mindspore.nn.Cell".)
        """
        if not model or isinstance(model, nn.Cell):
            return model
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
        )

    def check_level_valid(self):
        """Reject the (unimplemented) L2 dump level at construction time."""
        if self.config.level == "L2":
            raise MsprobeException(
                MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
            )

    def build_hook(self, target_type, name):
        """Build a (forward_hook, backward_hook) pair for one API or Cell.

        Args:
            target_type: ``BaseScope.Module_Type_Module`` for Cell-level
                hooks, otherwise API-level hooks.
            name: name prefix; ``Const.FORWARD`` / ``Const.BACKWARD`` are
                appended to form the dump entry name.

        Returns:
            Tuple of two callables matching MindSpore's forward/backward
            hook signatures.
        """
        def forward_hook(api_or_cell_name, cell, input, output):
            if target_type == BaseScope.Module_Type_Module:
                # Cell hooks use the name reserved on the cell instance
                # instead of the template passed via functools.partial.
                api_or_cell_name = cell.mindstudio_reserved_name
            self.data_collector.visit_and_clear_overflow_status(api_or_cell_name)
            if not self.switch:
                return None
            if self.data_collector:
                if target_type == BaseScope.Module_Type_Module:
                    module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
                else:
                    # API hooks carry the original kwargs on the cell object.
                    module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs, output=output)
                self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
                if self.data_collector.if_return_forward_new_output():
                    # The collector may substitute the forward output
                    # (e.g. for perturbation/overflow tasks).
                    return self.data_collector.get_forward_new_output()
            if target_type == BaseScope.Module_Type_API:
                # Drop the kwargs stashed on the cell to avoid keeping
                # references alive after collection.
                del cell.input_kwargs
            return output

        def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
            if target_type == BaseScope.Module_Type_Module:
                api_or_cell_name = cell.mindstudio_reserved_name
            self.data_collector.visit_and_clear_overflow_status(api_or_cell_name)
            if not self.switch:
                return
            if self.data_collector:
                # A recent framework interface change swapped the meaning of
                # grad_input and grad_output; to stay consistent with the
                # torch semantics they are passed here in swapped order.
                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
                self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)

        # pid is captured by both closures above; collection helpers use it
        # to tag which process produced the data.
        pid = os.getpid()
        forward_name_template = name + Const.FORWARD
        backward_name_template = name + Const.BACKWARD
        # Pre-bind the dump-entry name so the wrappers expose the plain
        # MindSpore hook signature.
        forward_hook = functools.partial(forward_hook, forward_name_template)
        backward_hook = functools.partial(backward_hook, backward_name_template)

        def wrap_forward_hook(cell, input, output):
            return forward_hook(cell, input, output)

        def wrap_backward_hook(cell, grad_input, grad_output):
            return backward_hook(cell, grad_input, grad_output)

        return wrap_forward_hook, wrap_backward_hook

    def wrap_primitive(self, origin_func, primitive_name):
        """Wrap a primitive's ``__call__`` so inputs/outputs/grads are dumped.

        Args:
            origin_func: the primitive's original ``__call__``.
            primitive_name: name used to build per-call dump entry names.

        Returns:
            A replacement ``__call__`` that hooks tensor inputs and outputs
            with ``ops.HookBackward`` and forwards data to the collector.
        """
        service_instance = self

        def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
            # Gradients arrive one tensor at a time; the hook accumulates
            # them and fires collection once all num_tensors have arrived.
            def backward_hook(grad):
                captured_grads.append(grad)
                backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
                try:
                    if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
                        service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
                        new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
                        service_instance.data_collector.backward_output_data_collect(
                            backward_primitive_name, service_instance, os.getpid(), new_module_input_output
                        )
                        captured_grads.clear()
                    elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
                        service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
                        new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
                        service_instance.data_collector.backward_input_data_collect(
                            backward_primitive_name, service_instance, os.getpid(), new_module_input_output
                        )
                        captured_grads.clear()

                except Exception as exception:
                    raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
                                    f" updated_primitive_name: {updated_primitive_name}") from exception

            return backward_hook

        def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
            # Attach a shared backward hook to every Tensor argument;
            # non-tensor args pass through untouched.
            hooked_inputs = []
            num_tensors = sum(isinstance(arg, Tensor) for arg in args)
            input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
                                                       Const.INPUT)
            for _, arg in enumerate(args):
                if isinstance(arg, Tensor):
                    arg_hooked = ops.HookBackward(input_backward_hook)(arg)
                    hooked_inputs.append(arg_hooked)
                else:
                    hooked_inputs.append(arg)
            return hooked_inputs

        def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
            # Mirror of hook_primitive_inputs for the forward result, which
            # may be a single Tensor, a tuple, or something else (returned
            # unchanged).
            if isinstance(out, tuple):
                num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
            else:
                num_output_tensors = 1
            output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
                                                        updated_primitive_name, Const.OUTPUT)

            if isinstance(out, Tensor):
                return ops.HookBackward(output_backward_hook)(out)
            elif isinstance(out, tuple):
                hooked_outputs = []
                for tensor in out:
                    if isinstance(tensor, Tensor):
                        hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
                    else:
                        hooked_outputs.append(tensor)
                return tuple(hooked_outputs)
            return out

        def wrapped_primitive_call(instance_self, *args, **kwargs):
            # Each invocation gets a unique name: Primitive.<name>.<count>.
            service_instance.update_primitive_counters(primitive_name)
            current_count = service_instance.primitive_counters.get(primitive_name, 0)
            updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"

            if not service_instance.switch:
                # Dump disabled: call through with no instrumentation.
                return origin_func(*args, **kwargs)

            captured_grads_input, captured_grads_output = [], []

            try:
                hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
            except Exception as exception:
                raise Exception("This is a primitive op dump error during input hooking: {},"
                                " primitive_name: {}".format(exception, primitive_name)) from exception

            try:
                out = origin_func(*hooked_inputs, **kwargs)
            except Exception as exception:
                raise Exception("This is a primitive op dump error during function call: {},"
                                " primitive_name: {}".format(exception, primitive_name)) from exception

            forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
            service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
            if service_instance.data_collector:
                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
                try:
                    service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
                                                                         os.getpid(), module_input_output)
                except Exception as exception:
                    raise Exception("This is a primitive op dump error during forward data collection: {},"
                                    " primitive_name: {}".format(exception, primitive_name)) from exception

                if service_instance.data_collector.if_return_forward_new_output():
                    out = service_instance.data_collector.get_forward_new_output()

            try:
                out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
            except Exception as exception:
                raise Exception("This is a primitive op dump error during output hooking: {},"
                                " primitive_name: {}".format(exception, primitive_name)) from exception

            return out

        return wrapped_primitive_call

    def update_primitive_counters(self, primitive_name):
        """Advance the per-primitive call counter (first call counts as 0)."""
        if primitive_name not in self.primitive_counters:
            self.primitive_counters[primitive_name] = 0
        else:
            self.primitive_counters[primitive_name] += 1

    def register_hooks(self):
        """Rebind ``__call__`` of every primitive found in the model.

        Primitives are first deduplicated in a set, then each one's class is
        swapped for a dynamically created subclass whose ``__call__`` is the
        wrapped version produced by ``wrap_primitive``.
        """
        primitive_set = set()
        for _, cell in self.model.cells_and_names():
            for pname, primitive in cell._primitives.items():
                primitive_set.add((pname, primitive))

        for pname, primitive in primitive_set:
            NewPrimitive = type('NewPrimitive', (primitive.__class__,),
                                {'__call__': self.wrap_primitive(primitive.__call__, pname)})
            primitive.__class__ = NewPrimitive

    def step(self):
        """Advance to the next training iteration and reset per-step state."""
        self.current_iter += 1
        self.data_collector.update_iter(self.current_iter)
        HOOKCell.cell_count = defaultdict(int)
        CellProcessor.cell_count = {}
        self.primitive_counters.clear()

    def start(self, model=None):
        """Turn the dump switch on for the current iteration.

        Args:
            model: optional ``nn.Cell``; required for L0 and for primitive
                hooking at L1.

        Raises:
            Exception: once ``current_iter`` passes ``max(config.step)``,
                to terminate the run after the last requested step.
        """
        self.model = self.check_model_valid(model)
        self.start_call = True
        logger.info("msprobe: debugger.start() is set successfully")
        if self.config.step and self.current_iter > max(self.config.step):
            self.stop()
            raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
        if self.config.step and self.current_iter not in self.config.step:
            # This iteration is not selected for dumping.
            return
        if self.first_start:
            try:
                self.current_rank = get_rank_if_initialized()
            except DistributedNotInitializedError:
                self.current_rank = None

            if self.config.rank and self.current_rank not in self.config.rank:
                return
            # Hooks are mounted exactly once, on the first effective start.
            self.register_hook_new()
            self.first_start = False
        self.switch = True
        logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
        self.create_dirs()
        logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
        if self.config.level == "L1":
            # Replace MindSpore's jit executor with JitDump so jit-compiled
            # functions are dumped too; also neutralise PIJit capture when
            # available so hooks are not bypassed.
            JitDump.set_config(self.config)
            JitDump.set_data_collector(self.data_collector)
            ms.common.api._MindsporeFunctionExecutor = JitDump
            ms.common.api._PyNativeExecutor.grad = JitDump.grad
            if pijit_label:
                PIJitCaptureContext.__enter__ = self.empty
                PIJitCaptureContext.__exit__ = self.empty

    def stop(self):
        """Turn the dump switch off and flush collected data to JSON.

        Raises:
            Exception: when called without a matching ``start()``.
        """
        logger.info("msprobe: debugger.stop() is set successfully. "
                    "Please set debugger.start() to turn on the dump switch again. ")
        if not self.start_call:
            logger.error("msprobe: debugger.start() is not set in the current scope.")
            raise Exception("debugger.start() is not set in the current scope.")
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.config.rank and self.current_rank not in self.config.rank:
            return
        self.switch = False
        self.start_call = False
        self.data_collector.write_json()

    def create_dirs(self):
        """Create ``<dump_path>/step<i>/rank<r>`` and register output paths.

        A ``dump_tensor_data`` sub-directory is created only for tasks that
        need raw tensor data; the dump/stack/construct JSON paths are handed
        to the data collector.
        """
        check_path_before_create(self.config.dump_path)
        if not os.path.exists(self.config.dump_path):
            Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True)
        file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
        file_check.common_check()
        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
        # Empty rank suffix ("rank") when distributed is not initialised.
        cur_rank = self.current_rank if self.current_rank is not None else ''
        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
        if not os.path.exists(dump_dir):
            Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)
        if self.config.task in self.data_collector.tasks_need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
        else:
            dump_data_dir = None

        dump_file_path = os.path.join(dump_dir, "dump.json")
        stack_file_path = os.path.join(dump_dir, "stack.json")
        construct_file_path = os.path.join(dump_dir, "construct.json")
        self.data_collector.update_dump_paths(
            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)

    def empty(self, *args, **kwargs):
        """No-op used to disable PIJitCaptureContext enter/exit."""
        pass

    def register_hook_new(self):
        """Mount the level-appropriate hooks onto the model.

        L1 installs API-surface hooks (plus primitive hooks when a model is
        available); L0 walks every sub-Cell and registers both the
        data-collection hooks and the construct-graph node hooks — note each
        cell therefore receives two forward hooks and two backward hooks.
        """
        logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
        if self.config.level == "L1":
            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
            api_register.api_set_hook_func()
            if self.model:
                self.register_hooks()

        if self.config.level == "L0":
            if not self.model:
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "The current level is L0, the model cannot be None")
            for name, cell in self.model.cells_and_names():
                if cell == self.model:
                    # Skip the top-level model cell itself.
                    continue
                prefix = 'Cell' + Const.SEP + name + Const.SEP + \
                         cell.__class__.__name__ + Const.SEP
                forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
                cell.register_forward_hook(forward_hook)
                cell.register_backward_hook(backward_hook)

                cell.register_forward_pre_hook(
                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
                cell.register_forward_hook(
                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
                cell.register_backward_pre_hook(
                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
                cell.register_backward_hook(
                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+ import copy
18
+ import functools
19
+ from collections import defaultdict
20
+
21
+ import mindspore as ms
22
+ from mindspore.common.tensor import Tensor
23
+ from mindspore import ops
24
+ from mindspore import nn
25
+ try:
26
+ from mindspore.common._pijit_context import PIJitCaptureContext
27
+ pijit_label = True
28
+ except ImportError:
29
+ pijit_label = False
30
+
31
+
32
+ from msprobe.core.data_dump.data_collector import build_data_collector
33
+ from msprobe.core.data_dump.scope import BaseScope
34
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
35
+ from msprobe.core.common.file_utils import create_directory
36
+ from msprobe.mindspore.common.log import logger
37
+ from msprobe.core.common.utils import Const
38
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
39
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
40
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
41
+ ModuleBackwardInputs, ModuleBackwardOutputs
42
+ from msprobe.core.common.exceptions import MsprobeException
43
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
44
+ from msprobe.mindspore.cell_processor import CellProcessor
45
+ from msprobe.mindspore.dump.jit_dump import JitDump
46
+
47
+
48
+ class Service:
49
+ def __init__(self, config):
50
+ self.model = None
51
+ self.config = copy.deepcopy(config)
52
+ self.config.level = self.config.level_ori
53
+ self.data_collector = build_data_collector(self.config)
54
+ self.cell_processor = CellProcessor(self.data_collector.scope)
55
+ self.switch = False
56
+ self.current_iter = 0
57
+ self.first_start = True
58
+ self.current_rank = None
59
+ self.primitive_counters = {}
60
+ self.dump_iter_dir = None
61
+ self.start_call = False
62
+ self.check_level_valid()
63
+ self.should_stop_service = False
64
+
65
+ @staticmethod
66
+ def check_model_valid(model):
67
+ if not model or isinstance(model, nn.Cell):
68
+ return model
69
+ raise MsprobeException(
70
+ MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
71
+ )
72
+
73
+ def check_level_valid(self):
74
+ if self.config.level == "L2":
75
+ raise MsprobeException(
76
+ MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
77
+ )
78
+
79
+ def build_hook(self, target_type, name):
80
+ def forward_hook(api_or_cell_name, cell, input, output):
81
+ if not self.should_excute_hook():
82
+ return None
83
+
84
+ if target_type == BaseScope.Module_Type_Module:
85
+ api_or_cell_name = cell.mindstudio_reserved_name
86
+ module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
87
+ else:
88
+ module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs,
89
+ output=output)
90
+
91
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
92
+ self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
93
+ if self.data_collector.if_return_forward_new_output():
94
+ return self.data_collector.get_forward_new_output()
95
+ if target_type == BaseScope.Module_Type_API:
96
+ del cell.input_kwargs
97
+ return output
98
+
99
+ def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
100
+ if not self.should_excute_hook():
101
+ return
102
+
103
+ if target_type == BaseScope.Module_Type_Module:
104
+ api_or_cell_name = cell.mindstudio_reserved_name
105
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
106
+ if self.data_collector:
107
+ # 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
108
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
109
+ self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
110
+
111
+ pid = os.getpid()
112
+ forward_name_template = name + Const.FORWARD
113
+ backward_name_template = name + Const.BACKWARD
114
+ forward_hook = functools.partial(forward_hook, forward_name_template)
115
+ backward_hook = functools.partial(backward_hook, backward_name_template)
116
+
117
+ def wrap_forward_hook(cell, input, output):
118
+ return forward_hook(cell, input, output)
119
+
120
+ def wrap_backward_hook(cell, grad_input, grad_output):
121
+ return backward_hook(cell, grad_input, grad_output)
122
+
123
+ return wrap_forward_hook, wrap_backward_hook
124
+
125
+ def wrap_primitive(self, origin_func, primitive_name):
126
+ service_instance = self
127
+
128
+ def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
129
+ def backward_hook(grad):
130
+ captured_grads.append(grad)
131
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
132
+ try:
133
+ if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
134
+ service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
135
+ new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
136
+ service_instance.data_collector.backward_output_data_collect(
137
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
138
+ )
139
+ captured_grads.clear()
140
+ elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
141
+ service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
142
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
143
+ service_instance.data_collector.backward_input_data_collect(
144
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
145
+ )
146
+ captured_grads.clear()
147
+
148
+ except Exception as exception:
149
+ raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
150
+ f" updated_primitive_name: {updated_primitive_name}") from exception
151
+
152
+ return backward_hook
153
+
154
+ def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
155
+ hooked_inputs = []
156
+ num_tensors = sum(isinstance(arg, Tensor) for arg in args)
157
+ input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
158
+ Const.INPUT)
159
+ for _, arg in enumerate(args):
160
+ if isinstance(arg, Tensor):
161
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
162
+ hooked_inputs.append(arg_hooked)
163
+ else:
164
+ hooked_inputs.append(arg)
165
+ return hooked_inputs
166
+
167
+ def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
168
+ if isinstance(out, tuple):
169
+ num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
170
+ else:
171
+ num_output_tensors = 1
172
+ output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
173
+ updated_primitive_name, Const.OUTPUT)
174
+
175
+ if isinstance(out, Tensor):
176
+ return ops.HookBackward(output_backward_hook)(out)
177
+ elif isinstance(out, tuple):
178
+ hooked_outputs = []
179
+ for tensor in out:
180
+ if isinstance(tensor, Tensor):
181
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
182
+ else:
183
+ hooked_outputs.append(tensor)
184
+ return tuple(hooked_outputs)
185
+ return out
186
+
187
+ def wrapped_primitive_call(instance_self, *args, **kwargs):
188
+ service_instance.update_primitive_counters(primitive_name)
189
+ current_count = service_instance.primitive_counters.get(primitive_name, 0)
190
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
191
+
192
+ if not service_instance.switch:
193
+ return origin_func(*args, **kwargs)
194
+
195
+ captured_grads_input, captured_grads_output = [], []
196
+
197
+ try:
198
+ hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
199
+ except Exception as exception:
200
+ raise Exception("This is a primitive op dump error during input hooking: {},"
201
+ " primitive_name: {}".format(exception, primitive_name)) from exception
202
+
203
+ try:
204
+ out = origin_func(*hooked_inputs, **kwargs)
205
+ except Exception as exception:
206
+ raise Exception("This is a primitive op dump error during function call: {},"
207
+ " primitive_name: {}".format(exception, primitive_name)) from exception
208
+
209
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
210
+ service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
211
+ if service_instance.data_collector:
212
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
213
+ try:
214
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
215
+ os.getpid(), module_input_output)
216
+ except Exception as exception:
217
+ raise Exception("This is a primitive op dump error during forward data collection: {},"
218
+ " primitive_name: {}".format(exception, primitive_name)) from exception
219
+
220
+ if service_instance.data_collector.if_return_forward_new_output():
221
+ out = service_instance.data_collector.get_forward_new_output()
222
+
223
+ try:
224
+ out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
225
+ except Exception as exception:
226
+ raise Exception("This is a primitive op dump error during output hooking: {},"
227
+ " primitive_name: {}".format(exception, primitive_name)) from exception
228
+
229
+ return out
230
+
231
+ return wrapped_primitive_call
232
+
233
+ def update_primitive_counters(self, primitive_name):
234
+ if primitive_name not in self.primitive_counters:
235
+ self.primitive_counters[primitive_name] = 0
236
+ else:
237
+ self.primitive_counters[primitive_name] += 1
238
+
239
+ def register_hooks(self):
240
+ primitive_set = set()
241
+ for _, cell in self.model.cells_and_names():
242
+ for pname, primitive in cell._primitives.items():
243
+ primitive_set.add((pname, primitive))
244
+
245
+ for pname, primitive in primitive_set:
246
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,),
247
+ {'__call__': self.wrap_primitive(primitive.__call__, pname)})
248
+ primitive.__class__ = NewPrimitive
249
+
250
+ def step(self):
251
+ self.current_iter += 1
252
+ self.data_collector.update_iter(self.current_iter)
253
+ HOOKCell.cell_count = defaultdict(int)
254
+ CellProcessor.cell_count = {}
255
+ self.primitive_counters.clear()
256
+
257
+ def start(self, model=None):
258
+ self.start_call = True
259
+ if self.should_stop_service:
260
+ return
261
+ if self.need_end_service():
262
+ api_register.api_set_ori_func()
263
+ self.should_stop_service = True
264
+ self.switch = False
265
+ logger.info("************************************************")
266
+ logger.info(f"* {Const.TOOL_NAME} ends successfully. *")
267
+ logger.info("************************************************")
268
+ return
269
+ if self.config.step and self.current_iter not in self.config.step:
270
+ return
271
+ self.model = self.check_model_valid(model)
272
+
273
+ logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")
274
+
275
+ if self.first_start:
276
+ try:
277
+ self.current_rank = get_rank_if_initialized()
278
+ except DistributedNotInitializedError:
279
+ self.current_rank = None
280
+
281
+ if self.config.rank and self.current_rank not in self.config.rank:
282
+ return
283
+ self.register_hook_new()
284
+ if self.config.level == "L1":
285
+ JitDump.set_config(self.config)
286
+ JitDump.set_data_collector(self.data_collector)
287
+ ms.common.api._MindsporeFunctionExecutor = JitDump
288
+ ms.common.api._PyNativeExecutor.grad = JitDump.grad
289
+ if pijit_label:
290
+ PIJitCaptureContext.__enter__ = self.empty
291
+ PIJitCaptureContext.__exit__ = self.empty
292
+ self.first_start = False
293
+
294
+ self.switch = True
295
+ logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
296
+ self.create_dirs()
297
+ logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
298
+
299
+ def stop(self):
300
+ if self.should_stop_service:
301
+ return
302
+ logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
303
+ "Please set debugger.start() to turn on the dump switch again. ")
304
+ if not self.start_call:
305
+ logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
306
+ raise Exception("debugger.start() is not set in the current scope.")
307
+ if self.config.step and self.current_iter not in self.config.step:
308
+ return
309
+ if self.config.rank and self.current_rank not in self.config.rank:
310
+ return
311
+ self.switch = False
312
+ self.start_call = False
313
+ self.data_collector.write_json()
314
+
315
+ def need_end_service(self):
316
+ if self.config.step and self.current_iter > max(self.config.step):
317
+ return True
318
+ if self.data_collector and self.data_collector.data_processor.is_terminated:
319
+ return True
320
+ return False
321
+
322
+ def should_excute_hook(self):
323
+ if not self.switch:
324
+ return False
325
+ if not self.data_collector or self.data_collector.data_processor.is_terminated:
326
+ return False
327
+ return True
328
+
329
+ def create_dirs(self):
330
+ create_directory(self.config.dump_path)
331
+ self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
332
+ cur_rank = self.current_rank if self.current_rank is not None else ''
333
+ dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
334
+ create_directory(dump_dir)
335
+ if self.config.task in self.data_collector.tasks_need_tensor_data:
336
+ dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
337
+ create_directory(dump_data_dir)
338
+ else:
339
+ dump_data_dir = None
340
+
341
+ dump_file_path = os.path.join(dump_dir, "dump.json")
342
+ stack_file_path = os.path.join(dump_dir, "stack.json")
343
+ construct_file_path = os.path.join(dump_dir, "construct.json")
344
+ self.data_collector.update_dump_paths(
345
+ dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
346
+
347
+ def empty(self, *args, **kwargs):
348
+ pass
349
+
350
+ def register_hook_new(self):
351
+ logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
352
+ if self.config.level == "L1":
353
+ api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
354
+ api_register.api_set_hook_func()
355
+ if self.model:
356
+ self.register_hooks()
357
+
358
+ if self.config.level == "L0":
359
+ if not self.model:
360
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
361
+ "The current level is L0, the model cannot be None")
362
+ for name, cell in self.model.cells_and_names():
363
+ if cell == self.model:
364
+ continue
365
+ prefix = 'Cell' + Const.SEP + name + Const.SEP + \
366
+ cell.__class__.__name__ + Const.SEP
367
+ forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
368
+ cell.register_forward_hook(forward_hook)
369
+ cell.register_backward_hook(backward_hook)
370
+
371
+ cell.register_forward_pre_hook(
372
+ self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
373
+ cell.register_forward_hook(
374
+ self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
375
+ cell.register_backward_pre_hook(
376
+ self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
377
+ cell.register_backward_hook(
378
+ self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))