mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/mindspore/dump/hook_cell/ms_hook_manager.py
@@ -14,21 +14,45 @@
 # limitations under the License.
 
 import threading
+from collections import OrderedDict
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore.common.api import _no_grad, _pynative_executor
+from mindspore.ops.operations import _inner_ops as inner
 
-from mindspore.common.api import _no_grad
 from msprobe.core.common.const import Const
+from msprobe.core.common.log import logger
 from msprobe.core.common.utils import replace_last_occurrence, ThreadSafe
 from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs
 from msprobe.core.hook_manager import BaseHookManager, HookSet
-from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook
+from msprobe.mindspore.common.const import Const as MsConst
+from msprobe.mindspore.common.utils import (
+    has_kwargs_in_forward_hook,
+    is_mindtorch,
+    is_backward_hook_output_a_view
+)
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 
+ms_version = ms.__version__
+
+
+class MindsporeHookManager(BaseHookManager):
+    cell_bw_hook_kernels = {}
+    cell_backward_pre_hook = []
+    cell_backward_hook = []
 
-class MindsproeHookManager(BaseHookManager):
     @property
     def _is_recompute(self):
         return None
 
+    @staticmethod
+    def reset_status():
+        BaseHookManager.reset_status()
+        MindsporeHookManager.cell_bw_hook_kernels.clear()
+        MindsporeHookManager.cell_backward_pre_hook.clear()
+        MindsporeHookManager.cell_backward_hook.clear()
+
     @staticmethod
     def _no_grad_context():
         return _no_grad()
@@ -38,9 +62,13 @@ class MindsproeHookManager(BaseHookManager):
         HOOKCell.add_cell_count(name)
 
     @staticmethod
-    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
+    def _get_count(name):
+        return HOOKCell.get_cell_count(name)
+
+    @staticmethod
+    def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs):
         if not has_kwargs_in_forward_hook() or hook_type == Const.API:
-            kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
+            kwargs = module.msprobe_input_kwargs.get(tid, {}) if hasattr(module, 'msprobe_input_kwargs') else {}
             output = kwargs_or_output
         else:
             kwargs = kwargs_or_output
@@ -49,17 +77,107 @@
 
     def build_hook(self, hook_type, name):
         if hook_type == Const.API:
-            full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
+            hook_set = HookSet(
+                forward_pre_hook=self._build_forward_pre_hook(hook_type, name)
+            )
         else:
-            full_forward_name = name
-        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
-        hookset = HookSet(
-            forward_hook=self._build_forward_hook(hook_type, full_forward_name),
-            forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
-            backward_hook=self._build_backward_hook(hook_type, full_backward_name),
-            backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name)
+            full_backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD)
+            hook_set = HookSet(
+                forward_hook=self._build_forward_hook(hook_type, name),
+                backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name),
+                backward_hook=self._build_backward_hook(hook_type, full_backward_name)
+            )
+        return hook_set
+
+    def _register_forward_hook(self, module, api_name):
+        if not hasattr(module, 'msprobe_forward_hook'):
+            forward_hook = self._build_forward_hook(Const.API, api_name)
+            if ms_version < "2.6.0" and not is_mindtorch():
+                getattr(module, "_forward_hook", {})[id(module)] = forward_hook
+            else:
+                module.register_forward_hook(forward_hook)
+            setattr(module, 'msprobe_forward_hook', True)
+
+    def _register_backward_hook(self, module, full_backward_name, args):
+        if not _pynative_executor.requires_grad():
+            return args
+
+        enable_hooked = sum(
+            [isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args]
+        )
+
+        if enable_hooked:
+            backward_hook_dict = OrderedDict()
+            backward_hook_dict[full_backward_name] = self._build_backward_hook(Const.API, full_backward_name)
+            MindsporeHookManager.cell_backward_hook.append(backward_hook_dict)
+            bw_hook = inner.CellBackwardHook(full_backward_name, module, MindsporeHookManager.cell_backward_hook[-1])
+            bw_hook.register_backward_hook()
+            MindsporeHookManager.cell_bw_hook_kernels[full_backward_name] = bw_hook
+            args = bw_hook(args) if is_backward_hook_output_a_view() else bw_hook(*args)
+        return args
+
+    def _register_backward_pre_hook(self, module, full_backward_name, output):
+        if not _pynative_executor.requires_grad():
+            return output
+
+        bw_hook = MindsporeHookManager.cell_bw_hook_kernels.get(full_backward_name)
+        if bw_hook:
+            if not isinstance(output, (Tensor, tuple)):
+                logger.debug("For backward hooks to be called, "
+                             "cell output should be a Tensor or a tuple of Tensors "
+                             f"but received {type(output)}")
+            if is_backward_hook_output_a_view():
+                new_outputs = bw_hook(output)
+            else:
+                if isinstance(output, tuple):
+                    new_outputs = bw_hook(*output)
+                else:
+                    new_outputs = bw_hook(output)
+                if isinstance(output, tuple) and len(output) == 1:
+                    new_outputs = (new_outputs,)
+            output = new_outputs
+
+        def get_backward_pre_hook(backward_pre_hook, backward_post_hook):
+            @ThreadSafe.synchronized
+            def backward_pre_hook_fn(cell, grad_output):
+                backward_pre_hook(cell, grad_output)
+                if backward_post_hook:
+                    backward_post_hook(cell, (), grad_output)
+
+            return backward_pre_hook_fn
+
+        backward_pre_hook = self._build_backward_pre_hook(Const.API, full_backward_name)
+        backward_post_hook = None if bw_hook else self._build_backward_hook(Const.API, full_backward_name)
+
+        backward_pre_hook_dict = OrderedDict()
+        backward_pre_hook_dict[full_backward_name] = get_backward_pre_hook(
+            backward_pre_hook,
+            backward_post_hook
+        )
+        MindsporeHookManager.cell_backward_pre_hook.append(backward_pre_hook_dict)
+        bw_pre_hook = inner.CellBackwardHook(
+            full_backward_name,
+            module,
+            MindsporeHookManager.cell_backward_pre_hook[-1]
         )
-        return hookset
+        bw_pre_hook.register_backward_pre_hook()
+
+        if is_backward_hook_output_a_view():
+            result = bw_pre_hook(output)
+        else:
+            if isinstance(output, tuple):
+                result = bw_pre_hook(*output)
+            else:
+                result = bw_pre_hook(output)
+            if isinstance(output, tuple):
+                if len(output) == 1:
+                    result = (result,)
+                if len(result) != len(output):
+                    raise TypeError(
+                        f"The backward pre hook return value size is {len(result)} "
+                        f"not equal to output size {len(output)}"
+                    )
+        return result
 
     def _need_exchange(self, module):
         if not hasattr(module, 'has_pre_hook_called') or not module.has_pre_hook_called:
@@ -71,23 +189,26 @@ class MindsproeHookManager(BaseHookManager):
         params_dict = {}
         if self.config.task != Const.STRUCTURE:
             params_dict = {
-                key.split(Const.SEP)[-1]: value
-                for key, value in module.parameters_dict(recurse=False).items()
-            }
+                key.split(Const.SEP)[-1]: value
+                for key, value in module.parameters_dict(recurse=False).items()
+            }
         return params_dict
 
-    def _build_backward_pre_hook(self, hook_type, name):
+    def _build_backward_pre_hook(self, hook_type, full_name):
         def backward_pre_hook(module, grad_input):
             if self.config.level != Const.LEVEL_L2:
                 return
             tid = threading.get_ident()
-            if not self._should_execute_hook(hook_type, module, False, tid):
+            if not self._should_execute_hook(tid):
                 return
 
             with ThreadSafe():
+                original_state = self.ensure_gc_enabled()
                 BaseHookManager.inner_switch[tid] = True
                 module_input = ModuleBackwardInputs(grad_input=grad_input)
-                self.data_collector.update_api_or_module_name(name)
-                self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
+                self.data_collector.update_api_or_module_name(full_name)
+                self.data_collector.backward_input_data_collect(full_name, module, self._pid, module_input)
                 BaseHookManager.inner_switch[tid] = False
+                self.restore_gc_state(original_state)
+
         return backward_pre_hook
msprobe/mindspore/dump/kernel_kbyk_dump.py
@@ -20,6 +20,9 @@ from msprobe.core.common.file_utils import create_directory, save_json
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 
+import mindspore as ms
+ms_version = ms.__version__
+
 
 class KernelKbykDump:
     COMMON_SETTINGS = "common_dump_settings"
@@ -39,6 +42,7 @@ class KernelKbykDump:
         common_set["input_output"] = 0
         common_set["kernels"] = []
         common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
+        common_set["statistic_category"] = []
 
         if config.stat_cal_mode and config.device_stat_precision_mode:
             e2e_set = {
@@ -71,10 +75,30 @@
             common_set["input_output"] = 1
         if config.data_mode[0] == Const.OUTPUT:
             common_set["input_output"] = 2
+        if config.summary_mode:
+            if isinstance(config.summary_mode, str):
+                if config.summary_mode == Const.STATISTICS:
+                    common_set["statistic_category"] = ["max", "min", "avg", "l2norm"]
+                else:
+                    mode = self._process_hash(config.summary_mode)
+                    common_set["statistic_category"] = [mode]
+            elif isinstance(config.summary_mode, list):
+                common_set["statistic_category"] = list({
+                    self._process_hash("avg" if mode == "mean" else mode)
+                    for mode in config.summary_mode
+                })
 
         self.dump_json[KernelKbykDump.COMMON_SETTINGS] = common_set
         self.dump_json[KernelKbykDump.E2E_SETTINGS] = e2e_set
 
+    @staticmethod
+    def _process_hash(value):
+        if ms_version <= "2.7.0" and (value == Const.HASH or value == Const.MD5):
+            value = "md5"
+        elif value == Const.MD5:
+            value = "hash:md5"
+        return value
+
     def handle(self):
         json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
         create_directory(json_path)
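
For orientation, the summary_mode handling above reduces to a small mapping from the msprobe config onto MindSpore's statistic_category field. The sketch below mirrors that logic outside the class; the literal values "statistics", "hash" and "md5" for Const.STATISTICS, Const.HASH and Const.MD5 are assumptions for illustration, not taken from this diff.

# Illustrative sketch only (not part of msprobe): mirrors the mapping implemented
# by KernelKbykDump above, assuming Const.STATISTICS == "statistics",
# Const.HASH == "hash" and Const.MD5 == "md5".

def process_hash(value, ms_version):
    # Older MindSpore only understands "md5"; newer versions expect "hash:md5".
    if ms_version <= "2.7.0" and value in ("hash", "md5"):
        return "md5"
    if value == "md5":
        return "hash:md5"
    return value


def statistic_category(summary_mode, ms_version):
    if isinstance(summary_mode, str):
        if summary_mode == "statistics":
            return ["max", "min", "avg", "l2norm"]
        return [process_hash(summary_mode, ms_version)]
    # list input: "mean" is spelled "avg" in the kernel dump config
    return list({process_hash("avg" if m == "mean" else m, ms_version) for m in summary_mode})


print(statistic_category("statistics", "2.5.0"))            # ['max', 'min', 'avg', 'l2norm']
print(statistic_category(["max", "mean", "md5"], "2.7.1"))  # ['max', 'avg', 'hash:md5'] (set order varies)
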
msprobe/mindspore/exception_dump/__init__.py: File without changes
msprobe/mindspore/exception_dump/exception_dump_tool_factory.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.core.common.log import logger
+from msprobe.mindspore.common.const import Const
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.exception_dump.kernel_graph_exception_dump import KernelGraphExceptionDump
+
+
+class ExceptionDumpToolFactory:
+    tools = {
+        Const.CELL: {
+            Const.GRAPH_KBYK_MODE: None,
+            Const.GRAPH_GE_MODE: None,
+            Const.PYNATIVE_MODE: None
+        },
+        Const.API: {
+            Const.GRAPH_KBYK_MODE: None,
+            Const.GRAPH_GE_MODE: None,
+            Const.PYNATIVE_MODE: None
+        },
+        Const.KERNEL: {
+            Const.GRAPH_KBYK_MODE: KernelGraphExceptionDump,
+            Const.GRAPH_GE_MODE: None,
+            Const.PYNATIVE_MODE: KernelGraphExceptionDump
+        }
+    }
+
+    @staticmethod
+    def create(config: DebuggerConfig):
+        tool = ExceptionDumpToolFactory.tools.get(config.level)
+        if not tool:
+            raise Exception("Valid level is needed.")
+        tool = tool.get(config.execution_mode)
+        if not tool:
+            logger.error(f"Exception dump is not supported in {config.execution_mode} mode "
+                         f"when level is {config.level}.")
+            raise ValueError
+        return (tool(config),)
msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from msprobe.core.common.file_utils import create_directory, save_json
+from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+
+
+class KernelGraphExceptionDump:
+
+    def __init__(self, config: DebuggerConfig):
+        self.dump_json = dict()
+        self.dump_json["common_dump_settings"] = dict()
+        self.dump_json["common_dump_settings"]["dump_mode"] = 0
+        self.dump_json["common_dump_settings"]["path"] = ""
+        self.dump_json["common_dump_settings"]["net_name"] = "Net"
+        self.dump_json["common_dump_settings"]["iteration"] = "all"
+        self.dump_json["common_dump_settings"]["saved_data"] = "tensor"
+        self.dump_json["common_dump_settings"]["input_output"] = 0
+        self.dump_json["common_dump_settings"]["kernels"] = []
+        self.dump_json["common_dump_settings"]["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
+        self.dump_json["common_dump_settings"]["op_debug_mode"] = 4
+        self.dump_json["common_dump_settings"]["file_format"] = "npy"
+        self.dump_json["e2e_dump_settings"] = dict()
+        self.dump_json["e2e_dump_settings"]["enable"] = not config.async_dump
+        self.dump_json["e2e_dump_settings"]["trans_flag"] = True
+
+        if config.stat_cal_mode and config.device_stat_precision_mode:
+            self.dump_json["e2e_dump_settings"]["stat_calc_mode"] = config.stat_cal_mode
+            self.dump_json["e2e_dump_settings"]["device_stat_precision_mode"] = config.device_stat_precision_mode
+        self.dump_json["common_dump_settings"]["path"] = config.dump_path
+        if len(config.step) > 0:
+            logger.warning("Step would change to all in this task.")
+        if len(config.rank) > 0:
+            self.dump_json["common_dump_settings"]["support_device"] = config.rank
+
+    def handle(self):
+        json_path = self.dump_json["common_dump_settings"]["path"]
+        create_directory(json_path)
+        json_path = os.path.join(json_path, "kernel_graph_exception_check.json")
+        save_json(json_path, self.dump_json, indent=4)
+        logger.info(json_path + " has been created.")
+        os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
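
A minimal usage sketch for the new exception dump classes, assuming mindstudio-probe 8.2.1 is installed; the SimpleNamespace stands in for a real DebuggerConfig and only carries the attributes the code above reads, so treat it as illustrative rather than the documented entry point.

# Illustrative sketch: drive the new exception-dump factory with a stub config.
# A real workflow would pass a DebuggerConfig; the stub below is an assumption.
from types import SimpleNamespace

from msprobe.mindspore.common.const import Const
from msprobe.mindspore.exception_dump.exception_dump_tool_factory import ExceptionDumpToolFactory

stub_config = SimpleNamespace(
    level=Const.KERNEL,                  # only the kernel level has an exception dump tool
    execution_mode=Const.PYNATIVE_MODE,
    async_dump=False,
    stat_cal_mode=None,
    device_stat_precision_mode=None,
    dump_path="./exception_dump_output",
    step=[],
    rank=[],
)

(tool,) = ExceptionDumpToolFactory.create(stub_config)
tool.handle()  # writes kernel_graph_exception_check.json and sets MINDSPORE_DUMP_CONFIG
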
msprobe/mindspore/free_benchmark/api_pynative_self_check.py
@@ -16,6 +16,7 @@
 import functools
 import importlib
 import os
+import threading
 import traceback
 
 import mindspore as ms
@@ -38,7 +39,6 @@ from msprobe.mindspore.free_benchmark.common.utils import Tools
 from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
 from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
 
-
 _api_register = get_api_register()
 
 
@@ -74,9 +74,10 @@ class ApiPyNativeSelfCheck:
 
         def forward_hook(api_name_with_id, cell, input_data, output_data):
             ret = None
+            tid = threading.get_ident()
 
             if not need_wrapper_func():
-                del cell.msprobe_input_kwargs
+                del cell.msprobe_input_kwargs[tid]
                 return ret
 
             api_name_with_id = api_name_with_id[:-1]
@@ -85,9 +86,9 @@ class ApiPyNativeSelfCheck:
                 api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)])
             if api_name in self.api_list:
                 ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name),
-                                 *input_data, **cell.msprobe_input_kwargs)
+                                 *input_data, **cell.msprobe_input_kwargs[tid])
 
-            del cell.msprobe_input_kwargs
+            del cell.msprobe_input_kwargs[tid]
             return ret
 
         def backward_hook(cell, grad_input, grad_output):
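
The changes above, together with the tid parameter added to _process_kwargs_and_output in ms_hook_manager.py, switch the stored forward kwargs from a single cell attribute to a dict keyed by thread id. A minimal sketch of that bookkeeping, with hypothetical hook and class names:

# Illustrative sketch (names are hypothetical, not the actual msprobe hooks):
# kwargs are stored per thread id so concurrent forward calls on one cell
# do not overwrite each other's entries.
import threading


class FakeCell:
    """Stand-in for an nn.Cell; only used to hold the attribute."""


def forward_pre_hook(cell, kwargs):
    if not hasattr(cell, 'msprobe_input_kwargs'):
        cell.msprobe_input_kwargs = {}
    cell.msprobe_input_kwargs[threading.get_ident()] = kwargs


def forward_hook(cell):
    tid = threading.get_ident()
    kwargs = cell.msprobe_input_kwargs.get(tid, {})
    cell.msprobe_input_kwargs.pop(tid, None)  # drop only this thread's entry
    return kwargs


cell = FakeCell()
forward_pre_hook(cell, {"scale": 2})
print(forward_hook(cell))  # {'scale': 2}
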
msprobe/mindspore/mindspore_service.py
@@ -27,7 +27,7 @@ from msprobe.mindspore.common.utils import (
     get_cells_and_names_with_index
 )
 from msprobe.mindspore.dump.hook_cell.api_register import get_api_register, ApiTemplate
-from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsproeHookManager
+from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsporeHookManager
 from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
 from msprobe.mindspore.dump.jit_dump import JitDump
 
@@ -59,7 +59,7 @@ class MindsporeService(BaseService):
         self.api_register = get_api_register()
         self.primitive_hook_service = PrimitiveHookService(self)
         self.cell_processor = CellProcessor(self.data_collector.scope)
-        self.hook_manager = MindsproeHookManager(self.data_collector, self.config)
+        self.hook_manager = MindsporeHookManager(self.data_collector, self.config)
         self._setup_jit_context()
         self.api_template = ApiTemplate
 
msprobe/mindspore/mindtorch/mindtorch_adaptor.py
@@ -93,6 +93,8 @@ from torch.nn.modules.module import (_global_backward_pre_hooks, _global_backward_hooks,
                                      _global_forward_hooks, _global_forward_hooks_always_called)
 from torch.utils.hooks import RemovableHandle
 
+from msprobe.mindspore.common.utils import is_backward_hook_output_a_view
+
 
 def _call_impl(self, *args, **kwargs):
     forward_call = self.forward
@@ -245,11 +247,14 @@ def _get_backward_hooks(self):
 
 
 def apply_backward_hook_on_tensors(cell_backward_hook, args):
-    is_tuple = True
-    if not isinstance(args, tuple):
-        args = (args,)
-        is_tuple = False
-    hooked_args = cell_backward_hook(*args)
-    if is_tuple and len(args) == 1:
-        hooked_args = (hooked_args, )
+    if is_backward_hook_output_a_view():
+        hooked_args = cell_backward_hook(args)
+    else:
+        is_tuple = True
+        if not isinstance(args, tuple):
+            args = (args,)
+            is_tuple = False
+        hooked_args = cell_backward_hook(*args)
+        if is_tuple and len(args) == 1:
+            hooked_args = (hooked_args, )
     return hooked_args
msprobe/mindspore/monitor/features.py
@@ -17,6 +17,8 @@ from mindspore import mint, ops, _no_grad
 from mindspore import Tensor
 from mindspore import dtype as mstype
 
+from msprobe.core.common.log import logger
+
 
 @_no_grad()
 def square_sum(x: Tensor):
@@ -74,3 +76,83 @@ FUNC_MAP = {
     "shape": get_shape,
     "dtype": get_dtype
 }
+
+
+def max_eigenvalue(input_tensor: Tensor, num_iterations=3):
+    input_tensor = input_tensor.float()
+    try:
+        check_tensor_dim(input_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"calcute max eigenvalue failed, {e}")
+        return Tensor(0)
+    in_features = input_tensor.shape[1]
+    u_tensor = ops.randn(in_features)
+    u_norm = u_tensor.norm()
+    if u_norm == 0:
+        return Tensor(0)
+    u_tensor /= u_tensor.norm()
+    input_seq = ops.matmul(input_tensor.T, input_tensor)
+    for _ in range(num_iterations):
+        v_tensor = ops.matmul(input_seq, u_tensor)
+        spectral_norm = ops.matmul(v_tensor.T, u_tensor)
+        v_norm = v_tensor.norm()
+        if v_norm > 0:
+            u_tensor = v_tensor / v_norm
+        else:
+            spectral_norm = Tensor(0)
+            break
+    return spectral_norm.sqrt()
+
+
+def check_tensor_dim(tensor, n):
+    if not isinstance(tensor, Tensor):
+        raise TypeError(
+            f"Input must be a mindspore Tensor, but got {type(tensor)} instead."
+        )
+    if len(tensor.shape) < n:
+        raise ValueError(
+            f"tensor dim must be at least {n} dimensions."
+            f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims"
+        )
+
+
+def cal_entropy(qk_tensor: Tensor, mask=None):
+    try:
+        check_tensor_dim(qk_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"calculate entropy failed, {e}")
+        return Tensor(0), Tensor(0)
+    if mask is None:
+        mask = ops.tril(ops.ones((qk_tensor.shape[1], qk_tensor.shape[1])))
+    qk_tensor = qk_tensor - ops.amax(qk_tensor, axis=1, keepdims=True)
+    qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf'))
+    softmax_qkt = ops.softmax(qk_tensor.float(), axis=1)
+    softmax_max = ops.mean(ops.amax(softmax_qkt, axis=1))
+    entropy = ops.mean(-ops.nansum(softmax_qkt * ops.log(softmax_qkt), axis=1))
+    return entropy, softmax_max
+
+
+def cal_stable_rank(weight: Tensor):
+    eig = max_eigenvalue(weight)
+    if eig == Tensor(0):
+        return Tensor(0), Tensor(0)
+    f_norm = ops.norm(weight, ord='fro')
+    return f_norm / eig, eig
+
+
+def cal_qkt(q_h: Tensor, k_h: Tensor, order="s,b,h,d"):
+    # q_h shape is (s, b, h, d)
+    try:
+        check_tensor_dim(q_h, 4)
+        check_tensor_dim(k_h, 4)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"calculatee qkt failed, {e}")
+        return Tensor(0)
+    if order == "s,b,h,d":
+        qkt = ops.matmul(q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5
+    elif order == "b,s,h,d":
+        qkt = ops.matmul(q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5
+    else:
+        logger.warning(f"Calculate qk tensor failed: Order unsupported.")
+        qkt = Tensor(0)
+    return qkt
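
To close, a short sketch exercising the new monitor feature functions; shapes and values are made up, and it assumes MindSpore 2.x with mindstudio-probe 8.2.1 installed.

# Illustrative sketch only: exercise the new monitor features on random data.
from mindspore import ops

from msprobe.mindspore.monitor.features import (
    max_eigenvalue, cal_stable_rank, cal_entropy, cal_qkt
)

weight = ops.randn(64, 32)            # 2-D weight matrix
print(max_eigenvalue(weight))         # approximate largest singular value via power iteration
print(cal_stable_rank(weight))        # (Frobenius norm / largest singular value, largest singular value)

q = ops.randn(16, 2, 8, 32)           # (s, b, h, d), as noted in cal_qkt
k = ops.randn(16, 2, 8, 32)
qkt = cal_qkt(q, k, order="s,b,h,d")  # scaled q @ k.T slice for the first batch and head
print(cal_entropy(qkt))               # (attention entropy, mean of per-row softmax max)
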