mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/service.py

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,24 +15,24 @@
 
 import functools
 import os
-from collections import namedtuple
+from collections import namedtuple, defaultdict
 
 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
+from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import create_directory
-from msprobe.core.common.utils import print_tools_ends_info
+from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation
 from msprobe.core.data_dump.data_collector import build_data_collector
 from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
 from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
-from msprobe.pytorch.hook_module import remove_dropout
+from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_registry import api_register
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.pytorch.module_processer import ModuleProcesser
+from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
@@ -48,100 +48,206 @@ class Service:
         self.data_collector = build_data_collector(config)
         self.module_processor = ModuleProcesser(self.data_collector.scope)
         self.switch = False
+        self.inner_switch = False
         self.current_iter = 0
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None
         self.should_stop_service = False
         self.attl = None
-
-    @staticmethod
-    def forward_backward_dump_end():
-        logger.info_on_rank_0("Data needed ends here.")
-        api_register.api_originality()
-
-    @staticmethod
-    def is_registered_backward_hook(module):
-        if hasattr(module, '_backward_hooks') and \
-                len(module._backward_hooks) > 0 and \
-                module._is_full_backward_hook is False:
-            return True
-        return False
-
-    def check_register_full_backward_hook(self, module):
-        if self.is_registered_backward_hook(module):
-            module._backward_hooks.clear()
-            module._is_full_backward_hook = None
-            logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
+        self.params_grad_info = {}
+        self.hook_handle_dict = {}
+        # Register early so that as many API hooks as possible are in place
+        self.register_api_hook()
+        self.init_for_debug_level()
 
     def build_hook(self, module_type, name):
         def pre_hook(api_or_module_name, module, args, kwargs):
-            if not self.should_execute_hook():
+            if not self.should_execute_hook(module_type, module, True):
                 return args, kwargs
+            is_recompute = is_recomputation()
 
+            self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
-                api_or_module_name = module.mindstudio_reserved_name
+                api_or_module_name = module.mindstudio_reserved_name[-1]
+            else:
+                module.forward_data_collected = True
+                HOOKModule.add_module_count(name)
             self.data_collector.update_api_or_module_name(api_or_module_name)
 
             if self.config.online_run_ut:
+                self.inner_switch = False
                 return None, None
             if self.data_collector:
                 module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
-                self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output)
+                self.data_collector.forward_input_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )
+
+            self.inner_switch = False
             return args, kwargs
 
+        def grad_hook(module, ori_name, param_name):
+            def hook_fn(grad):
+                if not self.should_execute_hook(module_type, module, False):
+                    return grad
+                self.inner_switch = True
+                self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
+                self.inner_switch = False
+                return grad
+
+            return hook_fn
+
+        def register_param_hook(ori_name, module, params_dict):
+            '''
+            Register parameter hooks.
+            '''
+            # When data_mode is forward-only, parameter hooks are not registered
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                for param_name, param in params_dict.items():
+                    if param.requires_grad:
+                        name = ori_name + Const.SEP + param_name
+                        old_handle = self.hook_handle_dict.get(name)
+                        if old_handle and hasattr(old_handle, "remove"):
+                            old_handle.remove()
+                        handle = param.register_hook(grad_hook(module, ori_name, param_name))
+                        self.hook_handle_dict[name] = handle
+
+        def init_params_grad_info(module, params_dict):
+            '''
+            Initialize parameter-gradient info: once the forward hook finishes, write the
+            parameter-gradient entry into cache_data as a placeholder.
+            '''
+            if not params_dict:
+                return
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
+                # Check whether a placeholder already exists in cache_data; if not, write one first
+                if not self.params_grad_info.get(grad_name):
+                    data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
+                    # Gradients are only computed for parameters with requires_grad=True,
+                    # so a placeholder is only needed in that case
+                    if data_info.get(grad_name):
+                        # Write grad_name's data_info into cache_data first; it is updated once the gradients are computed
+                        self.data_collector.handle_data(grad_name, data_info,
+                                                        flush=self.data_collector.data_processor.is_terminated)
+                        # Record that this module's parameter-gradient info now has its placeholder
+                        self.params_grad_info[grad_name] = True
+
         def forward_hook(api_or_module_name, module, args, kwargs, output):
-            if not self.should_execute_hook():
+            if not self.should_execute_hook(module_type, module, True):
                 return None
+            is_recompute = is_recomputation()
 
-            if module_type == BaseScope.Module_Type_Module:
-                api_or_module_name = module.mindstudio_reserved_name
-            self.data_collector.update_api_or_module_name(api_or_module_name)
-
+            self.inner_switch = True
             if self.config.online_run_ut:
+                self.data_collector.update_api_or_module_name(api_or_module_name)
                 if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
                     return None
-                api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank)
+                api_data = ApiData(
+                    api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)],
+                    args,
+                    kwargs,
+                    output,
+                    self.current_iter,
+                    self.current_rank
+                )
                 self.attl_send(api_data)
+                self.inner_switch = False
                 return None
 
-            if self.data_collector:
-                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
-                self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
-                if self.data_collector.if_return_forward_new_output():
-                    return self.data_collector.get_forward_new_output()
+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+            if module_type == BaseScope.Module_Type_Module:
+                api_or_module_name = module.mindstudio_reserved_name[-1]
+                self.data_collector.update_api_or_module_name(api_or_module_name)
+                params_dict = {}
+                if self.config.task != Const.STRUCTURE:
+                    params_dict = {
+                        key.split(Const.SEP)[-1]: value
+                        for key, value in module.named_parameters(recurse=False)
+                    }
+                setattr(module_input_output, Const.PARAMS, params_dict)
+                # Decide whether parameter hooks need to be registered
+                if params_dict:
+                    ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
+                    grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
+                    # On the first forward-hook run, add the params_grad_name attribute and register parameter hooks
+                    setattr(module, 'params_grad_name', grad_name)
+                    register_param_hook(ori_name, module, params_dict)
+                self.data_collector.forward_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )
+                init_params_grad_info(module, params_dict)
+            else:
+                self.data_collector.update_api_or_module_name(api_or_module_name)
+                self.data_collector.forward_output_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )
+
+            if self.data_collector.if_return_forward_new_output():
+                forward_new_output = self.data_collector.get_forward_new_output()
+                self.inner_switch = False
+                return forward_new_output
+            self.inner_switch = False
             return output
 
         def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
             return forward_hook(api_or_module_name, module, args, {}, output)
 
         def backward_hook(api_or_module_name, module, grad_input, grad_output):
-            if not self.should_execute_hook():
+            if not self.should_execute_hook(module_type, module, False):
                 return
+            is_recompute = is_recomputation()
 
+            self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
-                api_or_module_name = module.mindstudio_reserved_name
+                api_or_module_name = module.mindstudio_reserved_name[-1]
             self.data_collector.update_api_or_module_name(api_or_module_name)
 
             if self.config.online_run_ut:
+                self.inner_switch = False
                 return
 
             if self.data_collector:
                 # grad_input here is actually the output of the backward pass, and grad_output its input, so the two are swapped when passed on
                 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
-                self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
+                self.data_collector.backward_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )
+            self.inner_switch = False
 
         pid = os.getpid()
-        forward_name_template = name + Const.FORWARD
-        backward_name_template = name + Const.BACKWARD
-        pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
-        forward_hook_fn = functools.partial(forward_hook, forward_name_template)
-        backward_hook_fn = functools.partial(backward_hook, backward_name_template)
-        forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
-                                                                  forward_name_template)
+        full_forward_name = None
+        full_backward_name = None
+        if module_type == BaseScope.Module_Type_API:
+            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
+            full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
+        pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
+        forward_hook_fn = functools.partial(forward_hook, full_forward_name)
+        backward_hook_fn = functools.partial(backward_hook, full_backward_name)
+        forward_hook_torch_version_below_2_fn = functools.partial(
+            forward_hook_torch_version_below_2,
+            full_forward_name
+        )
         return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
 
-    def start(self, model, api_origin=False):
+    def start(self, model):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.need_stop_service():
             return
 
```
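A note on the new naming scheme in `build_hook`: for API hooks, the full dump name is now resolved up front by combining the API name, the per-name call counter from `HOOKModule.get_module_count`, and a `forward`/`backward` suffix, and that frozen name is bound into each hook with `functools.partial`; `pre_hook` advances the counter via `HOOKModule.add_module_count(name)` so repeated calls land under distinct keys. A simplified, self-contained sketch of the idea (not msprobe code; a plain dict stands in for the HOOKModule counter, and the increment is folded into one function):

```python
# Simplified sketch of the per-call naming scheme used by the new build_hook.
import functools
from collections import defaultdict

_call_counts = defaultdict(int)  # stand-in for HOOKModule's per-API call counter

def make_hooks(name):
    count = _call_counts[name]                    # get_module_count(name)
    _call_counts[name] += 1                       # pre_hook's add_module_count(name)
    full_forward_name = f"{name}{count}.forward"  # name + count + Const.SEP + Const.FORWARD

    def forward_hook(full_name, module, args, output):
        # stands in for data_collector.forward_data_collect(full_name, ...)
        print(f"collect {full_name}")

    # Freeze the resolved name into the hook, as build_hook does with functools.partial
    return functools.partial(forward_hook, full_forward_name)

make_hooks("Torch.matmul.")(None, (), None)  # -> collect Torch.matmul.0.forward
make_hooks("Torch.matmul.")(None, (), None)  # -> collect Torch.matmul.1.forward
```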
```diff
@@ -155,10 +261,10 @@
 
             if self.config.rank and self.current_rank not in self.config.rank:
                 return
-            self.register_hook_new()
+            self.register_module_hook()
+            if self.config.level == Const.LEVEL_MIX:
+                register_optimizer_hook(self.data_collector)
             self.first_start = False
-            if api_origin:
-                api_register.api_modularity()
         if self.config.online_run_ut and torch_version_above_or_equal_2:
             run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
         self.switch = True
@@ -168,32 +274,39 @@
             logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
 
     def stop(self):
-        if self.should_stop_service:
+        if self.config.level == Const.LEVEL_DEBUG:
             return
-        if self.config.level == "L2":
+        if self.should_stop_service:
             return
         if self.config.step and self.current_iter not in self.config.step:
             return
         if self.config.rank and self.current_rank not in self.config.rank:
             return
         self.switch = False
+        if self.config.level == Const.LEVEL_L2:
+            return
         if self.config.online_run_ut and torch_version_above_or_equal_2:
             run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
             return
+        if self.config.async_dump:
+            self.data_collector.fill_stack_tensor_data()
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
 
     def step(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.should_stop_service:
             return
+        if self.config.async_dump:
+            self.data_collector.fill_stack_tensor_data()
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
+        self.data_collector.write_json()
         self.current_iter += 1
         self.data_collector.update_iter(self.current_iter)
-
-        ModuleProcesser.reset_module_stats()
-        HOOKModule.reset_module_stats()
-        self.data_collector.data_writer.reset_cache()
-
-        if self.config.level == Const.LEVEL_L2:
-            self.data_collector.data_processor.reset_status()
+        self.reset_status()
 
     def need_stop_service(self):
         if self.should_stop_service:
```
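The `stop`/`step` pair above now flushes asynchronous dumps (`fill_stack_tensor_data`, `dump_async_data`) before `write_json`, and `step` delegates all per-step cleanup to the new `reset_status` (defined at the end of this diff). For orientation, a minimal sketch of how this lifecycle is typically driven from a training script through the documented `PrecisionDebugger` wrapper; the wrapper's exact constructor and method signatures may differ between releases, so treat this as illustrative and see msprobe/docs/05.data_dump_PyTorch.md for the authoritative usage:

```python
# Minimal sketch of driving Service.start/stop/step via PrecisionDebugger.
import torch
from msprobe.pytorch import PrecisionDebugger  # assumes msprobe is installed

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
debugger = PrecisionDebugger(config_path="./config.json")  # config sets task/level/dump_path

for _ in range(3):
    debugger.start(model=model)  # Service.start: registers module hooks on the first call
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    debugger.stop()              # Service.stop: flush async dumps, write dump.json
    debugger.step()              # Service.step: advance current_iter, reset per-step state
```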
```diff
@@ -204,8 +317,6 @@
             if self.config.online_run_ut:
                 # send stop signal if online_run_ut
                 self.attl_stop()
-            if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
-                api_register.api_originality()
             self.switch = False
             self.should_stop_service = True
             print_tools_ends_info()
@@ -214,10 +325,18 @@
             return True
         return False
 
-    def should_execute_hook(self):
-        if not self.switch:
+    def should_execute_hook(self, hook_type, module, is_forward):
+        is_module_hook = hook_type == BaseScope.Module_Type_Module
+        if is_module_hook and not self.switch:
+            return False
+        elif not is_module_hook and is_forward and not self.switch:
+            return False
+        elif not is_module_hook and not is_forward and not module.forward_data_collected:
+            return False
+
+        if self.inner_switch:
             return False
-        if self.data_collector and self.data_collector.data_processor.is_terminated:
+        if not self.data_collector or self.data_collector.data_processor.is_terminated:
             return False
         return True
 
```
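`should_execute_hook` above grew from a single `switch` check into a three-way guard: module hooks and API forward hooks obey `switch`, while an API backward hook instead requires that its forward counterpart already collected data; on top of that, the new `inner_switch` suppresses hooks that fire re-entrantly while another hook body is running. The re-entrancy idea in isolation, as a standalone sketch (not msprobe code):

```python
# Standalone sketch of the inner_switch re-entrancy guard. Tensor ops executed
# *inside* a hook body would otherwise trigger their own API hooks and recurse
# into data collection.
class HookGuard:
    def __init__(self):
        self.switch = False        # user-facing on/off, toggled by start()/stop()
        self.inner_switch = False  # True only while a hook body is running

    def should_execute_hook(self):
        return self.switch and not self.inner_switch

    def run_hook(self, body):
        if not self.should_execute_hook():
            return
        self.inner_switch = True   # anything body() executes sees hooks disabled
        try:
            body()
        finally:
            self.inner_switch = False

guard = HookGuard()
guard.switch = True
guard.run_hook(lambda: (
    print("outer hook collects data"),
    guard.run_hook(lambda: print("nested hook is suppressed")),
))
# prints only: outer hook collects data
```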
```diff
@@ -239,55 +358,28 @@
         else:
             dump_data_dir = None
 
-        dump_file_path = os.path.join(dump_dir, "dump.json")
-        stack_file_path = os.path.join(dump_dir, "stack.json")
-        construct_file_path = os.path.join(dump_dir, "construct.json")
-        free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
-        self.data_collector.update_dump_paths(
-            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path)
-
-    def register_hook_new(self):
-        logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
-        if self.config.level in ["L0", "mix"]:
-            if self.model is None:
-                logger.error_log_with_exp("The model is None.", MsprobeException.INVALID_PARAM_ERROR)
-            logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available")
-            for name, module in self.model.named_modules():
-                if module == self.model:
-                    continue
-                prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
-                         module.__class__.__name__ + Const.SEP
-
-                pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
-                    BaseScope.Module_Type_Module, prefix)
-                if torch_version_above_or_equal_2:
-                    module.register_forward_hook(forward_hook, with_kwargs=True)
-                else:
-                    self.check_register_full_backward_hook(module)
-                    module.register_full_backward_hook(
-                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
-                    module.register_forward_hook(forward_hook_torch_version_below_2)
-                self.check_register_full_backward_hook(module)
-                module.register_full_backward_hook(backward_hook)
-
-                module.register_forward_pre_hook(
-                    self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
-                module.register_forward_hook(
-                    self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
-                if torch_version_above_or_equal_2:
-                    module.register_full_backward_pre_hook(
-                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
-                    self.check_register_full_backward_hook(module)
-                    module.register_full_backward_hook(
-                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
-
-        if self.config.level in ["mix", "L1", "L2"]:
-            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
-                                         self.config.online_run_ut)
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
+
+    def register_api_hook(self):
+        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
+            logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
+            api_register.initialize_hook(
+                functools.partial(self.build_hook, BaseScope.Module_Type_API),
+                self.config.online_run_ut
+            )
             api_register.api_modularity()
 
-        if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
-            remove_dropout()
+    def register_module_hook(self):
+        if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
+            logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
+            self.module_processor.register_module_hook(self.model, self.build_hook)
 
     def attl_init(self):
         if self.config.online_run_ut:
```
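`update_dump_paths` previously took five positional paths; it now receives a single `DumpPathAggregation` (newly imported from `msprobe.core.common.utils`), so call sites set only the fields their mode needs: the block above fills all of them, while `init_for_debug_level` below sets just `dump_tensor_data_dir` and `debug_file_path`. The class definition itself is not part of this diff; a hypothetical outline inferred from the fields assigned here:

```python
# Hypothetical outline of DumpPathAggregation, inferred from the fields this
# diff assigns; the real definition lives in msprobe/core/common/utils.py.
class DumpPathAggregation:
    dump_file_path = None            # <dump_dir>/dump.json
    stack_file_path = None           # <dump_dir>/stack.json
    construct_file_path = None       # <dump_dir>/construct.json
    dump_tensor_data_dir = None      # <dump_dir>/dump_tensor_data, or None
    free_benchmark_file_path = None  # <dump_dir>/free_benchmark.csv
    debug_file_path = None           # <dump_dir>/debug.json (debug level only)
```

Bundling the paths keeps `update_dump_paths` at one argument as new dump modes add fields, instead of growing the positional signature with each release.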
```diff
@@ -319,3 +411,60 @@
         elif self.attl.socket_manager is not None:
             logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
             self.attl.socket_manager.send_stop_signal()
+
+    def reset_status(self):
+        ModuleProcesser.reset_module_stats()
+        HOOKModule.reset_module_stats()
+        self.data_collector.reset_status()
+        self.params_grad_info.clear()
+
+        if self.config.level == Const.LEVEL_L2:
+            self.data_collector.data_processor.reset_status()
+            return
+        if self.config.step and self.current_iter not in self.config.step:
+            return
+        if self.config.rank and self.current_rank not in self.config.rank:
+            return
+
+    def init_for_debug_level(self):
+        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
+            return
+        try:
+            self.current_rank = get_rank_if_initialized()
+        except DistributedNotInitializedError:
+            self.current_rank = None
+
+        # dir: dump_path -- rank{} -- debug.json
+        self.dump_iter_dir = self.config.dump_path
+        cur_rank = self.current_rank if self.current_rank is not None else ''
+        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+        create_directory(dump_dir)
+        if self.config.task in self.data_collector.tasks_need_tensor_data:
+            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+            create_directory(dump_data_dir)
+        else:
+            dump_data_dir = None
+
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
+
+        self.debug_variable_counter = defaultdict(int)
+
+    def save(self, variable, name, save_backward):
+        if self.config.level != Const.LEVEL_DEBUG:
+            return
+        count = self.debug_variable_counter[name]
+        self.debug_variable_counter[name] += 1
+
+        name_with_count = f"{name}.{count}"
+        grad_name_with_count = f"{name}_grad.{count}"
+
+        # forward save
+        self.data_collector.debug_data_collect_forward(variable, name_with_count)
+
+        # backward save
+        if save_backward:
+            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
```
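The new `save` method backs the `debug` level described in the new doc 28.debugger_save_instruction.md: each call records one variable under an auto-incremented index, and with `save_backward` set it also registers the gradient under a matching `_grad` key. A hedged usage sketch, assuming the public `PrecisionDebugger` exposes this `Service.save` (check the doc above for the official interface):

```python
# Usage sketch for the new debug-level save; assumes config.json sets
# "level": "debug" and that PrecisionDebugger forwards to Service.save.
import torch
from msprobe.pytorch import PrecisionDebugger

debugger = PrecisionDebugger(config_path="./config.json")

x = torch.randn(4, 4, requires_grad=True)
hidden = torch.nn.functional.relu(x)
debugger.save(hidden, "hidden", save_backward=True)
hidden.sum().backward()
# The first call dumps "hidden.0" into rank*/debug.json; a later call with the
# same name dumps "hidden.1", and with save_backward=True the gradient is
# recorded as "hidden_grad.0".
```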