mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/mindspore/debugger/precision_debugger.py
@@ -22,12 +22,12 @@ from mindspore._c_expression import MSContext
  from msprobe.core.common.const import Const, FileCheckConst, MsgConst
  from msprobe.core.common.exceptions import MsprobeException
  from msprobe.core.common.file_utils import FileChecker
- from msprobe.core.common.utils import get_real_step_or_rank
+ from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
  from msprobe.mindspore.cell_processor import CellProcessor
  from msprobe.mindspore.common.const import Const as MsConst
- from msprobe.mindspore.common.utils import set_register_backward_hook_functions
+ from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
  from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
  from msprobe.mindspore.ms_config import parse_json_config
@@ -84,11 +84,12 @@ class PrecisionDebugger:
          common_config.dump_path = dump_path if dump_path else common_config.dump_path
          self.config = DebuggerConfig(common_config, task_config)

-         if _msprobe_c:
+         if self._need_msprobe_c() and _msprobe_c:
              _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)

          self.config.execution_mode = self._get_execution_mode()
          if self._need_service():
+             self.config.check_config_with_l2()
              self.service = Service(self.config)

          Runtime.step_count = 0
@@ -139,18 +140,18 @@ class PrecisionDebugger:
      def _is_graph_dump(config):
          if config.level != MsConst.KERNEL:
              return False
-         if not config.list or len(config.list) > 1:
+         if not config.list:
              return True
-         if '-' in config.list[0] or '/' in config.list[0]:
-             return True
-         return False
+         is_graph = any(item.startswith("name-regex") for item in config.list)
+         is_graph |= all("." not in item for item in config.list)
+         return is_graph

      @classmethod
      def start(cls, model=None):
          instance = cls._instance
          if not instance:
              raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-         if _msprobe_c:
+         if cls._need_msprobe_c() and _msprobe_c:
              _msprobe_c._PrecisionDebugger().start()
          if instance.task in PrecisionDebugger.task_not_need_service:
              return
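Note: the reworked _is_graph_dump heuristic above can be exercised on its own. A minimal sketch, assuming config.list carries either dotted API names or "name-regex(...)" specifiers (the sample values below are illustrative, not taken from the package):

    # Mirrors the heuristic shown in the diff; not the msprobe implementation itself.
    def is_graph_dump(dump_list):
        if not dump_list:
            return True
        is_graph = any(item.startswith("name-regex") for item in dump_list)
        is_graph |= all("." not in item for item in dump_list)
        return is_graph

    print(is_graph_dump([]))                               # True: an empty list is treated as graph dump
    print(is_graph_dump(["name-regex(Conv.*)"]))           # True: a regex specifier is treated as graph dump
    print(is_graph_dump(["Functional.conv2d.0.forward"]))  # False: dotted names are not treated as graph dump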
@@ -162,7 +163,7 @@ class PrecisionDebugger:
              instance.service.start(model)
          else:
              if not instance.first_start:
-                 api_register.api_set_ori_func()
+                 get_api_register().restore_all_api()
              handler = TaskHandlerFactory.create(instance.config)
              handler.handle()

@@ -179,8 +180,6 @@ class PrecisionDebugger:
          instance = cls._instance
          if not instance:
              raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-         if _msprobe_c:
-             _msprobe_c._PrecisionDebugger().stop()
          if instance.task == Const.GRAD_PROBE:
              instance.gm.stop()
          if instance.task in PrecisionDebugger.task_not_need_service:
@@ -194,8 +193,6 @@ class PrecisionDebugger:
          instance = cls._instance
          if not instance:
              raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-         if _msprobe_c:
-             _msprobe_c._PrecisionDebugger().step()
          if instance.task in PrecisionDebugger.task_not_need_service:
              return
          if instance.service:
@@ -214,6 +211,33 @@ class PrecisionDebugger:
              return
          instance.gm.monitor(opt)

+     @classmethod
+     def save(cls, variable, name, save_backward=True):
+         instance = cls._instance
+         if not instance:
+             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+         if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level_ori != Const.LEVEL_DEBUG:
+             return
+         try:
+             check_save_param(variable, name, save_backward)
+         except ValueError:
+             return
+
+         instance.config.execution_mode = cls._get_execution_mode()
+         if cls._need_service():
+             if not instance.service:
+                 instance.service = Service(instance.config)
+             instance.service.save(variable, name, save_backward)
+
+     @classmethod
+     def set_init_step(cls, step):
+         instance = cls._instance
+         if not instance:
+             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+         check_init_step(step)
+         instance.service.init_step = step
+         instance.service.loop = 0
+
      @classmethod
      def _need_service(cls):
          instance = cls._instance
@@ -223,3 +247,10 @@ class PrecisionDebugger:
              return False
          else:
              return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
+
+     @classmethod
+     def _need_msprobe_c(cls):
+         instance = cls._instance
+         if not instance:
+             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+         return instance.config.level_ori == Const.LEVEL_L2
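Note: together with check_save_param/check_init_step, the new save and set_init_step entry points make debug-level dumping callable from training code. A hedged usage sketch, assuming a config.json with task "statistics" and level "debug" (the tensor and step values below are illustrative):

    import mindspore as ms
    from msprobe.mindspore import PrecisionDebugger

    debugger = PrecisionDebugger(config_path="./config.json")
    debugger.set_init_step(100)            # resume step numbering, e.g. after loading a checkpoint

    loss = ms.Tensor(0.5, ms.float32)      # stand-in for a real training loss
    PrecisionDebugger.save(loss, "loss", save_backward=True)
    PrecisionDebugger.step()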
msprobe/mindspore/dump/dump_tool_factory.py
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
  # limitations under the License.

  from msprobe.mindspore.common.const import Const
+ from msprobe.core.common.log import logger
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
  from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
  from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
@@ -47,6 +48,7 @@ class DumpToolFactory:
              raise Exception("Valid level is needed.")
          tool = tool.get(config.execution_mode)
          if not tool:
-             raise Exception(f"Data dump is not supported in {config.execution_mode} mode "
-                             f"when dump level is {config.level}.")
+             logger.error(f"Data dump is not supported in {config.execution_mode} mode "
+                          f"when dump level is {config.level}.")
+             raise ValueError
          return tool(config)
msprobe/mindspore/dump/hook_cell/api_register.py (new file)
@@ -0,0 +1,142 @@
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ from mindspore import Tensor, ops, mint
+ from mindspore.mint.nn import functional
+ from mindspore.communication import comm_func
+
+ from msprobe.core.common.file_utils import load_yaml
+ from msprobe.core.common.utils import Const
+ from msprobe.core.data_dump.api_registry import ApiRegistry
+ from msprobe.mindspore.common.const import Const as MsConst
+ from msprobe.mindspore.common.utils import is_mindtorch
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
+
+
+ stub_tensor_existed = True
+ try:
+     from mindspore.common._stub_tensor import StubTensor
+ except ImportError:
+     stub_tensor_existed = False
+
+ cur_path = os.path.dirname(os.path.realpath(__file__))
+ if not is_mindtorch():
+     _api_types = {
+         Const.MS_FRAMEWORK: {
+             Const.MS_API_TYPE_OPS: (ops, (ops,)),
+             Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
+             Const.MS_API_TYPE_MINT: (mint, (mint,)),
+             Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
+             Const.MS_API_TYPE_COM: (comm_func, (comm_func,))
+         }
+     }
+     if stub_tensor_existed:
+         _api_types.get(Const.MS_FRAMEWORK).update(
+             {Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,))}
+         )
+
+     _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
+ else:
+     import torch
+     import torch_npu
+     _api_types = {
+         Const.MT_FRAMEWORK: {
+             Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
+             Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
+             Const.PT_API_TYPE_TORCH: (torch, (torch,)),
+             Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)),
+             Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d))
+         }
+     }
+     _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
+                                              MsConst.SUPPORTED_API_LIST_FILE),)
+
+ _inner_used_api = {
+     Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
+         ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point"
+     ),
+     Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: (
+         Tensor, "to", "numel"
+     ),
+     Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: (
+         mint, "max", "min", "mean", "norm"
+     )
+ }
+
+
+ class ApiTemplate(HOOKCell):
+     def __init__(self, api_name, api_func, prefix, hook_build_func):
+         self.api_name = api_name
+         self.api_func = api_func
+         self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
+         super().__init__(hook_build_func)
+
+     @staticmethod
+     def async_to_sync(output):
+         # Fake handle, used to return after the CommHandle executes the wait method
+         fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
+         if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
+             output[1].wait()
+             output = (output[0], fake_handle)
+         elif hasattr(output, "wait"):
+             output.wait()
+             output = fake_handle
+         return output
+
+     def construct(self, *args, **kwargs):
+         if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
+             return args[0] if args else kwargs.get(Const.INPUT)
+
+         output = self.api_func(*args, **kwargs)
+
+         if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
+             if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
+                 output = self.async_to_sync(output)
+         return output
+
+     def forward(self, *args, **kwargs):
+         if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
+             return args[0] if args else kwargs.get(Const.INPUT)
+         return self.api_func(*args, **kwargs)
+
+
+ api_register = None
+ stub_tensor_set = False
+
+
+ def get_api_register(return_new=False):
+     global stub_tensor_set
+
+     def stub_method(method):
+         def wrapped_method(*args, **kwargs):
+             return method(*args, **kwargs)
+         return wrapped_method
+     if not is_mindtorch() and stub_tensor_existed and not stub_tensor_set:
+         api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, [])
+         for attr_name in dir(StubTensor):
+             attr = getattr(StubTensor, attr_name)
+             if attr_name in api_names and callable(attr):
+                 setattr(StubTensor, attr_name, stub_method(attr))
+         stub_tensor_set = True
+
+     if return_new:
+         return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+
+     global api_register
+     if api_register is None:
+         api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+     return api_register
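Note: call sites elsewhere in this release (precision_debugger.py, jit_dump.py) drive the registry through get_api_register(). Only get_api_register, register_all_api and restore_all_api are taken from this diff; the surrounding flow below is illustrative:

    from msprobe.mindspore.dump.hook_cell.api_register import get_api_register

    registry = get_api_register()
    registry.register_all_api()      # swap the supported MindSpore APIs for hooked ApiTemplate wrappers
    try:
        pass                         # run the network or jit function under dump here
    finally:
        registry.restore_all_api()   # always put the original functions back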
msprobe/mindspore/dump/hook_cell/hook_cell.py
@@ -28,23 +28,22 @@ def get_cell_count(name):
          return HOOKCell.cell_count[name]


-     def __init__(self, build_hook) -> None:
+     def __init__(self, hook_build_func) -> None:
          super(HOOKCell, self).__init__()
          self.changed_status = False
          self.input_kwargs = {}
-         self.prefix = ""
          if not HOOKCell.g_stop_hook:
              HOOKCell.g_stop_hook = True
              self.changed_status = True
-             if hasattr(self, "prefix_api_name"):
-                 self.prefix = self.prefix_api_name
-
              self.forward_data_collected = False
-             forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
-             self.register_forward_pre_hook(forward_pre_hook)
-             self.register_forward_hook(forward_hook)
-             register_backward_hook_functions["full"](self, backward_hook)
-             register_backward_hook_functions["pre"](self, backward_pre_hook)
+
+             prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+             if callable(hook_build_func):
+                 forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = hook_build_func(prefix)
+                 self.register_forward_pre_hook(forward_pre_hook)
+                 self.register_forward_hook(forward_hook)
+                 register_backward_hook_functions["full"](self, backward_hook)
+                 register_backward_hook_functions["pre"](self, backward_pre_hook)


      # Override __call__ and set the global flag.
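Note: HOOKCell now takes the hook builder directly and skips registration when it is not callable. A hedged sketch of the builder contract, with placeholder hooks (the hook signatures follow MindSpore's Cell hook API and are an assumption here, not part of the diff):

    def hook_build_func(prefix):
        # Must return four hooks, in this order:
        # forward_pre_hook, forward_hook, backward_hook, backward_pre_hook.
        def forward_pre_hook(cell, args):
            print(f"{prefix}forward pre-hook: {len(args)} positional inputs")

        def forward_hook(cell, args, output):
            print(f"{prefix}forward hook: output type {type(output).__name__}")

        def backward_hook(cell, grad_input, grad_output):
            print(f"{prefix}backward hook")

        def backward_pre_hook(cell, grad_output):
            print(f"{prefix}backward pre-hook")

        return forward_pre_hook, forward_hook, backward_hook, backward_pre_hook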
msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
@@ -564,15 +564,15 @@ tensor:
  - all
  - amax
  - amin
+ - angle
  - any
  - arccos
  - arccosh
- - argmax
- - angle
  - arcsin
  - arcsinh
  - arctan
  - arctanh
+ - argmax
  - argmin
  - argsort
  - asin
@@ -582,19 +582,23 @@ tensor:
  - atanh
  - baddbmm
  - bernoulli
+ - bfloat16
  - bincount
  - bitwise_and
  - bitwise_or
  - bitwise_xor
  - bmm
  - bool
+ - bool astype
  - broadcast_to
+ - byte
  - ceil
- - cholesky_solve
  - cholesky
+ - cholesky_solve
  - clamp
  - clip
  - conj
+ - copy
  - copysign
  - cos
  - cosh
@@ -606,11 +610,13 @@ tensor:
  - deg2rad
  - diag
  - diagflat
+ - diagonal
  - diff
  - digamma
  - div
  - div_
  - divide
+ - double
  - equal
  - erf
  - erfc
@@ -618,13 +624,16 @@ tensor:
  - exp
  - expand_as
  - expm1
+ - flatten
  - flip
  - fliplr
  - flipud
+ - float
  - float_power
  - floor
  - fmod
  - frac
+ - from_numpy
  - gather_elements
  - ge
  - geqrf
@@ -648,12 +657,12 @@ tensor:
  - inner
  - int
  - inverse
+ - is_complex
+ - is_signed
  - isclose
  - isfinite
  - isinf
  - isnan
- - is_complex
- - is_signed
  - isneginf
  - isposinf
  - isreal
@@ -704,28 +713,27 @@ tensor:
  - new_ones
  - new_zeros
  - nextafter
- - norm
  - nonzero
+ - norm
  - not_equal
  - ormqr
  - permute
  - pow
  - prod
  - qr
+ - rad2deg
  - ravel
  - real
  - reciprocal
  - remainder
  - renorm
- - rad2deg
- - tile
  - repeat_interleave
  - reshape
  - reshape
- - round
+ - resize
  - rot90
+ - round
  - rsqrt
- - sum_to_size
  - scatter
  - sgn
  - short
@@ -745,7 +753,8 @@ tensor:
  - sub
  - sub_
  - subtract
- - subtract
+ - sum
+ - sum_to_size
  - svd
  - swapaxes
  - swapdims
@@ -753,13 +762,13 @@ tensor:
  - take
  - tan
  - tanh
- - trace
- - swapaxes
+ - tensor_split
  - tile
+ - to
  - topk
- - tril
- - tensor_split
+ - trace
  - transpose
+ - tril
  - true_divide
  - trunc
  - unbind
@@ -769,17 +778,6 @@ tensor:
  - view
  - where
  - xlogy
- - from_numpy
- - std
- - take
- - var
- - all
- - any
- - copy
- - diagonal
- - flatten
- - resize
- - sum

  mint.ops:
  - abs
msprobe/mindspore/dump/jit_dump.py
@@ -16,15 +16,20 @@
  import os
  from collections import defaultdict

- from mindspore import Tensor
+ import mindspore
  from mindspore._c_expression import PyNativeExecutor_
- from mindspore.common.api import _MindsporeFunctionExecutor
+ try:
+     from mindspore.common.api import _MindsporeFunctionExecutor
+ except ImportError:
+     from mindspore.common.api import _JitExecutor as _MindsporeFunctionExecutor

  from msprobe.core.common.log import logger
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
  from msprobe.core.common.const import Const
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
+
+
+ _api_register = get_api_register()


  def dump_jit(name, in_feat, out_feat, is_forward):
@@ -40,8 +45,8 @@ def dump_jit(name, in_feat, out_feat, is_forward):
      if JitDump.need_dump():
          if is_forward:
              JitDump.jit_count[result] += 1
-             name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
-                             Const.FORWARD
+             name_template = (Const.JIT + Const.SEP + result + Const.SEP +
+                              str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD)
              JitDump.data_collector.update_api_or_module_name(name_template)
              module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
              JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
@@ -56,7 +61,7 @@
  class JitDump(_MindsporeFunctionExecutor):
      dump_config = None
      jit_enable = False
-     jit_dump_switch = True
+     jit_dump_switch = False
      jit_count = defaultdict(int)

      def __init__(self, *args, **kwargs):
@@ -67,8 +72,7 @@ class JitDump(_MindsporeFunctionExecutor):
          self._executor = PyNativeExecutor_.get_instance()

      def __call__(self, *args, **kwargs):
-         if JitDump.jit_dump_switch:
-             api_register.api_set_ori_func()
+         _api_register.restore_all_api()
          out = super().__call__(*args, **kwargs)
          if JitDump.jit_dump_switch and len(args) > 0:
              if self.name and self.name != "construct":
@@ -78,8 +82,7 @@
                  JitDump.jit_enable = True
          elif len(args) == 0:
              logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
-         if JitDump.jit_dump_switch:
-             api_register.api_set_hook_func()
+         _api_register.register_all_api()
          return out

      @classmethod
@@ -100,9 +103,12 @@

      def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
          if JitDump.jit_dump_switch and JitDump.jit_enable:
-             api_register.api_set_ori_func()
-         output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
+             _api_register.restore_all_api()
+         if mindspore.__version__ >= "2.5":
+             output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values()))
+         else:
+             output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
          if JitDump.jit_dump_switch and JitDump.jit_enable:
              dump_jit(obj, args, None, False)
-             api_register.api_set_hook_func()
+             _api_register.register_all_api()
          return output
msprobe/mindspore/dym_loader/hook_dynamic_loader.cc
@@ -18,37 +18,10 @@
  #include <sys/stat.h>
  #include <cstdlib>
  #include <cstring>
+ #include <pybind11/embed.h>
  #include "utils/log_adapter.h"

- namespace {
-
- // Utility function to check if a file path is valid
- bool IsValidPath(const std::string &path) {
-   struct stat fileStat;
-   if (stat(path.c_str(), &fileStat) != 0) {
-     MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
-     return false;
-   }
-
-   if (S_ISLNK(fileStat.st_mode)) {
-     MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
-     return false;
-   }
-
-   if (!S_ISREG(fileStat.st_mode)) {
-     MS_LOG(ERROR) << "File is not a regular file: " << path;
-     return false;
-   }
-
-   if (path.substr(path.find_last_of(".")) != ".so") {
-     MS_LOG(ERROR) << "File is not a .so file: " << path;
-     return false;
-   }
-
-   return true;
- }
-
- }  // namespace
+ namespace py = pybind11;

  HookDynamicLoader &HookDynamicLoader::GetInstance() {
    static HookDynamicLoader instance;
@@ -65,38 +38,31 @@ bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionNa
    return true;
  }

- bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
-   char *realPath = realpath(libPath.c_str(), nullptr);
-   if (!realPath) {
-     MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
-     return false;
-   }
-
-   bool isValid = IsValidPath(realPath);
-   free(realPath);  // Free memory allocated by realpath
-   return isValid;
- }
-
  bool HookDynamicLoader::LoadLibrary() {
-   const char *libPath = std::getenv("HOOK_TOOL_PATH");
-   if (!libPath) {
-     MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
-     return false;
-   }
-
-   std::string resolvedLibPath(libPath);
-   if (!validateLibraryPath(resolvedLibPath)) {
-     MS_LOG(WARNING) << "Library path validation failed.";
-     return false;
-   }
-
+   std::string msprobePath = "";
+   // Acquire the GIL
+   py::gil_scoped_acquire acquire;
+   try {
+     py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
+     if (!py::hasattr(msprobeMod, "__file__")) {
+       MS_LOG(WARNING) << "Adump mod not found";
+       return false;
+     }
+     msprobePath = msprobeMod.attr("__file__").cast<std::string>();
+   } catch (const std::exception& e) {
+     MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
+     return false;
+   }
    std::lock_guard<std::mutex> lock(mutex_);
    if (handle_) {
      MS_LOG(WARNING) << "Hook library already loaded!";
      return false;
    }
-
-   handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
+   if (msprobePath == "") {
+     MS_LOG(WARNING) << "Adump path not loaded";
+     return false;
+   }
+   handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
    if (!handle_) {
      MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
      return false;
@@ -104,7 +70,7 @@ bool HookDynamicLoader::LoadLibrary() {

    for (const auto &functionName : functionList_) {
      if (!loadFunction(handle_, functionName)) {
-       MS_LOG(WARNING) << "Failed to load function: " << functionName;
+       MS_LOG(WARNING) << "Failed to load adump function";
        dlclose(handle_);
        handle_ = nullptr;
        return false;
msprobe/mindspore/dym_loader/hook_dynamic_loader.h
@@ -40,7 +40,6 @@ class HookDynamicLoader {
   private:
    // Helper functions
    bool loadFunction(void *handle, const std::string &functionName);
-   bool validateLibraryPath(const std::string &libPath);

    HookDynamicLoader() = default;
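Note: the loader now resolves the hook library through the embedded Python interpreter instead of the HOOK_TOOL_PATH environment variable. Roughly, the C++ code above resolves the same path as this Python sketch (the module name is taken from the diff; the rest is illustrative):

    import importlib

    mod = importlib.import_module("msprobe.lib._msprobe_c")
    lib_path = getattr(mod, "__file__", None)
    if not lib_path:
        raise RuntimeError("msprobe.lib._msprobe_c has no __file__")
    print(lib_path)   # the shared object the loader then dlopen()s with RTLD_LAZY | RTLD_LOCAL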