mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,389 +1,366 @@
1
- import copy
2
- import os
3
- import zlib
4
- from dataclasses import asdict
5
- from typing import List
6
-
7
- import numpy as np
8
- import torch
9
- from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode
10
- from msprobe.core.common.log import logger
11
- from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
12
- from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
13
- ModuleForwardInputsOutputs, TensorStatInfo
14
- from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
15
- from msprobe.pytorch.common.utils import save_pt
16
-
17
-
18
# Detect the accelerator stack: importing torch_npu succeeds only in an
# Ascend NPU environment; on failure assume a GPU/CPU setup.
try:
    import torch_npu
except ImportError:
    is_gpu = True
else:
    is_gpu = False
23
-
24
-
25
class PytorchDataProcessor(BaseDataProcessor):
    """Serializes PyTorch call arguments/outputs into JSON-friendly dicts,
    recording per-tensor summary statistics (Max/Min/Mean/Norm)."""

    # torch-specific types this processor knows how to serialize, on top of
    # the base processor's special types.
    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # kwargs whose key alone determines the serialization routine
        self.torch_object_key = {
            "device": self.analyze_device_in_kwargs,
            "dtype": self.analyze_dtype_in_kwargs,
        }

    @staticmethod
    def get_md5_for_tensor(x):
        """Return an 8-hex-digit digest of the tensor's raw bytes.

        NOTE: despite the name, the digest is CRC32 (via zlib), not MD5.
        """
        if x.dtype == torch.bfloat16:
            # numpy has no bfloat16; widen before extracting bytes
            x = x.float()
        raw_bytes = x.cpu().detach().numpy().tobytes()
        return f"{zlib.crc32(raw_bytes):08x}"

    @staticmethod
    def analyze_device_in_kwargs(element):
        """Serialize a device kwarg (torch.device object or plain string)."""
        if isinstance(element, str):
            return {'type': "torch.device", "value": element}
        if hasattr(element, "index"):
            device_value = element.type + ":" + str(element.index)
        else:
            device_value = element.type
        return {'type': "torch.device", "value": device_value}

    @staticmethod
    def analyze_dtype_in_kwargs(element):
        """Serialize a dtype kwarg as its string representation."""
        return {"type": "torch.dtype", "value": str(element)}

    @staticmethod
    def get_stat_info(data):
        """Compute summary statistics for a tensor.

        Meta and empty tensors yield an all-None TensorStatInfo; bool,
        0-d, complex and numeric tensors each take a dedicated path.
        """
        stat = TensorStatInfo()
        if data.is_meta:
            return stat
        values = data.detach()
        if values.numel() == 0:
            return stat
        if values.dtype == torch.bool:
            stat.max = True in values
            stat.min = False not in values
        elif not values.shape:
            # 0-d tensor: every statistic is the single scalar value
            stat.max = stat.min = stat.mean = stat.norm = values.item()
        elif torch.is_complex(values):
            # statistics over magnitudes of the complex entries
            magnitudes = np.abs(values.cpu().numpy())
            stat.max = np.max(magnitudes).item()
            stat.min = np.min(magnitudes).item()
            stat.mean = np.mean(magnitudes).item()
        else:
            if not values.is_floating_point() or values.dtype == torch.float64:
                values = values.float()
            # low-level reductions to avoid re-dispatching through hooks
            vf = torch._C._VariableFunctionsClass
            stat.max = vf.max(values).item()
            stat.min = vf.min(values).item()
            stat.mean = vf.mean(values).item()
            stat.norm = vf.norm(values).item()
        return stat

    @staticmethod
    def handle_tensor_extremum_nan_inf(tensor, operator):
        """Extremum of *tensor* ignoring inf/nan where possible.

        All-NaN tensors return nan; otherwise prefer finite entries,
        falling back to non-NaN entries (i.e. +/-inf) when none are finite.
        """
        vf = torch._C._VariableFunctionsClass
        values = tensor.detach()
        nan_mask = vf.isnan(values)
        if int(vf.sum(nan_mask)) == values.numel():
            return float('nan')
        finite_mask = vf.isfinite(values)
        if int(vf.sum(finite_mask)) > 0:
            candidates = values[finite_mask]
        else:
            candidates = values[~nan_mask]
        reducer = vf.max if operator == 'max' else vf.min
        return reducer(candidates).item()

    @staticmethod
    def _analyze_builtin(arg):
        """Serialize a builtin scalar/str/slice argument."""
        if isinstance(arg, slice):
            # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型
            # (slice bounds may be tensors; convert to plain numbers for JSON)
            bounds = [
                bound.item() if isinstance(bound, torch.Tensor) else bound
                for bound in (arg.start, arg.stop, arg.step)
            ]
            return {"type": "slice", "value": bounds}
        return {"type": type(arg).__name__, "value": arg}

    @staticmethod
    def _analyze_torch_size(arg):
        """Serialize a torch.Size as a plain list."""
        return {"type": "torch.Size", "value": list(arg)}

    @classmethod
    def get_special_types(cls):
        return super().get_special_types() + cls.pytorch_special_type

    def analyze_single_element(self, element, suffix_stack):
        """Dispatch one leaf element to the matching serializer.

        Returns {} for unrecognized types.
        """
        if suffix_stack and suffix_stack[-1] in self.torch_object_key:
            handler = self.torch_object_key[suffix_stack[-1]]
            return handler(element)
        if isinstance(element, torch.Size):
            return self._analyze_torch_size(element)
        converted, numpy_type = self._convert_numpy_to_builtin(element)
        if converted is not element:
            return self._analyze_numpy(converted, numpy_type)
        if isinstance(element, torch.Tensor):
            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
        if isinstance(element, (bool, int, float, str, slice)):
            return self._analyze_builtin(element)
        return {}

    def _analyze_tensor(self, tensor, suffix):
        """Build the JSON record for a tensor: dtype/shape, statistics,
        inf/nan-excluded extrema, and optionally a digest."""
        stat = self.get_stat_info(tensor)
        tensor_json = {
            'type': 'torch.Tensor',
            'dtype': str(tensor.dtype),
            "shape": tensor.shape,
            "Max": stat.max,
            "Min": stat.min,
            "Mean": stat.mean,
            "Norm": stat.norm,
            "requires_grad": tensor.requires_grad,
        }

        if stat.max is not None and (np.isinf(stat.max) or np.isnan(stat.max)):
            tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
        if stat.min is not None and (np.isinf(stat.min) or np.isnan(stat.min)):
            tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")

        if self.config.summary_mode == Const.MD5:
            tensor_json.update({Const.MD5: self.get_md5_for_tensor(tensor)})
        return tensor_json
166
-
167
-
168
class StatisticsDataProcessor(PytorchDataProcessor):
    """Statistics-only dump: identical to PytorchDataProcessor — records
    summary statistics without persisting tensor payloads."""
170
-
171
-
172
class TensorDataProcessor(PytorchDataProcessor):
    """Dump processor that also saves each tensor's full payload to disk."""

    def _analyze_tensor(self, tensor, suffix):
        """Persist the tensor, then extend the base JSON record with the
        on-disk file name under "data_name"."""
        dump_data_name, file_path = self.get_save_file_path(suffix)
        save_pt(tensor.contiguous().detach(), file_path)
        tensor_json = super()._analyze_tensor(tensor, suffix)
        tensor_json.update({"data_name": dump_data_name})
        return tensor_json
180
-
181
-
182
class OverflowCheckDataProcessor(PytorchDataProcessor):
    """Dump processor that only keeps an API's data when an overflow is
    detected (inf/nan statistics on GPU / inf-nan-capable NPU, or the NPU
    hardware overflow flag otherwise), up to a configured occurrence count.
    """

    # Completed to cover every attribute assigned in __init__ — the original
    # listed only the first entry, which made the declaration misleading.
    # NOTE: the base classes keep an instance __dict__, so this is
    # declarative rather than memory-saving.
    __slots__ = [
        "cached_tensors_and_file_paths",
        "bits_for_overflow",
        "real_overflow_nums",
        "overflow_nums",
        "forward_inplace_inputs",
    ]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # tensors staged for saving; flushed only once overflow is confirmed
        self.cached_tensors_and_file_paths = {}
        self.bits_for_overflow = 8
        self.real_overflow_nums = 0
        # -1 means "no limit" (see is_terminated)
        self.overflow_nums = config.overflow_nums
        self.forward_inplace_inputs = None

    @property
    def is_terminated(self):
        """True once the configured number of overflows has been reached."""
        if self.overflow_nums == -1:
            return False
        if self.real_overflow_nums >= self.overflow_nums:
            logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
            return True
        return False

    @staticmethod
    def overflow_debug_mode_enable():
        """Whether NPU overflow debug mode is switched on via environment."""
        overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
        return overflow_mode == Const.ENV_ENABLE

    def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        # Deep-copy the inputs of an in-place API before the call mutates them.
        self.forward_inplace_inputs = copy.deepcopy(module_input_output)
        return None

    def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        """Analyze an in-place API: its "output" is the mutated args/kwargs;
        the pre-call inputs captured earlier are restored as the inputs."""
        module_input_output.output = module_input_output.concat_args_and_kwargs()
        module_input_output.args = self.forward_inplace_inputs.args
        module_input_output.kwargs = self.forward_inplace_inputs.kwargs
        # release memory used by forward inputs
        self.forward_inplace_inputs = None
        return self.analyze_forward(name, None, module_input_output)

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Run the base analysis, then keep the result only on overflow."""
        self.has_overflow = False
        api_info_struct = super().analyze_forward(name, module, module_input_output)
        self.maybe_save_overflow_data_and_check_overflow_times()
        return api_info_struct if self.has_overflow else None

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Backward counterpart of analyze_forward."""
        self.has_overflow = False
        api_info_struct = super().analyze_backward(name, module, module_input_output)
        self.maybe_save_overflow_data_and_check_overflow_times()
        return api_info_struct if self.has_overflow else None

    def maybe_save_overflow_data_and_check_overflow_times(self):
        """Flush staged tensors to disk when an overflow was seen and bump
        the counter; reset the staging cache either way."""
        if self.has_overflow:
            for file_path, tensor in self.cached_tensors_and_file_paths.items():
                save_pt(tensor, file_path)
            self.real_overflow_nums += 1
        self.cached_tensors_and_file_paths = {}

    def check_overflow_npu(self):
        """Query the NPU overflow flag (debug-mode register or driver API)."""
        if self.overflow_debug_mode_enable():
            float_status = torch.zeros(self.bits_for_overflow).npu()
            result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
            if result.cpu()[0] != 0:
                return True
            else:
                return False
        else:
            return torch_npu._C._check_overflow_npu()

    def clear_overflow_npu(self):
        """Reset the NPU overflow flag so subsequent checks start clean."""
        if self.overflow_debug_mode_enable():
            float_status = torch.zeros(self.bits_for_overflow).npu()
            torch_npu.npu_clear_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
        else:
            torch_npu._C._clear_overflow_npu()

    def _analyze_maybe_overflow_tensor(self, tensor_json):
        """Set self.has_overflow from the tensor's statistics (inf/nan-capable
        devices) or from the NPU hardware flag.

        Raises:
            RuntimeError: when the NPU overflow query itself fails.
        """
        if is_gpu or (hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan()):
            if tensor_json['Max'] is None:
                return
            if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
                self.has_overflow = True
            if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
                self.has_overflow = True
        else:
            try:
                self.has_overflow = self.check_overflow_npu()
                if self.has_overflow:
                    self.clear_overflow_npu()
            except Exception as e:
                # f-prefixes removed from the two literals below: they
                # contained no placeholders (message bytes unchanged).
                logger.error("Overflow check failed, the current environment may be abnormal.")
                raise RuntimeError("overflow check failed") from e

    def _analyze_tensor(self, tensor, suffix):
        """Stage the tensor for deferred saving, analyze it, and flag
        overflow; saving happens later only if overflow is confirmed."""
        dump_data_name, file_path = self.get_save_file_path(suffix)
        if not path_len_exceeds_limit(file_path):
            self.cached_tensors_and_file_paths.update({file_path: tensor})
        else:
            logger.warning(f'The file path {file_path} length exceeds limit.')
        single_arg = super()._analyze_tensor(tensor, suffix)
        self._analyze_maybe_overflow_tensor(single_arg)
        single_arg.update({"data_name": dump_data_name})
        return single_arg
283
-
284
-
285
class FreeBenchmarkDataProcessor(PytorchDataProcessor):
    """Runs free-benchmark perturbation checks around each hooked API and
    records any unequal results to CSV via the data writer."""

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        self.checker = FreeBenchmarkCheck(config=config)
        # set when the checker decides to substitute a fixed forward output
        self._return_forward_new_output = None
        self._forward_new_output = None

    def update_iter(self, current_iter):
        """Propagate the training-step counter to the checker as well."""
        super().update_iter(current_iter)
        self.checker.update_iter(current_iter)

    def update_unequal_rows(self, unequal_rows: List[UnequalRow]):
        """Append every mismatch row to the free-benchmark CSV file."""
        if not unequal_rows:
            return
        for row in unequal_rows:
            row_dict = asdict(row)
            self.data_writer.write_data_to_csv(
                row_dict.values(),
                row_dict.keys(),
                self.data_writer.free_benchmark_file_path,
            )
        return

    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Run the perturbed forward, log mismatches, and stash the fixed
        output when the checker requests substitution."""
        new_output, unequal_rows = self.checker.forward(
            name,
            module,
            module_input_output.args,
            module_input_output.kwargs,
            module_input_output.output,
        )
        self.update_unequal_rows(unequal_rows)
        if self.checker.if_fix():
            self._return_forward_new_output = True
            self._forward_new_output = new_output

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        self.checker.backward(name, module, module_input_output.grad_input)
327
-
328
-
329
class KernelDumpDataProcessor(PytorchDataProcessor):
    """Drives torch_npu ACL kernel-level dump around a single API call."""

    # class-level latch so a nested/re-entrant call cannot re-initialize the
    # dump session while one is in flight
    forward_init_status = False
    # APIs whose forward returns a sequence; backward is driven via output[0]
    multi_output_apis = ["_sort_", "npu_flash_attention"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)

    def analyze_forward(self, name, module, module_input_output):
        """Dispatch to forward-time or backward-time ACL dump based on config."""
        if self.config.is_forward_acl_dump:
            self.forward_acl_dump(name, module, module_input_output)
        else:
            self.dump_mode_backward_acl_dump(name, module, module_input_output)

    def forward_acl_dump(self, name, module, module_input_output):
        """Re-run the module's forward between init_dump/finalize_dump so the
        NPU records the kernels launched by this API."""
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            torch_npu.npu.synchronize()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if self.op_need_trigger(name):
                # .cpu() presumably forces device-side execution of the
                # __getitem__ result so its kernels are captured — TODO confirm
                module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
            else:
                module.forward(*module_input_output.args, **module_input_output.kwargs)
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
            torch_npu.npu.synchronize()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def acl_backward_dump_status(self, output, grad, module_name):
        """Run backward on the forward output.

        Returns True when backward was triggered (tensor output, or a known
        multi-output API where output[0] carries the graph); False otherwise.
        """
        if isinstance(output, torch.Tensor):
            output.backward(grad, retain_graph=True)
            return True

        for api_name in KernelDumpDataProcessor.multi_output_apis:
            if api_name in module_name:
                output[0].backward(grad, retain_graph=True)
                return True
        return False

    def dump_mode_backward_acl_dump(self, name, module, module_input_output):
        """Re-run forward, load the saved upstream gradient, and ACL-dump the
        kernels launched during backward."""
        grad_path = self.config.backward_input.get(name)
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            output = module.forward(*module_input_output.args, **module_input_output.kwargs)
            # NOTE(review): torch.load unpickles arbitrary objects — grad_path
            # must come from a trusted dump configuration
            grad = torch.load(grad_path).to("npu").requires_grad_()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if not self.acl_backward_dump_status(output, grad, name):
                logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
                               "you can manually construct a single API backward case for ACL dump.".format(
                                   name))
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def op_need_trigger(self, module_name):
        """Whether the output must be pulled to CPU to trigger kernel launch."""
        return 'Tensor.__getitem__.' in module_name
1
+ import zlib
2
+ from dataclasses import asdict
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import torch
7
+ from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode
8
+ from msprobe.core.common.log import logger
9
+ from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
10
+ from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
11
+ ModuleForwardInputsOutputs, TensorStatInfo
12
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
13
+ from msprobe.pytorch.common.utils import save_pt, load_pt
14
+
15
+ try:
16
+ import torch_npu
17
+ is_gpu = False
18
+ except ImportError:
19
+ is_gpu = True
20
+
21
+
22
class PytorchDataProcessor(BaseDataProcessor):
    """Base processor that converts PyTorch runtime objects (tensors, dtypes,
    devices, sizes) into JSON-serializable summaries for the dump files."""

    # torch types that need dedicated handling on top of the builtin ones
    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # kwargs whose key name alone determines how the value is summarized
        self.torch_object_key = {
            "device": self.analyze_device_in_kwargs,
            "dtype": self.analyze_dtype_in_kwargs
        }

    @staticmethod
    def get_md5_for_tensor(x):
        """Return an 8-hex-digit digest of the tensor's raw bytes.

        NOTE: despite the name this computes CRC32, not MD5. bfloat16 is
        widened to float32 first because numpy cannot represent it.
        """
        if x.dtype == torch.bfloat16:
            x = x.float()
        tensor_bytes = x.cpu().detach().numpy().tobytes()
        crc32_hash = zlib.crc32(tensor_bytes)
        return f"{crc32_hash:08x}"

    @staticmethod
    def analyze_device_in_kwargs(element):
        """Summarize a torch.device (or device string) as {'type', 'value'}."""
        single_arg = {}
        single_arg.update({'type': "torch.device"})
        if not isinstance(element, str):
            # NOTE(review): torch.device objects always expose .index (it may
            # be None) — verify devices without an index don't serialize as
            # e.g. "cpu:None"
            if hasattr(element, "index"):
                device_value = element.type + ":" + str(element.index)
            else:
                device_value = element.type
            single_arg.update({"value": device_value})
        else:
            single_arg.update({"value": element})
        return single_arg

    @staticmethod
    def analyze_dtype_in_kwargs(element):
        """Summarize a torch.dtype as {'type', 'value'}."""
        return {"type": "torch.dtype", "value": str(element)}

    @staticmethod
    def get_stat_info(data):
        """Compute max/min/mean/norm statistics for a tensor.

        Special cases: meta or empty tensors yield empty stats; bool tensors
        report any/all as max/min; 0-dim tensors use item(); complex tensors
        are reduced over their magnitudes on CPU; integer and float64 tensors
        are cast to float32 before the reductions.
        """
        tensor_stat = TensorStatInfo()
        if data.is_meta:
            return tensor_stat
        data_clone = data.detach()
        if data_clone.numel() == 0:
            return tensor_stat
        elif data_clone.dtype == torch.bool:
            # max == "any element is True", min == "all elements are True"
            tensor_stat.max = True in data_clone
            tensor_stat.min = False not in data_clone
        elif not data_clone.shape:
            tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
        elif torch.is_complex(data_clone):
            data_np = data_clone.cpu().numpy()
            data_abs = np.abs(data_np)
            tensor_stat.max = np.max(data_abs).item()
            tensor_stat.min = np.min(data_abs).item()
            tensor_stat.mean = np.mean(data_abs).item()
        else:
            if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
                data_clone = data_clone.float()
            # _VariableFunctionsClass is used instead of torch.max/min/... —
            # presumably to bypass the torch-function hooks this framework
            # installs on the public API; TODO confirm
            tensor_stat.max = torch._C._VariableFunctionsClass.max(data_clone).item()
            tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item()
            tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
            tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
        return tensor_stat

    @staticmethod
    def handle_tensor_extremum_nan_inf(tensor, operator):
        """Recompute an extremum ignoring inf/nan where possible.

        Returns nan when every element is nan; otherwise reduces over finite
        values if any exist, else over the non-nan values (i.e. +/-inf).

        Args:
            tensor: tensor whose reported extremum was inf or nan.
            operator: 'max' or 'min'.
        """
        data_clone = tensor.detach()
        data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
        if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
            return float('nan')
        finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
        if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
            finite_values = data_clone[finite_mask]
            return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
                torch._C._VariableFunctionsClass.min(finite_values).item()
        else:
            data_no_nan = data_clone[~data_nan]
            return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
                torch._C._VariableFunctionsClass.min(data_no_nan).item()

    @staticmethod
    def _analyze_torch_size(arg):
        """Summarize a torch.Size as a plain list of ints."""
        return {"type": "torch.Size", "value": list(arg)}

    @classmethod
    def get_special_types(cls):
        """Extend the base processor's special types with the torch ones."""
        return super().get_special_types() + cls.pytorch_special_type

    def analyze_single_element(self, element, suffix_stack):
        """Convert one element of an API's args/kwargs into a JSON summary.

        Dispatch order matters: named kwargs (device/dtype) first, then
        torch.Size (before the numpy/builtin path), numpy scalars, tensors,
        and finally plain builtins. Unknown types yield an empty dict.
        """
        if suffix_stack and suffix_stack[-1] in self.torch_object_key:
            return self.torch_object_key[suffix_stack[-1]](element)
        if isinstance(element, torch.Size):
            return self._analyze_torch_size(element)
        converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
        if converted_numpy is not element:
            # identity change signals element was a numpy scalar and converted
            return self._analyze_numpy(converted_numpy, numpy_type)
        if isinstance(element, torch.Tensor):
            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
        if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
            return self._analyze_builtin(element)
        return {}

    def _analyze_tensor(self, tensor, suffix):
        """Build the JSON stat entry for a tensor (type/dtype/shape/stats).

        When the reported Max/Min is inf or nan, an extra *_except_inf_nan
        field records the extremum over the remaining values; in MD5 summary
        mode a digest is appended under the md5 key.
        """
        tensor_stat = self.get_stat_info(tensor)
        tensor_json = {}
        tensor_json.update({'type': 'torch.Tensor'})
        tensor_json.update({'dtype': str(tensor.dtype)})
        tensor_json.update({"shape": tensor.shape})
        tensor_json.update({"Max": tensor_stat.max})
        tensor_json.update({"Min": tensor_stat.min})
        tensor_json.update({"Mean": tensor_stat.mean})
        tensor_json.update({"Norm": tensor_stat.norm})
        tensor_json.update({"requires_grad": tensor.requires_grad})

        if tensor_stat.max is not None:
            if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
                tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
        if tensor_stat.min is not None:
            if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
                tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")

        if self.config.summary_mode == Const.MD5:
            tensor_md5 = self.get_md5_for_tensor(tensor)
            tensor_json.update({Const.MD5: tensor_md5})
        return tensor_json
147
+
148
+
149
class StatisticsDataProcessor(PytorchDataProcessor):
    """Statistics-only dump mode; inherits all behaviour unchanged from
    PytorchDataProcessor."""
151
+
152
+
153
class TensorDataProcessor(PytorchDataProcessor):
    """Dump mode that persists every tensor to disk alongside its statistics."""

    def _analyze_tensor(self, tensor, suffix):
        """Save the raw tensor to a .pt file, then extend the base stats dict
        with the dump file's name."""
        data_name, save_path = self.get_save_file_path(suffix)
        # detach from autograd and force a contiguous layout before serializing
        save_pt(tensor.contiguous().detach(), save_path)
        tensor_info = super()._analyze_tensor(tensor, suffix)
        tensor_info["data_name"] = data_name
        return tensor_info
161
+
162
+
163
class OverflowCheckDataProcessor(PytorchDataProcessor):
    """Dump processor that only keeps data for API calls whose tensors
    contain inf/nan, or that raised the NPU hardware overflow flag."""

    # NOTE(review): __slots__ names one attribute but __init__ assigns several
    # more, so instances must still be getting a __dict__ from the base class;
    # confirm before relying on __slots__ here for memory savings.
    __slots__ = ["cached_tensors_and_file_paths"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        self.has_overflow = False  # reset at the start of each analyzed call
        self.support_inf_nan = None  # None = not yet detected (lazy)
        self.cached_inplace_api_info = {}
        self.cached_tensors_and_file_paths = {}  # tensors pending overflow confirmation
        self.bits_for_overflow = 8
        self.real_overflow_nums = 0  # overflow occurrences recorded so far
        self.overflow_nums = config.overflow_nums  # -1 means no limit

    @property
    def is_terminated(self):
        """True once the configured overflow budget has been used up."""
        if self.overflow_nums == -1:
            return False
        if self.real_overflow_nums >= self.overflow_nums:
            logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
            return True
        return False

    def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        """Cache the pre-call info for an inplace API; it is merged with the
        post-call info in analyze_forward_inplace."""
        self.has_overflow = False
        self._is_support_inf_nan()
        self.cached_inplace_api_info = super().analyze_pre_forward_inplace(name, module_input_output)
        return None

    def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        """Merge the post-call info into the cached pre-call info and report
        the combined record only when an overflow was detected."""
        self._is_support_inf_nan()
        api_info_struct = super().analyze_forward_inplace(name, module_input_output)
        if name in self.cached_inplace_api_info and name in api_info_struct:
            self.cached_inplace_api_info[name].update(api_info_struct[name])
        elif name in api_info_struct:
            # no cached pre-call entry for this name; take the post-call info
            self.cached_inplace_api_info = api_info_struct
        self.handle_overflow()
        return self.cached_inplace_api_info if self.has_overflow else None

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Analyze a forward call; return its record only on overflow."""
        self.has_overflow = False
        self._is_support_inf_nan()
        api_info_struct = super().analyze_forward(name, module, module_input_output)
        self.handle_overflow()
        return api_info_struct if self.has_overflow else None

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Analyze a backward call; return its record only on overflow."""
        self.has_overflow = False
        self._is_support_inf_nan()
        api_info_struct = super().analyze_backward(name, module, module_input_output)
        self.handle_overflow()
        return api_info_struct if self.has_overflow else None

    def handle_overflow(self):
        """On overflow, flush the cached tensors to disk and bump the counter.

        The tensor cache is cleared in every case.
        """
        if not self.support_inf_nan:
            # values never carry inf/nan on this device; consult the NPU
            # hardware overflow flag instead
            self._analyze_maybe_overflow_flag()
        if self.has_overflow:
            for file_path, tensor in self.cached_tensors_and_file_paths.items():
                save_pt(tensor, file_path)
            self.real_overflow_nums += 1
        self.cached_tensors_and_file_paths = {}

    def _is_support_inf_nan(self):
        """Lazily detect whether the device propagates inf/nan values.

        GPUs always do; NPUs are queried via torch_npu. Defaults to False
        when detection fails.
        """
        if self.support_inf_nan is not None:
            return
        try:
            self.support_inf_nan = is_gpu or torch_npu.npu.utils.is_support_inf_nan()
        except Exception:
            logger.warning(f"Unable to determine if the current device supports inf/nan mode, default not supported.")
            self.support_inf_nan = False

    def _analyze_maybe_overflow_flag(self):
        """Read (and, if set, clear) the NPU hardware overflow flag.

        Raises:
            RuntimeError: when the flag cannot be queried at all.
        """
        try:
            self.has_overflow = torch_npu.npu.utils.get_npu_overflow_flag()
            if self.has_overflow:
                torch_npu.npu.utils.clear_npu_overflow_flag()
        except Exception as e:
            logger.error(f"Overflow check failed, the current environment may be abnormal.")
            raise RuntimeError(f"overflow check failed") from e

    def _analyze_maybe_overflow_tensor(self, tensor_json):
        """Set has_overflow when the entry's Max or Min is inf or nan."""
        if tensor_json['Max'] is None or tensor_json['Min'] is None:
            return
        self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \
            np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])

    def _analyze_tensor(self, tensor, suffix):
        """Stat the tensor, remember it for potential later saving, and in
        inf/nan mode check its extrema for overflow."""
        dump_data_name, file_path = self.get_save_file_path(suffix)
        if not path_len_exceeds_limit(file_path):
            # keep the tensor in memory; it is only written out if this call
            # turns out to have overflowed (see handle_overflow)
            self.cached_tensors_and_file_paths.update({file_path: tensor})
        else:
            logger.warning(f'The file path {file_path} length exceeds limit.')
        single_arg = super()._analyze_tensor(tensor, suffix)
        single_arg.update({"data_name": dump_data_name})
        if not self.has_overflow and self.support_inf_nan:
            self._analyze_maybe_overflow_tensor(single_arg)
        return single_arg
259
+
260
+
261
class FreeBenchmarkDataProcessor(PytorchDataProcessor):
    """Processor that drives free-benchmark perturbation checks around each
    module call and records rows for comparisons that came out unequal."""

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        self.checker = FreeBenchmarkCheck(config=config)
        # set by analyze_forward when the checker decides to substitute output
        self._return_forward_new_output = None
        self._forward_new_output = None

    def update_iter(self, current_iter):
        """Propagate the new iteration index to the checker as well."""
        super().update_iter(current_iter)
        self.checker.update_iter(current_iter)

    def update_unequal_rows(self, unequal_rows: List[UnequalRow]):
        """Append each unequal-comparison row to the free-benchmark CSV.

        Args:
            unequal_rows: dataclass rows describing mismatched comparisons;
                a falsy value (None or empty list) is a no-op.
        """
        # Fix: dropped the redundant bare `return` that followed the loop.
        if not unequal_rows:
            return
        for row in unequal_rows:
            data_dict = asdict(row)
            self.data_writer.write_data_to_csv(
                data_dict.values(),
                data_dict.keys(),
                self.data_writer.free_benchmark_file_path
            )

    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Hand the pre-forward call to the checker so it can prepare state."""
        self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Run the checker's forward comparison and persist unequal rows.

        If the checker opts to fix, stash the replacement output so the dump
        framework returns it instead of the module's original output.
        """
        new_output, unequal_rows = self.checker.forward(
            name,
            module,
            module_input_output.args,
            module_input_output.kwargs,
            module_input_output.output,
        )
        self.update_unequal_rows(unequal_rows)
        if self.checker.if_fix():
            self._return_forward_new_output = True
            self._forward_new_output = new_output

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Let the checker inspect the incoming gradients during backward."""
        self.checker.backward(name, module, module_input_output.grad_input)
303
+
304
+
305
class KernelDumpDataProcessor(PytorchDataProcessor):
    """Drives torch_npu ACL kernel-level dump around a single API call."""

    # class-level latch so a nested/re-entrant call cannot re-initialize the
    # dump session while one is in flight
    forward_init_status = False
    # APIs whose forward returns a sequence; backward is driven via output[0]
    multi_output_apis = ["_sort_", "npu_flash_attention"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)

    def analyze_forward(self, name, module, module_input_output):
        """Dispatch to forward-time or backward-time ACL dump based on config."""
        if self.config.is_forward_acl_dump:
            self.forward_acl_dump(name, module, module_input_output)
        else:
            self.dump_mode_backward_acl_dump(name, module, module_input_output)

    def forward_acl_dump(self, name, module, module_input_output):
        """Re-run the module's forward between init_dump/finalize_dump so the
        NPU records the kernels launched by this API."""
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            torch_npu.npu.synchronize()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if self.op_need_trigger(name):
                # .cpu() presumably forces device-side execution of the
                # __getitem__ result so its kernels are captured — TODO confirm
                module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
            else:
                module.forward(*module_input_output.args, **module_input_output.kwargs)
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
            torch_npu.npu.synchronize()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def acl_backward_dump_status(self, output, grad, module_name):
        """Run backward on the forward output.

        Returns True when backward was triggered (tensor output, or a known
        multi-output API where output[0] carries the graph); False otherwise.
        """
        if isinstance(output, torch.Tensor):
            output.backward(grad, retain_graph=True)
            return True

        for api_name in KernelDumpDataProcessor.multi_output_apis:
            if api_name in module_name:
                output[0].backward(grad, retain_graph=True)
                return True
        return False

    def dump_mode_backward_acl_dump(self, name, module, module_input_output):
        """Re-run forward, load the saved upstream gradient, and ACL-dump the
        kernels launched during backward."""
        grad_path = self.config.backward_input.get(name)
        if not KernelDumpDataProcessor.forward_init_status:
            KernelDumpDataProcessor.forward_init_status = True
            output = module.forward(*module_input_output.args, **module_input_output.kwargs)
            # load the previously dumped upstream gradient via the project's
            # load helper, then move it to the NPU for the backward pass
            pt = load_pt(grad_path)
            grad = pt.to("npu").requires_grad_()
            torch_npu.npu.init_dump()
            torch_npu.npu.set_dump(self.config.acl_config)
            torch_npu.npu.synchronize()
            if not self.acl_backward_dump_status(output, grad, name):
                logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
                               "you can manually construct a single API backward case for ACL dump.".format(
                                   name))
            torch_npu.npu.synchronize()
            torch_npu.npu.finalize_dump()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def op_need_trigger(self, module_name):
        """Whether the output must be pulled to CPU to trigger kernel launch."""
        return 'Tensor.__getitem__.' in module_name