mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,389 +1,383 @@
1
- import copy
2
- import os
3
- import zlib
4
- from dataclasses import asdict
5
- from typing import List
6
-
7
- import numpy as np
8
- import torch
9
- from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode
10
- from msprobe.core.common.log import logger
11
- from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
12
- from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
13
- ModuleForwardInputsOutputs, TensorStatInfo
14
- from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
15
- from msprobe.pytorch.common.utils import save_pt
16
-
17
-
18
- try:
19
- import torch_npu
20
- is_gpu = False
21
- except ImportError:
22
- is_gpu = True
23
-
24
-
25
- class PytorchDataProcessor(BaseDataProcessor):
26
- pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)
27
-
28
- def __init__(self, config, data_writer):
29
- super().__init__(config, data_writer)
30
- self.torch_object_key = {
31
- "device": self.analyze_device_in_kwargs,
32
- "dtype": self.analyze_dtype_in_kwargs
33
- }
34
-
35
- @staticmethod
36
- def get_md5_for_tensor(x):
37
- if x.dtype == torch.bfloat16:
38
- x = x.float()
39
- tensor_bytes = x.cpu().detach().numpy().tobytes()
40
- crc32_hash = zlib.crc32(tensor_bytes)
41
- return f"{crc32_hash:08x}"
42
-
43
- @staticmethod
44
- def analyze_device_in_kwargs(element):
45
- single_arg = {}
46
- single_arg.update({'type': "torch.device"})
47
- if not isinstance(element, str):
48
- if hasattr(element, "index"):
49
- device_value = element.type + ":" + str(element.index)
50
- else:
51
- device_value = element.type
52
- single_arg.update({"value": device_value})
53
- else:
54
- single_arg.update({"value": element})
55
- return single_arg
56
-
57
- @staticmethod
58
- def analyze_dtype_in_kwargs(element):
59
- return {"type": "torch.dtype", "value": str(element)}
60
-
61
- @staticmethod
62
- def get_stat_info(data):
63
- tensor_stat = TensorStatInfo()
64
- if data.is_meta:
65
- return tensor_stat
66
- data_clone = data.detach()
67
- if data_clone.numel() == 0:
68
- return tensor_stat
69
- elif data_clone.dtype == torch.bool:
70
- tensor_stat.max = True in data_clone
71
- tensor_stat.min = False not in data_clone
72
- elif not data_clone.shape:
73
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
74
- elif torch.is_complex(data_clone):
75
- data_np = data_clone.cpu().numpy()
76
- data_abs = np.abs(data_np)
77
- tensor_stat.max = np.max(data_abs).item()
78
- tensor_stat.min = np.min(data_abs).item()
79
- tensor_stat.mean = np.mean(data_abs).item()
80
- else:
81
- if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
82
- data_clone = data_clone.float()
83
- tensor_stat.max = torch._C._VariableFunctionsClass.max(data_clone).item()
84
- tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item()
85
- tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
86
- tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
87
- return tensor_stat
88
-
89
- @staticmethod
90
- def handle_tensor_extremum_nan_inf(tensor, operator):
91
- data_clone = tensor.detach()
92
- data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
93
- if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
94
- return float('nan')
95
- finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
96
- if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
97
- finite_values = data_clone[finite_mask]
98
- return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
99
- torch._C._VariableFunctionsClass.min(finite_values).item()
100
- else:
101
- data_no_nan = data_clone[~data_nan]
102
- return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
103
- torch._C._VariableFunctionsClass.min(data_no_nan).item()
104
-
105
- @staticmethod
106
- def _analyze_builtin(arg):
107
- single_arg = {}
108
- if isinstance(arg, slice):
109
- single_arg.update({"type": "slice"})
110
- # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型
111
- values = [
112
- value if not isinstance(value, torch.Tensor) else value.item()
113
- for value in [arg.start, arg.stop, arg.step]
114
- ]
115
- single_arg.update({"value": values})
116
- else:
117
- single_arg.update({"type": type(arg).__name__})
118
- single_arg.update({"value": arg})
119
- return single_arg
120
-
121
- @staticmethod
122
- def _analyze_torch_size(arg):
123
- return {"type": "torch.Size", "value": list(arg)}
124
-
125
- @classmethod
126
- def get_special_types(cls):
127
- return super().get_special_types() + cls.pytorch_special_type
128
-
129
- def analyze_single_element(self, element, suffix_stack):
130
- if suffix_stack and suffix_stack[-1] in self.torch_object_key:
131
- return self.torch_object_key[suffix_stack[-1]](element)
132
- if isinstance(element, torch.Size):
133
- return self._analyze_torch_size(element)
134
- converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
135
- if converted_numpy is not element:
136
- return self._analyze_numpy(converted_numpy, numpy_type)
137
- if isinstance(element, torch.Tensor):
138
- return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
139
- if isinstance(element, (bool, int, float, str, slice)):
140
- return self._analyze_builtin(element)
141
- return {}
142
-
143
- def _analyze_tensor(self, tensor, suffix):
144
- tensor_stat = self.get_stat_info(tensor)
145
- tensor_json = {}
146
- tensor_json.update({'type': 'torch.Tensor'})
147
- tensor_json.update({'dtype': str(tensor.dtype)})
148
- tensor_json.update({"shape": tensor.shape})
149
- tensor_json.update({"Max": tensor_stat.max})
150
- tensor_json.update({"Min": tensor_stat.min})
151
- tensor_json.update({"Mean": tensor_stat.mean})
152
- tensor_json.update({"Norm": tensor_stat.norm})
153
- tensor_json.update({"requires_grad": tensor.requires_grad})
154
-
155
- if tensor_stat.max is not None:
156
- if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
157
- tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
158
- if tensor_stat.min is not None:
159
- if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
160
- tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
161
-
162
- if self.config.summary_mode == Const.MD5:
163
- tensor_md5 = self.get_md5_for_tensor(tensor)
164
- tensor_json.update({Const.MD5: tensor_md5})
165
- return tensor_json
166
-
167
-
168
- class StatisticsDataProcessor(PytorchDataProcessor):
169
- pass
170
-
171
-
172
- class TensorDataProcessor(PytorchDataProcessor):
173
- def _analyze_tensor(self, tensor, suffix):
174
- dump_data_name, file_path = self.get_save_file_path(suffix)
175
- saved_tensor = tensor.contiguous().detach()
176
- save_pt(saved_tensor, file_path)
177
- single_arg = super()._analyze_tensor(tensor, suffix)
178
- single_arg.update({"data_name": dump_data_name})
179
- return single_arg
180
-
181
-
182
- class OverflowCheckDataProcessor(PytorchDataProcessor):
183
- __slots__ = ["cached_tensors_and_file_paths"]
184
-
185
- def __init__(self, config, data_writer):
186
- super().__init__(config, data_writer)
187
- self.cached_tensors_and_file_paths = {}
188
- self.bits_for_overflow = 8
189
- self.real_overflow_nums = 0
190
- self.overflow_nums = config.overflow_nums
191
- self.forward_inplace_inputs = None
192
-
193
- @property
194
- def is_terminated(self):
195
- if self.overflow_nums == -1:
196
- return False
197
- if self.real_overflow_nums >= self.overflow_nums:
198
- logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
199
- return True
200
- return False
201
-
202
- @staticmethod
203
- def overflow_debug_mode_enable():
204
- overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
205
- return overflow_mode == Const.ENV_ENABLE
206
-
207
- def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
208
- self.forward_inplace_inputs = copy.deepcopy(module_input_output)
209
- return None
210
-
211
- def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
212
- module_input_output.output = module_input_output.concat_args_and_kwargs()
213
- module_input_output.args = self.forward_inplace_inputs.args
214
- module_input_output.kwargs = self.forward_inplace_inputs.kwargs
215
- # release memory used by forward inputs
216
- self.forward_inplace_inputs = None
217
- return self.analyze_forward(name, None, module_input_output)
218
-
219
- def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
220
- self.has_overflow = False
221
- api_info_struct = super().analyze_forward(name, module, module_input_output)
222
- self.maybe_save_overflow_data_and_check_overflow_times()
223
- return api_info_struct if self.has_overflow else None
224
-
225
- def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
226
- self.has_overflow = False
227
- api_info_struct = super().analyze_backward(name, module, module_input_output)
228
- self.maybe_save_overflow_data_and_check_overflow_times()
229
- return api_info_struct if self.has_overflow else None
230
-
231
- def maybe_save_overflow_data_and_check_overflow_times(self):
232
- if self.has_overflow:
233
- for file_path, tensor in self.cached_tensors_and_file_paths.items():
234
- save_pt(tensor, file_path)
235
- self.real_overflow_nums += 1
236
- self.cached_tensors_and_file_paths = {}
237
-
238
- def check_overflow_npu(self):
239
- if self.overflow_debug_mode_enable():
240
- float_status = torch.zeros(self.bits_for_overflow).npu()
241
- result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
242
- if result.cpu()[0] != 0:
243
- return True
244
- else:
245
- return False
246
- else:
247
- return torch_npu._C._check_overflow_npu()
248
-
249
- def clear_overflow_npu(self):
250
- if self.overflow_debug_mode_enable():
251
- float_status = torch.zeros(self.bits_for_overflow).npu()
252
- torch_npu.npu_clear_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
253
- else:
254
- torch_npu._C._clear_overflow_npu()
255
-
256
- def _analyze_maybe_overflow_tensor(self, tensor_json):
257
- if is_gpu or (hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan()):
258
- if tensor_json['Max'] is None:
259
- return
260
- if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
261
- self.has_overflow = True
262
- if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
263
- self.has_overflow = True
264
- else:
265
- try:
266
- self.has_overflow = self.check_overflow_npu()
267
- if self.has_overflow:
268
- self.clear_overflow_npu()
269
- except Exception as e:
270
- logger.error(f"Overflow check failed, the current environment may be abnormal.")
271
- raise RuntimeError(f"overflow check failed") from e
272
-
273
- def _analyze_tensor(self, tensor, suffix):
274
- dump_data_name, file_path = self.get_save_file_path(suffix)
275
- if not path_len_exceeds_limit(file_path):
276
- self.cached_tensors_and_file_paths.update({file_path: tensor})
277
- else:
278
- logger.warning(f'The file path {file_path} length exceeds limit.')
279
- single_arg = super()._analyze_tensor(tensor, suffix)
280
- self._analyze_maybe_overflow_tensor(single_arg)
281
- single_arg.update({"data_name": dump_data_name})
282
- return single_arg
283
-
284
-
285
class FreeBenchmarkDataProcessor(PytorchDataProcessor):
    """Processor for the free-benchmark feature.

    Delegates perturbed re-execution and result comparison to
    FreeBenchmarkCheck, and records rows that compared unequal to a CSV file
    via the data writer.
    """

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        self.checker = FreeBenchmarkCheck(config=config)
        # Set by analyze_forward when the checker decides to replace the
        # output; presumably read by the surrounding hook framework to swap
        # the forward result — TODO confirm against caller.
        self._return_forward_new_output = None
        self._forward_new_output = None

    def update_iter(self, current_iter):
        """Propagate the current iteration number to the checker as well."""
        super().update_iter(current_iter)
        self.checker.update_iter(current_iter)

    def update_unequal_rows(self, unequal_rows: List[UnequalRow]):
        """Append each unequal-comparison row to the free-benchmark CSV."""
        if not unequal_rows:
            return
        for row in unequal_rows:
            # UnequalRow is a dataclass; asdict preserves field order, so
            # values() and keys() stay aligned for the CSV writer.
            data_dict = asdict(row)
            self.data_writer.write_data_to_csv(
                data_dict.values(),
                data_dict.keys(),
                self.data_writer.free_benchmark_file_path
            )
        return

    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Run the checker's forward comparison and stash a fixed output if requested."""
        new_output, unequal_rows = self.checker.forward(
            name,
            module,
            module_input_output.args,
            module_input_output.kwargs,
            module_input_output.output,
        )
        self.update_unequal_rows(unequal_rows)
        if self.checker.if_fix():
            self._return_forward_new_output = True
            self._forward_new_output = new_output

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        self.checker.backward(name, module, module_input_output.grad_input)
327
-
328
-
329
class KernelDumpDataProcessor(PytorchDataProcessor):
    """Processor that triggers ACL (kernel-level) dumps on Ascend NPU.

    Re-executes the hooked API's forward (and, in backward mode, backward)
    inside a torch_npu init_dump()/finalize_dump() window so the device
    driver captures kernel-level data. Requires torch_npu.
    """
    # Class-level guard against re-entrant dumping while a dump is in progress.
    forward_init_status = False
    # APIs known to return a tuple; backward is driven through output[0].
    multi_output_apis = ["_sort_", "npu_flash_attention"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)

    def analyze_forward(self, name, module, module_input_output):
        # config.is_forward_acl_dump selects forward-only vs backward dump mode.
        if self.config.is_forward_acl_dump:
            self.forward_acl_dump(name, module, module_input_output)
        else:
            self.dump_mode_backward_acl_dump(name, module, module_input_output)

    def forward_acl_dump(self, name, module, module_input_output):
        """Re-run the forward between init_dump/finalize_dump to capture kernels."""
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            # Synchronize before and after each dump-control call so the dump
            # window covers exactly this forward execution.
            torch_npu.npu.synchronize()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if self.op_need_trigger(name):
                # .cpu() presumably forces device execution so the kernel
                # actually runs inside the dump window — TODO confirm.
                module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
            else:
                module.forward(*module_input_output.args, **module_input_output.kwargs)
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
            torch_npu.npu.synchronize()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def acl_backward_dump_status(self, output, grad, module_name):
        """Launch backward on *output*; return True if backward was started."""
        if isinstance(output, torch.Tensor):
            output.backward(grad, retain_graph=True)
            return True

        # Tuple-returning APIs: drive backward through the first element.
        for api_name in KernelDumpDataProcessor.multi_output_apis:
            if api_name in module_name:
                output[0].backward(grad, retain_graph=True)
                return True
        return False

    def dump_mode_backward_acl_dump(self, name, module, module_input_output):
        """Re-run forward, then backward with a saved grad, inside the dump window."""
        # Path of the pre-saved upstream gradient for this API call.
        grad_path = self.config.backward_input.get(name)
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            output = module.forward(*module_input_output.args, **module_input_output.kwargs)
            grad = torch.load(grad_path).to("npu").requires_grad_()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if not self.acl_backward_dump_status(output, grad, name):
                logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
                               "you can manually construct a single API backward case for ACL dump.".format(
                                   name))
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def op_need_trigger(self, module_name):
        # Only Tensor.__getitem__ needs the explicit .cpu() trigger above.
        return 'Tensor.__getitem__.' in module_name
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import zlib
17
+ from dataclasses import asdict
18
+ from typing import List
19
+
20
+ import numpy as np
21
+ import torch
22
+ from msprobe.core.common.const import Const
23
+ from msprobe.core.common.file_utils import path_len_exceeds_limit
24
+ from msprobe.core.common.log import logger
25
+ from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
26
+ ModuleForwardInputsOutputs, TensorStatInfo
27
+ from msprobe.pytorch.common.utils import save_pt, load_pt
28
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
29
+
30
# Probe for the Ascend extension: when torch_npu cannot be imported we are
# running on a GPU (or CPU) stack and NPU-specific paths must be skipped.
try:
    import torch_npu
    is_gpu = False
except ImportError:
    is_gpu = True
35
+
36
+
37
class PytorchDataProcessor(BaseDataProcessor):
    """Data processor that extracts dump statistics from PyTorch objects.

    Extends BaseDataProcessor with handling for torch-specific types
    (torch.device, torch.dtype, torch.Size, torch.Tensor).
    """
    # Torch types that need dedicated handling on top of the base processor's types.
    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # Dispatch table: keyword-argument name -> analyzer for torch objects
        # passed by keyword (see analyze_single_element).
        self.torch_object_key = {
            "device": self.analyze_device_in_kwargs,
            "dtype": self.analyze_dtype_in_kwargs
        }

    @staticmethod
    def get_md5_for_tensor(x):
        """Return an 8-hex-digit checksum of the tensor's raw bytes.

        NOTE(review): despite the name, this computes CRC32 (zlib.crc32),
        not MD5. bfloat16 is upcast to float32 first because numpy has no
        bfloat16 representation.
        """
        if x.dtype == torch.bfloat16:
            x = x.float()
        tensor_bytes = x.cpu().detach().numpy().tobytes()
        crc32_hash = zlib.crc32(tensor_bytes)
        return f"{crc32_hash:08x}"

    @staticmethod
    def analyze_device_in_kwargs(element):
        """Serialize a device kwarg (torch.device or str) to a type/value dict."""
        single_arg = {}
        single_arg.update({'type': "torch.device"})
        if not isinstance(element, str):
            # torch.device: render as "type:index" when an index is present.
            if hasattr(element, "index"):
                device_value = element.type + ":" + str(element.index)
            else:
                device_value = element.type
            single_arg.update({"value": device_value})
        else:
            single_arg.update({"value": element})
        return single_arg

    @staticmethod
    def analyze_dtype_in_kwargs(element):
        """Serialize a torch.dtype kwarg to a type/value dict."""
        return {"type": "torch.dtype", "value": str(element)}

    @staticmethod
    def get_stat_info(data):
        """Compute max/min/mean/norm statistics for a tensor.

        Returns an empty TensorStatInfo for meta or empty tensors. Bool
        tensors get any/all-style Python bools for max/min; complex tensors
        are reduced over element magnitudes via numpy.
        """
        tensor_stat = TensorStatInfo()
        if data.is_meta:
            # Meta tensors carry no data; nothing to compute.
            return tensor_stat
        data_clone = data.detach()
        if data_clone.numel() == 0:
            return tensor_stat
        elif data_clone.dtype == torch.bool:
            # max == "any element True", min == "all elements True".
            tensor_stat.max = True in data_clone
            tensor_stat.min = False not in data_clone
        elif not data_clone.shape:
            # 0-d tensor: the single scalar is every statistic.
            tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
        elif torch.is_complex(data_clone):
            # Complex: statistics over magnitudes (norm is left unset).
            data_np = data_clone.cpu().numpy()
            data_abs = np.abs(data_np)
            tensor_stat.max = np.max(data_abs).item()
            tensor_stat.min = np.min(data_abs).item()
            tensor_stat.mean = np.mean(data_abs).item()
        else:
            # Cast ints and float64 to float32 before the reductions —
            # presumably for device-op support; note float64 loses precision.
            if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
                data_clone = data_clone.float()
            # NOTE(review): the private torch._C._VariableFunctionsClass entry
            # points are presumably used instead of the public torch.* ops to
            # avoid re-triggering this tool's own API hooks — TODO confirm.
            tensor_stat.max = torch._C._VariableFunctionsClass.max(data_clone).item()
            tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item()
            tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
            tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
        return tensor_stat

    @staticmethod
    def handle_tensor_extremum_nan_inf(tensor, operator):
        """Return the extremum ('max' or 'min') ignoring inf/nan where possible.

        All-NaN tensors yield nan; otherwise the extremum is taken over finite
        values if any exist, else over the non-NaN values (i.e. the infs).
        """
        data_clone = tensor.detach()
        data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
        if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
            return float('nan')
        finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
        if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
            finite_values = data_clone[finite_mask]
            return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
                torch._C._VariableFunctionsClass.min(finite_values).item()
        else:
            # No finite values: fall back to the non-NaN (infinite) elements.
            data_no_nan = data_clone[~data_nan]
            return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
                torch._C._VariableFunctionsClass.min(data_no_nan).item()

    @staticmethod
    def _analyze_torch_size(arg):
        """Serialize a torch.Size to a type/value dict."""
        return {"type": "torch.Size", "value": list(arg)}

    @classmethod
    def get_special_types(cls):
        """Base special types plus the torch-specific ones."""
        return super().get_special_types() + cls.pytorch_special_type

    def analyze_single_element(self, element, suffix_stack):
        """Dispatch one element to the appropriate analyzer by kwarg name/type.

        Order matters: named torch kwargs first, then torch.Size, then numpy
        scalars converted to builtins, then tensors, then plain builtins.
        Unknown types yield an empty dict.
        """
        if suffix_stack and suffix_stack[-1] in self.torch_object_key:
            return self.torch_object_key[suffix_stack[-1]](element)
        if isinstance(element, torch.Size):
            return self._analyze_torch_size(element)
        converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
        if converted_numpy is not element:
            return self._analyze_numpy(converted_numpy, numpy_type)
        if isinstance(element, torch.Tensor):
            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
        if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
            return self._analyze_builtin(element)
        return {}

    def _analyze_tensor(self, tensor, suffix):
        """Build the JSON-serializable statistics entry for one tensor."""
        tensor_stat = self.get_stat_info(tensor)
        tensor_json = {}
        tensor_json.update({'type': 'torch.Tensor'})
        tensor_json.update({'dtype': str(tensor.dtype)})
        tensor_json.update({"shape": tensor.shape})
        tensor_json.update({"Max": tensor_stat.max})
        tensor_json.update({"Min": tensor_stat.min})
        tensor_json.update({"Mean": tensor_stat.mean})
        tensor_json.update({"Norm": tensor_stat.norm})
        tensor_json.update({"requires_grad": tensor.requires_grad})

        # When Max/Min are inf/nan, also record the extremum over the
        # remaining (finite or non-NaN) values for diagnosis.
        if tensor_stat.max is not None:
            if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
                tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
        if tensor_stat.min is not None:
            if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
                tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")

        # Checksum only in MD5 summary mode (see get_md5_for_tensor).
        if self.config.summary_mode == Const.MD5:
            tensor_md5 = self.get_md5_for_tensor(tensor)
            tensor_json.update({Const.MD5: tensor_md5})
        return tensor_json
162
+
163
+
164
class StatisticsDataProcessor(PytorchDataProcessor):
    """Statistics-mode processor.

    Summary collection is fully inherited from PytorchDataProcessor; this
    subclass exists only so the factory can select statistics mode by name.
    """
166
+
167
+
168
class TensorDataProcessor(PytorchDataProcessor):
    """Tensor-mode processor: persists the full tensor payload to disk in
    addition to the statistics collected by PytorchDataProcessor."""

    def _analyze_tensor(self, tensor, suffix):
        """Save a copy of *tensor* and return its stats entry extended with
        the dump file name."""
        data_name, save_path = self.get_save_file_path(suffix)
        # Contiguous detached copy so the saved file is independent of
        # autograd state and non-contiguous storage.
        save_pt(tensor.clone().contiguous().detach(), save_path)
        arg_entry = super()._analyze_tensor(tensor, suffix)
        arg_entry["data_name"] = data_name
        return arg_entry
176
+
177
+
178
class OverflowCheckDataProcessor(PytorchDataProcessor):
    """Processor that dumps tensors only when an overflow (inf/nan) occurs.

    Tensors are cached during analysis; handle_overflow() writes them to disk
    only if an overflow was detected for the current API call, then clears
    the cache. Overflow detection uses tensor Max/Min when the device
    supports inf/nan mode, or the NPU hardware overflow flag otherwise.
    """
    # NOTE(review): __slots__ is ineffective here — __init__ assigns several
    # attributes not listed, which only works because (presumably) a base
    # class lacks __slots__ and instances keep a __dict__. Verify the intent.
    __slots__ = ["cached_tensors_and_file_paths"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        self.has_overflow = False                 # overflow seen in the current API call
        self.support_inf_nan = None               # lazily set by _is_support_inf_nan
        self.cached_inplace_api_info = {}         # pre-forward info for inplace APIs
        self.cached_tensors_and_file_paths = {}   # file_path -> tensor pending write
        self.bits_for_overflow = 8                # not referenced in this class — TODO confirm consumer
        self.real_overflow_nums = 0               # overflows observed so far
        self.overflow_nums = config.overflow_nums # preset limit; -1 means unlimited

    @property
    def is_terminated(self):
        """True once the preset overflow count is reached (-1 never terminates)."""
        if self.overflow_nums == -1:
            return False
        if self.real_overflow_nums >= self.overflow_nums:
            return True
        return False

    def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        """Reset the overflow flag and cache the pre-forward (input) info.

        Returns None so nothing is written until the post-forward check.
        """
        self.has_overflow = False
        self._is_support_inf_nan()
        self.cached_inplace_api_info = super().analyze_pre_forward_inplace(name, module_input_output)
        return None

    def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        """Merge post-forward info into the cached entry; report only on overflow."""
        self._is_support_inf_nan()
        # NOTE(review): assumes super() returns a dict here; `name in None`
        # would raise TypeError — confirm the base-class contract.
        api_info_struct = super().analyze_forward_inplace(name, module_input_output)
        if name in self.cached_inplace_api_info and name in api_info_struct:
            self.cached_inplace_api_info[name].update(api_info_struct[name])
        elif name in api_info_struct:
            self.cached_inplace_api_info = api_info_struct
        self.handle_overflow()
        return self.cached_inplace_api_info if self.has_overflow else None

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Analyze a forward call; return its info only if an overflow occurred."""
        self.has_overflow = False
        self._is_support_inf_nan()
        api_info_struct = super().analyze_forward(name, module, module_input_output)
        self.handle_overflow()
        return api_info_struct if self.has_overflow else None

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Analyze a backward call; return its info only if an overflow occurred."""
        self.has_overflow = False
        self._is_support_inf_nan()
        api_info_struct = super().analyze_backward(name, module, module_input_output)
        self.handle_overflow()
        return api_info_struct if self.has_overflow else None

    def handle_overflow(self):
        """Finalize the current API call: flush cached tensors on overflow.

        In hardware-flag mode (no inf/nan support) the overflow flag is read
        from the device first. The tensor cache is always cleared afterwards.
        """
        if not self.support_inf_nan:
            self._analyze_maybe_overflow_flag()
        if self.has_overflow:
            for file_path, tensor in self.cached_tensors_and_file_paths.items():
                save_pt(tensor, file_path)
            self.real_overflow_nums += 1
            if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
                logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
                            f"current overflow times: {self.real_overflow_nums}.")
        self.cached_tensors_and_file_paths = {}

    def _is_support_inf_nan(self):
        """Determine once whether the device supports inf/nan mode (GPU: yes)."""
        if self.support_inf_nan is not None:
            return
        try:
            self.support_inf_nan = is_gpu or torch_npu.npu.utils.is_support_inf_nan()
        except Exception:
            # Best-effort probe: default to hardware-flag mode on failure.
            logger.warning(f"Unable to determine if the current device supports inf/nan mode, default not supported.")
            self.support_inf_nan = False

    def _analyze_maybe_overflow_flag(self):
        """Read (and clear) the NPU global overflow flag into has_overflow."""
        try:
            self.has_overflow = torch_npu.npu.utils.get_npu_overflow_flag()
            if self.has_overflow:
                torch_npu.npu.utils.clear_npu_overflow_flag()
        except Exception as e:
            logger.error(f"Overflow check failed, the current environment may be abnormal.")
            raise RuntimeError(f"overflow check failed") from e

    def _analyze_maybe_overflow_tensor(self, tensor_json):
        """Set has_overflow if the tensor's Max or Min is inf/nan."""
        if tensor_json['Max'] is None or tensor_json['Min'] is None:
            return
        self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \
            np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])

    def _analyze_tensor(self, tensor, suffix):
        """Collect stats and cache the tensor for a deferred overflow dump."""
        dump_data_name, file_path = self.get_save_file_path(suffix)
        if not path_len_exceeds_limit(file_path):
            # Defer the disk write: handle_overflow saves only on overflow.
            self.cached_tensors_and_file_paths.update({file_path: tensor})
        else:
            logger.warning(f'The file path {file_path} length exceeds limit.')
        single_arg = super()._analyze_tensor(tensor, suffix)
        single_arg.update({"data_name": dump_data_name})
        # Only check stats in inf/nan mode, and stop once overflow is flagged.
        if not self.has_overflow and self.support_inf_nan:
            self._analyze_maybe_overflow_tensor(single_arg)
        return single_arg
276
+
277
+
278
class FreeBenchmarkDataProcessor(PytorchDataProcessor):
    """Processor for the free-benchmark feature.

    Delegates perturbed re-execution and result comparison to
    FreeBenchmarkCheck, and records rows that compared unequal to a CSV file
    via the data writer.
    """

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        self.checker = FreeBenchmarkCheck(config=config)
        # Populated by analyze_forward when the checker requests an output fix.
        self._return_forward_new_output = None
        self._forward_new_output = None

    def update_iter(self, current_iter):
        """Keep the checker's iteration counter in sync with the processor's."""
        super().update_iter(current_iter)
        self.checker.update_iter(current_iter)

    def update_unequal_rows(self, unequal_rows: List[UnequalRow]):
        """Append every unequal-comparison row to the free-benchmark CSV."""
        if not unequal_rows:
            return
        csv_path = self.data_writer.free_benchmark_file_path
        for unequal_row in unequal_rows:
            # Dataclass field order keeps values() aligned with keys().
            row_dict = asdict(unequal_row)
            self.data_writer.write_data_to_csv(row_dict.values(), row_dict.keys(), csv_path)

    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Run the checker's forward comparison; stash a replacement output if fixing."""
        new_output, unequal_rows = self.checker.forward(
            name,
            module,
            module_input_output.args,
            module_input_output.kwargs,
            module_input_output.output,
        )
        self.update_unequal_rows(unequal_rows)
        if self.checker.if_fix():
            self._return_forward_new_output = True
            self._forward_new_output = new_output

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        self.checker.backward(name, module, module_input_output.grad_input)
320
+
321
+
322
class KernelDumpDataProcessor(PytorchDataProcessor):
    """Processor that triggers ACL (kernel-level) dumps on Ascend NPU.

    Re-executes the hooked API's forward (and, in backward mode, backward)
    inside a torch_npu init_dump()/finalize_dump() window so the device
    driver captures kernel-level data. Requires torch_npu.
    """
    # Class-level guard against re-entrant dumping while a dump is in progress.
    forward_init_status = False
    # APIs known to return a tuple; backward is driven through output[0].
    multi_output_apis = ["_sort_", "npu_flash_attention"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)

    def analyze_forward(self, name, module, module_input_output):
        # config.is_forward_acl_dump selects forward-only vs backward dump mode.
        if self.config.is_forward_acl_dump:
            self.forward_acl_dump(name, module, module_input_output)
        else:
            self.dump_mode_backward_acl_dump(name, module, module_input_output)

    def forward_acl_dump(self, name, module, module_input_output):
        """Re-run the forward between init_dump/finalize_dump to capture kernels."""
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            # Synchronize around each dump-control call so the dump window
            # covers exactly this forward execution.
            torch_npu.npu.synchronize()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if self.op_need_trigger(name):
                # .cpu() presumably forces device execution so the kernel
                # actually runs inside the dump window — TODO confirm.
                module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
            else:
                module.forward(*module_input_output.args, **module_input_output.kwargs)
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
            torch_npu.npu.synchronize()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def acl_backward_dump_status(self, output, grad, module_name):
        """Launch backward on *output*; return True if backward was started."""
        if isinstance(output, torch.Tensor):
            output.backward(grad, retain_graph=True)
            return True

        # Tuple-returning APIs: drive backward through the first element.
        for api_name in KernelDumpDataProcessor.multi_output_apis:
            if api_name in module_name:
                output[0].backward(grad, retain_graph=True)
                return True
        return False

    def dump_mode_backward_acl_dump(self, name, module, module_input_output):
        """Re-run forward, then backward with a saved grad, inside the dump window."""
        # Path of the pre-saved upstream gradient for this API call.
        grad_path = self.config.backward_input.get(name)
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            output = module.forward(*module_input_output.args, **module_input_output.kwargs)
            pt = load_pt(grad_path)
            grad = pt.to("npu").requires_grad_()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if not self.acl_backward_dump_status(output, grad, name):
                logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
                               "you can manually construct a single API backward case for ACL dump.".format(
                                   name))
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def op_need_trigger(self, module_name):
        # Only Tensor.__getitem__ needs the explicit .cpu() trigger above.
        return 'Tensor.__getitem__.' in module_name