mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -13,19 +13,24 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import hashlib
16
17
  import zlib
17
18
  from dataclasses import asdict
18
19
  from typing import List
19
20
 
20
21
  import numpy as np
21
22
  import torch
23
+ from torch import distributed as dist
24
+
22
25
  from msprobe.core.common.const import Const
23
26
  from msprobe.core.common.file_utils import path_len_exceeds_limit
24
27
  from msprobe.core.common.log import logger
28
+ from msprobe.core.common.utils import convert_tuple
25
29
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
26
30
  ModuleForwardInputsOutputs, TensorStatInfo
27
31
  from msprobe.pytorch.common.utils import save_pt, load_pt
28
32
  from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
33
+ from msprobe.core.common.utils import recursion_depth_decorator
29
34
 
30
35
  is_gpu = False
31
36
  try:
@@ -35,7 +40,13 @@ except ImportError:
35
40
 
36
41
 
37
42
  class PytorchDataProcessor(BaseDataProcessor):
38
- pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)
43
+ pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
44
+ memory_format = {
45
+ torch.contiguous_format: "contiguous_format",
46
+ torch.channels_last: "channels_last",
47
+ torch.channels_last_3d: "channels_last_3d",
48
+ torch.preserve_format: "preserve_format"
49
+ }
39
50
 
40
51
  def __init__(self, config, data_writer):
41
52
  super().__init__(config, data_writer)
@@ -43,6 +54,7 @@ class PytorchDataProcessor(BaseDataProcessor):
43
54
  "device": self.analyze_device_in_kwargs,
44
55
  "dtype": self.analyze_dtype_in_kwargs
45
56
  }
57
+ self._async_dump_cache = {}
46
58
 
47
59
  @staticmethod
48
60
  def get_md5_for_tensor(x):
@@ -71,53 +83,114 @@ class PytorchDataProcessor(BaseDataProcessor):
71
83
  return {"type": "torch.dtype", "value": str(element)}
72
84
 
73
85
@staticmethod
def get_stat_info_async(data):
    """Collect tensor statistics for the asynchronous dump path.

    Unlike the sync path, nothing is materialized with ``.item()`` here;
    the statistics are packed into one stacked device tensor stored as
    ``tensor_stat.stack_tensor_stat = (labels, stacked_values)`` so they
    can be read out later without forcing a sync at collection time.

    :param data: a detached ``torch.Tensor``; complex dtypes are not
        supported on this path.
    :return: ``TensorStatInfo``; left empty for complex input.
    """
    tensor_stat = TensorStatInfo()
    if torch.is_complex(data):
        logger.warning("Async dump do not support complex data!")
        return tensor_stat
    elif data.dtype == torch.bool:
        # any()/all() stand in for max/min on boolean tensors.
        tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
            [torch.any(data), torch.all(data)]))
    elif not data.shape:
        # 0-d tensor: every statistic equals the scalar itself.
        tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
    else:
        if not data.is_floating_point() or data.dtype == torch.float64:
            # Reductions below need a float dtype; float64 is downcast to float32.
            data = data.float()
        tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
            torch.max(data),
            torch.min(data),
            torch.mean(data),
            torch.norm(data)
        ]))
    return tensor_stat
106
+
107
@staticmethod
def get_stat_info_sync(data):
    """Compute tensor statistics eagerly (synchronous dump path).

    Each statistic is materialized immediately as a Python scalar via
    ``.item()`` or numpy.

    :param data: a detached ``torch.Tensor``.
    :return: ``TensorStatInfo`` with max/min (and mean/norm where
        meaningful) set as Python scalars.
    """
    tensor_stat = TensorStatInfo()
    if torch.is_complex(data):
        # Complex statistics are taken over the magnitudes |z|;
        # norm is intentionally left unset on this branch.
        data_np = data.cpu().numpy()
        data_abs = np.abs(data_np)
        tensor_stat.max = np.max(data_abs).item()
        tensor_stat.min = np.min(data_abs).item()
        tensor_stat.mean = np.mean(data_abs).item()
    elif data.dtype == torch.bool:
        # any()/all() stand in for max/min on boolean tensors.
        tensor_stat.max = torch.any(data).item()
        tensor_stat.min = torch.all(data).item()
    elif not data.shape:
        # 0-d tensor: every statistic equals the scalar itself.
        tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
    else:
        if not data.is_floating_point() or data.dtype == torch.float64:
            # Reductions below need a float dtype; float64 is downcast to float32.
            data = data.float()
        tensor_stat.max = torch.max(data).item()
        tensor_stat.min = torch.min(data).item()
        tensor_stat.mean = torch.mean(data).item()
        tensor_stat.norm = torch.norm(data).item()
    return tensor_stat
100
129
 
130
@staticmethod
def get_stat_info(data, async_dump=False):
    """Dispatch tensor statistics collection to the sync or async path.

    :param data: tensor to summarize; meta tensors and empty tensors
        yield an empty ``TensorStatInfo``.
    :param async_dump: when True and the tensor lives on a non-CPU
        device, statistics stay on device (see ``get_stat_info_async``).
    :return: ``TensorStatInfo``.
    """
    tensor_stat = TensorStatInfo()
    if data.is_meta:
        # Meta tensors carry no storage, so no statistics exist.
        return tensor_stat
    data_clone = data.detach()
    if data_clone.numel() == 0:
        return tensor_stat
    # CPU tensors always take the sync path, even in async mode:
    # there is no device synchronization to avoid.
    if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
        return PytorchDataProcessor.get_stat_info_sync(data_clone)
    return PytorchDataProcessor.get_stat_info_async(data_clone)
143
+
101
144
  @staticmethod
102
145
  def handle_tensor_extremum_nan_inf(tensor, operator):
103
146
  data_clone = tensor.detach()
104
- data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
105
- if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
147
+ data_nan = torch.isnan(data_clone)
148
+ if int(torch.sum(data_nan)) == data_clone.numel():
106
149
  return float('nan')
107
- finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
108
- if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
150
+
151
+ finite_mask = torch.isfinite(data_clone)
152
+ if int(torch.sum(finite_mask)) > 0:
109
153
  finite_values = data_clone[finite_mask]
110
- return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
111
- torch._C._VariableFunctionsClass.min(finite_values).item()
154
+ return torch.max(finite_values).item() if operator == 'max' else \
155
+ torch.min(finite_values).item()
112
156
  else:
113
157
  data_no_nan = data_clone[~data_nan]
114
- return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
115
- torch._C._VariableFunctionsClass.min(data_no_nan).item()
158
+ return torch.max(data_no_nan).item() if operator == 'max' else \
159
+ torch.min(data_no_nan).item()
160
+
161
@staticmethod
def process_group_hash(arg):
    """Derive a stable identifier for a process group from its rank list.

    :param arg: an initialized ``torch.distributed`` process group.
    :return: hex md5 digest of ``str(ranks)`` (fingerprint only, not
        security-sensitive).
    """
    group_ranks = dist.get_process_group_ranks(arg)
    group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
    return group_ranks_hash
166
+
167
+ @staticmethod
168
+ def is_distributed_op(module):
169
+ return getattr(module, "op_is_distributed", False)
116
170
 
117
171
  @staticmethod
118
172
  def _analyze_torch_size(arg):
119
173
  return {"type": "torch.Size", "value": list(arg)}
120
174
 
175
@staticmethod
def _analyze_memory_format(arg):
    """Serialize a ``torch.memory_format`` into a JSON-friendly dict.

    Formats missing from the class-level ``memory_format`` mapping
    yield ``"format": None`` (``dict.get`` default).
    """
    # Map the memory-format singleton to its human-readable name.
    format_type = PytorchDataProcessor.memory_format.get(arg)
    return {"type": "torch.memory_format", "format": format_type}
181
+
182
@staticmethod
def _analyze_process_group(arg):
    """Serialize a process group: its ranks plus a hash-based group id.

    Best effort: on any failure a bare ``{"type": ...}`` marker is still
    returned so the dump entry stays well-formed.
    """
    group_info = {"type": "torch.ProcessGroup"}
    try:
        group_ranks = dist.get_process_group_ranks(arg)
        group_info.update({"group_ranks": group_ranks})
        group_id = PytorchDataProcessor.process_group_hash(arg)
        group_info.update({"group_id": group_id})
    except Exception as e:
        # Bug fix: the original message interpolated group_id, which is
        # unbound when get_process_group_ranks itself raises, turning the
        # handler into a NameError. Log only the exception.
        logger.warning(f"Failed to get process group ranks info with error info: {e}.")
    return group_info
193
+
121
194
@classmethod
def get_special_types(cls):
    """Return the base processor's special types extended with the
    torch-specific ones (``pytorch_special_type``)."""
    return super().get_special_types() + cls.pytorch_special_type
@@ -127,6 +200,10 @@ class PytorchDataProcessor(BaseDataProcessor):
127
200
  return self.torch_object_key[suffix_stack[-1]](element)
128
201
  if isinstance(element, torch.Size):
129
202
  return self._analyze_torch_size(element)
203
+ if isinstance(element, torch.memory_format):
204
+ return self._analyze_memory_format(element)
205
+ if isinstance(element, dist.ProcessGroup):
206
+ return self._analyze_process_group(element)
130
207
  converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
131
208
  if converted_numpy is not element:
132
209
  return self._analyze_numpy(converted_numpy, numpy_type)
@@ -136,26 +213,35 @@ class PytorchDataProcessor(BaseDataProcessor):
136
213
  return self._analyze_builtin(element)
137
214
  return {}
138
215
 
216
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Analyze a forward call's output, with special handling for
    distributed ops.

    For ops flagged as distributed, args/kwargs are folded into the
    recorded output first — presumably because collective ops write
    their results into the input buffers (TODO confirm against
    update_output_with_args_and_kwargs).
    """
    if self.is_distributed_op(module):
        module_input_output.update_output_with_args_and_kwargs()
    return super().analyze_forward_output(name, module, module_input_output)
220
+
139
221
def _analyze_tensor(self, tensor, suffix):
    """Build the JSON summary entry for one tensor.

    Two layouts, depending on how statistics were collected:

    - sync stats (``stack_tensor_stat is None``): scalar Max/Min/Mean/Norm
      keys, plus ``Max_except_inf_nan`` / ``Min_except_inf_nan`` fallbacks
      when the extremum is inf or nan;
    - async stats: a single ``tensor_stat`` entry holding the stacked
      ``(labels, values)`` pair produced by ``get_stat_info_async``.

    An MD5 digest is appended only in MD5 summary mode with async dump
    disabled.
    """
    tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
    tensor_json = {}
    tensor_json.update({'type': 'torch.Tensor'})
    tensor_json.update({'dtype': str(tensor.dtype)})
    tensor_json.update({"shape": tensor.shape})
    if tensor_stat.stack_tensor_stat is None:
        tensor_json.update({"Max": tensor_stat.max})
        tensor_json.update({"Min": tensor_stat.min})
        tensor_json.update({"Mean": tensor_stat.mean})
        tensor_json.update({"Norm": tensor_stat.norm})
        tensor_json.update({"requires_grad": tensor.requires_grad})
        # Degenerate extrema: record the extremum over the non-inf/nan
        # elements alongside the raw one.
        if tensor_stat.max is not None:
            if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
                tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
        if tensor_stat.min is not None:
            if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
                tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
    else:
        tensor_json.update({"requires_grad": tensor.requires_grad})
        tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat})
    if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
        tensor_md5 = self.get_md5_for_tensor(tensor)
        tensor_json.update({Const.MD5: tensor_md5})
    return tensor_json
@@ -166,12 +252,20 @@ class StatisticsDataProcessor(PytorchDataProcessor):
166
252
 
167
253
 
168
254
  class TensorDataProcessor(PytorchDataProcessor):
255
def dump_async_data(self):
    """Flush the async dump cache: persist every cached tensor to its
    target path, then clear the cache."""
    for file_path, tensor in self._async_dump_cache.items():
        # contiguous() is deferred to flush time (see _analyze_tensor).
        save_pt(tensor.contiguous(), file_path)
    self._async_dump_cache.clear()
259
+
169
260
def _analyze_tensor(self, tensor, suffix):
    """Summarize the tensor (via the parent) and persist its full value.

    In async mode the detached clone is parked in
    ``self._async_dump_cache`` keyed by its target file path and written
    later by ``dump_async_data``; otherwise it is saved immediately.
    The returned summary always carries the ``data_name`` of the dump file.
    """
    dump_data_name, file_path = self.get_save_file_path(suffix)
    single_arg = super()._analyze_tensor(tensor, suffix)
    single_arg.update({"data_name": dump_data_name})
    if self.config.async_dump:
        self._async_dump_cache[file_path] = tensor.clone().detach()
    else:
        saved_tensor = tensor.clone().contiguous().detach()
        save_pt(saved_tensor, file_path)
    return single_arg
176
270
 
177
271
 
@@ -182,7 +276,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
182
276
  super().__init__(config, data_writer)
183
277
  self.has_overflow = False
184
278
  self.support_inf_nan = None
185
- self.cached_inplace_api_info = {}
279
+ self.cached_api_info = {}
186
280
  self.cached_tensors_and_file_paths = {}
187
281
  self.bits_for_overflow = 8
188
282
  self.real_overflow_nums = 0
@@ -196,21 +290,21 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
196
290
  return True
197
291
  return False
198
292
 
199
- def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
293
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
200
294
  self.has_overflow = False
201
295
  self._is_support_inf_nan()
202
- self.cached_inplace_api_info = super().analyze_pre_forward_inplace(name, module_input_output)
296
+ self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
203
297
  return None
204
298
 
205
- def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
299
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
206
300
  self._is_support_inf_nan()
207
- api_info_struct = super().analyze_forward_inplace(name, module_input_output)
208
- if name in self.cached_inplace_api_info and name in api_info_struct:
209
- self.cached_inplace_api_info[name].update(api_info_struct[name])
301
+ api_info_struct = super().analyze_forward_output(name, module, module_input_output)
302
+ if name in self.cached_api_info and name in api_info_struct:
303
+ self.cached_api_info[name].update(api_info_struct[name])
210
304
  elif name in api_info_struct:
211
- self.cached_inplace_api_info = api_info_struct
305
+ self.cached_api_info = api_info_struct
212
306
  self.handle_overflow()
213
- return self.cached_inplace_api_info if self.has_overflow else None
307
+ return self.cached_api_info if self.has_overflow else None
214
308
 
215
309
  def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
216
310
  self.has_overflow = False
@@ -225,6 +319,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
225
319
  api_info_struct = super().analyze_backward(name, module, module_input_output)
226
320
  self.handle_overflow()
227
321
  return api_info_struct if self.has_overflow else None
322
+
323
+ def analyze_params(self, name, param_name, grad):
324
+ self.has_overflow = False
325
+ self._is_support_inf_nan()
326
+ api_info_struct = super().analyze_params(name, param_name, grad)
327
+ self.handle_overflow()
328
+ return api_info_struct if self.has_overflow else None
228
329
 
229
330
  def handle_overflow(self):
230
331
  if not self.support_inf_nan:
@@ -299,10 +400,10 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
299
400
  )
300
401
  return
301
402
 
302
- def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
403
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
303
404
  self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)
304
405
 
305
- def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
406
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
306
407
  new_output, unequal_rows = self.checker.forward(
307
408
  name,
308
409
  module,
@@ -320,64 +421,120 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
320
421
 
321
422
 
322
423
  class KernelDumpDataProcessor(PytorchDataProcessor):
323
- forward_init_status = False
324
- multi_output_apis = ["_sort_", "npu_flash_attention"]
325
-
326
424
  def __init__(self, config, data_writer):
327
425
  super().__init__(config, data_writer)
426
+ self.enable_kernel_dump = True
427
+ self.is_found_output_tensor = False
428
+ self.is_found_grad_input_tensor = False
429
+ self.forward_args = None
430
+ self.forward_kwargs = None
431
+ self.forward_output_tensor = None
432
+ self.grad_input_tensor = None
433
+
434
+ @staticmethod
435
+ def start_kernel_dump(config_path):
436
+ torch_npu.npu.synchronize()
437
+ torch_npu.npu.init_dump()
438
+ torch_npu.npu.set_dump(config_path)
439
+ torch_npu.npu.synchronize()
440
+
441
+ @staticmethod
442
+ def stop_kernel_dump():
443
+ torch_npu.npu.synchronize()
444
+ torch_npu.npu.finalize_dump()
445
+ torch_npu.npu.synchronize()
446
+
447
+ @staticmethod
448
+ def _print_unsupported_log(api_name):
449
+ logger.warning(f"The kernel dump does not support the {api_name} API.")
450
+
451
+ def analyze_forward_input(self, name, module, module_input_output):
452
+ if not self.enable_kernel_dump:
453
+ return
454
+ if is_gpu:
455
+ logger.warning("The current environment is not a complete NPU environment, and kernel dump cannot be used.")
456
+ self.enable_kernel_dump = False
457
+ return
458
+
459
+ if self.config.is_backward_kernel_dump:
460
+ self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
461
+ self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
462
+ try:
463
+ output = module.forward(*self.forward_args, **self.forward_kwargs)
464
+ except Exception:
465
+ self._print_unsupported_log(name)
466
+ self.enable_kernel_dump = False
467
+ return
468
+
469
+ self.analyze_element(convert_tuple(output))
470
+ if not self.is_found_output_tensor:
471
+ self._print_unsupported_log(name)
472
+ self.enable_kernel_dump = False
473
+ return
474
+ self.start_kernel_dump(self.config.kernel_config_path)
475
+
476
+ def analyze_forward_output(self, name, module, module_input_output):
477
+ if not self.enable_kernel_dump:
478
+ return
479
+ if self.config.is_backward_kernel_dump:
480
+ return
481
+ self.enable_kernel_dump = False
482
+ self.stop_kernel_dump()
483
+ logger.info(f"The kernel data of {name} is dumped successfully.")
484
+
485
+ def analyze_backward(self, name, module, module_input_output):
486
+ if not self.enable_kernel_dump:
487
+ return
488
+ self.enable_kernel_dump = False
489
+
490
+ self.analyze_element(module_input_output.grad_input)
491
+ if not self.is_found_grad_input_tensor:
492
+ self._print_unsupported_log(name)
493
+ return
494
+ self.start_kernel_dump(self.config.kernel_config_path)
495
+
496
+ try:
497
+ self.forward_output_tensor.backward(self.grad_input_tensor, retain_graph=True)
498
+ except Exception:
499
+ self._print_unsupported_log(name)
500
+ self.stop_kernel_dump()
501
+ return
328
502
 
329
- def analyze_forward(self, name, module, module_input_output):
330
- if self.config.is_forward_acl_dump:
331
- self.forward_acl_dump(name, module, module_input_output)
503
+ self.stop_kernel_dump()
504
+ logger.info(f"The kernel data of {name} is dumped successfully.")
505
+
506
+ @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
507
+ def clone_and_detach_tensor(self, input_params):
508
+ if isinstance(input_params, torch.Tensor):
509
+ if input_params.requires_grad:
510
+ return input_params.clone().detach().requires_grad_()
511
+ return input_params.clone()
512
+ elif isinstance(input_params, tuple):
513
+ return tuple(self.clone_and_detach_tensor(x) for x in input_params)
514
+ elif isinstance(input_params, list):
515
+ return list(self.clone_and_detach_tensor(x) for x in input_params)
516
+ elif isinstance(input_params, dict):
517
+ return {k: self.clone_and_detach_tensor(v) for k, v in input_params.items()}
332
518
  else:
333
- self.dump_mode_backward_acl_dump(name, module, module_input_output)
334
-
335
- def forward_acl_dump(self, name, module, module_input_output):
336
- if not KernelDumpDataProcessor.forward_init_status:
337
- KernelDumpDataProcessor.forward_init_status = True
338
- torch_npu.npu.synchronize()
339
- torch_npu.npu.init_dump()
340
- torch_npu.npu.set_dump(self.config.acl_config)
341
- torch_npu.npu.synchronize()
342
- if self.op_need_trigger(name):
343
- module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
344
- else:
345
- module.forward(*module_input_output.args, **module_input_output.kwargs)
346
- torch_npu.npu.synchronize()
347
- torch_npu.npu.finalize_dump()
348
- torch_npu.npu.synchronize()
349
- KernelDumpDataProcessor.forward_init_status = False
350
- logger.info("Dump %s op file." % name)
351
-
352
- def acl_backward_dump_status(self, output, grad, module_name):
353
- if isinstance(output, torch.Tensor):
354
- output.backward(grad, retain_graph=True)
355
- return True
519
+ return input_params
356
520
 
357
- for api_name in KernelDumpDataProcessor.multi_output_apis:
358
- if api_name in module_name:
359
- output[0].backward(grad, retain_graph=True)
360
- return True
361
- return False
521
+ def analyze_single_element(self, element, suffix_stack):
522
+ if isinstance(element, torch.Tensor):
523
+ if not self.is_found_output_tensor:
524
+ if element.requires_grad:
525
+ self.forward_output_tensor = element
526
+ self.is_found_output_tensor = True
527
+ return {}
528
+ if not self.is_found_grad_input_tensor:
529
+ self.grad_input_tensor = element.clone()
530
+ self.is_found_grad_input_tensor = True
531
+ return {}
362
532
 
363
- def dump_mode_backward_acl_dump(self, name, module, module_input_output):
364
- grad_path = self.config.backward_input.get(name)
365
- if not KernelDumpDataProcessor.forward_init_status:
366
- KernelDumpDataProcessor.forward_init_status = True
367
- output = module.forward(*module_input_output.args, **module_input_output.kwargs)
368
- pt = load_pt(grad_path)
369
- grad = pt.to("npu").requires_grad_()
370
- torch_npu.npu.init_dump()
371
- torch_npu.npu.set_dump(self.config.acl_config)
372
- torch_npu.npu.synchronize()
373
- if not self.acl_backward_dump_status(output, grad, name):
374
- logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
375
- "you can manually construct a single API backward case for ACL dump.".format(
376
- name))
377
- torch_npu.npu.synchronize()
378
- torch_npu.npu.finalize_dump()
379
- KernelDumpDataProcessor.forward_init_status = False
380
- logger.info("Dump %s op file." % name)
381
-
382
- def op_need_trigger(self, module_name):
383
- return 'Tensor.__getitem__.' in module_name
533
+ def reset_status(self):
534
+ self.enable_kernel_dump = True
535
+ self.is_found_output_tensor = False
536
+ self.is_found_grad_input_tensor = False
537
+ self.forward_args = None
538
+ self.forward_kwargs = None
539
+ self.forward_output_tensor = None
540
+ self.grad_input_tensor = None
@@ -15,10 +15,12 @@
15
15
 
16
16
  import csv
17
17
  import os
18
+ import numpy as np
18
19
 
19
20
  from msprobe.core.common.const import Const, FileCheckConst
20
- from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
21
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
21
22
  from msprobe.core.common.log import logger
23
+ from msprobe.core.common.exceptions import MsprobeException
22
24
 
23
25
 
24
26
  class DataWriter:
@@ -115,3 +117,29 @@ class DataWriter:
115
117
  self.write_stack_info_json(self.stack_file_path)
116
118
  if self.cache_construct:
117
119
  self.write_construct_info_json(self.construct_file_path)
120
+
121
+ def fill_stack_tensor_data(self):
122
+ self.process_stat_data_recursive(self.cache_data)
123
+
124
+ def process_stat_data_recursive(self, data, depth=0):
125
+ if depth > Const.MAX_DEPTH:
126
+ logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
127
+ raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)
128
+ if isinstance(data, dict):
129
+ if "tensor_stat" in data.keys():
130
+ tensor_stat = data["tensor_stat"]
131
+ if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]):
132
+ logger.warning("Some bad data in async dump")
133
+ else:
134
+ tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
135
+ if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
136
+ tensor_stat_data = tensor_stat_data.cpu()
137
+ for index, stat in zip(tensor_stat_index, tensor_stat_data):
138
+ data.update({index, stat.item()})
139
+ del data["tensor_stat"]
140
+ else:
141
+ for key in data.keys():
142
+ self.process_stat_data_recursive(data[key], depth + 1)
143
+ elif isinstance(data, (list, tuple)):
144
+ for i in data:
145
+ self.process_stat_data_recursive(i, depth + 1)