mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
@@ -1,134 +1,133 @@
1
- import argparse
2
- import os
3
- import sys
4
-
5
- try:
6
- import torch_npu
7
- except ImportError:
8
- is_gpu = True
9
- else:
10
- is_gpu = False
11
- import torch
12
- from tqdm import tqdm
13
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
14
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
15
- from msprobe.core.common.utils import get_json_contents
16
- from msprobe.core.common.file_check import check_link
17
- from msprobe.pytorch.common.log import logger
18
- from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
19
- from msprobe.core.common.const import Const
20
-
21
-
22
- def check_tensor_overflow(x):
23
- if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool:
24
- if len(x.shape) == 0:
25
- tensor_max = x.cpu().detach().float().numpy().tolist()
26
- tensor_min = tensor_max
27
- else:
28
- tensor_max = torch._C._VariableFunctionsClass.max(x).cpu().detach().float().numpy().tolist()
29
- tensor_min = torch._C._VariableFunctionsClass.min(x).cpu().detach().float().numpy().tolist()
30
- # inf
31
- if tensor_max == float('inf') or tensor_min == float('-inf'):
32
- return True
33
- # nan
34
- elif tensor_max != tensor_max or tensor_min != tensor_min:
35
- return True
36
- else:
37
- return False
38
- elif isinstance(x, bool) or isinstance(x, int) or isinstance(x, float):
39
- if x == float('inf') or x == float('-inf') or x != x:
40
- return True
41
- else:
42
- return False
43
- else:
44
- return False
45
-
46
-
47
- def check_data_overflow(x):
48
- if isinstance(x, (tuple, list)) and x:
49
- for _, item in enumerate(x):
50
- if check_data_overflow(item):
51
- return True
52
- return False
53
- else:
54
- return check_tensor_overflow(x)
55
-
56
-
57
- def run_overflow_check(forward_file):
58
- logger.info("start UT test")
59
- forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
60
- for api_full_name, api_info_dict in tqdm(forward_content.items()):
61
- try:
62
- run_torch_api(api_full_name, api_info_dict, real_data_path)
63
- except Exception as err:
64
- _, api_name, _ = api_full_name.split(Const.SEP)
65
- if "not implemented for 'Half'" in str(err):
66
- logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
67
- f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
68
- elif "expected scalar type Long" in str(err):
69
- logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
70
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
71
- else:
72
- logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
73
-
74
-
75
- def run_torch_api(api_full_name, api_info_dict, real_data_path):
76
- torch.npu.clear_npu_overflow_flag()
77
- api_type, api_name, _ = api_full_name.split(Const.SEP)
78
- args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
79
- if not need_grad:
80
- logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
81
- % api_full_name)
82
- npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
83
- if kwargs.get("device"):
84
- del kwargs["device"]
85
- out = exec_api(api_type, api_name, args, kwargs)
86
- npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs)
87
- if out is None and npu_out is None:
88
- logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
89
- return
90
-
91
- cpu_overflow = check_data_overflow(out)
92
- npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
93
- if cpu_overflow == npu_overflow:
94
- logger.warning("The %s overflow is a normal overflow." % api_full_name)
95
- else:
96
- logger.warning("The %s overflow is an abnormal overflow." % api_full_name)
97
- return
98
-
99
-
100
- def _run_overflow_check_parser(parser):
101
- parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="",
102
- help="<Required> The api param tool result file: generate from api param tool, "
103
- "a json file.",
104
- required=True)
105
- parser.add_argument("-j", "--jit_compile", dest="jit_compile", help="<optional> whether to turn on jit compile",
106
- default=False, required=False)
107
- parser.add_argument("-d", "--device", dest="device_id", type=int, help="<optional> set NPU device id to run ut",
108
- default=0, required=False)
109
-
110
-
111
- def _run_overflow_check(parser=None):
112
- if not parser:
113
- parser = argparse.ArgumentParser()
114
- _run_overflow_check_parser(parser)
115
- args = parser.parse_args(sys.argv[1:])
116
- _run_overflow_check_command(args)
117
-
118
-
119
- def _run_overflow_check_command(args):
120
- torch.npu.set_compile_mode(jit_compile=args.jit_compile)
121
- npu_device = "npu:" + str(args.device_id)
122
- check_link(args.api_info_file)
123
- api_info = os.path.realpath(args.api_info_file)
124
- try:
125
- torch.npu.set_device(npu_device)
126
- except Exception as error:
127
- logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
128
- raise NotImplementedError from error
129
- run_overflow_check(api_info)
130
-
131
-
132
- if __name__ == '__main__':
133
- _run_overflow_check()
134
- logger.info("UT task completed.")
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ try:
6
+ import torch_npu
7
+ except ImportError:
8
+ is_gpu = True
9
+ else:
10
+ is_gpu = False
11
+ import torch
12
+ from tqdm import tqdm
13
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
14
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
15
+ from msprobe.core.common.file_utils import check_link
16
+ from msprobe.pytorch.common.log import logger
17
+ from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
18
+ from msprobe.core.common.const import Const
19
+
20
+
21
+ def check_tensor_overflow(x):
22
+ if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool:
23
+ if len(x.shape) == 0:
24
+ tensor_max = x.cpu().detach().float().numpy().tolist()
25
+ tensor_min = tensor_max
26
+ else:
27
+ tensor_max = torch._C._VariableFunctionsClass.max(x).cpu().detach().float().numpy().tolist()
28
+ tensor_min = torch._C._VariableFunctionsClass.min(x).cpu().detach().float().numpy().tolist()
29
+ # inf
30
+ if tensor_max == float('inf') or tensor_min == float('-inf'):
31
+ return True
32
+ # nan
33
+ elif tensor_max != tensor_max or tensor_min != tensor_min:
34
+ return True
35
+ else:
36
+ return False
37
+ elif isinstance(x, bool) or isinstance(x, int) or isinstance(x, float):
38
+ if x == float('inf') or x == float('-inf') or x != x:
39
+ return True
40
+ else:
41
+ return False
42
+ else:
43
+ return False
44
+
45
+
46
+ def check_data_overflow(x):
47
+ if isinstance(x, (tuple, list)) and x:
48
+ for _, item in enumerate(x):
49
+ if check_data_overflow(item):
50
+ return True
51
+ return False
52
+ else:
53
+ return check_tensor_overflow(x)
54
+
55
+
56
+ def run_overflow_check(forward_file):
57
+ logger.info("start UT test")
58
+ forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
59
+ for api_full_name, api_info_dict in tqdm(forward_content.items()):
60
+ try:
61
+ run_torch_api(api_full_name, api_info_dict, real_data_path)
62
+ except Exception as err:
63
+ _, api_name, _ = api_full_name.split(Const.SEP)
64
+ if "not implemented for 'Half'" in str(err):
65
+ logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
66
+ f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
67
+ elif "expected scalar type Long" in str(err):
68
+ logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
69
+ f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
70
+ else:
71
+ logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
72
+
73
+
74
+ def run_torch_api(api_full_name, api_info_dict, real_data_path):
75
+ torch.npu.clear_npu_overflow_flag()
76
+ api_type, api_name, _ = api_full_name.split(Const.SEP)
77
+ args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
78
+ if not need_grad:
79
+ logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
80
+ % api_full_name)
81
+ npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
82
+ if kwargs.get("device"):
83
+ del kwargs["device"]
84
+ out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs)
85
+ npu_out = exec_api(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs)
86
+ if out is None and npu_out is None:
87
+ logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
88
+ return
89
+
90
+ cpu_overflow = check_data_overflow(out)
91
+ npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
92
+ if cpu_overflow == npu_overflow:
93
+ logger.warning("The %s overflow is a normal overflow." % api_full_name)
94
+ else:
95
+ logger.warning("The %s overflow is an abnormal overflow." % api_full_name)
96
+ return
97
+
98
+
99
+ def _run_overflow_check_parser(parser):
100
+ parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="",
101
+ help="<Required> The api param tool result file: generate from api param tool, "
102
+ "a json file.",
103
+ required=True)
104
+ parser.add_argument("-j", "--jit_compile", dest="jit_compile", help="<optional> whether to turn on jit compile",
105
+ default=False, required=False)
106
+ parser.add_argument("-d", "--device", dest="device_id", type=int, help="<optional> set NPU device id to run ut",
107
+ default=0, required=False)
108
+
109
+
110
+ def _run_overflow_check(parser=None):
111
+ if not parser:
112
+ parser = argparse.ArgumentParser()
113
+ _run_overflow_check_parser(parser)
114
+ args = parser.parse_args(sys.argv[1:])
115
+ _run_overflow_check_command(args)
116
+
117
+
118
+ def _run_overflow_check_command(args):
119
+ torch.npu.set_compile_mode(jit_compile=args.jit_compile)
120
+ npu_device = "npu:" + str(args.device_id)
121
+ check_link(args.api_info_file)
122
+ api_info = os.path.realpath(args.api_info_file)
123
+ try:
124
+ torch.npu.set_device(npu_device)
125
+ except Exception as error:
126
+ logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
127
+ raise NotImplementedError from error
128
+ run_overflow_check(api_info)
129
+
130
+
131
+ if __name__ == '__main__':
132
+ _run_overflow_check()
133
+ logger.info("UT task completed.")