mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +21,9 @@ from msprobe.core.common.exceptions import MsprobeException
21
21
  from msprobe.core.common.file_utils import FileChecker
22
22
  from msprobe.core.common.utils import get_real_step_or_rank
23
23
  from msprobe.pytorch.common.log import logger
24
+ from msprobe.pytorch.common.utils import check_save_param
24
25
  from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
26
+ from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
25
27
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
26
28
  from msprobe.pytorch.pt_config import parse_json_config
27
29
  from msprobe.pytorch.service import Service
@@ -49,7 +51,7 @@ class PrecisionDebugger:
49
51
  dump_path=None,
50
52
  level=None,
51
53
  model=None,
52
- step=None,
54
+ step=None
53
55
  ):
54
56
  if not hasattr(self, "initialized"):
55
57
  config_params = ConfigParameters(config_path,
@@ -59,7 +61,6 @@ class PrecisionDebugger:
59
61
  model)
60
62
  self.check_input_params(config_params)
61
63
 
62
- self.api_origin = False
63
64
  self.initialized = True
64
65
  self.model = model
65
66
  common_config, task_config = parse_json_config(config_path, task)
@@ -67,12 +68,13 @@ class PrecisionDebugger:
67
68
  if self.task == Const.GRAD_PROBE:
68
69
  self.gm = GradientMonitor(common_config, task_config)
69
70
  return
70
- if step:
71
+ if step is not None:
71
72
  common_config.step = get_real_step_or_rank(step, Const.STEP)
72
73
  self.config = DebuggerConfig(
73
74
  common_config, task_config, task, dump_path, level
74
75
  )
75
76
  self.service = Service(self.config)
77
+ self.module_dumper = ModuleDumper(self.service)
76
78
  self.enable_dataloader = self.config.enable_dataloader
77
79
  if self.enable_dataloader:
78
80
  logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
@@ -105,9 +107,11 @@ class PrecisionDebugger:
105
107
  raise MsprobeException(
106
108
  MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
107
109
 
108
- if args.model is not None and not isinstance(args.model, torch.nn.Module):
109
- raise MsprobeException(
110
- MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
110
+ if args.model is not None:
111
+ logger.warning_on_rank_0(
112
+ "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
113
+ "It is recommended to pass the 'model' parameter in the start interface instead."
114
+ )
111
115
 
112
116
  @classmethod
113
117
  def start(cls, model=None):
@@ -120,15 +124,12 @@ class PrecisionDebugger:
120
124
  if instance.enable_dataloader:
121
125
  logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
122
126
  else:
123
- instance.service.start(instance.model, instance.api_origin)
124
- instance.api_origin = False
127
+ instance.service.start(instance.model)
125
128
 
126
- # 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
127
129
  @classmethod
128
130
  def forward_backward_dump_end(cls):
129
131
  instance = cls._instance
130
- instance.service.forward_backward_dump_end()
131
- instance.api_origin = True
132
+ instance.stop()
132
133
 
133
134
  @classmethod
134
135
  def stop(cls):
@@ -158,6 +159,49 @@ class PrecisionDebugger:
158
159
  return
159
160
  cls._instance.gm.monitor(model)
160
161
 
162
+ @classmethod
163
+ def save(cls, variable, name, save_backward=True):
164
+ instance = cls._instance
165
+ if not instance:
166
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
167
+ if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level != Const.LEVEL_DEBUG:
168
+ return
169
+ try:
170
+ check_save_param(variable, name, save_backward)
171
+ except ValueError:
172
+ return
173
+ instance.service.save(variable, name, save_backward)
174
+
175
+
176
+ def module_dump(module, dump_name):
177
+ if not isinstance(module, torch.nn.Module):
178
+ raise MsprobeException(
179
+ MsprobeException.INVALID_PARAM_ERROR,
180
+ f"the module argument in module_dump must be a torch.nn.Module subclass"
181
+ )
182
+ if not isinstance(dump_name, str):
183
+ raise MsprobeException(
184
+ MsprobeException.INVALID_PARAM_ERROR,
185
+ f"the dump_name argument in module_dump must be a str type"
186
+ )
187
+ instance = PrecisionDebugger._instance
188
+ if not instance:
189
+ raise MsprobeException(
190
+ MsprobeException.INTERFACE_USAGE_ERROR,
191
+ f"PrecisionDebugger must be instantiated before using module_dump interface"
192
+ )
193
+ instance.module_dumper.start_module_dump(module, dump_name)
194
+
195
+
196
+ def module_dump_end():
197
+ instance = PrecisionDebugger._instance
198
+ if not instance:
199
+ raise MsprobeException(
200
+ MsprobeException.INTERFACE_USAGE_ERROR,
201
+ f"PrecisionDebugger must be instantiated before using module_dump_end interface"
202
+ )
203
+ instance.module_dumper.stop_module_dump()
204
+
161
205
 
162
206
  def iter_tracer(func):
163
207
  def func_wrapper(*args, **kwargs):
File without changes
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.data_dump.scope import BaseScope
19
+ from msprobe.pytorch.common.log import logger
20
+ from msprobe.pytorch.hook_module.api_registry import api_register
21
+
22
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
23
+
24
+
25
+ class ModuleDumper:
26
+ def __init__(self, service):
27
+ self.service = service
28
+ self.hook_handle_list = []
29
+
30
+ def start_module_dump(self, module, dump_name):
31
+ api_register.api_originality()
32
+ self.register_hook(module, dump_name)
33
+
34
+ def stop_module_dump(self):
35
+ api_register.api_modularity()
36
+ for hook_handle in self.hook_handle_list:
37
+ if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
38
+ hook_handle.remove()
39
+ self.hook_handle_list.clear()
40
+
41
+ def register_hook(self, module, dump_name):
42
+ prefix_name = (
43
+ BaseScope.Module_Type_Module + Const.SEP +
44
+ dump_name + Const.SEP +
45
+ module.__class__.__name__ + Const.SEP
46
+ )
47
+ module_processor = self.service.module_processor
48
+ _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
49
+ BaseScope.Module_Type_Module,
50
+ prefix_name
51
+ )
52
+
53
+ if module_processor.has_register_backward_hook(module):
54
+ logger.warning(
55
+ f"The {dump_name} module has registered deprecated register_backward_hook,"
56
+ f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
57
+ )
58
+ if torch_version_above_or_equal_2:
59
+ forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
60
+ else:
61
+ if not module_processor.has_register_backward_hook(module):
62
+ backward_hook_handle = module.register_full_backward_hook(
63
+ module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
64
+ )
65
+ self.hook_handle_list.append(backward_hook_handle)
66
+ forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
67
+ self.hook_handle_list.append(forward_hook_handle)
68
+ if not module_processor.has_register_backward_hook(module):
69
+ backward_hook_handle = module.register_full_backward_hook(backward_hook)
70
+ self.hook_handle_list.append(backward_hook_handle)
71
+
72
+ forward_pre_hook_handle = module.register_forward_pre_hook(
73
+ module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
74
+ )
75
+ forward_hook_handle = module.register_forward_hook(
76
+ module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
77
+ )
78
+ self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
79
+ if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module):
80
+ backward_pre_hook_handle = module.register_full_backward_pre_hook(
81
+ module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
82
+ )
83
+ backward_hook_handle = module.register_full_backward_hook(
84
+ module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
85
+ )
86
+ self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
@@ -0,0 +1,204 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import wraps
17
+
18
+ import torch
19
+ from msprobe.core.common.const import Const
20
+ from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
21
+ from msprobe.pytorch.common.log import logger
22
+ from msprobe.pytorch.common.utils import replace_last_occurrence
23
+ from torch.utils.checkpoint import checkpoint as origin_checkpoint
24
+ from torch.utils.checkpoint import set_checkpoint_early_stop
25
+ from torch.utils.hooks import BackwardHook
26
+
27
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
28
+
29
+
30
+ def checkpoint_without_early_stop(*args, **kwargs):
31
+ with set_checkpoint_early_stop(False):
32
+ return origin_checkpoint(*args, **kwargs)
33
+
34
+
35
+ def replace_checkpoint():
36
+ torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
37
+
38
+
39
+ class ModuleProcesser:
40
+ module_count = {}
41
+ module_stack = []
42
+ api_parent_node = ""
43
+ module_node = {}
44
+
45
+ def __init__(self, scope):
46
+ self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
47
+ BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
48
+ BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
49
+ replace_checkpoint()
50
+
51
+ @staticmethod
52
+ def clone_return_value(func):
53
+ @wraps(func)
54
+ def clone_return_value_func(*args, **kwargs):
55
+ result = func(*args, **kwargs)
56
+ return ModuleProcesser.clone_if_tensor(result)
57
+
58
+ return clone_return_value_func
59
+
60
+ @staticmethod
61
+ def clone_if_tensor(result):
62
+ if isinstance(result, torch.Tensor):
63
+ return result.clone()
64
+ elif type(result) is tuple:
65
+ return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
66
+ elif type(result) is list:
67
+ return list(ModuleProcesser.clone_if_tensor(x) for x in result)
68
+ elif type(result) is dict:
69
+ return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
70
+ else:
71
+ return result
72
+
73
+ @staticmethod
74
+ def module_count_func(module_name):
75
+ if module_name not in ModuleProcesser.module_count:
76
+ ModuleProcesser.module_count[module_name] = 0
77
+ else:
78
+ ModuleProcesser.module_count[module_name] += 1
79
+ return ModuleProcesser.module_count[module_name]
80
+
81
+ @staticmethod
82
+ def has_register_backward_hook(module):
83
+ return hasattr(module, '_backward_hooks') and \
84
+ len(module._backward_hooks) > 0 and \
85
+ module._is_full_backward_hook is False
86
+
87
+ @staticmethod
88
+ def get_modules_and_names(models):
89
+ modules_and_names_with_index = {}
90
+ if isinstance(models, (list, tuple)):
91
+ for index, model in enumerate(models):
92
+ modules_and_names_with_index[str(index)] = model.named_modules()
93
+ else:
94
+ modules_and_names_with_index["-1"] = models.named_modules()
95
+ return modules_and_names_with_index
96
+
97
+ @classmethod
98
+ def reset_module_stats(cls):
99
+ cls.module_count = {}
100
+ cls.module_stack = []
101
+ cls.api_parent_node = ""
102
+ cls.module_node = {}
103
+
104
+ def register_module_hook(self, models, build_hook):
105
+ logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
106
+ modules_and_names_with_index = self.get_modules_and_names(models)
107
+ for index, modules_and_names in modules_and_names_with_index.items():
108
+ model = models if index == "-1" else models[int(index)]
109
+ for name, module in modules_and_names:
110
+ if module == model:
111
+ continue
112
+ module_index = (index + Const.SEP) if index != "-1" else ""
113
+ prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
114
+ name + Const.SEP + module.__class__.__name__ + Const.SEP)
115
+ pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
116
+ BaseScope.Module_Type_Module,
117
+ prefix_name
118
+ )
119
+
120
+ if self.has_register_backward_hook(module):
121
+ logger.warning(
122
+ f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
123
+ f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
124
+ )
125
+ if torch_version_above_or_equal_2:
126
+ module.register_forward_hook(forward_hook, with_kwargs=True)
127
+ else:
128
+ if not self.has_register_backward_hook(module):
129
+ module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
130
+ module.register_forward_hook(forward_hook_torch_version_below_2)
131
+ if not self.has_register_backward_hook(module):
132
+ module.register_full_backward_hook(backward_hook)
133
+
134
+ module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
135
+ module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
136
+ if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module):
137
+ module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
138
+ module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
139
+
140
+ def node_hook(self, name_prefix, start_or_stop, **kwargs):
141
+
142
+ def pre_hook(module, input, output=None):
143
+ try:
144
+ index = ModuleProcesser.module_count_func(name_prefix)
145
+ except IndexError as e:
146
+ index = None
147
+ pass
148
+ full_name = name_prefix + Const.SEP + str(index)
149
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
150
+ module.mindstudio_reserved_name = []
151
+ module.mindstudio_reserved_name.append(full_name)
152
+ if self.module_stack:
153
+ ModuleProcesser.module_node[full_name] = self.module_stack[-1]
154
+ else:
155
+ ModuleProcesser.module_node[full_name] = None
156
+
157
+ ModuleProcesser.module_stack.append(full_name)
158
+ if self.module_stack:
159
+ ModuleProcesser.api_parent_node = self.module_stack[-1]
160
+ if self.scope:
161
+ self.scope.begin_module(full_name)
162
+
163
+ def end_hook(module, input, output=None):
164
+ if self.module_stack:
165
+ ModuleProcesser.module_stack.pop()
166
+ if self.module_stack:
167
+ ModuleProcesser.api_parent_node = self.module_stack[-1]
168
+ else:
169
+ ModuleProcesser.api_parent_node = None
170
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
171
+ raise RuntimeError(f"module reserve name is None when pop")
172
+ current_name = module.mindstudio_reserved_name.pop()
173
+ if self.scope:
174
+ self.scope.end_module(current_name)
175
+
176
+ def backward_hook(module, input, output=None):
177
+ try:
178
+ index = ModuleProcesser.module_count_func(name_prefix)
179
+ except IndexError as e:
180
+ index = None
181
+ pass
182
+ full_name = name_prefix + Const.SEP + str(index)
183
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
184
+ module.mindstudio_reserved_name = []
185
+ module.mindstudio_reserved_name.append(full_name)
186
+ forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
187
+ ModuleProcesser.module_node[full_name] = replace_last_occurrence(
188
+ ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
189
+ ModuleProcesser.api_parent_node = None
190
+ if self.scope:
191
+ self.scope.begin_module(full_name)
192
+
193
+ if torch_version_above_or_equal_2:
194
+ if Const.START in start_or_stop:
195
+ return pre_hook
196
+ else:
197
+ return end_hook
198
+ else:
199
+ if Const.FORWARD in name_prefix and Const.START in start_or_stop:
200
+ return pre_hook
201
+ elif Const.BACKWARD in name_prefix:
202
+ return backward_hook
203
+ else:
204
+ return end_hook
@@ -39,7 +39,6 @@ class DataParams:
39
39
  origin_func: Optional[Callable] = None
40
40
  api_type: Optional[str] = None
41
41
  fuzz_stage: Optional[str] = None
42
- grad_unequal_flag: Optional[bool] = True
43
42
 
44
43
 
45
44
  @dataclass
@@ -127,6 +126,8 @@ def make_unequal_row(
127
126
  )
128
127
  if isinstance(ratio, float):
129
128
  row.max_rel = ratio - 1
129
+ if isinstance(ratio, str):
130
+ row.max_rel = ratio
130
131
  origin_tensor = data_params.original_result
131
132
  perturbed_tensor = data_params.perturbed_result
132
133
  if index is not None:
@@ -124,6 +124,7 @@ class TorchC:
124
124
  abs = torch._C._VariableFunctionsClass.abs
125
125
  where = torch._C._VariableFunctionsClass.where
126
126
  div = torch._C._VariableFunctionsClass.div
127
+ mul = torch._C._VariableFunctionsClass.mul
127
128
  max = torch._C._VariableFunctionsClass.max
128
129
  min = torch._C._VariableFunctionsClass.min
129
130
  gt = torch._C._VariableFunctionsClass.gt
@@ -138,3 +139,5 @@ class TorchC:
138
139
  tensor_split = torch._C._VariableFunctionsClass.tensor_split
139
140
  stack = torch._C._VariableFunctionsClass.stack
140
141
  reshape = torch._C._VariableFunctionsClass.reshape
142
+ nan_to_num = torch._C._VariableFunctionsClass.nan_to_num
143
+ aminmax = torch._C._VariableFunctionsClass.aminmax
@@ -82,13 +82,11 @@ class GradSaver:
82
82
  data_params = DataParams()
83
83
  data_params.original_result = origin_grad
84
84
  data_params.perturbed_result = perturbed_grad
85
- data_params.grad_unequal_flag = False
86
85
  data_params.valid_input_index = index
87
86
  try:
88
87
  handler.handle(data_params)
89
88
  if not data_params.is_consistent:
90
89
  self.is_compare = False
91
- data_params.grad_unequal_flag = True
92
90
  data_params.is_consistent = True
93
91
  data_params.perturbed_result = self.perturbed_grad_input
94
92
  data_params.original_result = self.origin_grad_input
@@ -89,12 +89,6 @@ class FuzzHandler(ABC):
89
89
  )
90
90
  return origin_output_chunks, perturbed_output_chunks
91
91
 
92
- @staticmethod
93
- def convert_overflow_ratio_to_consistent(ratio):
94
- if math.isnan(ratio) or math.isinf(ratio):
95
- return ThresholdConfig.COMP_CONSISTENT
96
- return ratio
97
-
98
92
  @abstractmethod
99
93
  def get_threshold(self, dtype):
100
94
  pass
@@ -107,10 +101,10 @@ class FuzzHandler(ABC):
107
101
  self, origin_output, perturbed_output, norm_type, abs_tol
108
102
  ):
109
103
  if norm_type == NormType.ENDLESS_NORM:
110
- return self.calculate_error(origin_output, perturbed_output, abs_tol)
104
+ return self.calculate_max_ratio(origin_output, perturbed_output, abs_tol)
111
105
  return ThresholdConfig.COMP_CONSISTENT
112
106
 
113
- def calculate_error(self, origin_output, perturbed_output, abs_tol):
107
+ def calculate_max_ratio(self, origin_output, perturbed_output, abs_tol):
114
108
  origin_output_chunks, perturbed_output_chunks = (
115
109
  self.tensor_split_for_error_calculate(origin_output, perturbed_output)
116
110
  )
@@ -122,42 +116,30 @@ class FuzzHandler(ABC):
122
116
  raise FreeBenchmarkException(
123
117
  FreeBenchmarkException.OutputIndexError, err_msg
124
118
  )
125
- norm1 = -np.inf
126
- norm2 = -np.inf
127
- norm3 = np.inf
119
+
120
+ max_ratio = ThresholdConfig.COMP_CONSISTENT
128
121
  for i, chunk_origin in enumerate(origin_output_chunks):
129
122
  if chunk_origin.nelement() == 0:
130
123
  break
131
124
  chunk_perturbed = perturbed_output_chunks[i]
132
- ratio_tensor1 = TorchC.where(
133
- TorchC.abs(chunk_perturbed) > abs_tol,
134
- TorchC.div(
135
- TorchC.clamp(chunk_origin, min=abs_tol),
136
- TorchC.clamp(chunk_perturbed, min=abs_tol),
137
- ),
138
- 1,
139
- )
140
- ratio_tensor2 = TorchC.where(
141
- TorchC.abs(chunk_origin) > abs_tol,
142
- TorchC.div(
143
- TorchC.clamp(chunk_perturbed, min=abs_tol),
144
- TorchC.clamp(chunk_origin, min=abs_tol),
145
- ),
146
- 1,
125
+ # 如果乘积最小值 < 极小值乘积的负值,认为存在非极小值符号相反的情况
126
+ if TorchC.lt(
127
+ TorchC.min(TorchC.mul(chunk_origin, chunk_perturbed)), -(abs_tol**2)
128
+ ):
129
+ return ThresholdConfig.SYMBOL_FLIPPING
130
+ # 求A/B B/A的比值前,将值限制在大于极小值范围内
131
+ clamp_origin = TorchC.clamp(TorchC.abs(chunk_origin), min=abs_tol)
132
+ clamp_perturbed = TorchC.clamp(TorchC.abs(chunk_perturbed), min=abs_tol)
133
+ # 对于计算结果为nan的情况,认为两者没有差异
134
+ ratio_tensor = TorchC.nan_to_num(
135
+ TorchC.div(clamp_origin, clamp_perturbed),
136
+ nan=ThresholdConfig.COMP_CONSISTENT,
147
137
  )
148
- norm_values = TorchC.stack(
149
- [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
150
- )
151
- max_ratio1, max_ratio2 = norm_values.tolist()
152
- norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
153
- norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
154
- norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
155
-
156
- if norm3 < 0:
157
- ratio = ThresholdConfig.SYMBOL_FLIPPING
158
- else:
159
- ratio = max(norm1, norm2)
160
- return ratio
138
+ # 求A/B 和 B/A比值最大值,其中 B/A的最大值为 A/B的最小值的倒数
139
+ min_ratio, max_ratio = TorchC.stack([*TorchC.aminmax(ratio_tensor)]).tolist()
140
+ min_ratio_reciprocal = np.inf if min_ratio == 0 else 1 / min_ratio
141
+ max_ratio = max(max_ratio, min_ratio_reciprocal)
142
+ return max_ratio
161
143
 
162
144
  def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
163
145
  try:
@@ -220,10 +202,12 @@ class FuzzHandler(ABC):
220
202
  )
221
203
  npu_consistent = is_consistent
222
204
  max_fuzz_ratio = (
223
- max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
205
+ max_fuzz_ratio
206
+ if not isinstance(ratio, (int, float))
207
+ else max(max_fuzz_ratio, ratio)
224
208
  )
225
- data_params.is_consistent = is_consistent and data_params.is_consistent
226
- if not is_consistent and data_params.grad_unequal_flag:
209
+ data_params.is_consistent = is_consistent
210
+ if not is_consistent:
227
211
  self.unequal_rows.append(
228
212
  make_unequal_row(data_params, self.params, ratio=ratio)
229
213
  )
@@ -235,12 +219,12 @@ class FuzzHandler(ABC):
235
219
  )
236
220
  npu_consistent = npu_consistent and is_consistent
237
221
  max_fuzz_ratio = (
238
- max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
239
- )
240
- data_params.is_consistent = (
241
- is_consistent and data_params.is_consistent
222
+ max_fuzz_ratio
223
+ if not isinstance(ratio, (int, float))
224
+ else max(max_fuzz_ratio, ratio)
242
225
  )
243
- if not is_consistent and data_params.grad_unequal_flag:
226
+ data_params.is_consistent = is_consistent
227
+ if not is_consistent:
244
228
  self.unequal_rows.append(
245
229
  make_unequal_row(
246
230
  data_params, self.params, ratio=ratio, index=index_
@@ -75,10 +75,6 @@ class PreheatHandler(FuzzHandler):
75
75
  if self.params.preheat_config.get("preheat_step") <= self.params.step:
76
76
  return data_params.original_result
77
77
 
78
- if not data_params.grad_unequal_flag:
79
- data_params.grad_unequal_flag = True
80
- data_params.is_consistent = False
81
- return data_params.original_result
82
78
  preheat_counter.add_api_called_time(self.pure_name)
83
79
 
84
80
  if not self._is_take_a_sample():
@@ -27,6 +27,11 @@ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotar
27
27
  from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
28
28
  npu_scaled_masked_softmax_backward
29
29
  from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
30
+ from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam
31
+ from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu
32
+ from msprobe.pytorch.bench_functions.mish import npu_mish
33
+ from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax
34
+ from msprobe.pytorch.bench_functions.sort_v2 import npu_sort_v2
30
35
  from msprobe.pytorch.common.utils import logger
31
36
 
32
37
 
@@ -79,7 +84,8 @@ class Register(dict):
79
84
  npu_custom_functions = Register()
80
85
  npu_custom_functions([
81
86
  npu_apply_adam_w, npu_confusion_transpose, npu_fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention,
82
- npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention
87
+ npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention, npu_apply_adam,
88
+ npu_group_norm_silu, npu_mish, npu_moe_gating_top_k_softmax, npu_sort_v2
83
89
  ])
84
90
 
85
91
  # register for npu custom backward bench functions
@@ -13,4 +13,4 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from .wrap_functional import remove_dropout
16
+ from msprobe.pytorch.common.utils import remove_dropout