mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -18,9 +18,11 @@ from typing import Any
18
18
  import mindspore as ms
19
19
  from mindspore import Tensor, ops
20
20
 
21
- from msprobe.mindspore.common.const import Const
21
+ from msprobe.core.common.const import Const
22
22
  from msprobe.mindspore.common.log import logger
23
+ from msprobe.mindspore.free_benchmark.common.config import Config
23
24
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
25
+ from msprobe.mindspore.free_benchmark.common.utils import Tools
24
26
  from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
25
27
 
26
28
 
@@ -40,10 +42,15 @@ class ImprovePrecisionPerturbation(BasePerturbation):
40
42
  def handle(self, params: HandlerParams) -> Any:
41
43
  args = self.improve_tensor_precision(params.args)
42
44
  kwargs = self.improve_tensor_precision(params.kwargs)
43
- fuzzed_value = args
44
- if self.api_name in Const.COMMUNICATION_API_LIST:
45
- params.fuzzed_value = fuzzed_value
46
45
  if not self.is_fuzzed:
47
- logger.warning(f"{self.api_name} can not improve precision.")
46
+ logger.warning(f"{self.api_name_with_id} can not improve precision.")
48
47
  return False
48
+
49
+ if Config.stage == Const.BACKWARD:
50
+ fuzzed_result = Tools.get_grad(params.original_func, *args, **kwargs)
51
+ if fuzzed_result is not None:
52
+ return fuzzed_result
53
+ else:
54
+ return False
55
+
49
56
  return params.original_func(*args, **kwargs)
@@ -36,9 +36,9 @@ class PerturbationFactory:
36
36
  }
37
37
 
38
38
  @staticmethod
39
- def create(api_name: str):
39
+ def create(api_name_with_id: str):
40
40
  perturbation = PerturbationFactory.perturbations.get(Config.pert_type)
41
41
  if perturbation:
42
- return perturbation(api_name)
42
+ return perturbation(api_name_with_id)
43
43
  else:
44
44
  raise Exception(f'{Config.pert_type} is a invalid perturbation type')
@@ -15,7 +15,7 @@
15
15
 
16
16
  from msprobe.mindspore.common.const import Const
17
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
18
- from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelFCheck
18
+ from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck
19
19
 
20
20
 
21
21
  class SelfCheckToolFactory:
@@ -28,7 +28,7 @@ class SelfCheckToolFactory:
28
28
  Const.API: {
29
29
  Const.GRAPH_KBYK_MODE: None,
30
30
  Const.GRAPH_GE_MODE: None,
31
- Const.PYNATIVE_MODE: ApiPyNativeSelFCheck
31
+ Const.PYNATIVE_MODE: ApiPyNativeSelfCheck
32
32
  },
33
33
  Const.KERNEL: {
34
34
  Const.GRAPH_KBYK_MODE: None,
@@ -1,15 +1,30 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  import threading
3
18
  from typing import Dict, Union, Tuple
4
19
 
5
- from msprobe.core.grad_probe.utils import check_str, check_bounds_element
20
+ from msprobe.core.common.utils import is_int
21
+ from msprobe.core.common.file_utils import create_directory, check_path_before_create
6
22
  from msprobe.core.grad_probe.constant import GradConst
23
+ from msprobe.core.grad_probe.utils import check_str, check_bounds_element, check_param_element
7
24
  from msprobe.mindspore.common.log import logger
8
- from msprobe.core.common.file_utils import create_directory, check_path_before_create
9
25
 
10
26
 
11
27
  class GlobalContext:
12
-
13
28
  _instance = None
14
29
  _instance_lock = threading.Lock()
15
30
  _setting = {
@@ -37,10 +52,10 @@ class GlobalContext:
37
52
  else:
38
53
  raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")
39
54
 
40
- self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
55
+ self._set_input_list(config_dict, GradConst.PARAM_LIST, (str,), element_check=check_param_element)
41
56
  self._set_input_list(config_dict, GradConst.BOUNDS, (float, int), element_check=check_bounds_element)
42
- self._set_input_list(config_dict, GradConst.STEP, int)
43
- self._set_input_list(config_dict, GradConst.RANK, int)
57
+ self._set_input_list(config_dict, GradConst.STEP, (int,))
58
+ self._set_input_list(config_dict, GradConst.RANK, (int,))
44
59
 
45
60
  output_path = config_dict.get(GradConst.OUTPUT_PATH)
46
61
  check_str(output_path, variable_name="output_path in yaml")
@@ -88,13 +103,18 @@ class GlobalContext:
88
103
  if value and isinstance(value, list):
89
104
  for val in value:
90
105
  if not isinstance(val, dtype):
91
- logger.warning(f"Invalid {name} which must be None or list of {type_str}")
106
+ logger.warning(f"Invalid {name} which must be None or list of {type_str}, use default value.")
107
+ return
108
+ elif isinstance(val, int) and not is_int(val):
109
+ logger.warning(f"Invalid {name} which must be None or list of int, use default value.")
92
110
  return
93
111
  if element_check and not element_check(val):
94
- logger.warning(f"Given {name} violates some rules.")
112
+ logger.warning(f"Given {name} violates some rules, use default value.")
95
113
  return
114
+
96
115
  self._setting[name] = value
97
116
  else:
98
117
  logger.warning(f"{name} is None or not a list with valid items, use default value.")
99
118
 
119
+
100
120
  grad_context = GlobalContext()
@@ -1,23 +1,48 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import multiprocessing
1
17
  import os
2
18
  import time
3
- from typing import List, Tuple
4
- import multiprocessing
19
+ from dataclasses import dataclass
5
20
  from multiprocessing import Process
21
+ from typing import List
6
22
 
7
- import numpy as np
8
23
  import mindspore as ms
9
- from mindspore.communication import get_rank
10
- from mindspore.ops import operations as P
24
+ import numpy as np
11
25
  from mindspore.common.parameter import Parameter
26
+ from mindspore.communication import get_rank
12
27
 
13
- from msprobe.core.grad_probe.utils import ListCache
14
- from msprobe.core.grad_probe.constant import GradConst
15
- from msprobe.mindspore.common.log import logger
16
28
  from msprobe.core.common.file_utils import (create_directory, check_file_or_directory_path,
17
29
  write_csv, remove_path, move_file, load_npy)
30
+ from msprobe.core.grad_probe.constant import GradConst
31
+ from msprobe.core.grad_probe.utils import ListCache
32
+ from msprobe.mindspore.common.log import logger
18
33
  from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext
19
34
 
20
35
 
36
+ @dataclass
37
+ class GradDumpConfig:
38
+ dump_dir: str
39
+ g_name: str
40
+ dump_step: Parameter
41
+ grad: ms.Tensor
42
+ level: str
43
+ bounds: List
44
+
45
+
21
46
  def get_rank_id():
22
47
  try:
23
48
  rank_id = get_rank()
@@ -27,35 +52,35 @@ def get_rank_id():
27
52
 
28
53
 
29
54
  @ms.jit
30
- def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List):
31
- '''
55
+ def grad_dump(config: GradDumpConfig):
56
+ """
32
57
  Dump gradient statistic data.
33
58
  level0: [step, max, min, norm, shape_dim, shape]
34
59
  level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
35
60
  level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
36
- '''
37
- dump_path = os.path.join(dump_dir, g_name)
61
+ """
62
+ dump_path = os.path.join(config.dump_dir, config.g_name)
38
63
  dump_dir_path = dump_path + "_dir"
39
64
  save_op = ms.ops.TensorDump()
40
65
 
41
- grad_flat = grad.reshape(-1)
66
+ grad_flat = config.grad.reshape(-1)
42
67
  max_val = grad_flat.max(axis=0).float()
43
68
  min_val = grad_flat.min(axis=0).float()
44
69
  norm_val = grad_flat.norm(ord=2).float()
45
- shape = grad.shape
46
- extrem_list = [dump_step[0].float(), max_val, min_val, norm_val]
70
+ shape = config.grad.shape
71
+ extrem_list = [config.dump_step[0].float(), max_val, min_val, norm_val]
47
72
  extrem_stat = ms.ops.stack(extrem_list)
48
73
  shape_list = [len(shape)] + list(shape)
49
74
  shape_stat = ms.Tensor(shape_list).float()
50
75
  level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
51
76
  level_stat = level0_stat
52
77
 
53
- if level == GradConst.LEVEL2:
54
- zero_grad = (grad == 0).sum()
55
- dist_dim = ms.Tensor([len(bounds) + 2]).float()
56
- bucket_result = ms.ops.bucketize(grad.float(), bounds)
78
+ if config.level == GradConst.LEVEL2:
79
+ zero_grad = (config.grad == 0).sum()
80
+ dist_dim = ms.Tensor([len(config.bounds) + 2]).float()
81
+ bucket_result = ms.ops.bucketize(config.grad.float(), config.bounds)
57
82
  bucket_result = bucket_result.astype(ms.int8)
58
- dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)]
83
+ dist_stat = [(bucket_result == i).sum() for i in range(len(config.bounds) + 1)]
59
84
  dist_stat.append(zero_grad)
60
85
  dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty
61
86
  dist_stat = ms.ops.stack(dist_stat, axis=0).float()
@@ -63,8 +88,8 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor,
63
88
  level_stat = level2_stat
64
89
 
65
90
  save_op(dump_path, level_stat)
66
- if level == GradConst.LEVEL1 or level == GradConst.LEVEL2:
67
- grad_direction = grad > 0
91
+ if config.level == GradConst.LEVEL1 or config.level == GradConst.LEVEL2:
92
+ grad_direction = config.grad > 0
68
93
  save_op(dump_dir_path, grad_direction)
69
94
 
70
95
 
@@ -182,7 +207,7 @@ class CSVGenerator(Process):
182
207
  shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
183
208
  file_name = os.path.basename(file_path)
184
209
  prefix_idx = len(file_name.split("_")[0])
185
- param_name = file_name[(prefix_idx + 1) : -(len(GradConst.NPY_SUFFIX) + 1)]
210
+ param_name = file_name[(prefix_idx + 1): -(len(GradConst.NPY_SUFFIX) + 1)]
186
211
  if not param_name:
187
212
  raise RuntimeError("Invalid gradient statistic file name.")
188
213
  csv_line = [param_name]
@@ -224,8 +249,9 @@ class CSVGenerator(Process):
224
249
  if i == 0:
225
250
  intervals.append(f"(-inf, {self.bounds[i]}]")
226
251
  else:
227
- intervals.append(f"({self.bounds[i-1]}, {self.bounds[i]}]")
252
+ intervals.append(f"({self.bounds[i - 1]}, {self.bounds[i]}]")
228
253
  intervals.extend([f"({self.bounds[-1]}, inf)", "=0"])
229
254
  return intervals
230
255
 
256
+
231
257
  csv_generator = CSVGenerator()
@@ -1,7 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.grad_probe.constant import GradConst
1
17
  from msprobe.mindspore.grad_probe.global_context import grad_context
2
18
  from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
3
19
  from msprobe.mindspore.grad_probe.hook import hook_optimizer
4
- from msprobe.core.grad_probe.constant import GradConst
5
20
 
6
21
 
7
22
  class GradientMonitor:
@@ -1,8 +1,23 @@
1
- from abc import ABC, abstractmethod
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import hashlib
17
+ from abc import ABC, abstractmethod
3
18
 
4
19
  import mindspore
5
- from mindspore import ops, Tensor
20
+ from mindspore import ops
6
21
  from msprobe.core.grad_probe.constant import GradConst
7
22
 
8
23
 
@@ -12,6 +27,7 @@ class CsvInput:
12
27
  self.grad = grad
13
28
  self.bounds = bounds
14
29
 
30
+
15
31
  class GradStatCsv:
16
32
  csv = {}
17
33
 
@@ -52,9 +68,11 @@ class CsvItem(ABC):
52
68
 
53
69
  @register_csv_item(GradConst.MD5)
54
70
  class CsvMd5(CsvItem):
71
+ @staticmethod
55
72
  def generate_csv_header(csv_input):
56
73
  return ["MD5"]
57
74
 
75
+ @staticmethod
58
76
  def generate_csv_content(csv_input):
59
77
  grad = csv_input.grad
60
78
  tensor_bytes = grad.float().numpy().tobytes()
@@ -64,19 +82,21 @@ class CsvMd5(CsvItem):
64
82
 
65
83
  @register_csv_item(GradConst.DISTRIBUTION)
66
84
  class CsvDistribution(CsvItem):
85
+ @staticmethod
67
86
  def generate_csv_header(csv_input):
68
87
  bounds = csv_input.bounds
69
88
  intervals = []
70
89
  if bounds:
71
90
  intervals.append(f"(-inf, {bounds[0]}]")
72
91
  for i in range(1, len(bounds)):
73
- intervals.append(f"({bounds[i-1]}, {bounds[i]}]")
92
+ intervals.append(f"({bounds[i - 1]}, {bounds[i]}]")
74
93
  if intervals:
75
94
  intervals.append(f"({bounds[-1]}, inf)")
76
95
  intervals.append("=0")
77
-
96
+
78
97
  return intervals
79
98
 
99
+ @staticmethod
80
100
  def generate_csv_content(csv_input):
81
101
  grad = csv_input.grad
82
102
  bounds = csv_input.bounds
@@ -94,9 +114,11 @@ class CsvDistribution(CsvItem):
94
114
 
95
115
  @register_csv_item(GradConst.MAX)
96
116
  class CsvMax(CsvItem):
117
+ @staticmethod
97
118
  def generate_csv_header(csv_input):
98
119
  return ["max"]
99
120
 
121
+ @staticmethod
100
122
  def generate_csv_content(csv_input):
101
123
  grad = csv_input.grad
102
124
  return [ops.amax(grad).float().numpy().tolist()]
@@ -104,9 +126,11 @@ class CsvMax(CsvItem):
104
126
 
105
127
  @register_csv_item(GradConst.MIN)
106
128
  class CsvMin(CsvItem):
129
+ @staticmethod
107
130
  def generate_csv_header(csv_input):
108
131
  return ["min"]
109
132
 
133
+ @staticmethod
110
134
  def generate_csv_content(csv_input):
111
135
  grad = csv_input.grad
112
136
  return [ops.amin(grad).float().numpy().tolist()]
@@ -114,9 +138,11 @@ class CsvMin(CsvItem):
114
138
 
115
139
  @register_csv_item(GradConst.NORM)
116
140
  class CsvNorm(CsvItem):
141
+ @staticmethod
117
142
  def generate_csv_header(csv_input):
118
143
  return ["norm"]
119
144
 
145
+ @staticmethod
120
146
  def generate_csv_content(csv_input):
121
147
  grad = csv_input.grad
122
148
  return [ops.norm(grad).float().numpy().tolist()]
@@ -124,9 +150,11 @@ class CsvNorm(CsvItem):
124
150
 
125
151
  @register_csv_item(GradConst.SHAPE)
126
152
  class CsvShape(CsvItem):
153
+ @staticmethod
127
154
  def generate_csv_header(csv_input):
128
155
  return ["shape"]
129
156
 
157
+ @staticmethod
130
158
  def generate_csv_content(csv_input):
131
159
  grad = csv_input.grad
132
- return [list(grad.shape)]
160
+ return [list(grad.shape)]
@@ -1,32 +1,51 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
1
15
 
2
16
  import os
3
17
 
4
18
  import mindspore
5
19
  import mindspore as ms
6
20
  from mindspore.common.api import jit
7
- from mindspore.nn.optim.optimizer import Optimizer
8
- from mindspore.common.parameter import Parameter
9
21
  from mindspore.common.initializer import initializer
10
-
22
+ from mindspore.common.parameter import Parameter
23
+ from mindspore.nn.optim.optimizer import Optimizer
24
+ from msprobe.core.common.file_utils import remove_path, write_csv, create_directory
11
25
  from msprobe.core.grad_probe.constant import GradConst
12
26
  from msprobe.mindspore.common.log import logger
13
-
14
- from msprobe.core.common.file_utils import remove_path, write_csv, create_directory
15
27
  from msprobe.mindspore.grad_probe.global_context import grad_context
16
- from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id
17
28
  from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
29
+ from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id, GradDumpConfig
18
30
  from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput
19
31
  from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level
20
32
 
21
- class HookInput:
22
33
 
34
+ class HookInput:
23
35
  '''
24
36
  HookInput is a class wrapping all the variables used for hooking optimizer
25
37
  '''
26
38
 
27
39
  def __init__(self, opt) -> None:
28
40
  self.func = opt.construct
29
- self.g_names = [param.name for param in opt._parameters]
41
+ if hasattr(opt, "_parameters"):
42
+ parameter_list = opt._parameters
43
+ elif hasattr(opt, "parameters"):
44
+ parameter_list = opt.parameters
45
+ else:
46
+ logger.error_log_with_exp("Given optimizer has no attributes: '_parameters' or 'parameters'. \
47
+ Please check the type of the given optimizer.", ValueError)
48
+ self.g_names = [param.name for param in parameter_list]
30
49
  self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
31
50
  self.rank_id = get_rank_id()
32
51
  output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
@@ -40,14 +59,17 @@ class HookInput:
40
59
  self.bounds = grad_context.get_context(GradConst.BOUNDS)
41
60
  self.mode = mindspore.get_context("mode")
42
61
 
62
+
43
63
  def hook_graph_mode_optimizer(opt, hook_input):
44
64
  @jit
45
65
  def new_construct(self, gradients):
46
66
  for index, grad_value in enumerate(gradients):
47
67
  if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list:
48
68
  continue
49
- grad_dump(hook_input.dump_dir, hook_input.g_names[index], self.dump_step,
50
- grad_value, hook_input.level, hook_input.bounds)
69
+ conf = GradDumpConfig(dump_dir=hook_input.dump_dir, g_name=hook_input.g_names[index],
70
+ dump_step=self.dump_step, grad=grad_value, level=hook_input.level,
71
+ bounds=hook_input.bounds)
72
+ grad_dump(conf)
51
73
  ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step)
52
74
  self.assignadd(self.dump_step, self.global_step_increase_tensor)
53
75
  out = hook_input.func(gradients)
@@ -57,11 +79,12 @@ def hook_graph_mode_optimizer(opt, hook_input):
57
79
  opt.construct = new_construct.__get__(opt, type(opt))
58
80
  csv_generator.start()
59
81
 
82
+
60
83
  def hook_pynative_optimizer(opt, hook_input):
61
84
  level_adapted = get_adapted_level(hook_input.level)
62
85
 
63
- def hook_fn(cell, input):
64
- gradients, = input
86
+ def hook_fn(cell, input_data):
87
+ gradients, = input_data
65
88
  cur_step = grad_context.get_context(GradConst.CURRENT_STEP)
66
89
  if grad_context.step_need_dump(cur_step) and grad_context.rank_need_dump(hook_input.rank_id):
67
90
  create_directory(hook_input.save_dir)
@@ -1,12 +1,26 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
 
3
18
  import mindspore
4
- from msprobe.core.grad_probe.constant import level_adp
5
- from msprobe.core.grad_probe.utils import check_param
6
19
  from msprobe.core.common.file_utils import (create_directory,
7
- check_path_before_create,
8
20
  check_file_or_directory_path,
9
21
  save_npy)
22
+ from msprobe.core.grad_probe.constant import level_adp
23
+ from msprobe.core.grad_probe.utils import check_param
10
24
 
11
25
 
12
26
  def save_grad_direction(param_name, grad, save_path):
@@ -15,7 +29,6 @@ def save_grad_direction(param_name, grad, save_path):
15
29
  check_file_or_directory_path(save_path, isdir=True)
16
30
  check_param(param_name)
17
31
  save_filepath = os.path.join(save_path, f"{param_name}.npy")
18
- check_path_before_create(save_filepath)
19
32
 
20
33
  if grad.dtype == mindspore.bfloat16:
21
34
  grad = grad.to(mindspore.float32)
@@ -27,4 +40,4 @@ def save_grad_direction(param_name, grad, save_path):
27
40
 
28
41
  def get_adapted_level(level: str):
29
42
  level_adapted = level_adp.get(level)
30
- return level_adapted
43
+ return level_adapted
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .mindtorch_adaptor import (_call_impl,
17
+ register_full_backward_pre_hook,
18
+ register_full_backward_hook)