mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,23 +15,24 @@
15
15
 
16
16
  import functools
17
17
  import os
18
-
19
18
  from collections import namedtuple
19
+
20
20
  import torch
21
21
  from msprobe.core.common.const import Const
22
- from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
22
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
23
23
  from msprobe.core.common.file_utils import create_directory
24
24
  from msprobe.core.common.utils import print_tools_ends_info
25
25
  from msprobe.core.data_dump.data_collector import build_data_collector
26
26
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
27
27
  from msprobe.core.data_dump.scope import BaseScope
28
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
28
29
  from msprobe.pytorch.common.log import logger
29
30
  from msprobe.pytorch.common.utils import get_rank_if_initialized
30
- from msprobe.pytorch.hook_module import remove_dropout
31
+ from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
32
+ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
31
33
  from msprobe.pytorch.hook_module.api_registry import api_register
32
34
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
33
- from msprobe.pytorch.module_processer import ModuleProcesser
34
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
35
+ from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
35
36
 
36
37
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
37
38
  if torch_version_above_or_equal_2:
@@ -47,100 +48,175 @@ class Service:
47
48
  self.data_collector = build_data_collector(config)
48
49
  self.module_processor = ModuleProcesser(self.data_collector.scope)
49
50
  self.switch = False
51
+ self.inner_switch = False
50
52
  self.current_iter = 0
51
53
  self.first_start = True
52
54
  self.current_rank = None
53
55
  self.dump_iter_dir = None
54
56
  self.should_stop_service = False
55
57
  self.attl = None
56
-
57
- @staticmethod
58
- def forward_backward_dump_end():
59
- logger.info_on_rank_0("Data needed ends here.")
60
- api_register.api_originality()
61
-
62
- @staticmethod
63
- def is_registered_backward_hook(module):
64
- if hasattr(module, '_backward_hooks') and \
65
- len(module._backward_hooks) > 0 and \
66
- module._is_full_backward_hook is False:
67
- return True
68
- return False
69
-
70
- def check_register_full_backward_hook(self, module):
71
- if self.is_registered_backward_hook(module):
72
- module._backward_hooks.clear()
73
- module._is_full_backward_hook = None
74
- logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
58
+ self.params_grad_info = {}
59
+ # 提前注册,确保注册尽可能多的API hook
60
+ self.register_api_hook()
75
61
 
76
62
  def build_hook(self, module_type, name):
77
63
  def pre_hook(api_or_module_name, module, args, kwargs):
78
- if not self.should_execute_hook():
64
+ if not self.should_execute_hook(module_type, module, True):
79
65
  return args, kwargs
80
66
 
67
+ self.inner_switch = True
81
68
  if module_type == BaseScope.Module_Type_Module:
82
- api_or_module_name = module.mindstudio_reserved_name
69
+ api_or_module_name = module.mindstudio_reserved_name[-1]
70
+ else:
71
+ module.forward_data_collected = True
72
+ HOOKModule.add_module_count(name)
83
73
  self.data_collector.update_api_or_module_name(api_or_module_name)
84
74
 
85
75
  if self.config.online_run_ut:
76
+ self.inner_switch = False
86
77
  return None, None
87
78
  if self.data_collector:
88
79
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
89
- self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output)
80
+ self.data_collector.forward_input_data_collect(api_or_module_name, module, pid, module_input_output)
81
+
82
+ self.inner_switch = False
90
83
  return args, kwargs
91
84
 
85
+ def grad_hook(module, ori_name, param_name):
86
+ def hook_fn(grad):
87
+ if not self.should_execute_hook(module_type, module, False):
88
+ return grad
89
+ self.inner_switch = True
90
+ self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
91
+ self.inner_switch = False
92
+ return grad
93
+
94
+ return hook_fn
95
+
96
+ def register_param_hook(ori_name, module, params_dict):
97
+ '''
98
+ 注册参数hook
99
+ '''
100
+ # data_mode为forward时,不注册参数hook
101
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
102
+ for param_name, param in params_dict.items():
103
+ if param.requires_grad:
104
+ param.register_hook(grad_hook(module, ori_name, param_name))
105
+
106
+ def init_params_grad_info(module, params_dict):
107
+ '''
108
+ 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
109
+ '''
110
+ if not params_dict:
111
+ return
112
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
113
+ grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
114
+ # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
115
+ if not self.params_grad_info.get(grad_name):
116
+ data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
117
+ # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
118
+ if data_info.get(grad_name):
119
+ # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
120
+ self.data_collector.handle_data(grad_name, data_info,
121
+ flush=self.data_collector.data_processor.is_terminated)
122
+ # 记录当前模块的参数梯度信息已占位
123
+ self.params_grad_info[grad_name] = True
124
+
92
125
  def forward_hook(api_or_module_name, module, args, kwargs, output):
93
- if not self.should_execute_hook():
126
+ if not self.should_execute_hook(module_type, module, True):
94
127
  return None
95
128
 
96
- if module_type == BaseScope.Module_Type_Module:
97
- api_or_module_name = module.mindstudio_reserved_name
98
- self.data_collector.update_api_or_module_name(api_or_module_name)
99
-
129
+ self.inner_switch = True
100
130
  if self.config.online_run_ut:
131
+ self.data_collector.update_api_or_module_name(api_or_module_name)
101
132
  if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
102
133
  return None
103
- api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank)
134
+ api_data = ApiData(
135
+ api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)],
136
+ args,
137
+ kwargs,
138
+ output,
139
+ self.current_iter,
140
+ self.current_rank
141
+ )
104
142
  self.attl_send(api_data)
143
+ self.inner_switch = False
105
144
  return None
106
145
 
107
- if self.data_collector:
108
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
109
- self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
110
- if self.data_collector.if_return_forward_new_output():
111
- return self.data_collector.get_forward_new_output()
146
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
147
+ if module_type == BaseScope.Module_Type_Module:
148
+ api_or_module_name = module.mindstudio_reserved_name[-1]
149
+ self.data_collector.update_api_or_module_name(api_or_module_name)
150
+ params_dict = {key.split(Const.SEP)[-1]: value for key, value in module.named_parameters(recurse=False)}
151
+ setattr(module_input_output, Const.PARAMS, params_dict)
152
+ # 判断是否需要注册参数hook
153
+ if not hasattr(module, 'params_grad_name') and params_dict:
154
+ ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
155
+ grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
156
+ # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
157
+ setattr(module, 'params_grad_name', grad_name)
158
+ register_param_hook(ori_name, module, params_dict)
159
+ self.data_collector.forward_data_collect(
160
+ api_or_module_name,
161
+ module,
162
+ pid,
163
+ module_input_output
164
+ )
165
+ init_params_grad_info(module, params_dict)
166
+ else:
167
+ self.data_collector.update_api_or_module_name(api_or_module_name)
168
+ self.data_collector.forward_output_data_collect(
169
+ api_or_module_name,
170
+ module,
171
+ pid,
172
+ module_input_output
173
+ )
174
+
175
+ if self.data_collector.if_return_forward_new_output():
176
+ forward_new_output = self.data_collector.get_forward_new_output()
177
+ self.inner_switch = False
178
+ return forward_new_output
179
+ self.inner_switch = False
112
180
  return output
113
181
 
114
182
  def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
115
183
  return forward_hook(api_or_module_name, module, args, {}, output)
116
184
 
117
185
  def backward_hook(api_or_module_name, module, grad_input, grad_output):
118
- if not self.should_execute_hook():
186
+ if not self.should_execute_hook(module_type, module, False):
119
187
  return
120
188
 
189
+ self.inner_switch = True
121
190
  if module_type == BaseScope.Module_Type_Module:
122
- api_or_module_name = module.mindstudio_reserved_name
191
+ api_or_module_name = module.mindstudio_reserved_name[-1]
123
192
  self.data_collector.update_api_or_module_name(api_or_module_name)
124
193
 
125
194
  if self.config.online_run_ut:
195
+ self.inner_switch = False
126
196
  return
127
197
 
128
198
  if self.data_collector:
129
199
  # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序
130
200
  module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
131
201
  self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
202
+ self.inner_switch = False
132
203
 
133
204
  pid = os.getpid()
134
- forward_name_template = name + Const.FORWARD
135
- backward_name_template = name + Const.BACKWARD
136
- pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
137
- forward_hook_fn = functools.partial(forward_hook, forward_name_template)
138
- backward_hook_fn = functools.partial(backward_hook, backward_name_template)
139
- forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
140
- forward_name_template)
205
+ full_forward_name = None
206
+ full_backward_name = None
207
+ if module_type == BaseScope.Module_Type_API:
208
+ full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
209
+ full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
210
+ pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
211
+ forward_hook_fn = functools.partial(forward_hook, full_forward_name)
212
+ backward_hook_fn = functools.partial(backward_hook, full_backward_name)
213
+ forward_hook_torch_version_below_2_fn = functools.partial(
214
+ forward_hook_torch_version_below_2,
215
+ full_forward_name
216
+ )
141
217
  return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
142
218
 
143
- def start(self, model, api_origin=False):
219
+ def start(self, model):
144
220
  if self.need_stop_service():
145
221
  return
146
222
 
@@ -154,42 +230,44 @@ class Service:
154
230
 
155
231
  if self.config.rank and self.current_rank not in self.config.rank:
156
232
  return
157
- self.register_hook_new()
233
+ self.register_module_hook()
158
234
  self.first_start = False
159
- if api_origin:
160
- api_register.api_modularity()
161
235
  if self.config.online_run_ut and torch_version_above_or_equal_2:
162
- run_ut_dispatch(self.attl, True)
236
+ run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
163
237
  self.switch = True
164
238
  logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
165
- if self.config.level != "L2" and not self.config.online_run_ut:
239
+ if not self.config.online_run_ut:
166
240
  self.create_dirs()
167
241
  logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
168
242
 
169
243
  def stop(self):
170
244
  if self.should_stop_service:
171
245
  return
172
- if self.config.level == "L2":
173
- return
174
246
  if self.config.step and self.current_iter not in self.config.step:
175
247
  return
176
248
  if self.config.rank and self.current_rank not in self.config.rank:
177
249
  return
178
250
  self.switch = False
251
+ if self.config.level == Const.LEVEL_L2:
252
+ return
179
253
  if self.config.online_run_ut and torch_version_above_or_equal_2:
180
- run_ut_dispatch(self.attl, False)
254
+ run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
181
255
  return
256
+ if self.config.async_dump:
257
+ self.data_collector.fill_stack_tensor_data()
258
+ self.data_collector.data_processor.dump_async_data()
182
259
  self.data_collector.write_json()
183
260
 
184
261
  def step(self):
185
262
  if self.should_stop_service:
186
263
  return
264
+ if self.config.async_dump:
265
+ self.data_collector.fill_stack_tensor_data()
266
+ self.data_collector.data_processor.dump_async_data()
267
+ self.data_collector.write_json()
187
268
  self.current_iter += 1
188
269
  self.data_collector.update_iter(self.current_iter)
189
-
190
- ModuleProcesser.reset_module_stats()
191
- HOOKModule.reset_module_stats()
192
- self.data_collector.data_writer.reset_cache()
270
+ self.reset_status()
193
271
 
194
272
  def need_stop_service(self):
195
273
  if self.should_stop_service:
@@ -200,8 +278,6 @@ class Service:
200
278
  if self.config.online_run_ut:
201
279
  # send stop signal if online_run_ut
202
280
  self.attl_stop()
203
- if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
204
- api_register.api_originality()
205
281
  self.switch = False
206
282
  self.should_stop_service = True
207
283
  print_tools_ends_info()
@@ -210,10 +286,18 @@ class Service:
210
286
  return True
211
287
  return False
212
288
 
213
- def should_execute_hook(self):
214
- if not self.switch:
289
+ def should_execute_hook(self, hook_type, module, is_forward):
290
+ is_module_hook = hook_type == BaseScope.Module_Type_Module
291
+ if is_module_hook and not self.switch:
292
+ return False
293
+ elif not is_module_hook and is_forward and not self.switch:
215
294
  return False
216
- if self.data_collector and self.data_collector.data_processor.is_terminated:
295
+ elif not is_module_hook and not is_forward and not module.forward_data_collected:
296
+ return False
297
+
298
+ if self.inner_switch:
299
+ return False
300
+ if not self.data_collector or self.data_collector.data_processor.is_terminated:
217
301
  return False
218
302
  return True
219
303
 
@@ -221,6 +305,12 @@ class Service:
221
305
  create_directory(self.config.dump_path)
222
306
  self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
223
307
  cur_rank = self.current_rank if self.current_rank is not None else ''
308
+ if self.config.level == Const.LEVEL_L2:
309
+ create_directory(self.dump_iter_dir)
310
+ kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
311
+ self.config.kernel_config_path = kernel_config_path
312
+ return
313
+
224
314
  dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
225
315
  create_directory(dump_dir)
226
316
  if self.config.task in self.data_collector.tasks_need_tensor_data:
@@ -234,50 +324,26 @@ class Service:
234
324
  construct_file_path = os.path.join(dump_dir, "construct.json")
235
325
  free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
236
326
  self.data_collector.update_dump_paths(
237
- dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path)
238
-
239
- def register_hook_new(self):
240
- logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
241
- if self.config.level in ["L0", "mix"]:
242
- if self.model is None:
243
- logger.error_log_with_exp("The model is None.", MsprobeException.INVALID_PARAM_ERROR)
244
- logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available")
245
- for name, module in self.model.named_modules():
246
- if module == self.model:
247
- continue
248
- prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
249
- module.__class__.__name__ + Const.SEP
250
-
251
- pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
252
- BaseScope.Module_Type_Module, prefix)
253
- if torch_version_above_or_equal_2:
254
- module.register_forward_hook(forward_hook, with_kwargs=True)
255
- else:
256
- self.check_register_full_backward_hook(module)
257
- module.register_full_backward_hook(
258
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
259
- module.register_forward_hook(forward_hook_torch_version_below_2)
260
- self.check_register_full_backward_hook(module)
261
- module.register_full_backward_hook(backward_hook)
262
-
263
- module.register_forward_pre_hook(
264
- self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
265
- module.register_forward_hook(
266
- self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
267
- if torch_version_above_or_equal_2:
268
- module.register_full_backward_pre_hook(
269
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
270
- self.check_register_full_backward_hook(module)
271
- module.register_full_backward_hook(
272
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
273
-
274
- if self.config.level in ["mix", "L1", "L2"]:
275
- api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
276
- self.config.online_run_ut)
327
+ dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path
328
+ )
329
+ self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
330
+
331
+ def register_api_hook(self):
332
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
333
+ logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
334
+ api_register.initialize_hook(
335
+ functools.partial(self.build_hook, BaseScope.Module_Type_API),
336
+ self.config.online_run_ut
337
+ )
277
338
  api_register.api_modularity()
278
339
 
279
- if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
280
- remove_dropout()
340
+ if self.config.level == Const.LEVEL_MIX:
341
+ register_optimizer_hook(self.data_collector)
342
+
343
+ def register_module_hook(self):
344
+ if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
345
+ logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
346
+ self.module_processor.register_module_hook(self.model, self.build_hook)
281
347
 
282
348
  def attl_init(self):
283
349
  if self.config.online_run_ut:
@@ -309,3 +375,17 @@ class Service:
309
375
  elif self.attl.socket_manager is not None:
310
376
  logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
311
377
  self.attl.socket_manager.send_stop_signal()
378
+
379
+ def reset_status(self):
380
+ ModuleProcesser.reset_module_stats()
381
+ HOOKModule.reset_module_stats()
382
+ self.data_collector.data_writer.reset_cache()
383
+ self.params_grad_info.clear()
384
+
385
+ if self.config.level == Const.LEVEL_L2:
386
+ self.data_collector.data_processor.reset_status()
387
+ return
388
+ if self.config.step and self.current_iter not in self.config.step:
389
+ return
390
+ if self.config.rank and self.current_rank not in self.config.rank:
391
+ return
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.