mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,473 +0,0 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import functools
17
- import os
18
- from collections import namedtuple, defaultdict
19
-
20
- import torch
21
- from msprobe.core.common.const import Const
22
- from msprobe.core.common.exceptions import DistributedNotInitializedError
23
- from msprobe.core.common.file_utils import create_directory
24
- from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation
25
- from msprobe.core.data_dump.data_collector import build_data_collector
26
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
27
- from msprobe.core.data_dump.scope import BaseScope
28
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
29
- from msprobe.pytorch.common.log import logger
30
- from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
31
- from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
32
- from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
33
- from msprobe.pytorch.hook_module.api_register import get_api_register
34
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
35
- from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
36
-
37
# Compare the numeric major version: a plain string comparison such as
# '10.0' >= '2.0' evaluates to False and would misdetect future releases.
torch_version_above_or_equal_2 = int(torch.__version__.split('+')[0].split('.')[0]) >= 2
if torch_version_above_or_equal_2:
    # Online run_ut dispatch is only available (and only imported) on torch >= 2.
    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch

# Bundle of hook callables produced by Service.build_hook for one API/module.
HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])
42
-
43
-
44
class Service:
    """PyTorch dump service: mounts API/module hooks and drives data collection.

    The service owns a data collector built from ``config``. ``start``/``stop``/``step``
    toggle and flush collection. ``switch`` gates whether hooks collect at all.
    ``inner_switch`` is raised while a hook body is running, so any hook triggered
    re-entrantly during that time is skipped by ``should_execute_hook``.
    """

    def __init__(self, config):
        self.model = None
        self.config = config
        self.data_collector = build_data_collector(config)
        self.module_processor = ModuleProcesser(self.data_collector.scope)
        self.switch = False
        self.inner_switch = False
        self.current_iter = 0
        self.loop = 0
        self.init_step = 0
        self.first_start = True
        self.current_rank = None
        self.dump_iter_dir = None
        self.should_stop_service = False
        self.attl = None
        self.params_grad_info = {}
        self.hook_handle_dict = {}
        # Per-name counters used by ``save``. They are created unconditionally so that
        # ``save`` cannot hit an AttributeError when init_for_debug_level returns early.
        self.debug_variable_counter = defaultdict(int)
        # Register API hooks as early as possible so that as many APIs as possible are hooked.
        self.api_register = get_api_register()
        self.register_api_hook()
        self.init_for_debug_level()

    def build_hook(self, module_type, name):
        """Build the set of hook callables for one API or module.

        Args:
            module_type: ``BaseScope.Module_Type_API`` or ``BaseScope.Module_Type_Module``.
            name: API name prefix, used for call counting and for the full
                forward/backward names of API hooks.

        Returns:
            A ``HookFn`` of (pre_hook, forward_hook, backward_hook,
            forward_hook_torch_version_below_2). The first argument of each hook is
            pre-bound to the full forward/backward name for APIs and to None for
            modules. Module hooks read their name from ``module.mindstudio_reserved_name``.
        """
        def pre_hook(api_or_module_name, module, args, kwargs):
            if not self.should_execute_hook(module_type, module, True):
                return args, kwargs
            is_recompute = is_recomputation()

            self.inner_switch = True
            try:
                if module_type == BaseScope.Module_Type_Module:
                    api_or_module_name = module.mindstudio_reserved_name[-1]
                else:
                    module.forward_data_collected = True
                    HOOKModule.add_module_count(name)
                self.data_collector.update_api_or_module_name(api_or_module_name)

                if self.config.online_run_ut:
                    return None, None
                if self.data_collector:
                    module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
                    self.data_collector.forward_input_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )
                return args, kwargs
            finally:
                # Always release the re-entrancy guard, even if collection raised.
                self.inner_switch = False

        def grad_hook(module, ori_name, param_name):
            def hook_fn(grad):
                if not self.should_execute_hook(module_type, module, False):
                    return grad
                self.inner_switch = True
                try:
                    self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
                finally:
                    self.inner_switch = False
                return grad

            return hook_fn

        def register_param_hook(ori_name, module, params_dict):
            """Register gradient hooks on the module's trainable parameters."""
            # When data_mode is forward-only, no parameter hooks are registered.
            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
                for param_name, param in params_dict.items():
                    if param.requires_grad:
                        name = ori_name + Const.SEP + param_name
                        # Replace any hook registered for this parameter on an earlier forward pass.
                        old_handle = self.hook_handle_dict.get(name)
                        if old_handle and hasattr(old_handle, "remove"):
                            old_handle.remove()
                        handle = param.register_hook(grad_hook(module, ori_name, param_name))
                        self.hook_handle_dict[name] = handle

        def init_params_grad_info(module, params_dict):
            """Write a placeholder entry for the module's parameter gradients into the cache data.

            This runs after the forward hook finishes. The placeholder is updated later,
            once the gradients have been computed.
            """
            if not params_dict:
                return
            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
                grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
                # Only write the placeholder if one has not been written yet.
                if not self.params_grad_info.get(grad_name):
                    data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
                    # A placeholder is only needed when some parameter has requires_grad=True,
                    # because only then are gradients computed.
                    if data_info.get(grad_name):
                        # Write grad_name's data_info into the cache now; it is updated after backward.
                        self.data_collector.handle_data(grad_name, data_info,
                                                        flush=self.data_collector.data_processor.is_terminated)
                        # Record that this module's parameter-gradient placeholder exists.
                        self.params_grad_info[grad_name] = True

        def forward_hook(api_or_module_name, module, args, kwargs, output):
            if not self.should_execute_hook(module_type, module, True):
                return None
            is_recompute = is_recomputation()

            self.inner_switch = True
            # The finally clause releases inner_switch on every exit path, including the
            # scope-filtered early return below. Previously that return left the guard
            # raised, which disabled all subsequent hooks.
            try:
                if self.config.online_run_ut:
                    self.data_collector.update_api_or_module_name(api_or_module_name)
                    if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
                        return None
                    api_data = ApiData(
                        api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)],
                        args,
                        kwargs,
                        output,
                        self.current_iter,
                        self.current_rank
                    )
                    self.attl_send(api_data)
                    return None

                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
                if module_type == BaseScope.Module_Type_Module:
                    api_or_module_name = module.mindstudio_reserved_name[-1]
                    self.data_collector.update_api_or_module_name(api_or_module_name)
                    params_dict = {}
                    if self.config.task != Const.STRUCTURE:
                        params_dict = {
                            key.split(Const.SEP)[-1]: value
                            for key, value in module.named_parameters(recurse=False)
                        }
                        setattr(module_input_output, Const.PARAMS, params_dict)
                    # Decide whether parameter hooks need to be registered.
                    if params_dict:
                        ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
                        grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
                        # On forward, tag the module with params_grad_name and (re)register param hooks.
                        setattr(module, 'params_grad_name', grad_name)
                        register_param_hook(ori_name, module, params_dict)
                    self.data_collector.forward_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )
                    init_params_grad_info(module, params_dict)
                else:
                    self.data_collector.update_api_or_module_name(api_or_module_name)
                    self.data_collector.forward_output_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )

                if self.data_collector.if_return_forward_new_output():
                    return self.data_collector.get_forward_new_output()
                return output
            finally:
                self.inner_switch = False

        def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
            # torch < 2 forward hooks do not receive kwargs.
            return forward_hook(api_or_module_name, module, args, {}, output)

        def backward_hook(api_or_module_name, module, grad_input, grad_output):
            if not self.should_execute_hook(module_type, module, False):
                return
            is_recompute = is_recomputation()

            self.inner_switch = True
            try:
                if module_type == BaseScope.Module_Type_Module:
                    api_or_module_name = module.mindstudio_reserved_name[-1]
                self.data_collector.update_api_or_module_name(api_or_module_name)

                if self.config.online_run_ut:
                    return

                if self.data_collector:
                    # grad_input received here is actually the output of the backward pass and
                    # grad_output is its input, so they are swapped when passed on.
                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
                    self.data_collector.backward_data_collect(
                        api_or_module_name,
                        module,
                        pid,
                        module_input_output,
                        is_recompute
                    )
            finally:
                self.inner_switch = False

        # Read lazily by the closures above at call time.
        pid = os.getpid()
        full_forward_name = None
        full_backward_name = None
        if module_type == BaseScope.Module_Type_API:
            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
            full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
        pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
        forward_hook_fn = functools.partial(forward_hook, full_forward_name)
        backward_hook_fn = functools.partial(backward_hook, full_backward_name)
        forward_hook_torch_version_below_2_fn = functools.partial(
            forward_hook_torch_version_below_2,
            full_forward_name
        )
        return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)

    def start(self, model):
        """Begin collection for the current step.

        On the first call this also resolves the rank and mounts the module/optimizer hooks.
        """
        self.current_iter = self.loop + self.init_step
        self.data_collector.update_iter(self.current_iter)
        if self.config.level == Const.LEVEL_DEBUG:
            return
        if self.need_stop_service():
            return

        self.model = model
        if self.first_start:
            try:
                self.current_rank = get_rank_if_initialized()
            except DistributedNotInitializedError:
                self.current_rank = None
            self.attl_init()

            if self.config.rank and self.current_rank not in self.config.rank:
                return
            self.register_module_hook()
            if self.config.level == Const.LEVEL_MIX:
                register_optimizer_hook(self.data_collector)
            self.first_start = False
        if self.config.online_run_ut and torch_version_above_or_equal_2:
            run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
        self.switch = True
        logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
        if not self.config.online_run_ut:
            self.create_dirs()
            logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")

    def stop(self):
        """Turn collection off for the current step and flush collected data."""
        if self.config.level == Const.LEVEL_DEBUG:
            return
        if self.should_stop_service:
            return
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.config.rank and self.current_rank not in self.config.rank:
            return
        self.switch = False
        if self.config.level == Const.LEVEL_L2:
            return
        if self.config.online_run_ut and torch_version_above_or_equal_2:
            run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
            return
        self._flush_collected_data()

    def step(self):
        """Flush collected data, advance the step counter and reset per-step state."""
        if self.config.level == Const.LEVEL_DEBUG:
            return
        if self.should_stop_service:
            return
        self._flush_collected_data()
        self.loop += 1
        self.reset_status()

    def _flush_collected_data(self):
        """Write pending (including async-dumped) data to the JSON output files."""
        if self.config.async_dump:
            self.data_collector.fill_stack_tensor_data()
            if self.config.task == Const.TENSOR:
                self.data_collector.data_processor.dump_async_data()
        self.data_collector.write_json()

    def need_stop_service(self):
        """Return True if collection should not run for the current step.

        This is the case when the service has permanently ended (the configured steps are
        exhausted or the data processor is terminated) or the step is not selected.
        """
        if self.should_stop_service:
            return True
        end_service = (self.config.step and self.current_iter > max(self.config.step)) or \
            (self.data_collector and self.data_collector.data_processor.is_terminated)
        if end_service:
            if self.config.online_run_ut:
                # send stop signal if online_run_ut
                self.attl_stop()
            self.switch = False
            self.should_stop_service = True
            print_tools_ends_info()
            return True
        if self.config.step and self.current_iter not in self.config.step:
            return True
        return False

    def should_execute_hook(self, hook_type, module, is_forward):
        """Decide whether a hook body should run for this call."""
        is_module_hook = hook_type == BaseScope.Module_Type_Module
        if is_module_hook and not self.switch:
            return False
        elif not is_module_hook and is_forward and not self.switch:
            return False
        elif not is_module_hook and not is_forward and not module.forward_data_collected:
            return False

        # Skip hooks triggered while another hook body is running.
        if self.inner_switch:
            return False
        if not self.data_collector or self.data_collector.data_processor.is_terminated:
            return False
        return True

    def _prepare_rank_dump_dir(self, root_dir):
        """Create ``root_dir/rank{rank}`` and, if the task needs it, its tensor-data subdirectory.

        Returns:
            A tuple ``(dump_dir, dump_data_dir)``. ``dump_data_dir`` is None when the task
            does not dump tensor data.
        """
        cur_rank = self.current_rank if self.current_rank is not None else ''
        dump_dir = os.path.join(root_dir, f"rank{cur_rank}")
        create_directory(dump_dir)
        if self.config.task in self.data_collector.tasks_need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            create_directory(dump_data_dir)
        else:
            dump_data_dir = None
        return dump_dir, dump_data_dir

    def create_dirs(self):
        """Create the output directories for the current step and point the collector at them."""
        create_directory(self.config.dump_path)
        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
        if self.config.level == Const.LEVEL_L2:
            cur_rank = self.current_rank if self.current_rank is not None else ''
            create_directory(self.dump_iter_dir)
            kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
            self.config.kernel_config_path = kernel_config_path
            return

        dump_dir, dump_data_dir = self._prepare_rank_dump_dir(self.dump_iter_dir)

        dump_path_aggregation = DumpPathAggregation()
        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
        dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
        self.data_collector.update_dump_paths(dump_path_aggregation)
        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)

    def register_api_hook(self):
        """Mount API hooks for the mix, L1 and L2 levels."""
        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
            logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
            self.api_register.initialize_hook(
                functools.partial(self.build_hook, BaseScope.Module_Type_API)
            )
            self.api_register.register_all_api()

    def register_module_hook(self):
        """Mount module hooks on ``self.model`` for the L0 and mix levels."""
        if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
            logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
            self.module_processor.register_module_hook(self.model, self.build_hook)

    def attl_init(self):
        """Create the ATTL transport used for online run_ut (no-op otherwise)."""
        if self.config.online_run_ut:
            from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
            attl_config = ATTLConfig(is_benchmark_device=False,
                                     connect_ip=self.config.host,
                                     connect_port=self.config.port,
                                     nfs_path=self.config.nfs_path,
                                     tls_path=self.config.tls_path)
            need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank
            self.attl = ATTL('npu', attl_config, need_dump=need_dump)
            if self.config.nfs_path:
                self.attl.upload("start")

    def attl_send(self, api_data):
        """Send (or upload via NFS) one API's data; distributed APIs are skipped."""
        logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}")
        # Only the leading component is needed. Indexing avoids a ValueError for names
        # that do not have exactly three components.
        api_type = api_data.name.split(Const.SEP)[0]
        if api_type in [Const.DISTRIBUTED]:
            logger.info(f"api {api_data.name} is not supported, skip")
            return
        if self.config.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)

    def attl_stop(self):
        """Signal the end of online run_ut to the benchmark side."""
        # need_stop_service() can run before attl_init() on the first start; nothing to stop then.
        if self.attl is None:
            return
        if self.config.nfs_path:
            self.attl.upload("end")
        elif self.attl.socket_manager is not None:
            logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
            self.attl.socket_manager.send_stop_signal()

    def reset_status(self):
        """Reset per-step counters and cached state before the next step."""
        ModuleProcesser.reset_module_stats()
        HOOKModule.reset_module_stats()
        self.data_collector.reset_status()
        self.params_grad_info.clear()

        if self.config.level == Const.LEVEL_L2:
            self.data_collector.data_processor.reset_status()

    def init_for_debug_level(self):
        """Set up ``dump_path/rank{rank}/debug.json`` output for the debug level (tensor/statistics tasks)."""
        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
            return
        try:
            self.current_rank = get_rank_if_initialized()
        except DistributedNotInitializedError:
            self.current_rank = None

        # dir: dump_path -- rank{} -- debug.json
        self.dump_iter_dir = self.config.dump_path
        dump_dir, dump_data_dir = self._prepare_rank_dump_dir(self.dump_iter_dir)

        dump_path_aggregation = DumpPathAggregation()
        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
        self.data_collector.update_dump_paths(dump_path_aggregation)
        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)

        self.debug_variable_counter = defaultdict(int)

    def save(self, variable, name, save_backward):
        """Record ``variable`` under ``name.{n}`` (debug level only).

        When ``save_backward`` is set, its gradient is also recorded under ``name_grad.{n}``.
        ``n`` counts previous saves of the same name.
        """
        if self.config.level != Const.LEVEL_DEBUG:
            return
        count = self.debug_variable_counter[name]
        self.debug_variable_counter[name] += 1

        name_with_count = f"{name}.{count}"
        grad_name_with_count = f"{name}_grad.{count}"

        # forward save
        self.data_collector.debug_data_collect_forward(variable, name_with_count)

        # backward save
        if save_backward:
            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
File without changes