mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,10 +23,16 @@ from mindspore.mint.nn import functional
23
23
  from msprobe.core.common.const import Const
24
24
  from msprobe.core.common.file_utils import load_yaml
25
25
  from msprobe.mindspore.common.const import Const as MsConst
26
+ from msprobe.mindspore.common.utils import is_mindtorch
26
27
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
27
28
 
29
+ if is_mindtorch():
30
+ import torch
31
+ import torch_npu
32
+
28
33
  cur_path = os.path.dirname(os.path.realpath(__file__))
29
34
  yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
35
+ torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
30
36
 
31
37
 
32
38
  class HOOKTensor(object):
@@ -53,6 +59,26 @@ class HOOKDistributedOP(object):
53
59
  pass
54
60
 
55
61
 
62
+ class HOOKTorchOP(object):
63
+ pass
64
+
65
+
66
+ class HOOKTorchTensor(object):
67
+ pass
68
+
69
+
70
+ class HOOKTorchFunctionalOP(object):
71
+ pass
72
+
73
+
74
+ class HOOKTorchDistributedOP(object):
75
+ pass
76
+
77
+
78
+ class HOOKTorchNpuOP(object):
79
+ pass
80
+
81
+
56
82
  class ApiTemplate(HOOKCell):
57
83
  def __init__(self, api_name, api_dict, prefix, hook):
58
84
  self.api_name = api_name
@@ -60,7 +86,30 @@ class ApiTemplate(HOOKCell):
60
86
  self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
61
87
  super().__init__(hook)
62
88
 
89
+ @staticmethod
90
+ def async_to_sync(output):
91
+ # Fake handle, used to return after the CommHandle executes the wait method
92
+ fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
93
+ if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
94
+ output[1].wait()
95
+ output = (output[0], fake_handle)
96
+ elif hasattr(output, "wait"):
97
+ output.wait()
98
+ output = fake_handle
99
+ return output
100
+
63
101
  def construct(self, *args, **kwargs):
102
+ if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
103
+ return args[0] if args else kwargs.get(Const.INPUT)
104
+
105
+ output = self.api_func(*args, **kwargs)
106
+
107
+ if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
108
+ if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
109
+ output = self.async_to_sync(output)
110
+ return output
111
+
112
+ def forward(self, *args, **kwargs):
64
113
  if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
65
114
  return args[0] if args else kwargs.get(Const.INPUT)
66
115
  return self.api_func(*args, **kwargs)
@@ -77,6 +126,15 @@ class WrapApiName:
77
126
  self.distributed_api_names = distributed_api_names
78
127
 
79
128
 
129
+ class WrapTorchApiName:
130
+ def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
131
+ self.torch_api_names = torch_api_names
132
+ self.tensor_api_names = tensor_api_names
133
+ self.functional_api_names = functional_api_names
134
+ self.distributed_api_names = distributed_api_names
135
+ self.npu_api_names = npu_api_names
136
+
137
+
80
138
  def get_wrap_api_list():
81
139
  api_list = load_yaml(yaml_path)
82
140
  tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
@@ -93,6 +151,21 @@ def get_wrap_api_list():
93
151
  return wrap_api_name
94
152
 
95
153
 
154
+ def get_wrap_torch_api_list():
155
+ api_list = load_yaml(torch_yaml_path)
156
+ torch_api = api_list.get("torch")
157
+ tensor_api = api_list.get("tensor")
158
+ functional_api = api_list.get("functional")
159
+ distributed_api = api_list.get("distributed")
160
+ npu_api = api_list.get("torch_npu")
161
+ wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
162
+ set(tensor_api) & set(dir(torch.Tensor)),
163
+ set(functional_api) & set(dir(torch.nn.functional)),
164
+ set(distributed_api) & set(dir(torch.distributed)),
165
+ set(npu_api) & set(dir(torch_npu)))
166
+ return wrap_api_name
167
+
168
+
96
169
  def wrap_api_func(api_name, api_dict, prefix, hook):
97
170
  def api_function(*args, **kwargs):
98
171
  return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
@@ -106,6 +179,24 @@ def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
106
179
 
107
180
 
108
181
  def setup_hooks(hook):
182
+ if is_mindtorch():
183
+ torch_wrap_api_name = get_wrap_torch_api_list()
184
+ wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
185
+ {f: getattr(torch, f) for f in dir(torch)},
186
+ MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
187
+ wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
188
+ {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
189
+ MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
190
+ wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
191
+ {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
192
+ MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
193
+ wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
194
+ {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
195
+ MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
196
+ wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
197
+ MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
198
+ return
199
+
109
200
  wrap_api_name = get_wrap_api_list()
110
201
  wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
111
202
  MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)
@@ -16,14 +16,15 @@
16
16
  import os
17
17
  from collections import defaultdict
18
18
 
19
- from mindspore import Tensor
20
19
  from mindspore._c_expression import PyNativeExecutor_
21
- from mindspore.common.api import _MindsporeFunctionExecutor
20
+ try:
21
+ from mindspore.common.api import _MindsporeFunctionExecutor
22
+ except ImportError:
23
+ from mindspore.common.api import _JitExecutor as _MindsporeFunctionExecutor
22
24
 
23
25
  from msprobe.core.common.log import logger
24
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
25
26
  from msprobe.core.common.const import Const
26
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
27
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
27
28
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
28
29
 
29
30
 
@@ -40,8 +41,8 @@ def dump_jit(name, in_feat, out_feat, is_forward):
40
41
  if JitDump.need_dump():
41
42
  if is_forward:
42
43
  JitDump.jit_count[result] += 1
43
- name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
44
- Const.FORWARD
44
+ name_template = (Const.JIT + Const.SEP + result + Const.SEP +
45
+ str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD)
45
46
  JitDump.data_collector.update_api_or_module_name(name_template)
46
47
  module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
47
48
  JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
@@ -0,0 +1,33 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.file_utils import save_json
19
+
20
+
21
+ def create_kernel_config_json(dump_path, cur_rank):
22
+ kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
23
+ kernel_config_path = os.path.join(dump_path, kernel_config_name)
24
+ config_info = {
25
+ "dump": {
26
+ "dump_list": [],
27
+ "dump_path": dump_path,
28
+ "dump_mode": "all",
29
+ "dump_op_switch": "on"
30
+ }
31
+ }
32
+ save_json(kernel_config_path, config_info, indent=4)
33
+ return kernel_config_path
@@ -56,6 +56,13 @@ class KernelGraphDump:
56
56
  self.dump_json["common_dump_settings"]["input_output"] = 2
57
57
 
58
58
  def handle(self):
59
+ try:
60
+ from msprobe.lib import _msprobe_c
61
+ return
62
+ except ImportError:
63
+ # 如果没有_msprobe_ce_c走MindSpore老流程
64
+ logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
65
+
59
66
  if os.getenv("GRAPH_OP_RUN") == "1":
60
67
  raise Exception("Must run in graph mode, not kbk mode")
61
68
  json_path = self.dump_json["common_dump_settings"]["path"]
@@ -19,7 +19,6 @@ import os
19
19
  import traceback
20
20
 
21
21
  import mindspore as ms
22
-
23
22
  from msprobe.core.common.const import Const
24
23
  from msprobe.core.common.exceptions import DistributedNotInitializedError
25
24
  from msprobe.core.common.file_utils import check_path_length, load_yaml
@@ -29,6 +28,7 @@ from msprobe.mindspore.common.log import logger
29
28
  from msprobe.mindspore.common.utils import get_rank_if_initialized
30
29
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
31
30
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
31
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
32
32
  from msprobe.mindspore.free_benchmark.common.config import Config
33
33
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
34
34
  from msprobe.mindspore.free_benchmark.common.utils import Tools
@@ -63,7 +63,10 @@ class ApiPyNativeSelfCheck:
63
63
  api_register.initialize_hook(self.build_hook)
64
64
  api_register.api_set_hook_func()
65
65
 
66
- def build_hook(self, api_name_with_id):
66
+ def build_hook(self, api_name):
67
+ def pre_hook(cell, input_data):
68
+ return None
69
+
67
70
  def forward_hook(api_name_with_id, cell, input_data, output_data):
68
71
  ret = None
69
72
 
@@ -85,7 +88,10 @@ class ApiPyNativeSelfCheck:
85
88
  def backward_hook(cell, grad_input, grad_output):
86
89
  pass
87
90
 
91
+ HOOKCell.get_cell_count(api_name)
92
+ api_name_with_id = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP
88
93
  forward_hook = functools.partial(forward_hook, api_name_with_id)
94
+ HOOKCell.add_cell_count(api_name)
89
95
 
90
96
  def wrap_forward_hook(cell, input_data, output_data):
91
97
  return forward_hook(cell, input_data, output_data)
@@ -93,7 +99,10 @@ class ApiPyNativeSelfCheck:
93
99
  def wrap_backward_hook(cell, grad_input, grad_output):
94
100
  return backward_hook(cell, grad_input, grad_output)
95
101
 
96
- return wrap_forward_hook, wrap_backward_hook
102
+ def pre_backward_hook(cell, grad_input):
103
+ return None
104
+
105
+ return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook
97
106
 
98
107
  def store_original_func(self):
99
108
  for api_name in self.api_list:
@@ -138,7 +147,7 @@ def get_module(api_name):
138
147
  module_obj = importlib.import_module(func_name_list[0])
139
148
  for i, module_name in enumerate(func_name_list[1:-1]):
140
149
  if not hasattr(module_obj, module_name):
141
- importlib.import_module(f"{Const.SEP.join(func_name_list[:i+2])}")
150
+ importlib.import_module(f"{Const.SEP.join(func_name_list[:i + 2])}")
142
151
  module_obj = getattr(module_obj, module_name)
143
152
  orig_func = getattr(module_obj, func_name)
144
153
 
@@ -35,12 +35,12 @@ class BitNoisePerturbation(BasePerturbation):
35
35
  noise_type = list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.keys())[
36
36
  list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.values()).index(bit_len_type)]
37
37
  noise = ops.full(inputs.shape, 1, dtype=noise_type)
38
- input_np = inputs.contiguous().asnumpy()
38
+ input_np = inputs.asnumpy()
39
39
  input_np_int = input_np.view(bit_len_type)
40
40
  result = Tensor(input_np_int)
41
41
  result = ops.where(ops.abs(inputs) > sub_normal,
42
42
  ops.bitwise_xor(result, noise), result)
43
- result_np = result.contiguous().asnumpy()
43
+ result_np = result.asnumpy()
44
44
  result_np_float = result_np.view(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.get(inputs.dtype))
45
45
  self.is_fuzzed = True
46
46
  return Tensor(result_np_float)
@@ -16,6 +16,7 @@
16
16
  import multiprocessing
17
17
  import os
18
18
  import time
19
+ from dataclasses import dataclass
19
20
  from multiprocessing import Process
20
21
  from typing import List
21
22
 
@@ -23,6 +24,7 @@ import mindspore as ms
23
24
  import numpy as np
24
25
  from mindspore.common.parameter import Parameter
25
26
  from mindspore.communication import get_rank
27
+
26
28
  from msprobe.core.common.file_utils import (create_directory, check_file_or_directory_path,
27
29
  write_csv, remove_path, move_file, load_npy)
28
30
  from msprobe.core.grad_probe.constant import GradConst
@@ -31,6 +33,16 @@ from msprobe.mindspore.common.log import logger
31
33
  from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext
32
34
 
33
35
 
36
+ @dataclass
37
+ class GradDumpConfig:
38
+ dump_dir: str
39
+ g_name: str
40
+ dump_step: Parameter
41
+ grad: ms.Tensor
42
+ level: str
43
+ bounds: List
44
+
45
+
34
46
  def get_rank_id():
35
47
  try:
36
48
  rank_id = get_rank()
@@ -40,35 +52,35 @@ def get_rank_id():
40
52
 
41
53
 
42
54
  @ms.jit
43
- def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List):
55
+ def grad_dump(config: GradDumpConfig):
44
56
  """
45
57
  Dump gradient statistic data.
46
58
  level0: [step, max, min, norm, shape_dim, shape]
47
59
  level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
48
60
  level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
49
61
  """
50
- dump_path = os.path.join(dump_dir, g_name)
62
+ dump_path = os.path.join(config.dump_dir, config.g_name)
51
63
  dump_dir_path = dump_path + "_dir"
52
64
  save_op = ms.ops.TensorDump()
53
65
 
54
- grad_flat = grad.reshape(-1)
66
+ grad_flat = config.grad.reshape(-1)
55
67
  max_val = grad_flat.max(axis=0).float()
56
68
  min_val = grad_flat.min(axis=0).float()
57
69
  norm_val = grad_flat.norm(ord=2).float()
58
- shape = grad.shape
59
- extrem_list = [dump_step[0].float(), max_val, min_val, norm_val]
70
+ shape = config.grad.shape
71
+ extrem_list = [config.dump_step[0].float(), max_val, min_val, norm_val]
60
72
  extrem_stat = ms.ops.stack(extrem_list)
61
73
  shape_list = [len(shape)] + list(shape)
62
74
  shape_stat = ms.Tensor(shape_list).float()
63
75
  level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
64
76
  level_stat = level0_stat
65
77
 
66
- if level == GradConst.LEVEL2:
67
- zero_grad = (grad == 0).sum()
68
- dist_dim = ms.Tensor([len(bounds) + 2]).float()
69
- bucket_result = ms.ops.bucketize(grad.float(), bounds)
78
+ if config.level == GradConst.LEVEL2:
79
+ zero_grad = (config.grad == 0).sum()
80
+ dist_dim = ms.Tensor([len(config.bounds) + 2]).float()
81
+ bucket_result = ms.ops.bucketize(config.grad.float(), config.bounds)
70
82
  bucket_result = bucket_result.astype(ms.int8)
71
- dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)]
83
+ dist_stat = [(bucket_result == i).sum() for i in range(len(config.bounds) + 1)]
72
84
  dist_stat.append(zero_grad)
73
85
  dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty
74
86
  dist_stat = ms.ops.stack(dist_stat, axis=0).float()
@@ -76,8 +88,8 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor,
76
88
  level_stat = level2_stat
77
89
 
78
90
  save_op(dump_path, level_stat)
79
- if level == GradConst.LEVEL1 or level == GradConst.LEVEL2:
80
- grad_direction = grad > 0
91
+ if config.level == GradConst.LEVEL1 or config.level == GradConst.LEVEL2:
92
+ grad_direction = config.grad > 0
81
93
  save_op(dump_dir_path, grad_direction)
82
94
 
83
95
 
@@ -26,7 +26,7 @@ from msprobe.core.grad_probe.constant import GradConst
26
26
  from msprobe.mindspore.common.log import logger
27
27
  from msprobe.mindspore.grad_probe.global_context import grad_context
28
28
  from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
29
- from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id
29
+ from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id, GradDumpConfig
30
30
  from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput
31
31
  from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level
32
32
 
@@ -38,7 +38,14 @@ class HookInput:
38
38
 
39
39
  def __init__(self, opt) -> None:
40
40
  self.func = opt.construct
41
- self.g_names = [param.name for param in opt._parameters]
41
+ if hasattr(opt, "_parameters"):
42
+ parameter_list = opt._parameters
43
+ elif hasattr(opt, "parameters"):
44
+ parameter_list = opt.parameters
45
+ else:
46
+ logger.error_log_with_exp("Given optimizer has no attributes: '_parameters' or 'parameters'. \
47
+ Please check the type of the given optimizer.", ValueError)
48
+ self.g_names = [param.name for param in parameter_list]
42
49
  self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
43
50
  self.rank_id = get_rank_id()
44
51
  output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
@@ -59,8 +66,10 @@ def hook_graph_mode_optimizer(opt, hook_input):
59
66
  for index, grad_value in enumerate(gradients):
60
67
  if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list:
61
68
  continue
62
- grad_dump(hook_input.dump_dir, hook_input.g_names[index], self.dump_step,
63
- grad_value, hook_input.level, hook_input.bounds)
69
+ conf = GradDumpConfig(dump_dir=hook_input.dump_dir, g_name=hook_input.g_names[index],
70
+ dump_step=self.dump_step, grad=grad_value, level=hook_input.level,
71
+ bounds=hook_input.bounds)
72
+ grad_dump(conf)
64
73
  ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step)
65
74
  self.assignadd(self.dump_step, self.global_step_increase_tensor)
66
75
  out = hook_input.func(gradients)
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .mindtorch_adaptor import (_call_impl,
17
+ register_full_backward_pre_hook,
18
+ register_full_backward_hook)