mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -15,6 +15,7 @@
15
15
 
16
16
  import io
17
17
  import os
18
+ import pickle
18
19
  import random
19
20
  import stat
20
21
  from functools import wraps
@@ -24,7 +25,7 @@ import torch
24
25
  import torch.distributed as dist
25
26
  from msprobe.core.common.exceptions import DistributedNotInitializedError
26
27
  from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
27
- check_file_or_directory_path, check_path_before_create)
28
+ check_file_or_directory_path, check_path_before_create, FileOpen)
28
29
  from msprobe.core.common.log import logger
29
30
  from msprobe.core.common.utils import check_seed_all
30
31
  from packaging import version
@@ -75,7 +76,7 @@ def parameter_adapter(func):
75
76
  else:
76
77
  res = [input_tensor[tensor_index] for tensor_index in indices]
77
78
  return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
78
- if self.op_name_ == "__eq__" and args[1] is None:
79
+ if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
79
80
  return False
80
81
  return func(self, *args, **kwargs)
81
82
 
@@ -104,8 +105,49 @@ def get_rank_if_initialized():
104
105
  raise DistributedNotInitializedError("torch distributed environment is not initialized")
105
106
 
106
107
 
107
- def seed_all(seed=1234, mode=False):
108
- check_seed_all(seed, mode)
108
+ def remove_dropout():
109
+ if torch.__version__ > "1.8":
110
+ logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
111
+ import torch.nn.functional as F
112
+ from torch import _VF
113
+ from torch.overrides import has_torch_function_unary, handle_torch_function
114
+
115
+ def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
116
+ inplace: bool = False) -> torch.Tensor:
117
+ if has_torch_function_unary(input_tensor):
118
+ return handle_torch_function(
119
+ function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
120
+ if p < 0.0 or p > 1.0:
121
+ raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
122
+ return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
123
+
124
+ def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
125
+ inplace: bool = False) -> torch.Tensor:
126
+ if has_torch_function_unary(input_tensor):
127
+ return handle_torch_function(
128
+ function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
129
+ if p < 0.0 or p > 1.0:
130
+ raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
131
+ return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
132
+ 0., training)
133
+
134
+ def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
135
+ inplace: bool = False) -> torch.Tensor:
136
+ if has_torch_function_unary(input_tensor):
137
+ return handle_torch_function(
138
+ function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
139
+ if p < 0.0 or p > 1.0:
140
+ raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
141
+ return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
142
+ 0., training)
143
+
144
+ F.dropout = function_dropout
145
+ F.dropout2d = function_dropout2d
146
+ F.dropout3d = function_dropout3d
147
+
148
+
149
+ def seed_all(seed=1234, mode=False, rm_dropout=True):
150
+ check_seed_all(seed, mode, rm_dropout)
109
151
  try:
110
152
  random.seed(seed)
111
153
  os.environ['PYTHONHASHSEED'] = str(seed)
@@ -125,6 +167,8 @@ def seed_all(seed=1234, mode=False):
125
167
  else:
126
168
  torch_npu.npu.manual_seed_all(seed)
127
169
  torch_npu.npu.manual_seed(seed)
170
+ if rm_dropout:
171
+ remove_dropout()
128
172
  except Exception as e:
129
173
  logger.error(f"There is an unexpected error while determinating randomness. {e}")
130
174
 
@@ -269,17 +313,17 @@ def load_pt(pt_path, to_cpu=False):
269
313
  check_file_or_directory_path(pt_path)
270
314
  try:
271
315
  if to_cpu:
272
- pt = torch.load(pt_path, map_location=torch.device("cpu"))
316
+ pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
273
317
  else:
274
- pt = torch.load(pt_path)
318
+ pt = torch.load(pt_path, weights_only=True)
275
319
  except Exception as e:
276
320
  raise RuntimeError(f"load pt file {pt_path} failed") from e
277
321
  return pt
278
322
 
279
323
 
280
324
  def save_pt(tensor, filepath):
281
- filepath = os.path.realpath(filepath)
282
325
  check_path_before_create(filepath)
326
+ filepath = os.path.realpath(filepath)
283
327
  try:
284
328
  torch.save(tensor, filepath)
285
329
  except Exception as e:
@@ -290,6 +334,56 @@ def save_pt(tensor, filepath):
290
334
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
291
335
 
292
336
 
337
+ class TypeCheckingUnpickler(pickle.Unpickler):
338
+ """
339
+ This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
340
+ It overrides the find_class method to add type checking functionality.
341
+ """
342
+ allowed_types = [
343
+ "str",
344
+ "ApiData",
345
+ "OrderedDict",
346
+ "_rebuild_tensor_v2", # from torch.utils
347
+ "_load_from_bytes" # from torch.storage
348
+ ]
349
+
350
+ def find_class(self, module, name):
351
+ """
352
+ Method to find the class of the object to be unpickled.
353
+ Throws pickle.UnpicklingError If the object type is not in the allowed types list.
354
+ """
355
+ if name in self.allowed_types:
356
+ return super().find_class(module, name)
357
+ raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
358
+
359
+
360
+ def save_pkl(tensor, filepath):
361
+ """Save ApiData or str objection by pickle"""
362
+ check_path_before_create(filepath)
363
+ filepath = os.path.realpath(filepath)
364
+ try:
365
+ with FileOpen(filepath, 'wb') as f:
366
+ pickle.dump(tensor, f)
367
+ except Exception as e:
368
+ logger.error("Save pt file failed, please check according possible error causes: "
369
+ "1. out of disk space or disk error, "
370
+ "2. no permission to write files, etc.")
371
+ raise RuntimeError(f"save pt file {filepath} failed") from e
372
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
373
+
374
+
375
+ def load_pkl(pt_path):
376
+ """Load ApiData or str objection by pickle for accuracy_checker_online"""
377
+ check_file_or_directory_path(pt_path)
378
+ pt_path = os.path.realpath(pt_path)
379
+ try:
380
+ with FileOpen(pt_path, 'rb') as f:
381
+ pt = TypeCheckingUnpickler(f).load()
382
+ except Exception as e:
383
+ raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
384
+ return pt
385
+
386
+
293
387
  def save_api_data(api_data):
294
388
  """Save data to io stream"""
295
389
  try:
@@ -14,53 +14,40 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
- from msprobe.core.common.utils import CompareException, check_compare_param, \
18
- check_configuration_param, task_dumppath_get
19
- from msprobe.core.common.file_utils import create_directory
17
+
20
18
  from msprobe.core.common.exceptions import FileCheckException
19
+ from msprobe.core.common.file_utils import create_directory
20
+ from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
21
+ set_dump_path
22
+ from msprobe.core.compare.acc_compare import ModeConfig
23
+ from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
21
24
  from msprobe.pytorch.common.log import logger
22
- from msprobe.pytorch.compare.pt_compare import PTComparator
23
- from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
25
+ from msprobe.pytorch.compare.pt_compare import PTComparator, compare
24
26
 
25
27
 
26
28
  def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
27
- if kwargs.get('suffix'):
29
+ if kwargs.get("suffix"):
28
30
  logger.error("Argument 'suffix' is not supported for compare_distributed.")
29
31
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
30
- stack_mode = kwargs.get('stack_mode', False)
31
- auto_analyze = kwargs.get('auto_analyze', True)
32
- fuzzy_match = kwargs.get('fuzzy_match', False)
32
+ is_print_compare_log = kwargs.get("is_print_compare_log", True)
33
33
  # get the ranks and match by order
34
34
  npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
35
35
  bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
36
36
  if len(npu_ranks) != len(bench_ranks):
37
- logger.error('The number of ranks in the two runs are different. '
38
- 'Unable to match the ranks. Please use another folder to compare '
39
- 'or use compare() api and manually match the ranks.')
37
+ logger.error(
38
+ "The number of ranks in the two runs are different. "
39
+ "Unable to match the ranks. "
40
+ "Please use another folder to compare or use compare() api and manually match the ranks.")
40
41
  raise CompareException(CompareException.INVALID_PATH_ERROR)
41
42
  for nr, br in zip(npu_ranks, bench_ranks):
42
43
  npu_data_dir = os.path.join(npu_dump_dir, nr)
43
44
  bench_data_dir = os.path.join(bench_dump_dir, br)
44
45
  npu_path = extract_json(npu_data_dir, stack_json=False)
45
46
  bench_path = extract_json(bench_data_dir, stack_json=False)
46
- stack_path = extract_json(npu_data_dir, stack_json=True)
47
47
 
48
48
  dump_result_param = {
49
- 'npu_json_path': npu_path,
50
- 'bench_json_path': bench_path,
51
- 'stack_json_path': stack_path,
52
- 'is_print_compare_log': True
49
+ "npu_json_path": npu_path,
50
+ "bench_json_path": bench_path,
51
+ "is_print_compare_log": is_print_compare_log
53
52
  }
54
- try:
55
- summary_compare, md5_compare = task_dumppath_get(dump_result_param)
56
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match,
57
- dump_result_param.get('is_print_compare_log', True))
58
- create_directory(output_path)
59
- check_compare_param(dump_result_param, output_path,
60
- summary_compare=summary_compare, md5_compare=md5_compare)
61
- except (CompareException, FileCheckException) as error:
62
- logger.error('Compare failed. Please check the arguments and do it again!')
63
- raise CompareException(error.code) from error
64
- pt_comparator = PTComparator()
65
- pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
66
- summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
53
+ compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
@@ -14,19 +14,29 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os.path
17
+
17
18
  import torch
19
+
18
20
  from msprobe.core.common.const import FileCheckConst
19
- from msprobe.pytorch.common.log import logger
20
21
  from msprobe.core.common.exceptions import FileCheckException
21
- from msprobe.core.compare.acc_compare import Comparator
22
- from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, \
23
- CompareException
24
22
  from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
23
+ from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
24
+ set_dump_path
25
+ from msprobe.core.compare.acc_compare import Comparator, ModeConfig
26
+ from msprobe.core.compare.utils import set_stack_json_path
27
+ from msprobe.pytorch.common.log import logger
25
28
  from msprobe.pytorch.common.utils import load_pt
26
29
 
27
30
 
28
- class PTComparator (Comparator):
29
- def __init__(self, data_mapping=None):
31
+ class PTComparator(Comparator):
32
+ def __init__(self, mode_config, data_mapping=None):
33
+ super().__init__(mode_config)
34
+
35
+ self.stack_mode = mode_config.stack_mode
36
+ self.auto_analyze = mode_config.auto_analyze
37
+ self.fuzzy_match = mode_config.fuzzy_match
38
+ self.dump_mode = mode_config.dump_mode
39
+
30
40
  self.frame_name = PTComparator.__name__
31
41
  self.data_mapping = data_mapping
32
42
  if isinstance(self.data_mapping, str) or self.data_mapping is None:
@@ -37,21 +47,24 @@ class PTComparator (Comparator):
37
47
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
38
48
  f"{type(self.data_mapping)}")
39
49
 
40
- def load_mapping_file(self, mapping_file):
50
+ @staticmethod
51
+ def load_mapping_file(mapping_file):
41
52
  if isinstance(mapping_file, str):
42
53
  mapping_dict = load_yaml(mapping_file)
43
54
  else:
44
55
  mapping_dict = {}
45
56
  return mapping_dict
46
-
57
+
47
58
  def read_npy_data(self, dir_path, file_name):
59
+ if not file_name:
60
+ return None
48
61
  data_path = os.path.join(dir_path, file_name)
49
62
  path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
50
- FileCheckConst.PT_SUFFIX, False)
63
+ FileCheckConst.PT_SUFFIX, False)
51
64
  data_path = path_checker.common_check()
52
65
  try:
53
- data_value = load_pt(data_path,
54
- to_cpu=True).detach() # detach because numpy can not process gradient information
66
+ # detach because numpy can not process gradient information
67
+ data_value = load_pt(data_path, to_cpu=True).detach()
55
68
  except RuntimeError as e:
56
69
  # 这里捕获 load_pt 中抛出的异常
57
70
  logger.error(f"Failed to load the .pt file at {data_path}.")
@@ -63,20 +76,29 @@ class PTComparator (Comparator):
63
76
  if data_value.dtype == torch.bfloat16:
64
77
  data_value = data_value.to(torch.float32)
65
78
  data_value = data_value.numpy()
66
- return data_value
67
-
68
-
69
- def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
79
+ return data_value
80
+
81
+
82
def compare(input_param, output_path, **kwargs):
    """Validate the compare arguments and run a PyTorch dump comparison.

    Args:
        input_param: Mapping describing the dumps to compare. It is passed to
            ``set_dump_path``, ``get_dump_mode`` and, when it has no
            ``"stack_json_path"`` key, to ``set_stack_json_path``.
        output_path: Directory for the comparison result; created through
            ``create_directory``.
        **kwargs: Optional settings read by name: ``auto_analyze`` (default
            ``True``), ``fuzzy_match`` (default ``False``), ``data_mapping``
            (default ``None``), ``suffix`` (default ``''``) and ``stack_mode``
            (default ``False``; only consulted when ``input_param`` already
            contains ``"stack_json_path"``).

    Raises:
        CompareException: Re-raised with the original error code when parameter
            preparation or validation fails with ``CompareException`` or
            ``FileCheckException``.
    """
    try:
        data_mapping = kwargs.get('data_mapping', None)
        suffix = kwargs.get('suffix', '')
        fuzzy_match = kwargs.get('fuzzy_match', False)
        auto_analyze = kwargs.get('auto_analyze', True)

        set_dump_path(input_param)
        dump_mode = get_dump_mode(input_param)
        # An explicit "stack_json_path" lets the caller choose stack_mode; otherwise
        # set_stack_json_path sets stack_mode and fills "stack_json_path" in input_param.
        stack_mode = (
            kwargs.get('stack_mode', False)
            if "stack_json_path" in input_param
            else set_stack_json_path(input_param)
        )
        print_compare_log = input_param.get('is_print_compare_log', True)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match, print_compare_log)
        create_directory(output_path)
        check_compare_param(input_param, output_path, dump_mode, stack_mode)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error

    comparator = PTComparator(
        ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode),
        data_mapping,
    )
    comparator.compare_core(input_param, output_path, suffix=suffix)
@@ -31,13 +31,14 @@ class DebuggerConfig:
31
31
  self.scope = task_config.scope if task_config.scope else []
32
32
  self.list = task_config.list if task_config.list else []
33
33
  self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
34
- self.backward_input_list = task_config.backward_input if task_config.backward_input else []
35
- self.backward_input = {}
36
- self.acl_config = common_config.acl_config if common_config.acl_config else ""
37
- self.is_forward_acl_dump = True
38
34
  self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
39
35
  self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
40
36
  self.framework = Const.PT_FRAMEWORK
37
+ self.async_dump = common_config.async_dump if common_config.async_dump else False
38
+
39
+ if self.level == Const.LEVEL_L2:
40
+ self.is_backward_kernel_dump = False
41
+ self._check_and_adjust_config_with_l2()
41
42
 
42
43
  if self.task == Const.FREE_BENCHMARK:
43
44
  self.fuzz_device = task_config.fuzz_device
@@ -59,20 +60,11 @@ class DebuggerConfig:
59
60
  self.tls_path = task_config.tls_path if task_config.tls_path else ""
60
61
  self.host = task_config.host if task_config.host else ""
61
62
  self.port = task_config.port if task_config.port else -1
63
+ self.online_run_ut_recompute = task_config.online_run_ut_recompute \
64
+ if isinstance(task_config.online_run_ut_recompute, bool) else False
62
65
 
63
66
  self.check()
64
67
 
65
- if self.level == "L2":
66
- if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
67
- raise ValueError("scope must be configured as a list with one api name")
68
- if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
69
- raise ValueError("backward_input must be configured when scope contains 'backward'")
70
- if Const.BACKWARD in self.scope[0]:
71
- self.is_forward_acl_dump = False
72
- for index, scope_spec in enumerate(self.scope):
73
- self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
74
- self.backward_input[self.scope[index]] = self.backward_input_list[index]
75
-
76
68
  def check_kwargs(self):
77
69
  if self.task and self.task not in Const.TASK_LIST:
78
70
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
@@ -83,26 +75,53 @@ class DebuggerConfig:
83
75
  if not self.dump_path:
84
76
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
85
77
  f"The dump_path not found.")
78
+ if not isinstance(self.async_dump, bool):
79
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
80
+ f"The parameters async_dump should be bool.")
86
81
 
87
82
def check(self):
    """Validate the configuration and report success.

    Delegates to ``check_kwargs``; any exception it raises propagates to the caller.

    Returns:
        bool: ``True`` whenever ``check_kwargs`` returns normally.
    """
    self.check_kwargs()
    return True
90
85
 
91
86
def check_model(self, instance, start_model):
    """Validate and store the model used for L0 / mix level dumping.

    For any other level the model is not used: an informational message is
    logged when one was supplied and the method returns without touching
    ``instance.model``.

    Args:
        instance: Debugger object whose ``model`` attribute holds the model given
            at construction time; it is overwritten by ``start_model`` when that
            is not ``None``.
        start_model: Model passed to the start interface, or ``None``.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when no model is available, or
            when the model (or any element of a list/tuple of models) is not a
            ``torch.nn.Module``.
    """
    if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
        if instance.model is not None or start_model is not None:
            logger.info_on_rank_0(
                f"The current level is not L0 or mix level, so the model parameters will not be used.")
        return
    if start_model is None and instance.model is None:
        logger.error_on_rank_0(
            f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")

    instance.model = start_model if start_model is not None else instance.model
    if isinstance(instance.model, torch.nn.Module):
        return

    # A dedicated sentinel marks "nothing invalid found"; using None for that would
    # let a None element inside the list/tuple slip through validation unnoticed.
    no_error = object()
    if isinstance(instance.model, (list, tuple)):
        error_model = next(
            (model for model in instance.model if not isinstance(model, torch.nn.Module)),
            no_error,
        )
    else:
        error_model = instance.model

    if error_model is not no_error:
        error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
                      f"type, currently there is a {type(error_model)} type.")
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR, error_info)
115
+
116
def _check_and_adjust_config_with_l2(self):
    """Validate the L2 configuration and extend ``self.list`` for backward APIs.

    ``self.scope`` must be empty and ``self.list`` must hold exactly one api name.
    When that name ends with ``Const.BACKWARD``, ``is_backward_kernel_dump`` is
    switched on and the matching forward name (same prefix, ``Const.FORWARD``
    suffix) is appended to ``self.list``.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when ``scope`` is configured or
            ``list`` does not contain exactly one entry.
    """
    if self.scope:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"When level is set to L2, the scope cannot be configured.")
    if not (self.list and len(self.list) == 1):
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"When level is set to L2, the list must be configured as a list with one api name.")

    api_name = self.list[0]
    if not api_name.endswith(Const.BACKWARD):
        return

    # Backward kernel dump also needs the forward api, so register its name as well.
    self.is_backward_kernel_dump = True
    forward_prefix = api_name[:-len(Const.BACKWARD)]
    self.list.append(forward_prefix + Const.FORWARD)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,6 +22,7 @@ from msprobe.core.common.file_utils import FileChecker
22
22
  from msprobe.core.common.utils import get_real_step_or_rank
23
23
  from msprobe.pytorch.common.log import logger
24
24
  from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
25
+ from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
25
26
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
26
27
  from msprobe.pytorch.pt_config import parse_json_config
27
28
  from msprobe.pytorch.service import Service
@@ -49,7 +50,7 @@ class PrecisionDebugger:
49
50
  dump_path=None,
50
51
  level=None,
51
52
  model=None,
52
- step=None,
53
+ step=None
53
54
  ):
54
55
  if not hasattr(self, "initialized"):
55
56
  config_params = ConfigParameters(config_path,
@@ -59,7 +60,6 @@ class PrecisionDebugger:
59
60
  model)
60
61
  self.check_input_params(config_params)
61
62
 
62
- self.api_origin = False
63
63
  self.initialized = True
64
64
  self.model = model
65
65
  common_config, task_config = parse_json_config(config_path, task)
@@ -67,12 +67,13 @@ class PrecisionDebugger:
67
67
  if self.task == Const.GRAD_PROBE:
68
68
  self.gm = GradientMonitor(common_config, task_config)
69
69
  return
70
- if step:
70
+ if step is not None:
71
71
  common_config.step = get_real_step_or_rank(step, Const.STEP)
72
72
  self.config = DebuggerConfig(
73
73
  common_config, task_config, task, dump_path, level
74
74
  )
75
75
  self.service = Service(self.config)
76
+ self.module_dumper = ModuleDumper(self.service)
76
77
  self.enable_dataloader = self.config.enable_dataloader
77
78
  if self.enable_dataloader:
78
79
  logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
@@ -105,9 +106,11 @@ class PrecisionDebugger:
105
106
  raise MsprobeException(
106
107
  MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
107
108
 
108
- if args.model is not None and not isinstance(args.model, torch.nn.Module):
109
- raise MsprobeException(
110
- MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
109
+ if args.model is not None:
110
+ logger.warning_on_rank_0(
111
+ "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
112
+ "It is recommended to pass the 'model' parameter in the start interface instead."
113
+ )
111
114
 
112
115
  @classmethod
113
116
  def start(cls, model=None):
@@ -120,15 +123,12 @@ class PrecisionDebugger:
120
123
  if instance.enable_dataloader:
121
124
  logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
122
125
  else:
123
- instance.service.start(instance.model, instance.api_origin)
124
- instance.api_origin = False
126
+ instance.service.start(instance.model)
125
127
 
126
- # 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
127
128
@classmethod
def forward_backward_dump_end(cls):
    """Finish dumping by calling ``stop`` on the current debugger instance."""
    cls._instance.stop()
132
132
 
133
133
  @classmethod
134
134
  def stop(cls):
@@ -159,6 +159,36 @@ class PrecisionDebugger:
159
159
  cls._instance.gm.monitor(model)
160
160
 
161
161
 
162
def module_dump(module, dump_name):
    """Start dumping a single module through the active ``PrecisionDebugger``.

    Args:
        module: The ``torch.nn.Module`` instance to dump.
        dump_name: Name (``str``) handed to the module dumper for this dump.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when ``module`` is not a
            ``torch.nn.Module`` or ``dump_name`` is not a ``str``;
            ``INTERFACE_USAGE_ERROR`` when no ``PrecisionDebugger`` instance exists.
    """
    # Arguments are validated first, then the debugger singleton is looked up.
    if not isinstance(module, torch.nn.Module):
        message = "the module argument in module_dump must be a torch.nn.Module subclass"
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, message)

    if not isinstance(dump_name, str):
        message = "the dump_name argument in module_dump must be a str type"
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, message)

    debugger = PrecisionDebugger._instance
    if not debugger:
        message = "PrecisionDebugger must be instantiated before using module_dump interface"
        raise MsprobeException(MsprobeException.INTERFACE_USAGE_ERROR, message)

    debugger.module_dumper.start_module_dump(module, dump_name)
180
+
181
+
182
def module_dump_end():
    """Stop the module dump on the active ``PrecisionDebugger``.

    Raises:
        MsprobeException: ``INTERFACE_USAGE_ERROR`` when no ``PrecisionDebugger``
            instance exists.
    """
    debugger = PrecisionDebugger._instance
    if debugger:
        debugger.module_dumper.stop_module_dump()
        return

    message = "PrecisionDebugger must be instantiated before using module_dump_end interface"
    raise MsprobeException(MsprobeException.INTERFACE_USAGE_ERROR, message)
190
+
191
+
162
192
  def iter_tracer(func):
163
193
  def func_wrapper(*args, **kwargs):
164
194
  debugger_instance = PrecisionDebugger.instance
@@ -0,0 +1,33 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.file_utils import save_json
19
+
20
+
21
def create_kernel_config_json(dump_path, cur_rank):
    """Write the kernel dump configuration file into ``dump_path``.

    Args:
        dump_path: Directory that receives the file; also stored in the
            configuration under ``"dump_path"``.
        cur_rank: Rank identifier. An empty string gives ``kernel_config.json``;
            any other value gives ``kernel_config_<cur_rank>.json``.

    Returns:
        The path of the configuration file passed to ``save_json``.
    """
    if cur_rank == '':
        file_name = "kernel_config.json"
    else:
        file_name = f"kernel_config_{cur_rank}.json"
    config_path = os.path.join(dump_path, file_name)

    dump_settings = {
        "dump_list": [],
        "dump_path": dump_path,
        "dump_mode": "all",
        "dump_op_switch": "on",
    }
    save_json(config_path, {"dump": dump_settings}, indent=4)
    return config_path