mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -20,7 +20,7 @@ from mindspore import Tensor
20
20
  from mindspore._c_expression import PyNativeExecutor_
21
21
  from mindspore.common.api import _MindsporeFunctionExecutor
22
22
 
23
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
23
+ from msprobe.core.common.log import logger
24
24
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
25
25
  from msprobe.core.common.const import Const
26
26
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
@@ -33,6 +33,8 @@ def dump_jit(name, in_feat, out_feat, is_forward):
33
33
  index = ori_args.find("<")
34
34
  if index != 0 and index != -1:
35
35
  result = ori_args[0:index]
36
+ elif name is not None and "<" not in str(name):
37
+ result = str(name)
36
38
  else:
37
39
  result = "JitFunction"
38
40
  if JitDump.need_dump():
@@ -47,7 +49,7 @@ def dump_jit(name, in_feat, out_feat, is_forward):
47
49
  name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
48
50
  Const.BACKWARD
49
51
  JitDump.data_collector.update_api_or_module_name(name_template)
50
- module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat ,grad_output=out_feat)
52
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
51
53
  JitDump.data_collector.backward_data_collect(name_template, None, pid, module_input_output)
52
54
 
53
55
 
@@ -59,15 +61,25 @@ class JitDump(_MindsporeFunctionExecutor):
59
61
 
60
62
  def __init__(self, *args, **kwargs):
61
63
  super().__init__(*args, **kwargs)
64
+ self.name = None
65
+ if len(args) > 0:
66
+ self.name = args[0].__name__
62
67
  self._executor = PyNativeExecutor_.get_instance()
63
68
 
64
69
  def __call__(self, *args, **kwargs):
65
- api_register.api_set_ori_func()
70
+ if JitDump.jit_dump_switch:
71
+ api_register.api_set_ori_func()
66
72
  out = super().__call__(*args, **kwargs)
67
73
  if JitDump.jit_dump_switch and len(args) > 0:
68
- dump_jit(args[0], args, out, True)
74
+ if self.name and self.name != "construct":
75
+ dump_jit(self.name, args, out, True)
76
+ else:
77
+ dump_jit(args[0], args, out, True)
69
78
  JitDump.jit_enable = True
70
- api_register.api_set_hook_func()
79
+ elif len(args) == 0:
80
+ logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
81
+ if JitDump.jit_dump_switch:
82
+ api_register.api_set_hook_func()
71
83
  return out
72
84
 
73
85
  @classmethod
@@ -0,0 +1,33 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.file_utils import save_json
19
+
20
+
21
def create_kernel_config_json(dump_path, cur_rank):
    """Write a kernel-dump configuration JSON into ``dump_path`` and return the file's path.

    The file is named ``kernel_config.json`` when ``cur_rank`` is the empty string, and
    ``kernel_config_<cur_rank>.json`` otherwise. The written config has an empty
    ``dump_list``, ``dump_mode`` "all", ``dump_op_switch`` "on", and points ``dump_path``
    back at the same directory.
    """
    if cur_rank == '':
        file_name = "kernel_config.json"
    else:
        file_name = f"kernel_config_{cur_rank}.json"
    target_path = os.path.join(dump_path, file_name)

    dump_section = {
        "dump_list": [],
        "dump_path": dump_path,
        "dump_mode": "all",
        "dump_op_switch": "on"
    }
    save_json(target_path, {"dump": dump_section}, indent=4)
    return target_path
@@ -13,10 +13,9 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import json
17
16
  import os
18
17
 
19
- from msprobe.core.common.file_utils import FileOpen, create_directory
18
+ from msprobe.core.common.file_utils import create_directory, save_json
20
19
  from msprobe.mindspore.common.log import logger
21
20
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
22
21
 
@@ -57,13 +56,19 @@ class KernelGraphDump:
57
56
  self.dump_json["common_dump_settings"]["input_output"] = 2
58
57
 
59
58
  def handle(self):
59
+ try:
60
+ from msprobe.lib import _msprobe_c
61
+ return
62
+ except ImportError:
63
+ # 如果没有_msprobe_ce_c走MindSpore老流程
64
+ logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
65
+
60
66
  if os.getenv("GRAPH_OP_RUN") == "1":
61
67
  raise Exception("Must run in graph mode, not kbk mode")
62
68
  json_path = self.dump_json["common_dump_settings"]["path"]
63
69
  create_directory(json_path)
64
70
  json_path = os.path.join(json_path, "kernel_graph_dump.json")
65
- with FileOpen(json_path, 'w') as f:
66
- json.dump(self.dump_json, f)
71
+ save_json(json_path, self.dump_json, indent=4)
67
72
  logger.info(json_path + " has been created.")
68
73
  os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
69
74
  if self.dump_json["common_dump_settings"]["dump_mode"] == 0:
@@ -13,11 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import json
17
16
  import os
18
17
 
19
18
  from msprobe.core.common.const import Const
20
- from msprobe.core.common.file_utils import FileOpen, create_directory
19
+ from msprobe.core.common.file_utils import create_directory, save_json
21
20
  from msprobe.mindspore.common.log import logger
22
21
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
23
22
 
@@ -70,8 +69,7 @@ class KernelKbykDump:
70
69
  json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
71
70
  create_directory(json_path)
72
71
  json_path = os.path.join(json_path, "kernel_kbyk_dump.json")
73
- with FileOpen(json_path, 'w') as f:
74
- json.dump(self.dump_json, f)
72
+ save_json(json_path, self.dump_json, indent=4)
75
73
  logger.info(json_path + " has been created.")
76
74
 
77
75
  os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
@@ -0,0 +1,140 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "hook_dynamic_loader.h"
18
+ #include <sys/stat.h>
19
+ #include <cstdlib>
20
+ #include <cstring>
21
+ #include "utils/log_adapter.h"
22
+
23
+ namespace {
24
+
25
+ // Utility function to check if a file path is valid
26
+ bool IsValidPath(const std::string &path) {
27
+ struct stat fileStat;
28
+ if (stat(path.c_str(), &fileStat) != 0) {
29
+ MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
30
+ return false;
31
+ }
32
+
33
+ if (S_ISLNK(fileStat.st_mode)) {
34
+ MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
35
+ return false;
36
+ }
37
+
38
+ if (!S_ISREG(fileStat.st_mode)) {
39
+ MS_LOG(ERROR) << "File is not a regular file: " << path;
40
+ return false;
41
+ }
42
+
43
+ if (path.substr(path.find_last_of(".")) != ".so") {
44
+ MS_LOG(ERROR) << "File is not a .so file: " << path;
45
+ return false;
46
+ }
47
+
48
+ return true;
49
+ }
50
+
51
+ } // namespace
52
+
53
+ HookDynamicLoader &HookDynamicLoader::GetInstance() {
54
+ static HookDynamicLoader instance;
55
+ return instance;
56
+ }
57
+
58
+ bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) {
59
+ void *func = dlsym(handle, functionName.c_str());
60
+ if (!func) {
61
+ MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
62
+ return false;
63
+ }
64
+ funcMap_[functionName] = func;
65
+ return true;
66
+ }
67
+
68
+ bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
69
+ char *realPath = realpath(libPath.c_str(), nullptr);
70
+ if (!realPath) {
71
+ MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
72
+ return false;
73
+ }
74
+
75
+ bool isValid = IsValidPath(realPath);
76
+ free(realPath); // Free memory allocated by realpath
77
+ return isValid;
78
+ }
79
+
80
+ bool HookDynamicLoader::LoadLibrary() {
81
+ const char *libPath = std::getenv("HOOK_TOOL_PATH");
82
+ if (!libPath) {
83
+ MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
84
+ return false;
85
+ }
86
+
87
+ std::string resolvedLibPath(libPath);
88
+ if (!validateLibraryPath(resolvedLibPath)) {
89
+ MS_LOG(WARNING) << "Library path validation failed.";
90
+ return false;
91
+ }
92
+
93
+ std::lock_guard<std::mutex> lock(mutex_);
94
+ if (handle_) {
95
+ MS_LOG(WARNING) << "Hook library already loaded!";
96
+ return false;
97
+ }
98
+
99
+ handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
100
+ if (!handle_) {
101
+ MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
102
+ return false;
103
+ }
104
+
105
+ for (const auto &functionName : functionList_) {
106
+ if (!loadFunction(handle_, functionName)) {
107
+ MS_LOG(WARNING) << "Failed to load function: " << functionName;
108
+ dlclose(handle_);
109
+ handle_ = nullptr;
110
+ return false;
111
+ }
112
+ }
113
+
114
+ MS_LOG(INFO) << "Hook library loaded successfully.";
115
+ return true;
116
+ }
117
+
118
+ bool HookDynamicLoader::UnloadLibrary() {
119
+ std::lock_guard<std::mutex> lock(mutex_);
120
+ if (!handle_) {
121
+ MS_LOG(WARNING) << "Hook library hasn't been loaded.";
122
+ return false;
123
+ }
124
+
125
+ dlclose(handle_);
126
+ handle_ = nullptr;
127
+ funcMap_.clear();
128
+ MS_LOG(INFO) << "Library unloaded successfully.";
129
+ return true;
130
+ }
131
+
132
+ void *HookDynamicLoader::GetHooker(const std::string &funcName) {
133
+ std::lock_guard<std::mutex> lock(mutex_);
134
+ auto iter = funcMap_.find(funcName);
135
+ if (iter == funcMap_.end()) {
136
+ MS_LOG(WARNING) << "Function not found: " << funcName;
137
+ return nullptr;
138
+ }
139
+ return iter->second;
140
+ }
@@ -0,0 +1,53 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef HOOK_DYNAMIC_LOADER_H
18
+ #define HOOK_DYNAMIC_LOADER_H
19
+
20
+ #include <dlfcn.h>
21
+ #include <string>
22
+ #include <vector>
23
+ #include <map>
24
+ #include <mutex>
25
+
26
// Symbol names that LoadLibrary() must resolve from the hook library.
constexpr const char *kHookBegin = "MS_DbgOnStepBegin";
constexpr const char *kHookEnd = "MS_DbgOnStepEnd";

// Process-wide singleton that dlopen()s the library named by HOOK_TOOL_PATH and hands out
// the addresses of the symbols listed in functionList_. The implementation guards
// handle_ and funcMap_ with mutex_.
class HookDynamicLoader {
public:
    static HookDynamicLoader &GetInstance();

    // Non-copyable: there is exactly one loader per process.
    HookDynamicLoader(const HookDynamicLoader &) = delete;
    HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;

    bool LoadLibrary();
    bool UnloadLibrary();
    void *GetHooker(const std::string &funcName);

private:
    HookDynamicLoader() = default;

    bool validateLibraryPath(const std::string &libPath);
    bool loadFunction(void *handle, const std::string &functionName);

    // Data-member order kept as before so the object layout is unchanged.
    void *handle_{nullptr};
    std::vector<std::string> functionList_{kHookBegin, kHookEnd};
    std::map<std::string, void *> funcMap_;
    std::mutex mutex_;
};
52
+
53
+ #endif // HOOK_DYNAMIC_LOADER_H
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -13,24 +13,31 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import functools
16
17
  import importlib
17
- import inspect
18
18
  import os
19
+ import traceback
19
20
 
20
21
  import mindspore as ms
21
- from mindspore.communication import comm_func
22
-
23
22
  from msprobe.core.common.const import Const
23
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
24
24
  from msprobe.core.common.file_utils import check_path_length, load_yaml
25
25
  from msprobe.mindspore.common.const import Const as MsConst
26
26
  from msprobe.mindspore.common.const import FreeBenchmarkConst
27
27
  from msprobe.mindspore.common.log import logger
28
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
28
29
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
30
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
31
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
29
32
  from msprobe.mindspore.free_benchmark.common.config import Config
30
- from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
33
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
34
+ from msprobe.mindspore.free_benchmark.common.utils import Tools
35
+ from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
36
+ from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
37
+ from msprobe.mindspore.runtime import Runtime
31
38
 
32
39
 
33
- class ApiPyNativeSelFCheck:
40
+ class ApiPyNativeSelfCheck:
34
41
  def __init__(self, config: DebuggerConfig):
35
42
  Config.is_enable = True
36
43
  Config.handler_type = config.handler_type
@@ -39,29 +46,77 @@ class ApiPyNativeSelFCheck:
39
46
  Config.dump_level = config.dump_level
40
47
  Config.steps = config.step
41
48
  Config.ranks = config.rank
42
- Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv")
49
+ Config.dump_path = os.path.join(config.dump_path, FreeBenchmarkConst.CHECK_RESULT_FILE)
43
50
  check_path_length(Config.dump_path)
44
51
 
52
+ self.ori_func = {}
53
+
45
54
  self.api_list = config.list
46
55
  all_api = get_supported_ops()
47
56
  if not self.api_list:
48
57
  self.api_list = all_api
49
58
  else:
50
59
  self.api_list = set(self.api_list) & all_api
60
+ self.store_original_func()
51
61
 
52
62
  def handle(self):
63
+ api_register.initialize_hook(self.build_hook)
64
+ api_register.api_set_hook_func()
65
+
66
+ def build_hook(self, api_name):
67
+ def pre_hook(cell, input_data):
68
+ return None
69
+
70
+ def forward_hook(api_name_with_id, cell, input_data, output_data):
71
+ ret = None
72
+
73
+ if not need_wrapper_func():
74
+ del cell.input_kwargs
75
+ return ret
76
+
77
+ api_name_with_id = api_name_with_id[:-1]
78
+ hook_prefix = api_name_with_id[:api_name_with_id.find(Const.SEP) + 1]
79
+ api_name = (MsConst.HOOK_MS_PREFIX_DICT.get(hook_prefix, "") +
80
+ api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)])
81
+ if api_name in self.api_list:
82
+ ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name),
83
+ *input_data, **cell.input_kwargs)
84
+
85
+ del cell.input_kwargs
86
+ return ret
87
+
88
+ def backward_hook(cell, grad_input, grad_output):
89
+ pass
90
+
91
+ HOOKCell.get_cell_count(api_name)
92
+ api_name_with_id = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP
93
+ forward_hook = functools.partial(forward_hook, api_name_with_id)
94
+ HOOKCell.add_cell_count(api_name)
95
+
96
+ def wrap_forward_hook(cell, input_data, output_data):
97
+ return forward_hook(cell, input_data, output_data)
98
+
99
+ def wrap_backward_hook(cell, grad_input, grad_output):
100
+ return backward_hook(cell, grad_input, grad_output)
101
+
102
+ def pre_backward_hook(cell, grad_input):
103
+ return None
104
+
105
+ return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook
106
+
107
+ def store_original_func(self):
53
108
  for api_name in self.api_list:
54
- hijack(api_name)
109
+ self.ori_func[api_name] = get_module(api_name)[1]
55
110
 
56
111
 
57
112
  def get_supported_ops():
58
113
  supported_ops = []
59
114
  cur_path = os.path.dirname(os.path.realpath(__file__))
60
- yaml_path = os.path.join(cur_path, "data", "support_wrap_ops.yaml")
115
+ yaml_path = os.path.join(cur_path, "data", FreeBenchmarkConst.SUPPORTED_CHECK_API_FILE)
61
116
 
62
- yaml_data = load_yaml(yaml_path)
117
+ supported_ops_list = load_yaml(yaml_path)
63
118
  for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items():
64
- ops = yaml_data.get(k)
119
+ ops = supported_ops_list.get(k)
65
120
  if ops:
66
121
  ops = [v + i for i in ops]
67
122
  supported_ops += ops
@@ -72,7 +127,7 @@ def get_supported_ops():
72
127
  _all_functional_ops += ms_ops
73
128
 
74
129
  ms_tensor = dir(ms.Tensor)
75
- ms_tensor = [MsConst.Tensor_PREFIX + i for i in ms_tensor]
130
+ ms_tensor = [MsConst.TENSOR_PREFIX + i for i in ms_tensor]
76
131
  _all_functional_ops += ms_tensor
77
132
 
78
133
  ms_mint = dir(ms.mint)
@@ -83,49 +138,109 @@ def get_supported_ops():
83
138
  ms_mint_nn_func = [MsConst.MINT_NN_FUNC_PREFIX + i for i in ms_mint_nn_func]
84
139
  _all_functional_ops += ms_mint_nn_func
85
140
 
86
- ms_communication = dir(comm_func)
87
- ms_communication = [MsConst.COMM_PREFIX + i for i in ms_communication]
88
- _all_functional_ops += ms_communication
89
-
90
141
  return set(supported_ops) & set(_all_functional_ops)
91
142
 
92
143
 
93
- def get_decorate_func():
94
- return decorate_forward_function
95
-
96
-
97
- def is_func_support_decorate(orig_func):
98
- return not inspect.isclass(orig_func) and callable(orig_func)
99
-
100
-
101
- def get_wrapper_obj(orig_func, api_name):
102
- if is_func_support_decorate(orig_func):
103
- wrapped_obj = get_decorate_func()(orig_func, api_name)
104
- else:
105
- wrapped_obj = orig_func
106
- return wrapped_obj
107
-
108
-
109
144
  def get_module(api_name):
110
145
  func_name_list = api_name.split(Const.SEP)
111
146
  func_name = func_name_list[-1]
112
147
  module_obj = importlib.import_module(func_name_list[0])
113
148
  for i, module_name in enumerate(func_name_list[1:-1]):
114
149
  if not hasattr(module_obj, module_name):
115
- importlib.import_module(f"{Const.SEP.join(func_name_list[:i+2])}")
150
+ importlib.import_module(f"{Const.SEP.join(func_name_list[:i + 2])}")
116
151
  module_obj = getattr(module_obj, module_name)
117
152
  orig_func = getattr(module_obj, func_name)
118
153
 
119
154
  return module_obj, orig_func
120
155
 
121
156
 
122
- def hijack(api_name):
123
- if not api_name.strip():
124
- return
157
+ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
158
+ ret = None
159
+
160
+ if Config.stage == Const.BACKWARD and not (check_all_tensor(args) and check_all_tensor(output)):
161
+ logger.warning(f"{api_name_with_id} has non-tensor input or output.")
162
+ return ret
163
+
164
+ params = data_pre_deal(api_name_with_id, ori_func, *args, **kwargs)
165
+ if params.index == -1:
166
+ return ret
167
+
168
+ logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
169
+ api_register.api_set_ori_func()
170
+
125
171
  try:
126
- func_name = api_name.split(Const.SEP)[-1]
127
- module_obj, origin_func = get_module(api_name)
128
- wrapped_obj = get_wrapper_obj(origin_func, api_name)
129
- setattr(module_obj, func_name, wrapped_obj)
172
+ perturbation = PerturbationFactory.create(api_name_with_id)
173
+ params.fuzzed_result = perturbation.handle(params)
174
+ if params.fuzzed_result is False:
175
+ api_register.api_set_hook_func()
176
+ return ret
177
+ if Config.stage == Const.BACKWARD:
178
+ params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
179
+ else:
180
+ params.original_result = output
181
+ ret = deal_fuzzed_and_original_result(api_name_with_id, params)
130
182
  except Exception as e:
131
- logger.error(f"Failed decorator {api_name}: {e}")
183
+ logger.error(f"[{api_name_with_id}] Error: {str(e)}")
184
+ logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
185
+
186
+ api_register.api_set_hook_func()
187
+ return ret
188
+
189
+
190
+ def check_all_tensor(input_output):
191
+ if isinstance(input_output, ms.Tensor):
192
+ return True
193
+ if isinstance(input_output, (tuple, list)):
194
+ return all([check_all_tensor(v) for v in input_output])
195
+ return False
196
+
197
+
198
+ def get_target_arg_index(args) -> int:
199
+ """
200
+ 类型校验
201
+
202
+ """
203
+ for i, arg in enumerate(args):
204
+ if ms.ops.is_tensor(arg):
205
+ if not ms.ops.is_floating_point(arg):
206
+ continue
207
+ return i
208
+ if isinstance(arg, (list, tuple, dict)):
209
+ return i
210
+ return -1
211
+
212
+
213
+ def data_pre_deal(api_name_with_id, func, *args, **kwargs):
214
+ params = HandlerParams()
215
+ params.args = args
216
+ params.kwargs = kwargs
217
+ params.original_func = func
218
+ index = get_target_arg_index(args)
219
+ if index == -1:
220
+ logger.warning(f"{api_name_with_id} has no supported input type.")
221
+ params.index = index
222
+ return params
223
+
224
+
225
+ def need_wrapper_func():
226
+ if not (Runtime.is_running and Config.is_enable):
227
+ return False
228
+
229
+ if Config.steps and Runtime.step_count not in Config.steps:
230
+ return False
231
+
232
+ if Runtime.rank_id == -1:
233
+ try:
234
+ Runtime.rank_id = get_rank_if_initialized()
235
+ except DistributedNotInitializedError:
236
+ Runtime.rank_id = -1
237
+ if Config.ranks and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks:
238
+ return False
239
+
240
+ return True
241
+
242
+
243
+ def deal_fuzzed_and_original_result(api_name_with_id, params: HandlerParams):
244
+ handler = HandlerFactory.create(api_name_with_id)
245
+ result = handler.handle(params)
246
+ return result
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -27,6 +27,5 @@ class HandlerParams:
27
27
  original_result: Optional[Any] = None
28
28
  fuzzed_result: Optional[Any] = None
29
29
  is_consistent: Optional[bool] = True
30
- save_flag: Optional[bool] = True
31
30
  fuzzed_value: Optional[Any] = None
32
31
  original_func: Optional[Callable] = None
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -17,7 +17,7 @@ from dataclasses import dataclass
17
17
  from typing import Any, Optional
18
18
 
19
19
  import mindspore as ms
20
- from mindspore import Tensor
20
+ from mindspore import Tensor, ops
21
21
 
22
22
  from msprobe.mindspore.common.const import FreeBenchmarkConst
23
23
  from msprobe.mindspore.free_benchmark.common.config import Config
@@ -43,6 +43,23 @@ class Tools:
43
43
  return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
44
44
  return FreeBenchmarkConst.ERROR_THRESHOLD.get(dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))
45
45
 
46
+ @staticmethod
47
+ def get_grad_out(outputs):
48
+ if isinstance(outputs, Tensor):
49
+ return ops.ones_like(outputs)
50
+ if isinstance(outputs, (tuple, list)):
51
+ return type(outputs)([Tools.get_grad_out(v) for v in outputs])
52
+ return outputs
53
+
54
+ @staticmethod
55
+ def get_grad(func, *args, **kwargs):
56
+ def target_func(*inputs):
57
+ return func(*inputs, **kwargs)
58
+
59
+ outputs, vjp_fn = ms.vjp(target_func, *args)
60
+ values = Tools.get_grad_out(outputs)
61
+ return vjp_fn(values)
62
+
46
63
 
47
64
  @dataclass
48
65
  class UnequalRow:
@@ -73,10 +90,8 @@ def make_unequal_row(
73
90
  if isinstance(ratio, float):
74
91
  row.max_rel = ratio - 1
75
92
  original_tensor = params.original_result
76
- fuzzed_tensor = params.fuzzed_result
77
93
  if index is not None:
78
94
  original_tensor = original_tensor[index]
79
- fuzzed_tensor = fuzzed_tensor[index]
80
95
  row.output_index = index
81
96
  if isinstance(original_tensor, Tensor):
82
97
  row.dtype = original_tensor.dtype