mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226):
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -17,6 +17,9 @@ import inspect
17
17
  import os
18
18
  from dataclasses import dataclass, is_dataclass
19
19
  from typing import Tuple, Dict, Optional, Any
20
+ from functools import partial
21
+ import copy
22
+ from typing import Union
20
23
 
21
24
  import numpy as np
22
25
 
@@ -39,9 +42,8 @@ class ModuleForwardInputsOutputs:
39
42
  def output_tuple(self):
40
43
  return convert_tuple(self.output)
41
44
 
42
- def concat_args_and_kwargs(self):
43
- args = self.args + tuple(self.kwargs.values())
44
- return args
45
+ def update_output_with_args_and_kwargs(self):
46
+ self.output = self.args + tuple(self.kwargs.values())
45
47
 
46
48
 
47
49
  @dataclass
@@ -77,17 +79,18 @@ class ModuleBackwardOutputs:
77
79
 
78
80
 
79
81
  class TensorStatInfo:
80
- def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
82
+ def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
81
83
  self.max = max_val
82
84
  self.min = min_val
83
85
  self.mean = mean_val
84
86
  self.norm = norm_val
87
+ self.stack_tensor_stat = stack_tensor_stat
85
88
 
86
89
 
87
90
  class BaseDataProcessor:
88
91
  _recursive_key_stack = []
89
92
  special_type = (
90
- np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
93
+ np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
91
94
  bool, int, float, str, slice,
92
95
  type(Ellipsis)
93
96
  )
@@ -102,6 +105,7 @@ class BaseDataProcessor:
102
105
  self.current_iter = 0
103
106
  self._return_forward_new_output = False
104
107
  self._forward_new_output = None
108
+ self.save_name = None
105
109
  if hasattr(config, "data_mode"):
106
110
  self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)
107
111
 
@@ -142,6 +146,37 @@ class BaseDataProcessor:
142
146
  else:
143
147
  return data
144
148
 
149
+ @staticmethod
150
+ def set_value_into_nested_structure(data_structure, indexes, value):
151
+ '''
152
+ Args:
153
+ data_structure: nested data structure
154
+ indexes: List
155
+ value: value to be set
156
+ '''
157
+ if not indexes:
158
+ raise ValueError("set_value_into_nested_structure failed: "
159
+ "indexes need to be non empty when set value to nested data structure")
160
+ current_level = data_structure
161
+ for i, index in enumerate(indexes):
162
+ valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index
163
+ valid_for_dict = isinstance(current_level, dict) and index in current_level
164
+ is_last = i == len(indexes) - 1
165
+ if valid_for_dict or valid_for_list:
166
+ if is_last:
167
+ try:
168
+ current_level[index] = value
169
+ except Exception as e:
170
+ raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
171
+ else:
172
+ try:
173
+ current_level = current_level[index]
174
+ except Exception as e:
175
+ raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
176
+ else:
177
+ raise ValueError("set_value_into_nested_structure failed: "
178
+ "invalid data_structure type or invalid index")
179
+
145
180
  @staticmethod
146
181
  def _convert_numpy_to_builtin(arg):
147
182
  type_mapping = {
@@ -182,8 +217,22 @@ class BaseDataProcessor:
182
217
  return single_arg
183
218
 
184
219
  @staticmethod
185
- def _analyze_numpy(value, numpy_type):
186
- return {"type": numpy_type, "value": value}
220
+ def _analyze_numpy(ndarray, numpy_type):
221
+ ndarray_json = {}
222
+ ndarray_json.update({'type': 'numpy.ndarray'})
223
+ ndarray_json.update({'dtype': str(ndarray.dtype)})
224
+ ndarray_json.update({'shape': ndarray.shape})
225
+ if ndarray.size > 0:
226
+ ndarray_json.update({"Max": np.max(ndarray).item()})
227
+ ndarray_json.update({"Min": np.min(ndarray).item()})
228
+ ndarray_json.update({"Mean": np.mean(ndarray).item()})
229
+ ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
230
+ else:
231
+ ndarray_json.update({"Max": None})
232
+ ndarray_json.update({"Min": None})
233
+ ndarray_json.update({"Mean": None})
234
+ ndarray_json.update({"Norm": None})
235
+ return ndarray_json
187
236
 
188
237
  @staticmethod
189
238
  def _get_allowed_data_mode(data_mode):
@@ -202,7 +251,7 @@ class BaseDataProcessor:
202
251
  return cls.special_type
203
252
 
204
253
  @classmethod
205
- def recursive_apply_transform(cls, args, transform, depth=0):
254
+ def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
206
255
  if depth > Const.MAX_DEPTH:
207
256
  logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
208
257
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
@@ -219,20 +268,20 @@ class BaseDataProcessor:
219
268
  return cls.apply_transform_dict(args_dict, transform, depth)
220
269
  elif isinstance(args, (list, tuple)):
221
270
  result_list = cls.apply_transform_list(args, transform, depth)
222
- return type(args)(result_list)
271
+ return result_list
223
272
  elif isinstance(args, dict):
224
273
  return cls.apply_transform_dict(args, transform, depth)
225
274
  elif args is not None:
226
- logger.warning(f"Data type {type(args)} is not supported.")
275
+ logger.debug(f"Data type {type(args)} is not supported.")
227
276
  return None
228
277
  else:
229
278
  return None
230
-
279
+
231
280
  @classmethod
232
281
  def apply_transform_dict(cls, args, transform, depth):
233
282
  result_dict = {}
234
283
  for k, arg in args.items():
235
- cls._recursive_key_stack.append(str(k))
284
+ cls._recursive_key_stack.append(k)
236
285
  result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
237
286
  cls._recursive_key_stack.pop()
238
287
  return result_dict
@@ -241,11 +290,21 @@ class BaseDataProcessor:
241
290
  def apply_transform_list(cls, args, transform, depth):
242
291
  result_list = []
243
292
  for i, arg in enumerate(args):
244
- cls._recursive_key_stack.append(str(i))
293
+ cls._recursive_key_stack.append(i)
245
294
  result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
246
295
  cls._recursive_key_stack.pop()
247
296
  return result_list
248
297
 
298
+ @classmethod
299
+ def register_hook_single_element(cls, element, suffix_stack, hook_fn):
300
+ if cls.is_hookable_element(element):
301
+ indexes = copy.deepcopy(suffix_stack)
302
+ wrap_hook_fn = partial(hook_fn, indexes=indexes)
303
+
304
+ def real_hook_fn(grad):
305
+ return wrap_hook_fn(grad)
306
+ element.register_hook(real_hook_fn)
307
+
249
308
  def if_return_forward_new_output(self):
250
309
  return self._return_forward_new_output
251
310
 
@@ -273,13 +332,10 @@ class BaseDataProcessor:
273
332
  """
274
333
  return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode
275
334
 
276
- def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
277
- pass
278
-
279
335
  def analyze_element(self, element):
280
336
  return self.recursive_apply_transform(element, self.analyze_single_element)
281
337
 
282
- def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
338
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
283
339
  api_info_struct = {}
284
340
  # check whether data_mode contains forward or input
285
341
  if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -291,16 +347,22 @@ class BaseDataProcessor:
291
347
  kwargs_info_list = self.analyze_element(module_input_output.kwargs)
292
348
  api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
293
349
 
294
- # check whether data_mode contains forward or output
350
+ return api_info_struct
351
+
352
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
353
+ api_info_struct = {}
354
+ # check whether data_mode contains forward or input
295
355
  if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
296
- api_info_struct[name] = api_info_struct.get(name, {})
356
+ api_info_struct[name] = {}
297
357
  self.api_data_category = Const.OUTPUT
298
358
  output_info_list = self.analyze_element(module_input_output.output_tuple)
299
359
  api_info_struct[name][Const.OUTPUT] = output_info_list
360
+
300
361
  return api_info_struct
301
362
 
302
- def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
363
+ def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
303
364
  api_info_struct = {}
365
+ # check whether data_mode contains forward or input
304
366
  if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
305
367
  api_info_struct[name] = {}
306
368
  self.api_data_category = Const.INPUT
@@ -309,16 +371,18 @@ class BaseDataProcessor:
309
371
  self.api_data_category = Const.KWARGS
310
372
  kwargs_info_list = self.analyze_element(module_input_output.kwargs)
311
373
  api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
312
- return api_info_struct
313
374
 
314
- def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
315
- concat_args = module_input_output.concat_args_and_kwargs()
316
- api_info_struct = {}
375
+ # check whether data_mode contains forward or output
317
376
  if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
318
- api_info_struct[name] = {}
377
+ api_info_struct[name] = api_info_struct.get(name, {})
319
378
  self.api_data_category = Const.OUTPUT
320
- output_info_list = self.analyze_element(concat_args)
379
+ output_info_list = self.analyze_element(module_input_output.output_tuple)
321
380
  api_info_struct[name][Const.OUTPUT] = output_info_list
381
+
382
+ if name in api_info_struct and hasattr(module_input_output, Const.PARAMS):
383
+ self.api_data_category = Const.PARAMS
384
+ api_info_struct[name][Const.PARAMS] = self.analyze_element(getattr(module_input_output, Const.PARAMS))
385
+
322
386
  return api_info_struct
323
387
 
324
388
  def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
@@ -359,9 +423,47 @@ class BaseDataProcessor:
359
423
  api_info_struct[name][Const.OUTPUT] = output_info_list
360
424
  return api_info_struct
361
425
 
426
+ def analyze_params(self, name, param_name, grad):
427
+ api_info_struct = {}
428
+ self.save_name = name + Const.SEP + param_name
429
+ data_info = self.analyze_element(grad)
430
+ grad_info_dict = {param_name: [data_info]}
431
+ api_info_struct[name] = grad_info_dict
432
+ return api_info_struct
433
+
362
434
  def get_save_file_path(self, suffix):
363
435
  file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
364
- dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
365
- suffix + file_format)
436
+ if self.save_name is not None:
437
+ dump_data_name = (self.save_name + file_format)
438
+ self.save_name = None
439
+ else:
440
+ dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
441
+ suffix + file_format)
366
442
  file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
367
443
  return dump_data_name, file_path
444
+
445
+ def analyze_element_to_all_none(self, element):
446
+ return self.recursive_apply_transform(element, lambda element, stack: None)
447
+
448
+ def analyze_debug_forward(self, variable, name_with_count):
449
+ self.current_api_or_module_name = name_with_count
450
+ self.api_data_category = Const.TENSOR
451
+ # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt
452
+ data_info = self.analyze_element(variable)
453
+ return data_info
454
+
455
+ def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
456
+ def hook_fn(grad, indexes):
457
+ suffix = Const.SEP.join([str(index) for index in indexes])
458
+ self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
459
+ grad_data_info = self.analyze_element(grad)
460
+ self.save_name = None
461
+ full_index = [grad_name_with_count] + indexes
462
+ try:
463
+ self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
464
+ except (ValueError, IndexError) as e:
465
+ logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, "
466
+ f"skip current recording, detailed infomation: {e}")
467
+ return grad
468
+ wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
469
+ self.recursive_apply_transform(variable, wrap_register_hook_single_element)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.core.common.const import Const
17
+ from msprobe.core.data_dump.data_processor.base import BaseDataProcessor
17
18
 
18
19
 
19
20
  class DataProcessorFactory:
@@ -56,21 +57,25 @@ class DataProcessorFactory:
56
57
  FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
57
58
  KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
58
59
  )
59
- from msprobe.pytorch.module_processer import ModuleProcesser
60
+ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
60
61
  cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
61
62
  cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
62
63
  cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
63
64
  cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
64
65
  cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
66
+ cls.register_processor(Const.PT_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
65
67
  cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
66
68
  elif framework == Const.MS_FRAMEWORK:
67
69
  from msprobe.core.data_dump.data_processor.mindspore_processor import (
68
70
  StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
69
71
  TensorDataProcessor as MindsporeTensorDataProcessor,
70
- OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
72
+ OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
73
+ KernelDumpDataProcessor as MindsporeKernelDumpDataProcessor
71
74
  )
72
75
  from msprobe.mindspore.cell_processor import CellProcessor
73
76
  cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
74
77
  cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
75
78
  cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
79
+ cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
80
+ cls.register_processor(Const.MS_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
76
81
  cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright 2024-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,18 +16,24 @@
16
16
  import zlib
17
17
 
18
18
  import mindspore as ms
19
- from mindspore import mint, ops
19
+ from mindspore import mint, ops, hal
20
20
  from mindspore._c_expression.typing import Number
21
21
  import numpy as np
22
22
 
23
23
  from msprobe.core.common.const import Const
24
24
  from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
25
25
  ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
26
- from msprobe.core.common.file_utils import path_len_exceeds_limit
26
+ from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
27
27
  from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
28
28
  from msprobe.mindspore.common.log import logger
29
29
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
30
30
 
31
+ has_adump = True
32
+ try:
33
+ from msprobe.lib import _msprobe_c
34
+ except ImportError:
35
+ has_adump = False
36
+
31
37
 
32
38
  class MindsporeDataProcessor(BaseDataProcessor):
33
39
  mindspore_special_type = tuple([ms.Tensor, Number])
@@ -37,11 +43,12 @@ class MindsporeDataProcessor(BaseDataProcessor):
37
43
  self.mindspore_object_key = {
38
44
  "dtype": self.analyze_dtype_in_kwargs
39
45
  }
46
+ self._async_dump_cache = {}
40
47
 
41
48
  @staticmethod
42
49
  def get_md5_for_tensor(x):
43
50
  x = convert_bf16_to_fp32(x)
44
- tensor_bytes = x.contiguous().asnumpy().tobytes()
51
+ tensor_bytes = x.asnumpy().tobytes()
45
52
  crc32_hash = zlib.crc32(tensor_bytes)
46
53
  return f"{crc32_hash:08x}"
47
54
 
@@ -49,22 +56,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
49
56
  def analyze_dtype_in_kwargs(element):
50
57
  return {"type": "mindspore.dtype", "value": str(element)}
51
58
 
52
- @classmethod
53
- def get_special_types(cls):
54
- return super().get_special_types() + cls.mindspore_special_type
55
-
56
- def get_stat_info(self, data):
59
+ @staticmethod
60
+ def get_stat_info_sync(data):
57
61
  tensor_stat = TensorStatInfo()
58
- if data.numel() == 0:
59
- return tensor_stat
60
- elif data.dtype == ms.bool_:
61
- data_np = data.contiguous().asnumpy()
62
+ if data.dtype == ms.bool_:
63
+ data_np = data.asnumpy()
62
64
  tensor_stat.max = np.max(data_np).item()
63
65
  tensor_stat.min = np.min(data_np).item()
64
66
  elif not data.shape:
65
67
  tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
66
68
  elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
67
- data_abs = np.abs(data.contiguous().asnumpy())
69
+ data_abs = np.abs(data.asnumpy())
68
70
  tensor_stat.max = np.max(data_abs).item()
69
71
  tensor_stat.min = np.min(data_abs).item()
70
72
  tensor_stat.mean = np.mean(data_abs).item()
@@ -87,17 +89,64 @@ class MindsporeDataProcessor(BaseDataProcessor):
87
89
  api_register.norm_inner_op_set_hook_func()
88
90
  return tensor_stat
89
91
 
92
+ @staticmethod
93
+ def get_stat_info_async(data):
94
+ tensor_stat = TensorStatInfo()
95
+ stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
96
+ if data.dtype == ms.complex64 or data.dtype == ms.complex128:
97
+ logger.warning("Async dump do not support complex data!")
98
+ return tensor_stat
99
+ elif data.dtype == ms.bool_:
100
+ tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
101
+ elif not data.shape:
102
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
103
+ else:
104
+ if not ops.is_floating_point(data) or data.dtype == ms.float64:
105
+ data = data.to(ms.float32)
106
+ api_register.norm_inner_op_set_ori_func()
107
+ get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
108
+ get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
109
+ get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
110
+ if hasattr(mint, "norm"):
111
+ get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
112
+ else:
113
+ get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
114
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
115
+ [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
116
+ api_register.norm_inner_op_set_hook_func()
117
+ return tensor_stat
118
+
119
+ @staticmethod
120
+ def is_hookable_element(element):
121
+ return hasattr(element, "register_hook") and callable(element.register_hook)
122
+
123
+ @classmethod
124
+ def get_special_types(cls):
125
+ return super().get_special_types() + cls.mindspore_special_type
126
+
127
+ def get_stat_info(self, data):
128
+ tensor_stat = TensorStatInfo()
129
+ if data.numel() == 0:
130
+ return tensor_stat
131
+ else:
132
+ if self.config.async_dump:
133
+ return MindsporeDataProcessor.get_stat_info_async(data)
134
+ else:
135
+ return MindsporeDataProcessor.get_stat_info_sync(data)
136
+
90
137
  def analyze_single_element(self, element, suffix_stack):
91
138
  if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
92
139
  return self.mindspore_object_key[suffix_stack[-1]](element)
93
140
 
94
141
  converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
95
142
  if converted_numpy is not element:
96
- return self._analyze_numpy(converted_numpy, numpy_type)
143
+ return {"type": numpy_type, "value": converted_numpy}
97
144
  if isinstance(element, Number):
98
145
  return self.analyze_dtype_in_kwargs(element)
99
146
  if isinstance(element, ms.Tensor):
100
- return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
147
+ return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
148
+ if isinstance(element, np.ndarray):
149
+ return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
101
150
  if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
102
151
  return self._analyze_builtin(element)
103
152
  return {}
@@ -107,13 +156,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
107
156
  tensor_json = {
108
157
  'type': 'mindspore.Tensor',
109
158
  'dtype': str(tensor.dtype),
110
- 'shape': tensor.shape,
111
- 'Max': self.transfer_type(tensor_stat.max),
112
- 'Min': self.transfer_type(tensor_stat.min),
113
- 'Mean': self.transfer_type(tensor_stat.mean),
114
- 'Norm': self.transfer_type(tensor_stat.norm),
159
+ 'shape': tensor.shape
115
160
  }
116
- if self.config.summary_mode == Const.MD5:
161
+
162
+ if tensor_stat.stack_tensor_stat is None:
163
+ tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
164
+ tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
165
+ tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
166
+ tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
167
+ else:
168
+ tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
169
+ if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
117
170
  tensor_md5 = self.get_md5_for_tensor(tensor)
118
171
  tensor_json.update({Const.MD5: tensor_md5})
119
172
  return tensor_json
@@ -124,12 +177,27 @@ class StatisticsDataProcessor(MindsporeDataProcessor):
124
177
 
125
178
 
126
179
class TensorDataProcessor(MindsporeDataProcessor):
    """Processor that saves tensor and ndarray contents as .npy files.

    In async mode, tensor copies are held in ``self._async_dump_cache`` and
    written out only when ``dump_async_data`` is called.
    """

    def dump_async_data(self):
        """Write every cached tensor to its file path, then empty the cache."""
        pending = self._async_dump_cache
        for target_path, cached_tensor in pending.items():
            save_tensor_as_npy(cached_tensor, target_path)
        pending.clear()

    def _analyze_tensor(self, tensor, suffix):
        """Describe ``tensor`` and save it (or queue a copy in async mode).

        Returns:
            The parent's description with a ``"data_name"`` entry added.
        """
        dump_data_name, file_path = self.get_save_file_path(suffix)
        single_arg = super()._analyze_tensor(tensor, suffix)
        single_arg.update({"data_name": dump_data_name})
        if not self.config.async_dump:
            save_tensor_as_npy(tensor, file_path)
        else:
            # Store a copy so the queued data is a separate tensor object.
            self._async_dump_cache[file_path] = tensor.copy()
        return single_arg

    def _analyze_numpy(self, ndarray, suffix):
        """Save ``ndarray`` to disk and describe it.

        Returns:
            The parent's description with a ``"data_name"`` entry added.
        """
        dump_data_name, file_path = self.get_save_file_path(suffix)
        save_npy(ndarray, file_path)
        description = super()._analyze_numpy(ndarray, suffix)
        description.update({"data_name": dump_data_name})
        return description
133
201
 
134
202
 
135
203
  class OverflowCheckDataProcessor(MindsporeDataProcessor):
@@ -138,6 +206,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
138
206
  def __init__(self, config, data_writer):
139
207
  super().__init__(config, data_writer)
140
208
  self.has_overflow = False
209
+ self.cached_api_info = {}
141
210
  self.cached_tensors_and_file_paths = {}
142
211
  self.real_overflow_nums = 0
143
212
  self.overflow_nums = config.overflow_nums
@@ -150,6 +219,20 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
150
219
  return True
151
220
  return False
152
221
 
222
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
223
+ self.has_overflow = False
224
+ self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
225
+ return None
226
+
227
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
228
+ api_info_struct = super().analyze_forward_output(name, module, module_input_output)
229
+ if name in self.cached_api_info and name in api_info_struct:
230
+ self.cached_api_info[name].update(api_info_struct[name])
231
+ elif name in api_info_struct:
232
+ self.cached_api_info = api_info_struct
233
+ self.maybe_save_overflow_data()
234
+ return self.cached_api_info if self.has_overflow else None
235
+
153
236
  def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
154
237
  self.has_overflow = False
155
238
  api_info_struct = super().analyze_forward(name, module, module_input_output)
@@ -162,6 +245,12 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
162
245
  self.maybe_save_overflow_data()
163
246
  return api_info_struct if self.has_overflow else None
164
247
 
248
+ def analyze_params(self, name, param_name, grad):
249
+ self.has_overflow = False
250
+ api_info_struct = super().analyze_params(name, param_name, grad)
251
+ self.maybe_save_overflow_data()
252
+ return api_info_struct if self.has_overflow else None
253
+
165
254
  def maybe_save_overflow_data(self):
166
255
  if self.has_overflow:
167
256
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
@@ -190,3 +279,61 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
190
279
  self._analyze_maybe_overflow_tensor(single_arg)
191
280
  single_arg.update({"data_name": dump_data_name})
192
281
  return single_arg
282
+
283
+
284
class KernelDumpDataProcessor(MindsporeDataProcessor):
    """Processor that turns the native (``_msprobe_c``) kernel dump on and off.

    The dump is started on a forward/backward input hook and stopped on the
    matching output hook. ``enable_kernel_dump`` is cleared after the first
    stop (or when the native module is missing), so later hooks do nothing
    until ``reset_status`` is called.
    """

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        self.enable_kernel_dump = True

    @staticmethod
    def start_kernel_dump(config_path):
        """Initialise the native dump and point it at ``config_path``.

        Callers must check ``has_adump`` first: ``_msprobe_c`` is undefined
        when the extension failed to import.
        """
        # NOTE(review): synchronising before and after presumably keeps
        # pending device work from straddling the dump boundary - confirm.
        hal.synchronize()
        _msprobe_c.init_dump()
        _msprobe_c.set_dump(config_path)
        hal.synchronize()

    @staticmethod
    def stop_kernel_dump():
        """Finalise the native dump, synchronising before and after."""
        hal.synchronize()
        _msprobe_c.finalize_dump()
        hal.synchronize()

    @staticmethod
    def _print_unsupported_log(api_name):
        logger.warning(f"The kernel dump does not support the {api_name} API.")

    def analyze_forward_input(self, name, module, module_input_output):
        """Start the kernel dump when a forward call begins (if still enabled)."""
        self._start_if_enabled()

    def analyze_forward_output(self, name, module, module_input_output):
        """Stop the kernel dump when the forward call finishes (if still enabled)."""
        self._stop_if_enabled(name)

    def analyze_backward_input(self, name, module, module_input_output):
        """Start the kernel dump when a backward call begins (if still enabled)."""
        self._start_if_enabled()

    def analyze_backward(self, name, module, module_input_output):
        """Stop the kernel dump when the backward call finishes (if still enabled)."""
        self._stop_if_enabled(name)

    def reset_status(self):
        """Re-enable the kernel dump for subsequent hooks."""
        self.enable_kernel_dump = True

    def _start_if_enabled(self):
        """Start the dump unless it is disabled or the native module is missing."""
        if not self.enable_kernel_dump:
            return
        if not has_adump:
            logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
            # Disable so the warning is not repeated and no stop is attempted.
            self.enable_kernel_dump = False
            return
        self.start_kernel_dump(self.config.kernel_config_path)

    def _stop_if_enabled(self, name):
        """Stop the dump once and log which API's kernel data was dumped."""
        if not self.enable_kernel_dump:
            return
        # Clear the flag before stopping so the dump is not stopped twice.
        self.enable_kernel_dump = False
        self.stop_kernel_dump()
        logger.info(f"The kernel data of {name} is dumped successfully.")