mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,543 +0,0 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import copy
17
- import functools
18
- import os
19
- from collections import defaultdict
20
-
21
- import mindspore as ms
22
- from mindspore import nn
23
- from mindspore.common.api import _no_grad
24
- from mindspore.ops.primitive import Primitive
25
-
26
- try:
27
- from mindspore.common._pijit_context import PIJitCaptureContext
28
- except ImportError:
29
- pijit_label = False
30
- else:
31
- pijit_label = True
32
-
33
- from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
34
- from msprobe.core.common.file_utils import create_directory
35
- from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
36
- from msprobe.core.data_dump.data_collector import build_data_collector
37
- from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
38
- ModuleBackwardInputs)
39
- from msprobe.core.data_dump.scope import BaseScope
40
- from msprobe.mindspore.cell_processor import CellProcessor
41
- from msprobe.mindspore.common.log import logger
42
- from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
43
- is_mindtorch, register_backward_hook_functions)
44
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
45
- from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
46
- from msprobe.mindspore.dump.jit_dump import JitDump
47
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
48
- from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
49
-
50
- if is_mindtorch():
51
- import torch
52
-
53
-
54
- class Service:
55
def __init__(self, config):
    """Set up dump state for one debugger session.

    Args:
        config: Debugger configuration. A deep copy is stored, and ``level`` on the copy is
            reset to ``level_ori``.

    API hooks are registered at the end of construction via ``register_api_hook``, followed
    by ``init_for_debug_level``.
    """
    self.model = None
    self.config = copy.deepcopy(config)
    self.config.level = self.config.level_ori
    self.data_collector = build_data_collector(self.config)
    self.cell_processor = CellProcessor(self.data_collector.scope)
    # NOTE(review): receives ``self`` before the attributes below are assigned; keep this
    # ordering in mind when editing.
    self.primitive_hook_service = PrimitiveHookService(self)
    # Main dump switch consulted by should_execute_hook; turned on by start(), off by stop().
    self.switch = False
    # True while a hook is collecting data, so hooks fired from inside the collector are skipped.
    self.inner_switch = False
    # Toggled together with ``switch`` in start()/stop(); presumably read by the primitive hooks.
    self.primitive_switch = False
    # Step counter, advanced by step().
    self.current_iter = 0
    # True until start() has completed its one-time setup.
    self.first_start = True
    # Rank resolved on the first start(); None when distributed is not initialised.
    self.current_rank = None
    # ``<dump_path>/step<N>`` for the current step, set by create_dirs().
    self.dump_iter_dir = None
    # Set by start(); stop() raises if it is not set, and clears it after a completed stop.
    self.start_call = False
    # Once True, start() and stop() return immediately.
    self.should_stop_service = False
    # grad_name -> True for cells whose parameter-gradient placeholder has been handled.
    self.params_grad_info = {}
    # Parameter full name -> handle of its registered gradient hook.
    self.hook_handle_dict = {}
    # Register early so that as many API hooks as possible are installed.
    self.register_api_hook()
    self.init_for_debug_level()
76
-
77
@staticmethod
def check_model_valid(models):
    """Validate the ``model`` argument given to ``start()``.

    Accepted values are ``None``, a single model, or a list/tuple of models. A model is a
    ``torch.nn.Module`` when running under MindTorch and a ``mindspore.nn.Cell`` otherwise.

    Returns:
        ``models`` unchanged when it is valid.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` naming the type of the first invalid entry.
    """
    if is_mindtorch():
        expected_cls, expected_name = torch.nn.Module, "torch.nn.Module"
    else:
        expected_cls, expected_name = nn.Cell, "mindspore.nn.Cell"

    if models is None or isinstance(models, expected_cls):
        return models

    if isinstance(models, (list, tuple)):
        # First element that is not a model (None when every element is acceptable).
        offender = next((item for item in models if not isinstance(item, expected_cls)), None)
    else:
        offender = models

    if offender is None:
        return models

    error_info = (f"The 'model' parameter must be a {expected_name} or list[{expected_name}] "
                  f"type, currently there is a {type(offender)} type.")
    raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
97
-
98
@staticmethod
def prepare_module_input_output(target_type, cell, input_data, output):
    """Bundle forward inputs and outputs for the data collector.

    Cell (module) hooks record positional inputs only. API hooks also attach the keyword
    arguments stored on ``cell.input_kwargs``.
    """
    is_cell_hook = target_type == BaseScope.Module_Type_Module
    kwargs = {} if is_cell_hook else cell.input_kwargs
    return ModuleForwardInputsOutputs(args=input_data, kwargs=kwargs, output=output)
105
-
106
def build_hook(self, target_type, name):
    """Build the forward/backward hook callbacks for one API or cell.

    Args:
        target_type: ``BaseScope.Module_Type_Module`` for cell (module) hooks. Any other value
            is handled as an API hook (``register_api_hook`` passes ``BaseScope.Module_Type_API``).
        name: Name prefix of the API or cell. Cell hooks append ``Const.FORWARD`` /
            ``Const.BACKWARD`` directly. API hooks first insert the current
            ``HOOKCell.get_cell_count(name)`` followed by ``Const.SEP``.

    Returns:
        Tuple ``(pre_forward_hook, forward_hook, backward_hook, pre_backward_hook)``.

    While a callback collects data it sets ``self.inner_switch``, so that
    ``should_execute_hook`` rejects hooks fired from inside the collector. The flag is cleared
    in a ``finally`` clause, so an exception raised during collection cannot leave every later
    hook permanently disabled.
    """
    pid = os.getpid()

    def param_grad_dump_enabled():
        # Parameter-gradient hooks are skipped when data_mode selects forward data only.
        data_mode = self.config.data_mode
        return not (Const.FORWARD in data_mode and Const.BACKWARD not in data_mode)

    def pre_hook(api_or_cell_name, cell, input_data):
        if not self.should_execute_hook(target_type, cell, True):
            clean_input_kwargs(cell)
            return None

        with _no_grad():
            self.inner_switch = True
            try:
                if target_type == BaseScope.Module_Type_Module:
                    api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
                else:
                    # Mark the API as forward-collected so its backward hook is allowed to run.
                    cell.forward_data_collected = True
                    HOOKCell.add_cell_count(name)
                module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
                self.data_collector.update_api_or_module_name(api_or_cell_name)
                self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
            finally:
                self.inner_switch = False
            return input_data

    def grad_hook(cell, ori_name, param_name):
        """Return a gradient hook that records ``grad`` for parameter ``param_name`` of ``ori_name``."""
        def hook_fn(grad):
            if not self.should_execute_hook(target_type, cell, False):
                return None
            self.inner_switch = True
            try:
                self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
            finally:
                self.inner_switch = False
            return None

        return hook_fn

    def register_param_hook(ori_name, cell, params_dict):
        """Register gradient hooks on the cell's parameters that require grad.

        Any handle previously registered for the same parameter is removed first. Nothing is
        registered when data_mode is forward-only.
        """
        if not param_grad_dump_enabled():
            return
        for param_name, param in params_dict.items():
            if not param.requires_grad:
                continue
            handle_key = ori_name + Const.SEP + param_name
            old_handle = self.hook_handle_dict.get(handle_key)
            if old_handle and hasattr(old_handle, "remove"):
                old_handle.remove()
            handle = param.register_hook(grad_hook(cell, ori_name, param_name))
            self.hook_handle_dict[handle_key] = handle

    def init_params_grad_info(cell, params_dict):
        """Write a placeholder for the cell's parameter gradients into the collector cache.

        Called after forward collection. The placeholder holds one ``[None]`` per parameter
        that requires grad, presumably updated once gradients are computed. Each grad name is
        handled at most once, tracked in ``self.params_grad_info``.
        """
        if not params_dict:
            return
        if not param_grad_dump_enabled():
            return
        grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
        # Skip cells whose placeholder has already been handled.
        if self.params_grad_info.get(grad_name):
            return
        data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
        # Only parameters with requires_grad=True get gradients, so only those need a placeholder.
        if data_info.get(grad_name):
            self.data_collector.handle_data(grad_name, data_info,
                                            flush=self.data_collector.data_processor.is_terminated)
        # Record that this cell's parameter-gradient placeholder has been handled.
        self.params_grad_info[grad_name] = True

    def forward_hook(api_or_cell_name, cell, input_data, output):
        if not self.should_execute_hook(target_type, cell, True):
            clean_input_kwargs(cell)
            return None
        with _no_grad():
            self.inner_switch = True
            try:
                module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
                if target_type == BaseScope.Module_Type_Module:
                    api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
                    params_dict = {}
                    if self.config.task != Const.STRUCTURE:
                        params_dict = {
                            key.split(Const.SEP)[-1]: value
                            for key, value in cell.parameters_dict(recurse=False).items()
                        }
                        setattr(module_input_output, Const.PARAMS, params_dict)
                    # Parameter hooks are only needed when the cell owns parameters.
                    if params_dict:
                        ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
                        grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
                        # Record the grad name on the cell, then (re)register its parameter hooks.
                        cell.params_grad_name = grad_name
                        register_param_hook(ori_name, cell, params_dict)
                    self.data_collector.update_api_or_module_name(api_or_cell_name)
                    self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
                    init_params_grad_info(cell, params_dict)
                else:
                    self.data_collector.update_api_or_module_name(api_or_cell_name)
                    self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid,
                                                                    module_input_output)

                if self.data_collector.if_return_forward_new_output():
                    output = self.data_collector.get_forward_new_output()
                # Release the stored kwargs on every successful path, including when the
                # collector substitutes a new forward output.
                clean_input_kwargs(cell)
            finally:
                self.inner_switch = False
            return output

    def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
        if not self.should_execute_hook(target_type, cell, False):
            return
        self.inner_switch = True
        try:
            need_exchange = True
            if target_type == BaseScope.Module_Type_Module:
                if not getattr(cell, 'has_pre_hook_called', False):
                    need_exchange = False
                api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)

            # should_execute_hook has already verified that data_collector is usable.
            self.data_collector.update_api_or_module_name(api_or_cell_name)
            # The framework's latest API changed the meaning of grad_input/grad_output to match
            # torch, so the two are swapped here when needed.
            if need_exchange:
                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
            else:
                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
            self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
        finally:
            self.inner_switch = False

    def pre_backward_hook(api_or_cell_name, cell, grad_input):
        if not self.should_execute_hook(target_type, cell, False):
            return
        self.inner_switch = True
        try:
            module_input = ModuleBackwardInputs(grad_input=grad_input)
            self.data_collector.update_api_or_module_name(api_or_cell_name)
            self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
        finally:
            self.inner_switch = False

    if target_type == BaseScope.Module_Type_Module:
        full_forward_name = name + Const.FORWARD
        full_backward_name = name + Const.BACKWARD
    else:
        # API names embed the current call count of this API.
        api_prefix = name + str(HOOKCell.get_cell_count(name)) + Const.SEP
        full_forward_name = api_prefix + Const.FORWARD
        full_backward_name = api_prefix + Const.BACKWARD

    bound_pre_forward = functools.partial(pre_hook, full_forward_name)
    bound_forward = functools.partial(forward_hook, full_forward_name)
    bound_backward = functools.partial(backward_hook, full_backward_name)
    bound_pre_backward = functools.partial(pre_backward_hook, full_backward_name)

    # Expose plain functions with explicit positional signatures instead of partial objects.
    # NOTE(review): presumably required by the framework's hook registration; confirm.
    def wrap_pre_forward_hook(cell, input_data):
        return bound_pre_forward(cell, input_data)

    def wrap_forward_hook(cell, input_data, output_data):
        return bound_forward(cell, input_data, output_data)

    def wrap_backward_hook(cell, grad_input, grad_output):
        return bound_backward(cell, grad_input, grad_output)

    def wrap_pre_backward_hook(cell, grad_input):
        return bound_pre_backward(cell, grad_input)

    return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
264
-
265
def update_primitive_counters(self, primitive_name):
    """Advance the call counter for ``primitive_name``.

    The first call for a name stores 0 and each later call adds one, so the stored value is the
    zero-based index of the most recent call.

    ``__init__`` does not create ``primitive_counters``, so the dict is created lazily here.
    Otherwise the first call would raise ``AttributeError``.
    """
    counters = getattr(self, "primitive_counters", None)
    if counters is None:
        counters = {}
        self.primitive_counters = counters
    if primitive_name in counters:
        counters[primitive_name] += 1
    else:
        counters[primitive_name] = 0
270
-
271
def step(self):
    """Close the current iteration.

    Flushes pending data to JSON, advances ``current_iter``, informs the collector of the new
    iteration and resets per-step state. Does nothing at debug level.
    """
    if self.config.level == Const.LEVEL_DEBUG:
        return

    collector = self.data_collector
    if self.config.async_dump:
        # In async mode, fill stack/tensor data (and dump queued tensors for the tensor task)
        # before writing JSON.
        collector.fill_stack_tensor_data()
        if self.config.task == Const.TENSOR:
            collector.data_processor.dump_async_data()
    collector.write_json()

    self.current_iter += 1
    collector.update_iter(self.current_iter)
    self.reset_status()
282
-
283
def start(self, model=None):
    """Turn dumping on for the current step (backs ``debugger.start()``).

    Args:
        model: Optional model, or list/tuple of models, validated by ``check_model_valid`` and
            stored in ``self.model``.

    Returns early, leaving dumping off, in these cases:
    - at debug level;
    - once the service has been stopped for good;
    - when ``need_end_service()`` reports the run is over (this also stops the service
      permanently);
    - when the current step is not in ``config.step``;
    - when the current rank is not in ``config.rank``.
    """
    if self.config.level == Const.LEVEL_DEBUG:
        return
    # Set before the other checks; stop() raises if this flag is missing.
    self.start_call = True
    if self.should_stop_service:
        return
    if self.need_end_service():
        # Past the last configured step, or the data processor has terminated: shut down for good.
        self.should_stop_service = True
        self.switch = False
        self.primitive_switch = False
        print_tools_ends_info()
        return
    if self.config.step and self.current_iter not in self.config.step:
        return
    self.model = self.check_model_valid(model)

    logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")

    if self.first_start:
        # One-time setup: resolve the rank, install primitive/cell hooks and, for mix/L1
        # levels, route MindSpore's jit executor and PyNative grad through JitDump.
        try:
            self.current_rank = get_rank_if_initialized()
        except DistributedNotInitializedError:
            self.current_rank = None

        if self.config.rank and self.current_rank not in self.config.rank:
            # first_start stays True, so this setup is attempted again on the next start().
            return
        self.register_primitive_hook()
        self.register_cell_hook()
        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
            JitDump.set_config(self.config)
            JitDump.set_data_collector(self.data_collector)
            # The executor class name differs between MindSpore versions.
            if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
                ms.common.api._MindsporeFunctionExecutor = JitDump
            else:
                ms.common.api._JitExecutor = JitDump
            ms.common.api._PyNativeExecutor.grad = JitDump.grad
            if pijit_label:
                # Replace PIJit capture enter/exit with no-ops.
                # NOTE(review): presumably to keep PIJit from capturing while dumping; confirm.
                PIJitCaptureContext.__enter__ = self.empty
                PIJitCaptureContext.__exit__ = self.empty
        self.first_start = False

    api_register.api_set_hook_func()
    self.switch = True
    self.primitive_switch = True
    logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
    self.create_dirs()
    logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
    JitDump.jit_dump_switch = True
331
-
332
def stop(self):
    """Turn dumping off for the current step and flush collected data (backs ``debugger.stop()``).

    This is a no-op at debug level, after the service has been stopped for good, or when the
    current step or rank is excluded by ``config.step`` / ``config.rank``.

    Raises:
        Exception: if ``self.start_call`` is not set. ``start()`` sets it and a completed
            ``stop()`` clears it.
    """
    if self.config.level == Const.LEVEL_DEBUG:
        return
    if self.should_stop_service:
        return
    # Check the precondition first, so the success message is only logged for a valid call.
    if not self.start_call:
        logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
        raise Exception("debugger.start() is not set in the current scope.")
    logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
                "Please set debugger.start() to turn on the dump switch again. ")
    if self.config.step and self.current_iter not in self.config.step:
        return
    if self.config.rank and self.current_rank not in self.config.rank:
        return
    self.switch = False
    self.primitive_switch = False
    self.start_call = False
    if self.config.async_dump:
        # In async mode, fill stack/tensor data (and dump queued tensors for the tensor task)
        # before writing JSON.
        self.data_collector.fill_stack_tensor_data()
        if self.config.task == Const.TENSOR:
            self.data_collector.data_processor.dump_async_data()
    self.data_collector.write_json()
    JitDump.jit_dump_switch = False
355
-
356
def need_end_service(self):
    """Tell whether dumping should end for good.

    Returns True when the current iteration is beyond the largest configured step, or when
    the data collector's processor reports it has terminated. Returns False otherwise.
    """
    steps = self.config.step
    if steps and self.current_iter > max(steps):
        return True
    collector = self.data_collector
    return bool(collector and collector.data_processor.is_terminated)
362
-
363
def should_execute_hook(self, hook_type, cell, is_forward):
    """Decide whether a hook invocation should collect data.

    Args:
        hook_type: ``BaseScope.Module_Type_Module`` for cell hooks; anything else is an API hook.
        cell: The cell or API object the hook fired on.
        is_forward: True for (pre-)forward hooks, False for backward/gradient hooks.

    Cell hooks and forward API hooks require the dump switch to be on. Backward API hooks
    ignore the switch and instead require ``cell.forward_data_collected``. In every case the
    hook is skipped while another hook is collecting (``inner_switch``), or when the data
    collector is missing or its processor has terminated.
    """
    if hook_type == BaseScope.Module_Type_Module or is_forward:
        if not self.switch:
            return False
    elif not cell.forward_data_collected:
        return False

    if self.inner_switch:
        return False
    collector = self.data_collector
    return bool(collector) and not collector.data_processor.is_terminated
377
-
378
def create_dirs(self):
    """Create the dump directories for the current step and register them.

    For level L2 only ``<dump_path>/step<iter>`` is created, and the path
    returned by ``create_kernel_config_json`` is stored on the config.
    Otherwise ``<dump_path>/step<iter>/rank<rank>`` is created, with a
    ``dump_tensor_data`` sub-directory when the task needs tensor data. The
    data collector is then pointed at dump.json / stack.json /
    construct.json in that directory and its JSON files are initialized.
    """
    create_directory(self.config.dump_path)
    self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
    rank_suffix = '' if self.current_rank is None else self.current_rank

    if self.config.level == Const.LEVEL_L2:
        create_directory(self.dump_iter_dir)
        self.config.kernel_config_path = create_kernel_config_json(self.dump_iter_dir, rank_suffix)
        return

    rank_dir = os.path.join(self.dump_iter_dir, f"rank{rank_suffix}")
    create_directory(rank_dir)

    tensor_dir = None
    if self.config.task in self.data_collector.tasks_need_tensor_data:
        tensor_dir = os.path.join(rank_dir, "dump_tensor_data")
        create_directory(tensor_dir)

    paths = DumpPathAggregation()
    paths.dump_file_path = os.path.join(rank_dir, "dump.json")
    paths.stack_file_path = os.path.join(rank_dir, "stack.json")
    paths.construct_file_path = os.path.join(rank_dir, "construct.json")
    paths.dump_tensor_data_dir = tensor_dir
    self.data_collector.update_dump_paths(paths)

    framework = Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
    self.data_collector.initialize_json_file(framework=framework)
406
-
407
def empty(self, *args, **kwargs):
    """No-op stand-in that accepts and ignores any arguments.

    It is assigned over other callables to disable them, for example over
    ``PIJitCaptureContext.__exit__``.
    """
    return None
409
-
410
def register_api_hook(self):
    """Mount the API-level dump hook for the mix, L1 and L2 levels.

    The hook factory is ``self.build_hook`` bound to
    ``BaseScope.Module_Type_API``. It is handed to ``api_register``, which
    then installs the hooked API functions.
    """
    if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
        return
    logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
    api_hook_factory = functools.partial(self.build_hook, BaseScope.Module_Type_API)
    api_register.initialize_hook(api_hook_factory)
    api_register.api_set_hook_func()
415
-
416
def get_cells_and_names(self):
    """Map each model to its (name, cell) iterator, keyed by position.

    If ``self.model`` is a list or tuple, the keys are the string indices
    "0", "1", .... A single model is stored under the sentinel key "-1".
    Under mindtorch the iterator comes from ``named_modules()``, otherwise
    from ``cells_and_names()``.
    """
    def iter_cells(net):
        if is_mindtorch():
            return net.named_modules()
        return net.cells_and_names()

    if not isinstance(self.model, (list, tuple)):
        return {"-1": iter_cells(self.model)}
    return {str(idx): iter_cells(net) for idx, net in enumerate(self.model)}
428
-
429
def register_primitive_hook(self):
    """Hook every Primitive stored as an attribute of any cell of the model(s).

    This runs only for the mix/L1 levels, when a model is set and the task
    is in ``Const.DUMP_DATA_COLLECTION_LIST``. For each distinct
    (attribute name, primitive) pair, the primitive's class is swapped for
    a dynamically created subclass. That subclass's ``__call__`` comes from
    ``primitive_hook_service.wrap_primitive``, under the name
    ``<attribute><Const.SEP><PrimitiveClassName>``.
    """
    if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
        return
    if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
        return

    found = set()
    for cells_and_names in self.get_cells_and_names().values():
        for _, cell in cells_and_names:
            found.update(
                (attr_name, attr_value)
                for attr_name, attr_value in vars(cell).items()
                if isinstance(attr_value, Primitive)
            )

    for attr_name, prim in found:
        hooked_name = attr_name + Const.SEP + prim.__class__.__name__
        hooked_call = self.primitive_hook_service.wrap_primitive(prim.__call__, hooked_name)
        # Rebinding __class__ swaps __call__ for this instance only.
        prim.__class__ = type('NewPrimitive', (prim.__class__,), {'__call__': hooked_call})
450
-
451
def register_cell_hook(self):
    """Register forward/backward dump hooks on every sub-cell of the model(s).

    This runs only for the mix and L0 levels. For each entry from
    ``get_cells_and_names``, key "-1" means ``self.model`` itself and any
    other key is an index into ``self.model``. Every sub-cell except the
    root gets two sets of hooks:

    * the forward hook from ``build_hook``, plus ``cell_processor`` node
      hooks registered as forward-pre (START) and forward (STOP) hooks;
    * the backward hook from ``build_hook``, plus backward "pre" (START)
      and "full" (STOP) node hooks, registered via
      ``register_backward_hook_functions``.

    Hook names are prefixed with ``Const.MODULE`` (mindtorch) or
    ``Const.CELL``, followed by the optional model index, the cell name and
    the cell class name, all joined by ``Const.SEP``.

    Raises:
        MsprobeException: if the level requires cell hooks but no model was given.
    """
    if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L0]:
        return
    # Validate before logging, so we never report a successful mount and
    # then raise.
    if not self.model:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"The current level is {self.config.level}, the model cannot be None")
    model_type = Const.MODULE if is_mindtorch() else Const.CELL
    cells_and_names_with_index = self.get_cells_and_names()

    for index, cells_and_names in cells_and_names_with_index.items():
        model = self.model if index == "-1" else self.model[int(index)]
        # Only list/tuple models carry their index in the hook name; the
        # index is the same for every cell of this model.
        cell_index = (index + Const.SEP) if index != "-1" else ""
        for name, cell in cells_and_names:
            # The root cell itself is not hooked, only its sub-cells.
            if cell is model:
                continue
            prefix = (model_type + Const.SEP + cell_index + name +
                      Const.SEP + cell.__class__.__name__ + Const.SEP)
            _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
            cell.register_forward_hook(forward_hook)
            cell.register_forward_pre_hook(
                self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
            cell.register_forward_hook(
                self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))

            register_backward_hook_functions["full"](cell, backward_hook)
            register_backward_hook_functions["pre"](
                cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
            register_backward_hook_functions["full"](
                cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))

    logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
480
-
481
def reset_status(self):
    """Clear per-step collection state before the next step.

    Clears the primitive call counters, the data collector's status, the
    JIT call counts and the gradient-info cache. At level L2 it also resets
    the data processor.

    The previous version ended with step/rank filter checks whose branches
    only ``return``ed at the end of the function. They had no effect and
    have been removed.
    """
    self.primitive_hook_service.primitive_counters.clear()
    self.data_collector.reset_status()
    JitDump.jit_count = defaultdict(int)
    self.params_grad_info.clear()
    if self.config.level == Const.LEVEL_L2:
        self.data_collector.data_processor.reset_status()
493
-
494
def init_for_debug_level(self):
    """Prepare output paths for the debug level (used by ``save``).

    This runs only when the level is debug and the task is tensor or
    statistics. The rank comes from ``get_rank_if_initialized``, falling
    back to None when distributed is not initialized.

    Layout: ``<dump_path>/rank<rank>/debug.json``, plus a
    ``dump_tensor_data`` directory when the task needs tensor data.
    The method also resets the per-name counter that ``save`` uses.
    """
    is_debug_level = self.config.level == Const.LEVEL_DEBUG
    if not is_debug_level or self.config.task not in [Const.TENSOR, Const.STATISTICS]:
        return
    try:
        self.current_rank = get_rank_if_initialized()
    except DistributedNotInitializedError:
        self.current_rank = None

    # dir: dump_path -- rank{} -- debug.json
    self.dump_iter_dir = self.config.dump_path
    rank_suffix = '' if self.current_rank is None else self.current_rank
    rank_dir = os.path.join(self.dump_iter_dir, f"rank{rank_suffix}")
    create_directory(rank_dir)

    tensor_dir = None
    if self.config.task in self.data_collector.tasks_need_tensor_data:
        tensor_dir = os.path.join(rank_dir, "dump_tensor_data")
        create_directory(tensor_dir)

    paths = DumpPathAggregation()
    paths.dump_tensor_data_dir = tensor_dir
    paths.debug_file_path = os.path.join(rank_dir, "debug.json")
    self.data_collector.update_dump_paths(paths)

    framework = Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
    self.data_collector.initialize_json_file(framework=framework)
    self.debug_variable_counter = defaultdict(int)
520
-
521
def save(self, variable, name, save_backward):
    '''
    Record ``variable`` under an auto-numbered name (debug level only).

    Each call with the same ``name`` gets the next sequence number, so the
    forward data is saved as ``<name>.<n>``. When requested, the gradient
    is saved as ``<name>_grad.<n>``.

    Args:
        variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
        name: str
        save_backward: boolean
    Return:
        void
    '''
    if self.config.level != Const.LEVEL_DEBUG:
        return
    counter = self.debug_variable_counter
    seq = counter[name]
    counter[name] = seq + 1

    # Forward data is always collected; backward only on request.
    self.data_collector.debug_data_collect_forward(variable, f"{name}.{seq}")
    if save_backward:
        self.data_collector.debug_data_collect_backward(variable, f"{name}_grad.{seq}")
- self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)