mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
21
21
  from msprobe.core.common.utils import Const, DumpException
22
22
  from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
23
23
  ModuleForwardInputsOutputs)
24
+ from msprobe.core.hook_manager import BaseHookManager
24
25
  from msprobe.mindspore.common.log import logger
25
26
 
26
27
 
@@ -58,7 +59,7 @@ class PrimitiveHookService:
58
59
  def backward_hook(grad):
59
60
  captured_grads.extend(grad)
60
61
  backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
61
-
62
+ self.service_instance.inner_switch = True
62
63
  try:
63
64
  if hook_type == Const.INPUT:
64
65
  self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
@@ -77,6 +78,7 @@ class PrimitiveHookService:
77
78
  logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
78
79
  f"updated_primitive_name: {updated_primitive_name}")
79
80
  raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
81
+ self.service_instance.inner_switch = False
80
82
 
81
83
  return backward_hook
82
84
 
@@ -137,6 +139,7 @@ class PrimitiveHookService:
137
139
 
138
140
  def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
139
141
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
142
+ self.service_instance.inner_switch = True
140
143
  try:
141
144
  self.service_instance.data_collector.forward_input_data_collect(
142
145
  primitive_name,
@@ -148,9 +151,11 @@ class PrimitiveHookService:
148
151
  logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
149
152
  f"primitive_name: {primitive_name}")
150
153
  raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
154
+ self.service_instance.inner_switch = False
151
155
 
152
156
  def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
153
157
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
158
+ self.service_instance.inner_switch = True
154
159
  try:
155
160
  self.service_instance.data_collector.forward_output_data_collect(
156
161
  primitive_name,
@@ -162,6 +167,7 @@ class PrimitiveHookService:
162
167
  logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
163
168
  f"primitive_name: {primitive_name}")
164
169
  raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
170
+ self.service_instance.inner_switch = False
165
171
 
166
172
  def wrapped_primitive_call(instance_self, *args, **kwargs):
167
173
  """
@@ -179,7 +185,7 @@ class PrimitiveHookService:
179
185
  current_count = self.primitive_counters.get(primitive_name, 0)
180
186
  updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
181
187
 
182
- if not self.service_instance.primitive_switch:
188
+ if not self.service_instance.primitive_switch or BaseHookManager.inner_switch:
183
189
  return origin_func(*args, **kwargs)
184
190
 
185
191
  captured_grads_input, captured_grads_output = [], []
@@ -564,15 +564,15 @@ tensor:
564
564
  - all
565
565
  - amax
566
566
  - amin
567
+ - angle
567
568
  - any
568
569
  - arccos
569
570
  - arccosh
570
- - argmax
571
- - angle
572
571
  - arcsin
573
572
  - arcsinh
574
573
  - arctan
575
574
  - arctanh
575
+ - argmax
576
576
  - argmin
577
577
  - argsort
578
578
  - asin
@@ -582,19 +582,23 @@ tensor:
582
582
  - atanh
583
583
  - baddbmm
584
584
  - bernoulli
585
+ - bfloat16
585
586
  - bincount
586
587
  - bitwise_and
587
588
  - bitwise_or
588
589
  - bitwise_xor
589
590
  - bmm
590
591
  - bool
592
+ - bool astype
591
593
  - broadcast_to
594
+ - byte
592
595
  - ceil
593
- - cholesky_solve
594
596
  - cholesky
597
+ - cholesky_solve
595
598
  - clamp
596
599
  - clip
597
600
  - conj
601
+ - copy
598
602
  - copysign
599
603
  - cos
600
604
  - cosh
@@ -606,11 +610,13 @@ tensor:
606
610
  - deg2rad
607
611
  - diag
608
612
  - diagflat
613
+ - diagonal
609
614
  - diff
610
615
  - digamma
611
616
  - div
612
617
  - div_
613
618
  - divide
619
+ - double
614
620
  - equal
615
621
  - erf
616
622
  - erfc
@@ -618,13 +624,16 @@ tensor:
618
624
  - exp
619
625
  - expand_as
620
626
  - expm1
627
+ - flatten
621
628
  - flip
622
629
  - fliplr
623
630
  - flipud
631
+ - float
624
632
  - float_power
625
633
  - floor
626
634
  - fmod
627
635
  - frac
636
+ - from_numpy
628
637
  - gather_elements
629
638
  - ge
630
639
  - geqrf
@@ -648,12 +657,12 @@ tensor:
648
657
  - inner
649
658
  - int
650
659
  - inverse
660
+ - is_complex
661
+ - is_signed
651
662
  - isclose
652
663
  - isfinite
653
664
  - isinf
654
665
  - isnan
655
- - is_complex
656
- - is_signed
657
666
  - isneginf
658
667
  - isposinf
659
668
  - isreal
@@ -704,28 +713,27 @@ tensor:
704
713
  - new_ones
705
714
  - new_zeros
706
715
  - nextafter
707
- - norm
708
716
  - nonzero
717
+ - norm
709
718
  - not_equal
710
719
  - ormqr
711
720
  - permute
712
721
  - pow
713
722
  - prod
714
723
  - qr
724
+ - rad2deg
715
725
  - ravel
716
726
  - real
717
727
  - reciprocal
718
728
  - remainder
719
729
  - renorm
720
- - rad2deg
721
- - tile
722
730
  - repeat_interleave
723
731
  - reshape
724
732
  - reshape
725
- - round
733
+ - resize
726
734
  - rot90
735
+ - round
727
736
  - rsqrt
728
- - sum_to_size
729
737
  - scatter
730
738
  - sgn
731
739
  - short
@@ -745,7 +753,8 @@ tensor:
745
753
  - sub
746
754
  - sub_
747
755
  - subtract
748
- - subtract
756
+ - sum
757
+ - sum_to_size
749
758
  - svd
750
759
  - swapaxes
751
760
  - swapdims
@@ -753,13 +762,13 @@ tensor:
753
762
  - take
754
763
  - tan
755
764
  - tanh
756
- - trace
757
- - swapaxes
765
+ - tensor_split
758
766
  - tile
767
+ - to
759
768
  - topk
760
- - tril
761
- - tensor_split
769
+ - trace
762
770
  - transpose
771
+ - tril
763
772
  - true_divide
764
773
  - trunc
765
774
  - unbind
@@ -769,17 +778,6 @@ tensor:
769
778
  - view
770
779
  - where
771
780
  - xlogy
772
- - from_numpy
773
- - std
774
- - take
775
- - var
776
- - all
777
- - any
778
- - copy
779
- - diagonal
780
- - flatten
781
- - resize
782
- - sum
783
781
 
784
782
  mint.ops:
785
783
  - abs
@@ -1027,3 +1025,21 @@ communication.comm_func:
1027
1025
  - recv
1028
1026
  - isend
1029
1027
  - irecv
1028
+
1029
+ mint.distributed:
1030
+ - send
1031
+ - recv
1032
+ - broadcast
1033
+ - all_reduce
1034
+ - reduce
1035
+ - all_gather
1036
+ - gather
1037
+ - isend
1038
+ - irecv
1039
+ - scatter
1040
+ - reduce_scatter
1041
+ - all_to_all_single
1042
+ - all_to_all
1043
+ - all_gather_into_tensor
1044
+ - reduce_scatter_tensor
1045
+ - batch_isend_irecv
@@ -13,9 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
17
16
  from collections import defaultdict
17
+ import os
18
+ import types
18
19
 
20
+ import mindspore
21
+ from mindspore import nn
19
22
  from mindspore._c_expression import PyNativeExecutor_
20
23
  try:
21
24
  from mindspore.common.api import _MindsporeFunctionExecutor
@@ -24,30 +27,31 @@ except ImportError:
24
27
 
25
28
  from msprobe.core.common.log import logger
26
29
  from msprobe.core.common.const import Const
30
+ from msprobe.core.common.runtime import Runtime
27
31
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
28
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
32
+ from msprobe.mindspore.common.const import Const as MsConst
33
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
34
+
35
+
36
+ _api_register = get_api_register()
29
37
 
30
38
 
31
39
  def dump_jit(name, in_feat, out_feat, is_forward):
32
40
  pid = os.getpid()
33
- ori_args = str(name)
34
- index = ori_args.find("<")
35
- if index != 0 and index != -1:
36
- result = ori_args[0:index]
37
- elif name is not None and "<" not in str(name):
38
- result = str(name)
39
- else:
40
- result = "JitFunction"
41
+ name = name if name else "JitFunction"
41
42
  if JitDump.need_dump():
42
43
  if is_forward:
43
- JitDump.jit_count[result] += 1
44
- name_template = (Const.JIT + Const.SEP + result + Const.SEP +
45
- str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD)
44
+ if name in JitDump.jit_count:
45
+ JitDump.jit_count[name] += 1
46
+ else:
47
+ JitDump.jit_count[name] = 0
48
+ name_template = (Const.JIT + Const.SEP + name + Const.SEP +
49
+ str(JitDump.jit_count[name]) + Const.SEP + Const.FORWARD)
46
50
  JitDump.data_collector.update_api_or_module_name(name_template)
47
51
  module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
48
52
  JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
49
53
  else:
50
- name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
54
+ name_template = Const.JIT + Const.SEP + name + Const.SEP + str(JitDump.jit_count[name]) + Const.SEP + \
51
55
  Const.BACKWARD
52
56
  JitDump.data_collector.update_api_or_module_name(name_template)
53
57
  module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
@@ -57,7 +61,7 @@ def dump_jit(name, in_feat, out_feat, is_forward):
57
61
  class JitDump(_MindsporeFunctionExecutor):
58
62
  dump_config = None
59
63
  jit_enable = False
60
- jit_dump_switch = True
64
+ jit_dump_switch = False
61
65
  jit_count = defaultdict(int)
62
66
 
63
67
  def __init__(self, *args, **kwargs):
@@ -68,19 +72,17 @@ class JitDump(_MindsporeFunctionExecutor):
68
72
  self._executor = PyNativeExecutor_.get_instance()
69
73
 
70
74
  def __call__(self, *args, **kwargs):
71
- if JitDump.jit_dump_switch:
72
- api_register.api_set_ori_func()
75
+ _api_register.restore_all_api()
73
76
  out = super().__call__(*args, **kwargs)
74
- if JitDump.jit_dump_switch and len(args) > 0:
75
- if self.name and self.name != "construct":
77
+ if JitDump.jit_dump_switch and len(args) > 0 and self.name:
78
+ if self.name != "construct":
76
79
  dump_jit(self.name, args, out, True)
77
- else:
78
- dump_jit(args[0], args, out, True)
80
+ elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(args[0], nn.Cell):
81
+ dump_jit(args[0].__class__.__name__, args, out, True)
79
82
  JitDump.jit_enable = True
80
83
  elif len(args) == 0:
81
84
  logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
82
- if JitDump.jit_dump_switch:
83
- api_register.api_set_hook_func()
85
+ _api_register.register_all_api()
84
86
  return out
85
87
 
86
88
  @classmethod
@@ -101,9 +103,15 @@ class JitDump(_MindsporeFunctionExecutor):
101
103
 
102
104
  def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
103
105
  if JitDump.jit_dump_switch and JitDump.jit_enable:
104
- api_register.api_set_ori_func()
105
- output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
106
+ _api_register.restore_all_api()
107
+ if mindspore.__version__ >= "2.5":
108
+ output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values()))
109
+ else:
110
+ output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
106
111
  if JitDump.jit_dump_switch and JitDump.jit_enable:
107
- dump_jit(obj, args, None, False)
108
- api_register.api_set_hook_func()
112
+ if isinstance(obj, types.FunctionType):
113
+ dump_jit(obj.__name__, args, None, False)
114
+ elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(obj, nn.Cell):
115
+ dump_jit(obj.__class__.__name__, args, None, False)
116
+ _api_register.register_all_api()
109
117
  return output
@@ -39,9 +39,12 @@ class KernelKbykDump:
39
39
  common_set["input_output"] = 0
40
40
  common_set["kernels"] = []
41
41
  common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
42
- e2e_set = dict()
43
- e2e_set["enable"] = True
44
- e2e_set["trans_flag"] = True
42
+ e2e_set = {
43
+ "enable": not config.async_dump,
44
+ "trans_flag": True,
45
+ "stat_calc_mode": config.stat_cal_mode,
46
+ "device_stat_precision_mode": config.device_stat_precision_mode,
47
+ }
45
48
 
46
49
  if config.list:
47
50
  common_set["dump_mode"] = 1
@@ -0,0 +1,110 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "hook_dynamic_loader.h"
18
+ #include <sys/stat.h>
19
+ #include <cstdlib>
20
+ #include <cstring>
21
+ #include <pybind11/embed.h>
22
+ #include "utils/log_adapter.h"
23
+
24
+ namespace py = pybind11;
25
+
26
+ HookDynamicLoader &HookDynamicLoader::GetInstance()
27
+ {
28
+ static HookDynamicLoader instance;
29
+ return instance;
30
+ }
31
+
32
+ bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName) {
33
+ void *func = dlsym(handle, functionName.c_str());
34
+ if (!func) {
35
+ MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
36
+ return false;
37
+ }
38
+ funcMap_[functionName] = func;
39
+ return true;
40
+ }
41
+
42
+ bool HookDynamicLoader::LoadLibrary()
43
+ {
44
+ std::string msprobePath = "";
45
+ // 获取gil锁
46
+ py::gil_scoped_acquire acquire;
47
+ try {
48
+ py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
49
+ if (!py::hasattr(msprobeMod, "__file__")) {
50
+ MS_LOG(WARNING) << "Adump mod not found";
51
+ return false;
52
+ }
53
+ msprobePath = msprobeMod.attr("__file__").cast<std::string>();
54
+ } catch (const std::exception& e) {
55
+ MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
56
+ return false;
57
+ }
58
+ std::lock_guard<std::mutex> lock(mutex_);
59
+ if (handle_) {
60
+ MS_LOG(WARNING) << "Hook library already loaded!";
61
+ return false;
62
+ }
63
+ if (msprobePath == "") {
64
+ MS_LOG(WARNING) << "Adump path not loaded";
65
+ return false;
66
+ }
67
+ handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
68
+ if (!handle_) {
69
+ MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
70
+ return false;
71
+ }
72
+
73
+ for (const auto &functionName : functionList_) {
74
+ if (!LoadFunction(handle_, functionName)) {
75
+ MS_LOG(WARNING) << "Failed to load adump function";
76
+ dlclose(handle_);
77
+ handle_ = nullptr;
78
+ return false;
79
+ }
80
+ }
81
+
82
+ MS_LOG(INFO) << "Hook library loaded successfully.";
83
+ return true;
84
+ }
85
+
86
+ bool HookDynamicLoader::UnloadLibrary()
87
+ {
88
+ std::lock_guard<std::mutex> lock(mutex_);
89
+ if (!handle_) {
90
+ MS_LOG(WARNING) << "Hook library hasn't been loaded.";
91
+ return false;
92
+ }
93
+
94
+ dlclose(handle_);
95
+ handle_ = nullptr;
96
+ funcMap_.clear();
97
+ MS_LOG(INFO) << "Library unloaded successfully.";
98
+ return true;
99
+ }
100
+
101
+ void *HookDynamicLoader::GetHooker(const std::string &funcName)
102
+ {
103
+ std::lock_guard<std::mutex> lock(mutex_);
104
+ auto iter = funcMap_.find(funcName);
105
+ if (iter == funcMap_.end()) {
106
+ MS_LOG(WARNING) << "Function not found: " << funcName;
107
+ return nullptr;
108
+ }
109
+ return iter->second;
110
+ }
@@ -27,27 +27,26 @@ constexpr auto kHookBegin = "MS_DbgOnStepBegin";
27
27
  constexpr auto kHookEnd = "MS_DbgOnStepEnd";
28
28
 
29
29
  class HookDynamicLoader {
30
- public:
31
- static HookDynamicLoader &GetInstance();
30
+ public:
31
+ static HookDynamicLoader &GetInstance();
32
32
 
33
- HookDynamicLoader(const HookDynamicLoader &) = delete;
34
- HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;
33
+ HookDynamicLoader(const HookDynamicLoader &) = delete;
34
+ HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;
35
35
 
36
- bool LoadLibrary();
37
- bool UnloadLibrary();
38
- void *GetHooker(const std::string &funcName);
36
+ bool LoadLibrary();
37
+ bool UnloadLibrary();
38
+ void *GetHooker(const std::string &funcName);
39
39
 
40
- private:
41
- // Helper functions
42
- bool loadFunction(void *handle, const std::string &functionName);
43
- bool validateLibraryPath(const std::string &libPath);
40
+ private:
41
+ // Helper functions
42
+ bool LoadFunction(void *handle, const std::string &functionName);
44
43
 
45
- HookDynamicLoader() = default;
44
+ HookDynamicLoader() = default;
46
45
 
47
- void *handle_ = nullptr;
48
- std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};
49
- std::map<std::string, void *> funcMap_;
50
- std::mutex mutex_;
46
+ void *handle_ = nullptr;
47
+ std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};
48
+ std::map<std::string, void *> funcMap_;
49
+ std::mutex mutex_;
51
50
  };
52
51
 
53
52
  #endif // HOOK_DYNAMIC_LOADER_H
@@ -19,22 +19,27 @@ import os
19
19
  import traceback
20
20
 
21
21
  import mindspore as ms
22
+
22
23
  from msprobe.core.common.const import Const
23
24
  from msprobe.core.common.exceptions import DistributedNotInitializedError
24
25
  from msprobe.core.common.file_utils import check_path_length, load_yaml
26
+ from msprobe.core.common.runtime import Runtime
27
+ from msprobe.core.hook_manager import HookSet
25
28
  from msprobe.mindspore.common.const import Const as MsConst
26
29
  from msprobe.mindspore.common.const import FreeBenchmarkConst
27
30
  from msprobe.mindspore.common.log import logger
28
31
  from msprobe.mindspore.common.utils import get_rank_if_initialized
29
32
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
30
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
33
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
31
34
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
32
35
  from msprobe.mindspore.free_benchmark.common.config import Config
33
36
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
34
37
  from msprobe.mindspore.free_benchmark.common.utils import Tools
35
38
  from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
36
39
  from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
37
- from msprobe.mindspore.runtime import Runtime
40
+
41
+
42
+ _api_register = get_api_register()
38
43
 
39
44
 
40
45
  class ApiPyNativeSelfCheck:
@@ -60,8 +65,8 @@ class ApiPyNativeSelfCheck:
60
65
  self.store_original_func()
61
66
 
62
67
  def handle(self):
63
- api_register.initialize_hook(self.build_hook)
64
- api_register.api_set_hook_func()
68
+ _api_register.initialize_hook(self.build_hook)
69
+ _api_register.register_all_api()
65
70
 
66
71
  def build_hook(self, api_name):
67
72
  def pre_hook(cell, input_data):
@@ -71,7 +76,7 @@ class ApiPyNativeSelfCheck:
71
76
  ret = None
72
77
 
73
78
  if not need_wrapper_func():
74
- del cell.input_kwargs
79
+ del cell.msprobe_input_kwargs
75
80
  return ret
76
81
 
77
82
  api_name_with_id = api_name_with_id[:-1]
@@ -80,9 +85,9 @@ class ApiPyNativeSelfCheck:
80
85
  api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)])
81
86
  if api_name in self.api_list:
82
87
  ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name),
83
- *input_data, **cell.input_kwargs)
88
+ *input_data, **cell.msprobe_input_kwargs)
84
89
 
85
- del cell.input_kwargs
90
+ del cell.msprobe_input_kwargs
86
91
  return ret
87
92
 
88
93
  def backward_hook(cell, grad_input, grad_output):
@@ -101,8 +106,13 @@ class ApiPyNativeSelfCheck:
101
106
 
102
107
  def pre_backward_hook(cell, grad_input):
103
108
  return None
104
-
105
- return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook
109
+
110
+ return HookSet(
111
+ forward_hook=wrap_forward_hook,
112
+ forward_pre_hook=pre_hook,
113
+ backward_hook=wrap_backward_hook,
114
+ backward_pre_hook=pre_backward_hook
115
+ )
106
116
 
107
117
  def store_original_func(self):
108
118
  for api_name in self.api_list:
@@ -166,13 +176,13 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
166
176
  return ret
167
177
 
168
178
  logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
169
- api_register.api_set_ori_func()
179
+ _api_register.restore_all_api()
170
180
 
171
181
  try:
172
182
  perturbation = PerturbationFactory.create(api_name_with_id)
173
183
  params.fuzzed_result = perturbation.handle(params)
174
184
  if params.fuzzed_result is False:
175
- api_register.api_set_hook_func()
185
+ _api_register.register_all_api()
176
186
  return ret
177
187
  if Config.stage == Const.BACKWARD:
178
188
  params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
@@ -183,7 +193,7 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
183
193
  logger.error(f"[{api_name_with_id}] Error: {str(e)}")
184
194
  logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
185
195
 
186
- api_register.api_set_hook_func()
196
+ _api_register.register_all_api()
187
197
  return ret
188
198
 
189
199
 
@@ -19,10 +19,10 @@ from typing import Any, Optional
19
19
  import mindspore as ms
20
20
  from mindspore import Tensor, ops
21
21
 
22
+ from msprobe.core.common.runtime import Runtime
22
23
  from msprobe.mindspore.common.const import FreeBenchmarkConst
23
24
  from msprobe.mindspore.free_benchmark.common.config import Config
24
25
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
25
- from msprobe.mindspore.runtime import Runtime
26
26
 
27
27
 
28
28
  class Tools:
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.mindspore.common.const import FreeBenchmarkConst
17
+ from msprobe.mindspore.common.log import logger
17
18
  from msprobe.mindspore.free_benchmark.common.config import Config
18
19
  from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation
19
20
  from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation
@@ -41,4 +42,5 @@ class PerturbationFactory:
41
42
  if perturbation:
42
43
  return perturbation(api_name_with_id)
43
44
  else:
44
- raise Exception(f'{Config.pert_type} is a invalid perturbation type')
45
+ logger.error(f'{Config.pert_type} is a invalid perturbation type')
46
+ raise ValueError