mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0

msprobe/core/data_dump/data_collector.py

@@ -13,9 +13,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import atexit
  import os

- from msprobe.core.data_dump.scope import build_scope, ListScope
+ from msprobe.core.data_dump.scope import ScopeFactory
  from msprobe.core.data_dump.json_writer import DataWriter
  from msprobe.core.common.log import logger
  from msprobe.core.common.const import Const
@@ -27,7 +28,6 @@ def build_data_collector(config):


  class DataCollector:
-     multi_output_apis = ["_sort_", "npu_flash_attention"]
      tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
      level_without_construct = [Const.LEVEL_L1, Const.LEVEL_L2]

@@ -37,13 +37,10 @@ class DataCollector:
          self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
          self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
          self.module_count = {}
-         if self.config.task == Const.FREE_BENCHMARK:
-             self.scope = build_scope(ListScope, self.config.scope, self.config.list)
-         else:
-             self.scope = build_scope(None, self.config.scope, self.config.list)
-
-     def __del__(self):
-         self.write_json()
+         self.scope = ScopeFactory(self.config).build_scope()
+         self.backward_module_names = {}
+         self.optimizer_status = ""
+         atexit.register(self.write_json)

      @property
      def dump_data_dir(self):
@@ -57,10 +54,6 @@ class DataCollector:
      def check_scope_and_pid(scope, name, pid):
          return (not scope or scope.check(name)) and pid == os.getpid()

-     @staticmethod
-     def is_inplace(module):
-         return getattr(module, "op_is_inplace", False)
-
      def if_return_forward_new_output(self):
          return self.data_processor.if_return_forward_new_output()

@@ -84,36 +77,54 @@ class DataCollector:
          logger.debug(msg)
          self.data_writer.update_data(data_info)

-     def pre_forward_data_collect(self, name, module, pid, module_input_output):
-         backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
-         if self.check_scope_and_pid(self.scope, backward_name, pid):
-             self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
-         if not self.is_inplace(module) or not self.check_scope_and_pid(self.scope, name, pid):
+     def forward_input_data_collect(self, name, module, pid, module_input_output):
+         if self.config.task == Const.FREE_BENCHMARK:
+             backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
+             if self.check_scope_and_pid(self.scope, backward_name, pid):
+                 self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
+             return
+
+         if not self.check_scope_and_pid(self.scope, name, pid):
+             return
+
+         data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+         if self.config.level == Const.LEVEL_L2:
              return
-         logger.info(f"API {name} is inplace.")
-         data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
          self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-     def forward_data_collect(self, name, module, pid, module_input_output):
+     def forward_output_data_collect(self, name, module, pid, module_input_output):
          self.update_construct(name)
          if not self.check_scope_and_pid(self.scope, name, pid):
              return

-         if not self.is_inplace(module):
-             data_info = self.data_processor.analyze_forward(name, module, module_input_output)
-         else:
-             data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
-         if self.config.level == "L2":
+         data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+         if self.config.level == Const.LEVEL_L2:
              return
          self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
          self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

+     def forward_data_collect(self, name, module, pid, module_input_output):
+         self.update_construct(name)
+         if not self.check_scope_and_pid(self.scope, name, pid):
+             return
+
+         data_info = self.data_processor.analyze_forward(name, module, module_input_output)
+         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
+         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+
      def backward_data_collect(self, name, module, pid, module_input_output):
          self.update_construct(name)
          if not self.check_scope_and_pid(self.scope, name, pid):
              return

          data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+         if self.config.level == Const.LEVEL_L2:
+             return
+         # Get the name of the module that ran this backward pass
+         if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
+             module_name = name.rsplit(Const.SEP, 2)[0]
+             # Record the module name so gradient collection can decide whether its gradients need to be collected
+             self.backward_module_names[module_name] = True
          self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

      def backward_input_data_collect(self, name, module, pid, module_input_output):
@@ -134,12 +145,17 @@

      def update_construct(self, name):
          if self.config.level not in DataCollector.level_without_construct:
-             self.data_writer.update_construct({name: self.module_processor.api_parent_node})
+             if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
+                 self.data_writer.update_construct({name: self.optimizer_status})
+             else:
+                 self.data_writer.update_construct({name: self.module_processor.api_parent_node})
              self.data_writer.update_construct(self.module_processor.module_node)

      def handle_data(self, name, data_info, flush=False):
          if data_info:
              self.update_data(name, data_info)
+         if self.config.async_dump:
+             return
          if not flush:
              self.data_writer.flush_data_periodically()
          else:
@@ -147,7 +163,23 @@

      def update_dump_paths(self, *args):
          self.data_writer.update_dump_paths(*args)
-         self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)
+
+     def initialize_json_file(self, framework=Const.UNKNOWN_FRAMEWORK):
+         self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level, framework=framework)

      def update_iter(self, current_iter):
          self.data_processor.update_iter(current_iter)
+
+     def params_data_collect(self, name, param_name, pid, data):
+         grad_name = name + Const.SEP + Const.PARAMS_GRAD
+         # Check the scope and pid, and whether this name has already run a backward pass
+         if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
+             # If no backward pass was run, clear the placeholder grad data written earlier
+             if self.data_writer.cache_data.get("data"):
+                 self.data_writer.cache_data.get("data").pop(grad_name, None)
+             return
+         data_info = self.data_processor.analyze_params(grad_name, param_name, data)
+         self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
+
+     def fill_stack_tensor_data(self):
+         self.data_writer.fill_stack_tensor_data()
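
The changes to DataCollector above replace the `__del__`-based flush with `atexit.register(self.write_json)`, so cached dump data is written exactly once at interpreter shutdown instead of depending on when (or whether) the collector is garbage-collected. A minimal standalone sketch of that pattern; the `Writer` class and its fields are illustrative only, not msprobe API:

    import atexit


    class Writer:
        """Stand-in for the dump writer: flush at exit rather than in __del__."""

        def __init__(self):
            self.cache = []
            # atexit callbacks run at normal interpreter exit even if the last
            # reference to this object is never dropped explicitly.
            atexit.register(self.write_json)

        def write_json(self):
            print(f"flushing {len(self.cache)} cached records")


    writer = Writer()
    writer.cache.append({"Tensor.add.0.forward": {}})
    # At exit, write_json() runs once; a __del__ hook might never run at all.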

msprobe/core/data_dump/data_processor/base.py

@@ -1,7 +1,7 @@
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
- # Licensed under the Apache License, Version 2.0 (the "License");
+ # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
@@ -15,10 +15,11 @@

  import inspect
  import os
- from dataclasses import dataclass
+ from dataclasses import dataclass, is_dataclass
  from typing import Tuple, Dict, Optional, Any

  import numpy as np
+
  from msprobe.core.common.const import Const
  from msprobe.core.common.log import logger
  from msprobe.core.common.utils import convert_tuple, CompareException
@@ -38,9 +39,8 @@ class ModuleForwardInputsOutputs:
      def output_tuple(self):
          return convert_tuple(self.output)

-     def concat_args_and_kwargs(self):
-         args = self.args + tuple(self.kwargs.values())
-         return args
+     def update_output_with_args_and_kwargs(self):
+         self.output = self.args + tuple(self.kwargs.values())


  @dataclass
@@ -76,11 +76,12 @@ class ModuleBackwardOutputs:


  class TensorStatInfo:
-     def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
+     def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
          self.max = max_val
          self.min = min_val
          self.mean = mean_val
          self.norm = norm_val
+         self.stack_tensor_stat = stack_tensor_stat


  class BaseDataProcessor:
@@ -101,6 +102,9 @@ class BaseDataProcessor:
          self.current_iter = 0
          self._return_forward_new_output = False
          self._forward_new_output = None
+         self.save_name = None
+         if hasattr(config, "data_mode"):
+             self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)

      @property
      def data_path(self):
@@ -182,6 +186,18 @@
      def _analyze_numpy(value, numpy_type):
          return {"type": numpy_type, "value": value}

+     @staticmethod
+     def _get_allowed_data_mode(data_mode):
+         if Const.ALL in data_mode:
+             allowed_data_mode = [Const.FORWARD, Const.BACKWARD, Const.INPUT, Const.OUTPUT]
+         else:
+             allowed_data_mode = list(set(data_mode))
+         if Const.FORWARD not in allowed_data_mode and Const.BACKWARD not in allowed_data_mode:
+             allowed_data_mode += [Const.FORWARD, Const.BACKWARD]
+         if Const.INPUT not in allowed_data_mode and Const.OUTPUT not in allowed_data_mode:
+             allowed_data_mode += [Const.INPUT, Const.OUTPUT]
+         return allowed_data_mode
+
      @classmethod
      def get_special_types(cls):
          return cls.special_type
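
The new `_get_allowed_data_mode` helper expands a partial `data_mode` setting so that naming only a direction (forward/backward) or only a side (input/output) still produces a complete filter, and the rewritten `is_dump_for_data_mode` then requires both the direction and the side of the current call to be present. A rough standalone sketch with literal strings standing in for the `Const` members (the exact string values are an assumption here):

    def get_allowed_data_mode(data_mode):
        # Mirrors the expansion logic shown in the hunk above, with plain strings.
        if "all" in data_mode:
            return ["forward", "backward", "input", "output"]
        allowed = list(set(data_mode))
        if "forward" not in allowed and "backward" not in allowed:
            allowed += ["forward", "backward"]   # only a side was given; keep both directions
        if "input" not in allowed and "output" not in allowed:
            allowed += ["input", "output"]       # only a direction was given; keep both sides
        return allowed


    print(get_allowed_data_mode(["input"]))     # ['input', 'forward', 'backward']
    print(get_allowed_data_mode(["backward"]))  # ['backward', 'input', 'output']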

@@ -194,25 +210,42 @@
          if isinstance(args, cls.get_special_types()):
              arg_transform = transform(args, cls._recursive_key_stack)
              return arg_transform
+         elif isinstance(args, tuple) and hasattr(args, '_fields'):
+             # namedtuple to dict
+             args_dict = {field: getattr(args, field) for field in args._fields}
+             return cls.apply_transform_dict(args_dict, transform, depth)
+         elif is_dataclass(args):
+             # dataclass to dict
+             args_dict = {field: getattr(args, field) for field in args.__dataclass_fields__}
+             return cls.apply_transform_dict(args_dict, transform, depth)
          elif isinstance(args, (list, tuple)):
-             result_list = []
-             for i, arg in enumerate(args):
-                 cls._recursive_key_stack.append(str(i))
-                 result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
-                 cls._recursive_key_stack.pop()
+             result_list = cls.apply_transform_list(args, transform, depth)
              return type(args)(result_list)
          elif isinstance(args, dict):
-             result_dict = {}
-             for k, arg in args.items():
-                 cls._recursive_key_stack.append(str(k))
-                 result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
-                 cls._recursive_key_stack.pop()
-             return result_dict
+             return cls.apply_transform_dict(args, transform, depth)
          elif args is not None:
-             logger.warning(f"Data type {type(args)} is not supported.")
+             logger.debug(f"Data type {type(args)} is not supported.")
              return None
          else:
              return None
+
+     @classmethod
+     def apply_transform_dict(cls, args, transform, depth):
+         result_dict = {}
+         for k, arg in args.items():
+             cls._recursive_key_stack.append(str(k))
+             result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
+             cls._recursive_key_stack.pop()
+         return result_dict
+
+     @classmethod
+     def apply_transform_list(cls, args, transform, depth):
+         result_list = []
+         for i, arg in enumerate(args):
+             cls._recursive_key_stack.append(str(i))
+             result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
+             cls._recursive_key_stack.pop()
+         return result_list

      def if_return_forward_new_output(self):
          return self._return_forward_new_output
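
The new branches in `recursive_apply_transform` above let the traversal descend into namedtuples and dataclasses by converting them to dicts before recursing. A self-contained sketch of that detection order; `to_plain` is a hypothetical helper that only flattens, whereas the real method applies a transform and maintains a key stack:

    from collections import namedtuple
    from dataclasses import dataclass, is_dataclass

    Point = namedtuple("Point", ["x", "y"])


    @dataclass
    class Box:
        lo: int
        hi: int


    def to_plain(obj):
        # Same order as the hunk above: namedtuple -> dict, dataclass -> dict,
        # then list/tuple and dict recursion; leaves are returned unchanged.
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return {f: to_plain(getattr(obj, f)) for f in obj._fields}
        if is_dataclass(obj):
            return {f: to_plain(getattr(obj, f)) for f in obj.__dataclass_fields__}
        if isinstance(obj, (list, tuple)):
            return type(obj)(to_plain(v) for v in obj)
        if isinstance(obj, dict):
            return {k: to_plain(v) for k, v in obj.items()}
        return obj


    print(to_plain((Point(1, 2), Box(lo=0, hi=9))))
    # ({'x': 1, 'y': 2}, {'lo': 0, 'hi': 9})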

@@ -239,17 +272,12 @@
          Return:
              bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
          """
-         return (Const.ALL in self.config.data_mode or
-                 forward_backward in self.config.data_mode or
-                 input_output in self.config.data_mode)
-
-     def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
-         pass
+         return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode

      def analyze_element(self, element):
          return self.recursive_apply_transform(element, self.analyze_single_element)

-     def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+     def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
          api_info_struct = {}
          # check whether data_mode contains forward or input
          if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -261,16 +289,22 @@
              kwargs_info_list = self.analyze_element(module_input_output.kwargs)
              api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list

-         # check whether data_mode contains forward or output
+         return api_info_struct
+
+     def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+         api_info_struct = {}
+         # check whether data_mode contains forward or input
          if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
-             api_info_struct[name] = api_info_struct.get(name, {})
+             api_info_struct[name] = {}
              self.api_data_category = Const.OUTPUT
              output_info_list = self.analyze_element(module_input_output.output_tuple)
              api_info_struct[name][Const.OUTPUT] = output_info_list
+
          return api_info_struct

-     def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
+     def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
          api_info_struct = {}
+         # check whether data_mode contains forward or input
          if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
              api_info_struct[name] = {}
              self.api_data_category = Const.INPUT
@@ -279,16 +313,18 @@
              self.api_data_category = Const.KWARGS
              kwargs_info_list = self.analyze_element(module_input_output.kwargs)
              api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
-         return api_info_struct

-     def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
-         concat_args = module_input_output.concat_args_and_kwargs()
-         api_info_struct = {}
+         # check whether data_mode contains forward or output
          if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
-             api_info_struct[name] = {}
+             api_info_struct[name] = api_info_struct.get(name, {})
              self.api_data_category = Const.OUTPUT
-             output_info_list = self.analyze_element(concat_args)
+             output_info_list = self.analyze_element(module_input_output.output_tuple)
              api_info_struct[name][Const.OUTPUT] = output_info_list
+
+         if name in api_info_struct and hasattr(module_input_output, Const.PARAMS):
+             self.api_data_category = Const.PARAMS
+             api_info_struct[name][Const.PARAMS] = self.analyze_element(getattr(module_input_output, Const.PARAMS))
+
          return api_info_struct

      def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
@@ -329,9 +365,21 @@
          api_info_struct[name][Const.OUTPUT] = output_info_list
          return api_info_struct

+     def analyze_params(self, name, param_name, grad):
+         api_info_struct = {}
+         self.save_name = name + Const.SEP + param_name
+         data_info = self.analyze_element(grad)
+         grad_info_dict = {param_name: [data_info]}
+         api_info_struct[name] = grad_info_dict
+         return api_info_struct
+
      def get_save_file_path(self, suffix):
          file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
-         dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
-                           suffix + file_format)
+         if self.save_name is not None:
+             dump_data_name = (self.save_name + file_format)
+             self.save_name = None
+         else:
+             dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
+                               suffix + file_format)
          file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
          return dump_data_name, file_path

msprobe/core/data_dump/data_processor/factory.py

@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,7 +56,7 @@ class DataProcessorFactory:
              FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
              KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
          )
-         from msprobe.pytorch.module_processer import ModuleProcesser
+         from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
          cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
          cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
          cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
@@ -67,10 +67,12 @@ class DataProcessorFactory:
          from msprobe.core.data_dump.data_processor.mindspore_processor import (
              StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
              TensorDataProcessor as MindsporeTensorDataProcessor,
-             OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
+             OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
+             KernelDumpDataProcessor as MindsporeKernelDumpDataProcessor
          )
          from msprobe.mindspore.cell_processor import CellProcessor
          cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
          cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
          cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
+         cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
          cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)

msprobe/core/data_dump/data_processor/mindspore_processor.py

@@ -1,4 +1,4 @@
- # Copyright 2024 Huawei Technologies Co., Ltd
+ # Copyright 2024-2025 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
  import zlib

  import mindspore as ms
- from mindspore import mint, ops
+ from mindspore import mint, ops, hal
  from mindspore._c_expression.typing import Number
  import numpy as np

@@ -28,6 +28,12 @@ from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_
  from msprobe.mindspore.common.log import logger
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register

+ has_adump = True
+ try:
+     from msprobe.lib import _msprobe_c
+ except ImportError:
+     has_adump = False
+

  class MindsporeDataProcessor(BaseDataProcessor):
      mindspore_special_type = tuple([ms.Tensor, Number])
@@ -37,6 +43,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
          self.mindspore_object_key = {
              "dtype": self.analyze_dtype_in_kwargs
          }
+         self._async_dump_cache = {}

      @staticmethod
      def get_md5_for_tensor(x):
@@ -49,15 +56,10 @@ class MindsporeDataProcessor(BaseDataProcessor):
      def analyze_dtype_in_kwargs(element):
          return {"type": "mindspore.dtype", "value": str(element)}

-     @classmethod
-     def get_special_types(cls):
-         return super().get_special_types() + cls.mindspore_special_type
-
-     def get_stat_info(self, data):
+     @staticmethod
+     def get_stat_info_sync(data):
          tensor_stat = TensorStatInfo()
-         if data.numel() == 0:
-             return tensor_stat
-         elif data.dtype == ms.bool_:
+         if data.dtype == ms.bool_:
              data_np = data.asnumpy()
              tensor_stat.max = np.max(data_np).item()
              tensor_stat.min = np.min(data_np).item()
@@ -70,7 +72,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
              tensor_stat.mean = np.mean(data_abs).item()
              tensor_stat.norm = np.linalg.norm(data_abs).item()
          else:
-             if not ops.is_floating_point(data):
+             if not ops.is_floating_point(data) or data.dtype == ms.float64:
                  data = data.to(ms.float32)
              api_register.norm_inner_op_set_ori_func()
              get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
@@ -87,6 +89,47 @@ class MindsporeDataProcessor(BaseDataProcessor):
              api_register.norm_inner_op_set_hook_func()
          return tensor_stat

+     @staticmethod
+     def get_stat_info_async(data):
+         tensor_stat = TensorStatInfo()
+         stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
+         if data.dtype == ms.complex64 or data.dtype == ms.complex128:
+             logger.warning("Async dump do not support complex data!")
+             return tensor_stat
+         elif data.dtype == ms.bool_:
+             tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
+         elif not data.shape:
+             tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
+         else:
+             if not ops.is_floating_point(data) or data.dtype == ms.float64:
+                 data = data.to(ms.float32)
+             api_register.norm_inner_op_set_ori_func()
+             get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
+             get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
+             get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
+             if hasattr(mint, "norm"):
+                 get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
+             else:
+                 get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
+             tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
+                 [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
+             api_register.norm_inner_op_set_hook_func()
+         return tensor_stat
+
+     @classmethod
+     def get_special_types(cls):
+         return super().get_special_types() + cls.mindspore_special_type
+
+     def get_stat_info(self, data):
+         tensor_stat = TensorStatInfo()
+         if data.numel() == 0:
+             return tensor_stat
+         else:
+             if self.config.async_dump:
+                 return MindsporeDataProcessor.get_stat_info_async(data)
+             else:
+                 return MindsporeDataProcessor.get_stat_info_sync(data)
+

      def analyze_single_element(self, element, suffix_stack):
          if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
              return self.mindspore_object_key[suffix_stack[-1]](element)
@@ -107,13 +150,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
          tensor_json = {
              'type': 'mindspore.Tensor',
              'dtype': str(tensor.dtype),
-             'shape': tensor.shape,
-             'Max': self.transfer_type(tensor_stat.max),
-             'Min': self.transfer_type(tensor_stat.min),
-             'Mean': self.transfer_type(tensor_stat.mean),
-             'Norm': self.transfer_type(tensor_stat.norm),
+             'shape': tensor.shape
          }
-         if self.config.summary_mode == Const.MD5:
+
+         if tensor_stat.stack_tensor_stat is None:
+             tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
+             tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
+             tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
+             tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
+         else:
+             tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
+         if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
              tensor_md5 = self.get_md5_for_tensor(tensor)
              tensor_json.update({Const.MD5: tensor_md5})
          return tensor_json
@@ -124,11 +171,19 @@ class StatisticsDataProcessor(MindsporeDataProcessor):


  class TensorDataProcessor(MindsporeDataProcessor):
+     def dump_async_data(self):
+         for file_path, tensor in self._async_dump_cache.items():
+             save_tensor_as_npy(tensor, file_path)
+         self._async_dump_cache.clear()
+
      def _analyze_tensor(self, tensor, suffix):
          dump_data_name, file_path = self.get_save_file_path(suffix)
          single_arg = super()._analyze_tensor(tensor, suffix)
          single_arg.update({"data_name": dump_data_name})
-         save_tensor_as_npy(tensor, file_path)
+         if self.config.async_dump:
+             self._async_dump_cache[file_path] = tensor.copy()
+         else:
+             save_tensor_as_npy(tensor, file_path)
          return single_arg


@@ -138,6 +193,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
      def __init__(self, config, data_writer):
          super().__init__(config, data_writer)
          self.has_overflow = False
+         self.cached_api_info = {}
          self.cached_tensors_and_file_paths = {}
          self.real_overflow_nums = 0
          self.overflow_nums = config.overflow_nums
@@ -150,6 +206,20 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
                  return True
          return False

+     def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+         self.has_overflow = False
+         self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
+         return None
+
+     def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+         api_info_struct = super().analyze_forward_output(name, module, module_input_output)
+         if name in self.cached_api_info and name in api_info_struct:
+             self.cached_api_info[name].update(api_info_struct[name])
+         elif name in api_info_struct:
+             self.cached_api_info = api_info_struct
+         self.maybe_save_overflow_data()
+         return self.cached_api_info if self.has_overflow else None
+
      def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
          self.has_overflow = False
          api_info_struct = super().analyze_forward(name, module, module_input_output)
@@ -161,6 +231,12 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
          api_info_struct = super().analyze_backward(name, module, module_input_output)
          self.maybe_save_overflow_data()
          return api_info_struct if self.has_overflow else None
+
+     def analyze_params(self, name, param_name, grad):
+         self.has_overflow = False
+         api_info_struct = super().analyze_params(name, param_name, grad)
+         self.maybe_save_overflow_data()
+         return api_info_struct if self.has_overflow else None

      def maybe_save_overflow_data(self):
          if self.has_overflow:
@@ -190,3 +266,61 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
          self._analyze_maybe_overflow_tensor(single_arg)
          single_arg.update({"data_name": dump_data_name})
          return single_arg
+
+
+ class KernelDumpDataProcessor(MindsporeDataProcessor):
+     def __init__(self, config, data_writer):
+         super().__init__(config, data_writer)
+         self.enable_kernel_dump = True
+
+     @staticmethod
+     def start_kernel_dump(config_path):
+         hal.synchronize()
+         _msprobe_c.init_dump()
+         _msprobe_c.set_dump(config_path)
+         hal.synchronize()
+
+     @staticmethod
+     def stop_kernel_dump():
+         hal.synchronize()
+         _msprobe_c.finalize_dump()
+         hal.synchronize()
+
+     @staticmethod
+     def _print_unsupported_log(api_name):
+         logger.warning(f"The kernel dump does not support the {api_name} API.")
+
+     def analyze_forward_input(self, name, module, module_input_output):
+         if not self.enable_kernel_dump:
+             return
+         if not has_adump:
+             logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
+             self.enable_kernel_dump = False
+             return
+         self.start_kernel_dump(self.config.kernel_config_path)
+
+     def analyze_forward_output(self, name, module, module_input_output):
+         if not self.enable_kernel_dump:
+             return
+         self.enable_kernel_dump = False
+         self.stop_kernel_dump()
+         logger.info(f"The kernel data of {name} is dumped successfully.")
+
+     def analyze_backward_input(self, name, module, module_input_output):
+         if not self.enable_kernel_dump:
+             return
+         if not has_adump:
+             logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
+             self.enable_kernel_dump = False
+             return
+         self.start_kernel_dump(self.config.kernel_config_path)
+
+     def analyze_backward(self, name, module, module_input_output):
+         if not self.enable_kernel_dump:
+             return
+         self.enable_kernel_dump = False
+         self.stop_kernel_dump()
+         logger.info(f"The kernel data of {name} is dumped successfully.")
+
+     def reset_status(self):
+         self.enable_kernel_dump = True
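
The new KernelDumpDataProcessor brackets a single operator: it synchronizes the device, calls `_msprobe_c.init_dump()` and `_msprobe_c.set_dump(config_path)` before the call, and `_msprobe_c.finalize_dump()` after it. A hedged sketch of the same bracketing as a context manager; it assumes the optional `msprobe.lib._msprobe_c` extension (adump) is compiled in, which the diff itself guards with a try/except:

    from contextlib import contextmanager

    from mindspore import hal

    try:
        from msprobe.lib import _msprobe_c  # optional adump extension, as in the diff
        HAS_ADUMP = True
    except ImportError:
        HAS_ADUMP = False


    @contextmanager
    def kernel_dump(config_path):
        # Hypothetical convenience wrapper around the start/stop calls shown above.
        if not HAS_ADUMP:
            yield  # adump not built; run the operator without kernel dump
            return
        hal.synchronize()
        _msprobe_c.init_dump()
        _msprobe_c.set_dump(config_path)
        hal.synchronize()
        try:
            yield
        finally:
            hal.synchronize()
            _msprobe_c.finalize_dump()
            hal.synchronize()

    # Usage sketch:
    # with kernel_dump("./kernel_config.json"):
    #     out = some_npu_operator(x)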