mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/service.py (deleted, -549 lines)
@@ -1,549 +0,0 @@
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
- # All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import copy
- import functools
- import os
- from collections import defaultdict
-
- import mindspore as ms
- from mindspore import nn
- from mindspore.common.api import _no_grad
- from mindspore.ops.primitive import Primitive
-
- try:
-     from mindspore.common._pijit_context import PIJitCaptureContext
- except ImportError:
-     pijit_label = False
- else:
-     pijit_label = True
-
- from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
- from msprobe.core.common.file_utils import create_directory
- from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
- from msprobe.core.data_dump.data_collector import build_data_collector
- from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
-                                                         ModuleBackwardInputs)
- from msprobe.core.data_dump.scope import BaseScope
- from msprobe.mindspore.cell_processor import CellProcessor
- from msprobe.mindspore.common.log import logger
- from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
-                                             is_mindtorch, register_backward_hook_functions)
- from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
- from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
- from msprobe.mindspore.dump.jit_dump import JitDump
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
- from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
-
- if is_mindtorch():
-     import torch
-
-
- class Service:
-     def __init__(self, config):
-         self.model = None
-         self.config = copy.deepcopy(config)
-         self.config.level = self.config.level_ori
-         self.data_collector = build_data_collector(self.config)
-         self.cell_processor = CellProcessor(self.data_collector.scope)
-         self.primitive_hook_service = PrimitiveHookService(self)
-         self.switch = False
-         self.inner_switch = False
-         self.primitive_switch = False
-         self.current_iter = 0
-         self.loop = 0
-         self.init_step = 0
-         self.first_start = True
-         self.current_rank = None
-         self.dump_iter_dir = None
-         self.start_call = False
-         self.should_stop_service = False
-         self.params_grad_info = {}
-         self.hook_handle_dict = {}
-         # Register early so that as many API hooks as possible get registered
-         self.api_register = get_api_register()
-         self.register_api_hook()
-         self.init_for_debug_level()
-
-     @staticmethod
-     def check_model_valid(models):
-         target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
-         if models is None or isinstance(models, target_module_type[0]):
-             return models
-         error_model = None
-         if isinstance(models, (list, tuple)):
-             for model in models:
-                 if not isinstance(model, target_module_type[0]):
-                     error_model = model
-                     break
-         else:
-             error_model = models
-
-         if error_model is not None:
-             error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
-                           f"type, currently there is a {type(error_model)} type.")
-             raise MsprobeException(
-                 MsprobeException.INVALID_PARAM_ERROR, error_info)
-         return models
-
-     @staticmethod
-     def prepare_module_input_output(target_type, cell, input_data, output):
-         if target_type == BaseScope.Module_Type_Module:
-             module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
-         else:
-             module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
-         return module_input_output
-
-     def build_hook(self, target_type, name):
-         def pre_hook(api_or_cell_name, cell, input_data):
-             if not self.should_execute_hook(target_type, cell, True):
-                 clean_input_kwargs(cell)
-                 return None
-
-             with _no_grad():
-                 self.inner_switch = True
-                 if target_type == BaseScope.Module_Type_Module:
-                     api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-                 else:
-                     cell.forward_data_collected = True
-                     HOOKCell.add_cell_count(name)
-                 module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
-                 self.data_collector.update_api_or_module_name(api_or_cell_name)
-                 self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
-                 self.inner_switch = False
-             return input_data
-
-         def grad_hook(cell, ori_name, param_name):
-             def hook_fn(grad):
-                 if not self.should_execute_hook(target_type, cell, False):
-                     return None
-                 self.inner_switch = True
-                 self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
-                 self.inner_switch = False
-                 return None
-
-             return hook_fn
-
-         def register_param_hook(ori_name, cell, params_dict):
-             '''
-             Register parameter hooks.
-             '''
-             # When data_mode is forward-only, parameter hooks are not registered
-             if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
-                 for param_name, param in params_dict.items():
-                     if param.requires_grad:
-                         name = ori_name + Const.SEP + param_name
-                         old_handle = self.hook_handle_dict.get(name)
-                         if old_handle and hasattr(old_handle, "remove"):
-                             old_handle.remove()
-                         handle = param.register_hook(grad_hook(cell, ori_name, param_name))
-                         self.hook_handle_dict[name] = handle
-
-         def init_params_grad_info(cell, params_dict):
-             '''
-             Initialize parameter-gradient info: after the forward hook finishes, write it into cache_data as a placeholder.
-             '''
-             if not params_dict:
-                 return
-             if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
-                 grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
-                 # Check whether a placeholder already exists in cache_data; if not, write one first
-                 if not self.params_grad_info.get(grad_name):
-                     data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
-                     # Gradients are computed only for parameters with requires_grad=True, so only then is a placeholder needed
-                     if data_info.get(grad_name):
-                         # Write grad_name's data_info into cache_data first; it is updated again after the gradients are computed
-                         self.data_collector.handle_data(grad_name, data_info,
-                                                         flush=self.data_collector.data_processor.is_terminated)
-                         # Record that this module's parameter-gradient info already has a placeholder
-                         self.params_grad_info[grad_name] = True
-
-         def forward_hook(api_or_cell_name, cell, input_data, output):
-             if not self.should_execute_hook(target_type, cell, True):
-                 clean_input_kwargs(cell)
-                 return None
-             with _no_grad():
-                 self.inner_switch = True
-                 module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
-                 if target_type == BaseScope.Module_Type_Module:
-                     api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-                     params_dict = {}
-                     if self.config.task != Const.STRUCTURE:
-                         params_dict = {
-                             key.split(Const.SEP)[-1]: value
-                             for key, value in cell.parameters_dict(recurse=False).items()
-                         }
-                     setattr(module_input_output, Const.PARAMS, params_dict)
-                     # Decide whether parameter hooks need to be registered
-                     if params_dict:
-                         ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
-                         grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
-                         # On the first forward-hook execution, add the params_grad_name attribute and register parameter hooks
-                         setattr(cell, 'params_grad_name', grad_name)
-                         register_param_hook(ori_name, cell, params_dict)
-                     self.data_collector.update_api_or_module_name(api_or_cell_name)
-                     self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
-                     init_params_grad_info(cell, params_dict)
-                 else:
-                     self.data_collector.update_api_or_module_name(api_or_cell_name)
-                     self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
-
-                 if self.data_collector.if_return_forward_new_output():
-                     forward_new_output = self.data_collector.get_forward_new_output()
-                     self.inner_switch = False
-                     return forward_new_output
-                 clean_input_kwargs(cell)
-                 self.inner_switch = False
-             return output
-
-         def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
-             if not self.should_execute_hook(target_type, cell, False):
-                 return
-             self.inner_switch = True
-
-             need_exchange = True
-             if target_type == BaseScope.Module_Type_Module:
-                 if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
-                     need_exchange = False
-                 api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-
-             self.data_collector.update_api_or_module_name(api_or_cell_name)
-             if self.data_collector:
-                 # The framework's latest interface changed the meanings of grad_input and grad_output; to stay consistent with torch they are passed in swapped order here
-                 if need_exchange:
-                     module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
-                 else:
-                     module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
-                 self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
-             self.inner_switch = False
-
-         def pre_backward_hook(api_or_cell_name, cell, grad_input):
-             if not self.should_execute_hook(target_type, cell, False):
-                 return
-             self.inner_switch = True
-             module_input = ModuleBackwardInputs(grad_input=grad_input)
-             self.data_collector.update_api_or_module_name(api_or_cell_name)
-             self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
-
-             self.inner_switch = False
-
-         pid = os.getpid()
-         if target_type == BaseScope.Module_Type_Module:
-             full_forward_name = name + Const.FORWARD
-             full_backward_name = name + Const.BACKWARD
-         else:
-             full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
-             full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
-         pre_forward_hook = functools.partial(pre_hook, full_forward_name)
-         forward_hook = functools.partial(forward_hook, full_forward_name)
-         backward_hook = functools.partial(backward_hook, full_backward_name)
-         pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
-
-         def wrap_pre_forward_hook(cell, input_data):
-             return pre_forward_hook(cell, input_data)
-
-         def wrap_forward_hook(cell, input_data, output_data):
-             return forward_hook(cell, input_data, output_data)
-
-         def wrap_backward_hook(cell, grad_input, grad_output):
-             return backward_hook(cell, grad_input, grad_output)
-
-         def wrap_pre_backward_hook(cell, grad_input):
-             return pre_backward_hook(cell, grad_input)
-
-         return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
-
-     def update_primitive_counters(self, primitive_name):
-         if primitive_name not in self.primitive_counters:
-             self.primitive_counters[primitive_name] = 0
-         else:
-             self.primitive_counters[primitive_name] += 1
-
-     def step(self):
-         if self.config.level == Const.LEVEL_DEBUG:
-             return
-         if self.config.async_dump:
-             self.data_collector.fill_stack_tensor_data()
-             if self.config.task == Const.TENSOR:
-                 self.data_collector.data_processor.dump_async_data()
-         self.data_collector.write_json()
-         self.loop += 1
-         self.reset_status()
-
-     def start(self, model=None):
-         if self.current_iter == 0:
-             if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
-                 JitDump.set_config(self.config)
-                 JitDump.set_data_collector(self.data_collector)
-                 if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
-                     ms.common.api._MindsporeFunctionExecutor = JitDump
-                 else:
-                     ms.common.api._JitExecutor = JitDump
-                 ms.common.api._PyNativeExecutor.grad = JitDump.grad
-                 if pijit_label:
-                     PIJitCaptureContext.__enter__ = self.empty
-                     PIJitCaptureContext.__exit__ = self.empty
-         self.current_iter = self.loop + self.init_step
-         self.data_collector.update_iter(self.current_iter)
-         if self.config.level == Const.LEVEL_DEBUG:
-             return
-         self.start_call = True
-         if self.should_stop_service:
-             return
-         if self.need_end_service():
-             self.should_stop_service = True
-             self.switch = False
-             self.primitive_switch = False
-             print_tools_ends_info()
-             return
-         if self.config.step and self.current_iter not in self.config.step:
-             JitDump.jit_dump_switch = False
-             return
-         self.model = self.check_model_valid(model)
-
-         logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")
-
-         if self.first_start:
-             try:
-                 self.current_rank = get_rank_if_initialized()
-             except DistributedNotInitializedError:
-                 self.current_rank = None
-
-             if self.config.rank and self.current_rank not in self.config.rank:
-                 return
-             self.register_primitive_hook()
-             self.register_cell_hook()
-             self.first_start = False
-
-         self.api_register.register_all_api()
-         self.switch = True
-         self.primitive_switch = True
-         logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
-         self.create_dirs()
-         logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
-         JitDump.jit_dump_switch = True
-
-     def stop(self):
-         if self.config.level == Const.LEVEL_DEBUG:
-             return
-         if self.should_stop_service:
-             return
-         logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
-                     "Please set debugger.start() to turn on the dump switch again. ")
-         if not self.start_call:
-             logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
-             raise Exception("debugger.start() is not set in the current scope.")
-         if self.config.step and self.current_iter not in self.config.step:
-             return
-         if self.config.rank and self.current_rank not in self.config.rank:
-             return
-         self.switch = False
-         self.primitive_switch = False
-         self.start_call = False
-         if self.config.async_dump:
-             self.data_collector.fill_stack_tensor_data()
-             if self.config.task == Const.TENSOR:
-                 self.data_collector.data_processor.dump_async_data()
-         self.data_collector.write_json()
-         JitDump.jit_dump_switch = False
-
-     def need_end_service(self):
-         if self.config.step and self.current_iter > max(self.config.step):
-             return True
-         if self.data_collector and self.data_collector.data_processor.is_terminated:
-             return True
-         return False
-
-     def should_execute_hook(self, hook_type, cell, is_forward):
-         is_cell_hook = hook_type == BaseScope.Module_Type_Module
-         if is_cell_hook and not self.switch:
-             return False
-         elif not is_cell_hook and is_forward and not self.switch:
-             return False
-         elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
-             return False
-
-         if self.inner_switch:
-             return False
-         if not self.data_collector or self.data_collector.data_processor.is_terminated:
-             return False
-         return True
-
-     def create_dirs(self):
-         create_directory(self.config.dump_path)
-         self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
-         cur_rank = self.current_rank if self.current_rank is not None else ''
-         if self.config.level == Const.LEVEL_L2:
-             create_directory(self.dump_iter_dir)
-             kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
-             self.config.kernel_config_path = kernel_config_path
-             return
-
-         dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
-         create_directory(dump_dir)
-         if self.config.task in self.data_collector.tasks_need_tensor_data:
-             dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
-             create_directory(dump_data_dir)
-         else:
-             dump_data_dir = None
-
-         dump_path_aggregation = DumpPathAggregation()
-         dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
-         dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
-         dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
-         dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
-         self.data_collector.update_dump_paths(dump_path_aggregation)
-
-         self.data_collector.initialize_json_file(
-             framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
-         )
-
-     def empty(self, *args, **kwargs):
-         pass
-
-     def register_api_hook(self):
-         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
-             logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
-             self.api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
-             self.api_register.register_all_api()
-
-     def get_cells_and_names(self):
-         cells_and_names_with_index = {}
-
-         def get_cell_or_module(model):
-             return model.named_modules() if is_mindtorch() else model.cells_and_names()
-
-         if isinstance(self.model, (list, tuple)):
-             for index, model in enumerate(self.model):
-                 cells_and_names_with_index[str(index)] = get_cell_or_module(model)
-         else:
-             cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
-         return cells_and_names_with_index
-
-     def register_primitive_hook(self):
-         if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
-             return
-         if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
-             return
-
-         primitive_set = set()
-         cells_and_names_with_index = self.get_cells_and_names()
-         for cells_and_names in cells_and_names_with_index.values():
-             for _, cell in cells_and_names:
-                 for attribute, value in vars(cell).items():
-                     if isinstance(value, Primitive):
-                         primitive_set.add((attribute, value))
-
-         for pname, primitive in primitive_set:
-             primitive_class_name = primitive.__class__.__name__
-             primitive_combined_name = pname + Const.SEP + primitive_class_name
-             new_primitive = type('NewPrimitive', (primitive.__class__,),
-                                  {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
-                                                                                          primitive_combined_name)})
-             primitive.__class__ = new_primitive
-
-     def register_cell_hook(self):
-         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
-             logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
-             if not self.model:
-                 raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
-                                        f"The current level is {self.config.level}, the model cannot be None")
-             model_type = Const.MODULE if is_mindtorch() else Const.CELL
-             cells_and_names_with_index = self.get_cells_and_names()
-
-             for index, cells_and_names in cells_and_names_with_index.items():
-                 model = self.model if index == "-1" else self.model[int(index)]
-                 for name, cell in cells_and_names:
-                     if cell == model:
-                         continue
-                     cell_index = (index + Const.SEP) if index != "-1" else ""
-                     prefix = (model_type + Const.SEP + cell_index + name +
-                               Const.SEP + cell.__class__.__name__ + Const.SEP)
-                     _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
-                     cell.register_forward_hook(forward_hook)
-                     cell.register_forward_pre_hook(
-                         self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
-                     cell.register_forward_hook(
-                         self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
-
-                     register_backward_hook_functions["full"](cell, backward_hook)
-                     register_backward_hook_functions["pre"](
-                         cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
-                     register_backward_hook_functions["full"](
-                         cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
-
-     def reset_status(self):
-         self.primitive_hook_service.primitive_counters.clear()
-         self.data_collector.reset_status()
-         JitDump.jit_count = defaultdict(int)
-         self.params_grad_info.clear()
-         if self.config.level == Const.LEVEL_L2:
-             self.data_collector.data_processor.reset_status()
-             return
-         if self.config.step and self.current_iter not in self.config.step:
-             return
-         if self.config.rank and self.current_rank not in self.config.rank:
-             return
-
-     def init_for_debug_level(self):
-         if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
-             return
-         try:
-             self.current_rank = get_rank_if_initialized()
-         except DistributedNotInitializedError:
-             self.current_rank = None
-         # dir: dump_path -- rank{} -- debug.json
-         self.dump_iter_dir = self.config.dump_path
-         cur_rank = self.current_rank if self.current_rank is not None else ''
-         dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
-         create_directory(dump_dir)
-         if self.config.task in self.data_collector.tasks_need_tensor_data:
-             dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
-             create_directory(dump_data_dir)
-         else:
-             dump_data_dir = None
-
-         dump_path_aggregation = DumpPathAggregation()
-         dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
-         dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
-         self.data_collector.update_dump_paths(dump_path_aggregation)
-         self.data_collector.initialize_json_file(
-             framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
-         )
-         self.debug_variable_counter = defaultdict(int)
-
-     def save(self, variable, name, save_backward):
-         '''
-         Args:
-             variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
-             name: str
-             save_backward: boolean
-         Return:
-             void
-         '''
-         if self.config.level != Const.LEVEL_DEBUG:
-             return
-         count = self.debug_variable_counter[name]
-         self.debug_variable_counter[name] += 1
-
-         name_with_count = f"{name}.{count}"
-         grad_name_with_count = f"{name}_grad.{count}"
-
-         # forward save
-         self.data_collector.debug_data_collect_forward(variable, name_with_count)
-
-         # backward save
-         if save_backward:
-             self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
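
One implementation note on the removed file: register_primitive_hook intercepts primitives per instance rather than patching Primitive.__call__ globally. For each Primitive attribute found on a cell it builds a throwaway subclass whose __call__ is wrapped, then reassigns the instance's __class__. Below is a minimal, framework-free sketch of that pattern; the Primitive stand-in, wrap_call, and the hook body are illustrative, not msprobe code.

    import functools

    class Primitive:  # stand-in for mindspore.ops.primitive.Primitive
        def __call__(self, x):
            return x * 2

    def wrap_call(origin_call, hook_name):
        # origin_call is the class-level __call__; the wrapper keeps its signature
        # and brackets it with hook logic (msprobe collects dump data at this point).
        @functools.wraps(origin_call)
        def wrapped(self, *args, **kwargs):
            print(f"[hook] {hook_name} in")
            out = origin_call(self, *args, **kwargs)
            print(f"[hook] {hook_name} out")
            return out
        return wrapped

    prim = Primitive()
    # One-off subclass with a wrapped __call__, assigned to this *instance* only;
    # other Primitive instances remain un-hooked.
    prim.__class__ = type('NewPrimitive', (prim.__class__,),
                          {'__call__': wrap_call(prim.__class__.__call__, 'matmul.Primitive')})
    assert prim(3) == 6  # original behaviour preserved; hook messages print around it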
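
A second note: Service was not called directly by users; it sat behind PrecisionDebugger. The following is a hedged sketch of the usual per-step driving loop, reconstructed from the start/stop/step methods above. The names net, dataset, and train_one_step are placeholders, and the exact 1.3.0 wiring of Service into PrecisionDebugger is assumed here, not shown in this diff.

    from msprobe.mindspore import PrecisionDebugger

    debugger = PrecisionDebugger(config_path="./config.json")  # task/level/step/rank come from the config

    for data in dataset:                  # placeholder training loop
        debugger.start(model=net)         # Service.start(): registers hooks on first call, flips the dump switch on
        loss = train_one_step(net, data)  # forward/backward hooks collect data during the step
        debugger.stop()                   # Service.stop(): flushes dump.json / stack.json / construct.json
        debugger.step()                   # Service.step(): advances current_iter and resets per-step state

At level "debug" no hooks are involved: debugger.save(variable, name, save_backward) routes through the Service.save() shown above and writes rank*/debug.json instead.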