mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
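
The largest single-file change in this release is msprobe/mindspore/service.py (entry 132, +347 -107); its diff is reproduced below.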
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +20,9 @@ from collections import defaultdict
 
 import mindspore as ms
 from mindspore import nn
+from mindspore.common.api import _no_grad
+from mindspore.ops.primitive import Primitive
+
 try:
     from mindspore.common._pijit_context import PIJitCaptureContext
 except ImportError:
@@ -27,19 +30,25 @@ except ImportError:
 else:
     pijit_label = True
 
-
 from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
 from msprobe.core.common.file_utils import create_directory
-from msprobe.core.common.utils import Const, print_tools_ends_info
+from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
 from msprobe.core.data_dump.data_collector import build_data_collector
-from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
+from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
+                                                        ModuleBackwardInputs)
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.mindspore.cell_processor import CellProcessor
 from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.common.utils import get_rank_if_initialized
+from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
+                                            is_mindtorch, register_backward_hook_functions)
 from msprobe.mindspore.dump.hook_cell.api_registry import api_register
 from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
 from msprobe.mindspore.dump.jit_dump import JitDump
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
+from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
+
+if is_mindtorch():
+    import torch
 
 
 class Service:
@@ -51,54 +60,155 @@ class Service:
         self.cell_processor = CellProcessor(self.data_collector.scope)
         self.primitive_hook_service = PrimitiveHookService(self)
         self.switch = False
+        self.inner_switch = False
         self.primitive_switch = False
         self.current_iter = 0
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None
         self.start_call = False
-        self.check_level_valid()
         self.should_stop_service = False
+        self.params_grad_info = {}
+        self.hook_handle_dict = {}
+        # Register early to ensure as many API hooks as possible are in place
+        self.register_api_hook()
+        self.init_for_debug_level()
 
     @staticmethod
-    def check_model_valid(model):
-        if not model or isinstance(model, nn.Cell):
-            return model
-        raise MsprobeException(
-            MsprobeException.INVALID_PARAM_ERROR, "The 'model' parameter must be of type mindspore.nn.Cell."
-        )
+    def check_model_valid(models):
+        target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
+        if models is None or isinstance(models, target_module_type[0]):
+            return models
+        error_model = None
+        if isinstance(models, (list, tuple)):
+            for model in models:
+                if not isinstance(model, target_module_type[0]):
+                    error_model = model
+                    break
+        else:
+            error_model = models
 
-    def check_level_valid(self):
-        if self.config.level == Const.LEVEL_L2:
+        if error_model is not None:
+            error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
+                          f"type, currently there is a {type(error_model)} type.")
             raise MsprobeException(
-                MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
-            )
+                MsprobeException.INVALID_PARAM_ERROR, error_info)
+        return models
+
+    @staticmethod
+    def prepare_module_input_output(target_type, cell, input_data, output):
+        if target_type == BaseScope.Module_Type_Module:
+            module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
+        else:
+            module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
+        return module_input_output
 
     def build_hook(self, target_type, name):
-        def forward_hook(api_or_cell_name, cell, input_data, output):
-            if not self.should_excute_hook():
-                if hasattr(cell, 'input_kwargs'):
-                    del cell.input_kwargs
+        def pre_hook(api_or_cell_name, cell, input_data):
+            if not self.should_execute_hook(target_type, cell, True):
+                clean_input_kwargs(cell)
                 return None
 
-            if target_type == BaseScope.Module_Type_Module:
-                api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-                module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
-            else:
-                module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs,
-                                                                 output=output)
+            with _no_grad():
+                self.inner_switch = True
+                if target_type == BaseScope.Module_Type_Module:
+                    api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
+                else:
+                    cell.forward_data_collected = True
+                    HOOKCell.add_cell_count(name)
+                module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
+                self.data_collector.update_api_or_module_name(api_or_cell_name)
+                self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
+                self.inner_switch = False
+            return input_data
+
+        def grad_hook(cell, ori_name, param_name):
+            def hook_fn(grad):
+                if not self.should_execute_hook(target_type, cell, False):
+                    return None
+                self.inner_switch = True
+                self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
+                self.inner_switch = False
+                return None
 
-            self.data_collector.update_api_or_module_name(api_or_cell_name)
-            self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
-            if self.data_collector.if_return_forward_new_output():
-                return self.data_collector.get_forward_new_output()
-            if hasattr(cell, 'input_kwargs'):
-                del cell.input_kwargs
-            return output
+            return hook_fn
+
+        def register_param_hook(ori_name, cell, params_dict):
+            '''
+            Register parameter hooks.
+            '''
+            # When data_mode is forward-only, do not register parameter hooks
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                for param_name, param in params_dict.items():
+                    if param.requires_grad:
+                        name = ori_name + Const.SEP + param_name
+                        old_handle = self.hook_handle_dict.get(name)
+                        if old_handle and hasattr(old_handle, "remove"):
+                            old_handle.remove()
+                        handle = param.register_hook(grad_hook(cell, ori_name, param_name))
+                        self.hook_handle_dict[name] = handle
+
+        def init_params_grad_info(cell, params_dict):
+            '''
+            Initialize parameter gradient info; after the forward hook ends, write it into cache_data as a placeholder
+            '''
+            if not params_dict:
+                return
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
+                # Check whether a placeholder already exists in cache_data; if not, write one first
+                if not self.params_grad_info.get(grad_name):
+                    data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
+                    # Gradients are only computed when a module parameter has requires_grad=True; only then is a placeholder needed
+                    if data_info.get(grad_name):
+                        # Write grad_name's data_info into cache_data first; update it after gradients are computed
+                        self.data_collector.handle_data(grad_name, data_info,
+                                                        flush=self.data_collector.data_processor.is_terminated)
+                        # Record that this module's parameter gradient info has been placeholdered
+                        self.params_grad_info[grad_name] = True
+
+        def forward_hook(api_or_cell_name, cell, input_data, output):
+            if not self.should_execute_hook(target_type, cell, True):
+                clean_input_kwargs(cell)
+                return None
+            with _no_grad():
+                self.inner_switch = True
+                module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
+                if target_type == BaseScope.Module_Type_Module:
+                    api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
+                    params_dict = {}
+                    if self.config.task != Const.STRUCTURE:
+                        params_dict = {
+                            key.split(Const.SEP)[-1]: value
+                            for key, value in cell.parameters_dict(recurse=False).items()
+                        }
+                    setattr(module_input_output, Const.PARAMS, params_dict)
+                    # Decide whether parameter hooks need to be registered
+                    if params_dict:
+                        ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
+                        grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
+                        # On the first forward hook call, add the params_grad_name attribute and register parameter hooks
+                        setattr(cell, 'params_grad_name', grad_name)
+                        register_param_hook(ori_name, cell, params_dict)
+                    self.data_collector.update_api_or_module_name(api_or_cell_name)
+                    self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
+                    init_params_grad_info(cell, params_dict)
+                else:
+                    self.data_collector.update_api_or_module_name(api_or_cell_name)
+                    self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
+
+                if self.data_collector.if_return_forward_new_output():
+                    forward_new_output = self.data_collector.get_forward_new_output()
+                    self.inner_switch = False
+                    return forward_new_output
+            clean_input_kwargs(cell)
+            self.inner_switch = False
+            return output
 
         def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
-            if not self.should_excute_hook():
+            if not self.should_execute_hook(target_type, cell, False):
                 return
+            self.inner_switch = True
 
             need_exchange = True
             if target_type == BaseScope.Module_Type_Module:
@@ -114,12 +224,32 @@
             else:
                 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
             self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
+            self.inner_switch = False
+
+        def pre_backward_hook(api_or_cell_name, cell, grad_input):
+            if not self.should_execute_hook(target_type, cell, False):
+                return
+            self.inner_switch = True
+            module_input = ModuleBackwardInputs(grad_input=grad_input)
+            self.data_collector.update_api_or_module_name(api_or_cell_name)
+            self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
+
+            self.inner_switch = False
 
         pid = os.getpid()
-        forward_name_template = name + Const.FORWARD
-        backward_name_template = name + Const.BACKWARD
-        forward_hook = functools.partial(forward_hook, forward_name_template)
-        backward_hook = functools.partial(backward_hook, backward_name_template)
+        if target_type == BaseScope.Module_Type_Module:
+            full_forward_name = name + Const.FORWARD
+            full_backward_name = name + Const.BACKWARD
+        else:
+            full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
+            full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
+        pre_forward_hook = functools.partial(pre_hook, full_forward_name)
+        forward_hook = functools.partial(forward_hook, full_forward_name)
+        backward_hook = functools.partial(backward_hook, full_backward_name)
+        pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
+
+        def wrap_pre_forward_hook(cell, input_data):
+            return pre_forward_hook(cell, input_data)
 
         def wrap_forward_hook(cell, input_data, output_data):
            return forward_hook(cell, input_data, output_data)
@@ -127,7 +257,10 @@
         def wrap_backward_hook(cell, grad_input, grad_output):
             return backward_hook(cell, grad_input, grad_output)
 
-        return wrap_forward_hook, wrap_backward_hook
+        def wrap_pre_backward_hook(cell, grad_input):
+            return pre_backward_hook(cell, grad_input)
+
+        return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
 
     def update_primitive_counters(self, primitive_name):
         if primitive_name not in self.primitive_counters:
@@ -135,33 +268,25 @@
         else:
             self.primitive_counters[primitive_name] += 1
 
-    def register_primitive_hooks(self):
-        primitive_set = set()
-        for _, cell in self.model.cells_and_names():
-            for pname, primitive in cell._primitives.items():
-                primitive_set.add((pname, primitive))
-
-        for pname, primitive in primitive_set:
-            primitive_class_name = primitive.__class__.__name__
-            primitive_combined_name = pname + Const.SEP + primitive_class_name
-            new_primitive = type('NewPrimitive', (primitive.__class__,),
-                                 {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
-                                                                                         primitive_combined_name)})
-            primitive.__class__ = new_primitive
-
     def step(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
+        if self.config.async_dump:
+            self.data_collector.fill_stack_tensor_data()
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
+        self.data_collector.write_json()
         self.current_iter += 1
         self.data_collector.update_iter(self.current_iter)
-        self.primitive_hook_service.primitive_counters.clear()
-        self.data_collector.data_writer.reset_cache()
-        JitDump.jit_count = defaultdict(int)
+        self.reset_status()
 
     def start(self, model=None):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         self.start_call = True
         if self.should_stop_service:
             return
         if self.need_end_service():
-            api_register.api_set_ori_func()
             self.should_stop_service = True
             self.switch = False
             self.primitive_switch = False
@@ -181,11 +306,15 @@
 
         if self.config.rank and self.current_rank not in self.config.rank:
             return
-        self.register_hook_new()
+        self.register_primitive_hook()
+        self.register_cell_hook()
         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
             JitDump.set_config(self.config)
             JitDump.set_data_collector(self.data_collector)
-            ms.common.api._MindsporeFunctionExecutor = JitDump
+            if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
+                ms.common.api._MindsporeFunctionExecutor = JitDump
+            else:
+                ms.common.api._JitExecutor = JitDump
             ms.common.api._PyNativeExecutor.grad = JitDump.grad
             if pijit_label:
                 PIJitCaptureContext.__enter__ = self.empty
@@ -200,26 +329,9 @@
             logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
         JitDump.jit_dump_switch = True
 
-    def forward_backward_dump_end(self):
-        if self.should_stop_service:
-            return
-        logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
-        if not self.start_call:
-            logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
-            raise Exception("debugger.start() is not set in the current scope.")
-        if not self.switch:
-            logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
-                         "debugger.start() and debugger.stop() ")
-            raise Exception("debugger.stop() is already called. ")
-        if self.config.step and self.current_iter not in self.config.step:
-            return
-        if self.config.rank and self.current_rank not in self.config.rank:
-            return
-        self.primitive_switch = False
-        api_register.api_set_ori_func()
-        JitDump.jit_dump_switch = False
-
     def stop(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.should_stop_service:
             return
         logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
@@ -234,6 +346,10 @@
         self.switch = False
         self.primitive_switch = False
         self.start_call = False
+        if self.config.async_dump:
+            self.data_collector.fill_stack_tensor_data()
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
         JitDump.jit_dump_switch = False
 
@@ -244,8 +360,16 @@
             return True
         return False
 
-    def should_excute_hook(self):
-        if not self.switch:
+    def should_execute_hook(self, hook_type, cell, is_forward):
+        is_cell_hook = hook_type == BaseScope.Module_Type_Module
+        if is_cell_hook and not self.switch:
+            return False
+        elif not is_cell_hook and is_forward and not self.switch:
+            return False
+        elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
+            return False
+
+        if self.inner_switch:
             return False
         if not self.data_collector or self.data_collector.data_processor.is_terminated:
             return False
@@ -255,6 +379,12 @@
         create_directory(self.config.dump_path)
         self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
         cur_rank = self.current_rank if self.current_rank is not None else ''
+        if self.config.level == Const.LEVEL_L2:
+            create_directory(self.dump_iter_dir)
+            kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
+            self.config.kernel_config_path = kernel_config_path
+            return
+
         dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
         create_directory(dump_dir)
         if self.config.task in self.data_collector.tasks_need_tensor_data:
@@ -263,41 +393,151 @@
         else:
             dump_data_dir = None
 
-        dump_file_path = os.path.join(dump_dir, "dump.json")
-        stack_file_path = os.path.join(dump_dir, "stack.json")
-        construct_file_path = os.path.join(dump_dir, "construct.json")
-        self.data_collector.update_dump_paths(
-            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+
+        self.data_collector.initialize_json_file(
+            framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
+        )
 
     def empty(self, *args, **kwargs):
         pass
 
-    def register_hook_new(self):
-        logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
-        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
+    def register_api_hook(self):
+        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
+            logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
             api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
             api_register.api_set_hook_func()
-        if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
-            self.register_primitive_hooks()
 
+    def get_cells_and_names(self):
+        cells_and_names_with_index = {}
+
+        def get_cell_or_module(model):
+            return model.named_modules() if is_mindtorch() else model.cells_and_names()
+
+        if isinstance(self.model, (list, tuple)):
+            for index, model in enumerate(self.model):
+                cells_and_names_with_index[str(index)] = get_cell_or_module(model)
+        else:
+            cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
+        return cells_and_names_with_index
+
+    def register_primitive_hook(self):
+        if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
+            return
+        if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
+            return
+
+        primitive_set = set()
+        cells_and_names_with_index = self.get_cells_and_names()
+        for cells_and_names in cells_and_names_with_index.values():
+            for _, cell in cells_and_names:
+                for attribute, value in vars(cell).items():
+                    if isinstance(value, Primitive):
+                        primitive_set.add((attribute, value))
+
+        for pname, primitive in primitive_set:
+            primitive_class_name = primitive.__class__.__name__
+            primitive_combined_name = pname + Const.SEP + primitive_class_name
+            new_primitive = type('NewPrimitive', (primitive.__class__,),
+                                 {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
+                                                                                         primitive_combined_name)})
+            primitive.__class__ = new_primitive
+
+    def register_cell_hook(self):
         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
+            logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
             if not self.model:
                 raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                        f"The current level is {self.config.level}, the model cannot be None")
-            for name, cell in self.model.cells_and_names():
-                if cell == self.model:
-                    continue
-                prefix = 'Cell' + Const.SEP + name + Const.SEP + \
-                         cell.__class__.__name__ + Const.SEP
-                forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
-                cell.register_forward_hook(forward_hook)
-                cell.register_backward_hook(backward_hook)
-
-                cell.register_forward_pre_hook(
-                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
-                cell.register_forward_hook(
-                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
-                cell.register_backward_pre_hook(
-                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
-                cell.register_backward_hook(
-                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+            model_type = Const.MODULE if is_mindtorch() else Const.CELL
+            cells_and_names_with_index = self.get_cells_and_names()
+
+            for index, cells_and_names in cells_and_names_with_index.items():
+                model = self.model if index == "-1" else self.model[int(index)]
+                for name, cell in cells_and_names:
+                    if cell == model:
+                        continue
+                    cell_index = (index + Const.SEP) if index != "-1" else ""
+                    prefix = (model_type + Const.SEP + cell_index + name +
+                              Const.SEP + cell.__class__.__name__ + Const.SEP)
+                    _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
+                    cell.register_forward_hook(forward_hook)
+                    cell.register_forward_pre_hook(
+                        self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
+                    cell.register_forward_hook(
+                        self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
+
+                    register_backward_hook_functions["full"](cell, backward_hook)
+                    register_backward_hook_functions["pre"](
+                        cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
+                    register_backward_hook_functions["full"](
+                        cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+
+    def reset_status(self):
+        self.primitive_hook_service.primitive_counters.clear()
+        self.data_collector.reset_status()
+        JitDump.jit_count = defaultdict(int)
+        self.params_grad_info.clear()
+        if self.config.level == Const.LEVEL_L2:
+            self.data_collector.data_processor.reset_status()
+            return
+        if self.config.step and self.current_iter not in self.config.step:
+            return
+        if self.config.rank and self.current_rank not in self.config.rank:
+            return
+
+    def init_for_debug_level(self):
+        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
+            return
+        try:
+            self.current_rank = get_rank_if_initialized()
+        except DistributedNotInitializedError:
+            self.current_rank = None
+        # dir: dump_path -- rank{} -- debug.json
+        self.dump_iter_dir = self.config.dump_path
+        cur_rank = self.current_rank if self.current_rank is not None else ''
+        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+        create_directory(dump_dir)
+        if self.config.task in self.data_collector.tasks_need_tensor_data:
+            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+            create_directory(dump_data_dir)
+        else:
+            dump_data_dir = None
+
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+        self.data_collector.initialize_json_file(
+            framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
+        )
+        self.debug_variable_counter = defaultdict(int)
+
+    def save(self, variable, name, save_backward):
+        '''
+        Args:
+            variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
+            name: str
+            save_backward: boolean
+        Return:
+            void
+        '''
+        if self.config.level != Const.LEVEL_DEBUG:
+            return
+        count = self.debug_variable_counter[name]
+        self.debug_variable_counter[name] += 1

+        name_with_count = f"{name}.{count}"
+        grad_name_with_count = f"{name}_grad.{count}"
+
+        # forward save
+        self.data_collector.debug_data_collect_forward(variable, name_with_count)
+
+        # backward save
+        if save_backward:
+            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
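
A minimal sketch of the debug-level save() semantics introduced above: each call stamps the variable name with a per-name call count, and reserves a matching "_grad" entry when save_backward is true. StubCollector and DebugSaver below are illustrative stand-ins written for this note, not msprobe APIs.

from collections import defaultdict


class StubCollector:
    """Illustrative stand-in for msprobe's data collector (assumption, not the real class)."""
    def __init__(self):
        self.records = []

    def debug_data_collect_forward(self, variable, name):
        self.records.append(("forward", name, variable))

    def debug_data_collect_backward(self, variable, name):
        self.records.append(("backward", name, variable))


class DebugSaver:
    """Mirrors the per-name counter logic of Service.save() shown in the diff."""
    def __init__(self, collector):
        self.collector = collector
        self.debug_variable_counter = defaultdict(int)

    def save(self, variable, name, save_backward):
        count = self.debug_variable_counter[name]
        self.debug_variable_counter[name] += 1
        # forward save is recorded as "name.count"
        self.collector.debug_data_collect_forward(variable, f"{name}.{count}")
        # backward save is recorded as "name_grad.count", only when requested
        if save_backward:
            self.collector.debug_data_collect_backward(variable, f"{name}_grad.{count}")


saver = DebugSaver(StubCollector())
saver.save(0.37, "loss", save_backward=False)  # recorded as "loss.0"
saver.save(0.29, "loss", save_backward=True)   # recorded as "loss.1" and "loss_grad.1"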