mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,581 +1,518 @@
1
- import argparse
2
- import os
3
- import csv
4
- import sys
5
- import time
6
- import gc
7
- from collections import namedtuple
8
-
9
- try:
10
- import torch_npu
11
- except ImportError:
12
- is_gpu = True
13
- current_device = "cuda"
14
- else:
15
- is_gpu = False
16
- current_device = "npu"
17
- import torch
18
- from tqdm import tqdm
19
-
20
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
21
- get_validated_result_csv_path, get_validated_details_csv_path, exec_api
22
- from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
23
- from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
24
- initialize_save_path, UtDataProcessor
25
- from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
26
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
27
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
28
- from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
29
- from msprobe.core.common.file_check import FileOpen, FileChecker, \
30
- change_mode, check_path_before_create, create_directory
31
- from msprobe.pytorch.common.log import logger
32
- from msprobe.core.common.utils import get_json_contents
33
- from msprobe.pytorch.pt_config import parse_json_config
34
- from msprobe.core.common.const import Const, FileCheckConst, CompareConst
35
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, ApiData, move2device_exec
36
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
37
-
38
-
39
- current_time = time.strftime("%Y%m%d%H%M%S")
40
- UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
41
- RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
42
- DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
43
- RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
44
- 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
45
- 'black_list', 'error_data_path', 'online_config'])
46
-
47
- OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
48
-
49
- not_backward_list = ['repeat_interleave']
50
- not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
51
- not_raise_dtype_set = {'type_as'}
52
-
53
- RAISE_PRECISION = {
54
- torch.float16: torch.float32,
55
- torch.bfloat16: torch.float32,
56
- torch.float32: torch.float64
57
- }
58
-
59
- tqdm_params = {
60
- 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
61
- 'desc': 'Processing', # 进度条前的描述文字
62
- 'leave': True, # 迭代完成后保留进度条的显示
63
- 'ncols': 75, # 进度条的固定宽度
64
- 'mininterval': 0.1, # 更新进度条的最小间隔秒数
65
- 'maxinterval': 1.0, # 更新进度条的最大间隔秒数
66
- 'miniters': 1, # 更新进度条之间的最小迭代次数
67
- 'ascii': None, # 根据环境自动使用ASCII或Unicode字符
68
- 'unit': 'it', # 迭代单位
69
- 'unit_scale': True, # 自动根据单位缩放
70
- 'dynamic_ncols': True, # 动态调整进度条宽度以适应控制台
71
- 'bar_format': '{l_bar}{bar}| {n}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]' # 自定义进度条输出格式
72
- }
73
-
74
-
75
- def deal_detach(arg, to_detach=True):
76
- return arg.detach() if to_detach else arg
77
-
78
-
79
- def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
80
- '''
81
- 将标杆数据的dtype转换为raise_dtype
82
- 输入:
83
- api_name:api名称
84
- arg:标杆输入
85
- raise_dtype:需要转换的dtype
86
- 输出:
87
- arg: 转换dtype的标杆输入
88
- '''
89
- if api_name in hf_32_standard_api and arg.dtype == torch.float32:
90
- return arg
91
- if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
92
- return arg
93
- return arg.type(raise_dtype)
94
-
95
-
96
- def generate_device_params(input_args, input_kwargs, need_backward, api_name):
97
- def recursive_arg_to_device(arg_in, to_detach):
98
- if isinstance(arg_in, (list, tuple)):
99
- return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
100
- elif isinstance(arg_in, torch.Tensor):
101
- if need_backward and arg_in.requires_grad:
102
- arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
103
- temp_arg_in = arg_in * 1
104
- arg_in = temp_arg_in.type_as(arg_in)
105
- arg_in.retain_grad()
106
- return arg_in
107
- else:
108
- return deal_detach(arg_in.clone(), to_detach).to(current_device)
109
- else:
110
- return arg_in
111
-
112
- is_detach = api_name not in not_detach_set
113
- device_args = recursive_arg_to_device(input_args, is_detach)
114
- device_kwargs = \
115
- {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
116
- return device_args, device_kwargs
117
-
118
-
119
- def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
120
- def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
121
- if isinstance(arg_in, (list, tuple)):
122
- return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
123
- elif isinstance(arg_in, torch.Tensor):
124
- if need_backward and arg_in.requires_grad:
125
- arg_in = deal_detach(raise_bench_data_dtype(
126
- api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
127
- temp_arg_in = arg_in * 1
128
- arg_in = temp_arg_in.type_as(arg_in)
129
- arg_in.retain_grad()
130
- return arg_in
131
- else:
132
- return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
133
- else:
134
- return arg_in
135
-
136
- def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
137
- if arg_in.dtype in RAISE_PRECISION:
138
- return True
139
- if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
140
- return True
141
- return False
142
-
143
- def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
144
- if isinstance(arg_in, (list, tuple)):
145
- return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
146
- elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
147
- return set([arg_in.dtype])
148
- elif isinstance(arg_in, dict) and check_kwargs:
149
- return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
150
- return set()
151
-
152
- raise_dtype = None
153
- need_raise_dtypes = recursive_find_dtypes(input_args)
154
- need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
155
- if len(need_raise_dtypes) == 1:
156
- raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
157
- elif len(need_raise_dtypes) >= 2:
158
- raise_dtype = torch.float32
159
-
160
- raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
161
- is_detach = api_name not in not_detach_set
162
- cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
163
- cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
164
- return cpu_args, cpu_kwargs
165
-
166
-
167
- def run_ut(config):
168
- logger.info("start UT test")
169
- if config.online_config.is_online:
170
- logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
171
- logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
172
- else:
173
- logger.info(f"UT task result will be saved in {config.result_csv_path}")
174
- logger.info(f"UT task details will be saved in {config.details_csv_path}")
175
-
176
- if config.save_error_data:
177
- logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
178
- compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
179
-
180
- if config.online_config.is_online:
181
- run_api_online(config, compare)
182
- else:
183
- with FileOpen(config.result_csv_path, 'r') as file:
184
- csv_reader = csv.reader(file)
185
- next(csv_reader)
186
- api_name_set = {row[0] for row in csv_reader}
187
- run_api_offline(config, compare, api_name_set)
188
- for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
189
- change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
190
- change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
191
- logger.info(f"UT task result csv is saved in {result_csv_path}")
192
- logger.info(f"UT task details csv is saved in {details_csv_path}")
193
- compare.print_pretest_result()
194
-
195
-
196
- def run_api_offline(config, compare, api_name_set):
197
- for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
198
- if api_full_name in api_name_set:
199
- continue
200
- if is_unsupported_api(api_full_name):
201
- continue
202
- [_, api_name, _] = api_full_name.split(Const.SEP)
203
- try:
204
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
205
- continue
206
- data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
207
- is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
208
- if config.save_error_data:
209
- do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success)
210
- except Exception as err:
211
- if "expected scalar type Long" in str(err):
212
- logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
213
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
214
- else:
215
- logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
216
- err_column = CompareColumn()
217
- fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
218
- result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
219
- compare.record_results(result_info)
220
- finally:
221
- if is_gpu:
222
- torch.cuda.empty_cache()
223
- else:
224
- torch.npu.empty_cache()
225
- gc.collect()
226
-
227
-
228
- def run_api_online(config, compare):
229
- attl = init_attl(config.online_config)
230
- dispatcher = ConsumerDispatcher(compare=compare)
231
- dispatcher.start(handle_func=run_torch_api_online, config=config)
232
-
233
- def tcp_communication_flow():
234
- while True:
235
- api_data = attl.recv()
236
- if api_data == 'STOP_':
237
- continue
238
- if api_data == 'KILL_':
239
- time.sleep(1)
240
- logger.info("==========接收到STOP信号==========")
241
- dispatcher.stop()
242
- attl.stop_serve()
243
- time.sleep(1)
244
- break
245
- if not isinstance(api_data, ApiData):
246
- continue
247
- api_full_name = api_data.name
248
- [_, api_name, _] = api_full_name.split(Const.SEP)
249
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
250
- continue
251
- dispatcher.update_consume_queue(api_data)
252
-
253
- def shared_storage_communication_flow():
254
- flag_num = -1
255
- while True:
256
- api_data = attl.download()
257
- if api_data == "start":
258
- if flag_num == -1:
259
- flag_num += 1
260
- flag_num += 1
261
- if api_data == "end":
262
- flag_num -= 1
263
- if flag_num == 0:
264
- dispatcher.stop()
265
- break
266
- if not isinstance(api_data, ApiData):
267
- continue
268
- api_full_name = api_data.name
269
- [_, api_name, _] = api_full_name.split(Const.SEP)
270
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
271
- continue
272
- dispatcher.update_consume_queue(api_data)
273
-
274
- if config.online_config.nfs_path:
275
- shared_storage_communication_flow()
276
- else:
277
- tcp_communication_flow()
278
-
279
-
280
- def blacklist_and_whitelist_filter(api_name, black_list, white_list):
281
- """
282
- run api(api_name) if api_name not in black_list and in white_list.
283
- If api is both in black_list and black_list, black_list first.
284
- return: False for exec api, True for not exec
285
- """
286
- if black_list and api_name in black_list:
287
- return True
288
- if white_list and api_name not in white_list:
289
- return True
290
- return False
291
-
292
-
293
- def is_unsupported_api(api_name):
294
- split_name = api_name.split(Const.SEP)[0]
295
- flag = split_name in [Const.NPU, Const.DISTRIBUTED]
296
- if flag:
297
- logger.info(f"{split_name} api is not supported for run ut. SKIP.")
298
- return flag
299
-
300
-
301
- def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
302
- if not is_fwd_success or not is_bwd_success:
303
- processor = UtDataProcessor(error_data_path)
304
- for element in data_info.in_fwd_data_list:
305
- processor.save_tensors_in_element(api_full_name + '.forward.input', element)
306
- processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
307
- processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
308
- processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
309
- processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
310
- processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
311
-
312
-
313
- def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
314
- in_fwd_data_list = []
315
- backward_message = ''
316
- [api_type, api_name, _] = api_full_name.split(Const.SEP)
317
- args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
318
- in_fwd_data_list.append(args)
319
- in_fwd_data_list.append(kwargs)
320
- need_backward = api_full_name in backward_content
321
- if not need_grad:
322
- logger.warning("%s %s" % (api_full_name, Backward_Message.UNSUPPORT_BACKWARD_MESSAGE))
323
- backward_message += Backward_Message.UNSUPPORT_BACKWARD_MESSAGE
324
- if api_name in not_backward_list:
325
- need_grad = False
326
- logger.warning("%s %s" % (api_full_name, Backward_Message.NO_BACKWARD_RESULT_MESSAGE))
327
- backward_message += Backward_Message.NO_BACKWARD_RESULT_MESSAGE
328
- need_backward = need_backward and need_grad
329
- if kwargs.get("device"):
330
- del kwargs["device"]
331
- cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
332
- device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
333
- bench_grad_out, device_grad_out = None, None
334
- out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
335
- device_out = exec_api(api_type, api_name, device_args, device_kwargs)
336
- current_path = os.path.dirname(os.path.realpath(__file__))
337
- ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
338
- api_setting_dict = get_json_contents(ut_setting_path)
339
- grad_input_index = api_setting_dict.get(api_name)
340
- grad_index = None
341
- grad, bench_grad = None, None
342
- if grad_input_index is not None:
343
- grad_index = grad_input_index.get('grad_index')
344
-
345
- if need_backward:
346
- if need_to_backward(grad_index, out):
347
- backward_args = backward_content[api_full_name].get("input")
348
- grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
349
- bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
350
- bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
351
- device_grad = grad.clone().detach().to(current_device)
352
- device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
353
- else:
354
- backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE
355
-
356
- return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
357
-
358
-
359
- def run_torch_api_online(api_full_name, api_data, backward_content):
360
- in_fwd_data_list = []
361
- [api_type, api_name, _] = api_full_name.split(Const.SEP)
362
- args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
363
- in_fwd_data_list.append(args)
364
- in_fwd_data_list.append(kwargs)
365
- if kwargs.get("device"):
366
- del kwargs["device"]
367
-
368
- device_out = exec_api(api_type, api_name, args, kwargs)
369
- device_out = move2device_exec(device_out, "cpu")
370
- return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
371
-
372
-
373
- def get_api_info(api_info_dict, api_name, real_data_path):
374
- convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
375
- need_grad = True
376
- if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"):
377
- need_grad = False
378
- args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
379
- return args, kwargs, need_grad
380
-
381
-
382
- def need_to_backward(grad_index, out):
383
- if grad_index is None and isinstance(out, (list, tuple)):
384
- return False
385
- return True
386
-
387
-
388
- def run_backward(args, grad, grad_index, out):
389
- if grad_index is not None:
390
- out[grad_index].backward(grad)
391
- else:
392
- out.backward(grad)
393
- args_grad = []
394
- for arg in args:
395
- if isinstance(arg, torch.Tensor):
396
- args_grad.append(arg.grad)
397
- grad_out = args_grad
398
-
399
- return grad_out
400
-
401
-
402
- def initialize_save_error_data(error_data_path):
403
- check_path_before_create(error_data_path)
404
- create_directory(error_data_path)
405
- error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
406
- ability=FileCheckConst.WRITE_ABLE)
407
- error_data_path = error_data_path_checker.common_check()
408
- error_data_path =initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
409
- return error_data_path
410
-
411
-
412
- def init_attl(config):
413
- """config: OnlineConfig"""
414
- attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
415
- connect_ip=config.host,
416
- connect_port=config.port,
417
- nfs_path=config.nfs_path,
418
- tls_path=config.tls_path))
419
- return attl
420
-
421
-
422
- def _run_ut_parser(parser):
423
- parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
424
- help="<Optional> The api param tool result file: generate from api param tool, "
425
- "a json file.",
426
- required=False)
427
- parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
428
- help="<optional> The ut task result out path.",
429
- required=False)
430
- parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
431
- help="<optional> Save compare failed api output.", required=False)
432
- parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true",
433
- help="<optional> whether to turn on jit compile", required=False)
434
-
435
- class UniqueDeviceAction(argparse.Action):
436
- def __call__(self, parser, namespace, values, option_string=None):
437
- unique_values = set(values)
438
- if len(values) != len(unique_values):
439
- parser.error("device id must be unique")
440
- for device_id in values:
441
- if not 0 <= device_id:
442
- parser.error("device id must be greater than or equal to 0")
443
- setattr(namespace, self.dest, values)
444
-
445
- parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
446
- help="<optional> set device id to run ut, must be unique and in range 0-7",
447
- default=[0], required=False, action=UniqueDeviceAction)
448
- parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
449
- help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
450
- "when run ut is interrupted, enter the file path to continue run ut.",
451
- required=False)
452
- parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
453
- help="<optional> Whether to filter the api in the api_info_file.", required=False)
454
- parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
455
- help="<optional> The path of config.json", required=False)
456
-
457
-
458
- def preprocess_forward_content(forward_content):
459
- processed_content = {}
460
- base_keys_variants = {}
461
- arg_cache = {}
462
-
463
- for key, value in forward_content.items():
464
- base_key = key.rsplit(Const.SEP, 1)[0]
465
-
466
- if key not in arg_cache:
467
- filtered_new_args = [
468
- {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
469
- for arg in value['input_args'] if isinstance(arg, dict)
470
- ]
471
- arg_cache[key] = (filtered_new_args, value['input_kwargs'])
472
-
473
- filtered_new_args, new_kwargs = arg_cache[key]
474
-
475
- if base_key not in base_keys_variants:
476
- processed_content[key] = value
477
- base_keys_variants[base_key] = {key}
478
- else:
479
- is_duplicate = False
480
- for variant in base_keys_variants.get(base_key, []):
481
- try:
482
- existing_args, existing_kwargs = arg_cache.get(variant)
483
- except KeyError as e:
484
- logger.error(f"KeyError: {e} when processing {key}")
485
- if existing_args == filtered_new_args and existing_kwargs == new_kwargs:
486
- is_duplicate = True
487
- break
488
-
489
- if not is_duplicate:
490
- processed_content[key] = value
491
- base_keys_variants[base_key].add(key)
492
-
493
- return processed_content
494
-
495
-
496
- def _run_ut(parser=None):
497
- if not parser:
498
- parser = argparse.ArgumentParser()
499
- _run_ut_parser(parser)
500
- args = parser.parse_args(sys.argv[1:])
501
- run_ut_command(args)
502
-
503
-
504
- def run_ut_command(args):
505
- if not is_gpu:
506
- torch.npu.set_compile_mode(jit_compile=args.jit_compile)
507
- used_device = current_device + ":" + str(args.device_id[0])
508
- try:
509
- if is_gpu:
510
- torch.cuda.set_device(used_device)
511
- else:
512
- torch.npu.set_device(used_device)
513
- except Exception as error:
514
- logger.error(f"Set device id failed. device id is: {args.device_id}")
515
- raise NotImplementedError from error
516
-
517
- # 在线预检场景下,不需要外出输出api信息,forward_content, backward_content, real_data_path设置为None
518
- # 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析
519
- forward_content, backward_content, real_data_path = None, None, None
520
- if args.api_info_file:
521
- api_info_file_checker = FileChecker(file_path = args.api_info_file, path_type = FileCheckConst.FILE,
522
- ability = FileCheckConst.READ_ABLE, file_type = FileCheckConst.JSON_SUFFIX)
523
- checked_api_info = api_info_file_checker.common_check()
524
- forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
525
- if args.filter_api:
526
- logger.info("Start filtering the api in the forward_input_file.")
527
- forward_content = preprocess_forward_content(forward_content)
528
- logger.info("Finish filtering the api in the forward_input_file.")
529
-
530
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
531
- check_path_before_create(out_path)
532
- create_directory(out_path)
533
- out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
534
- out_path = out_path_checker.common_check()
535
- save_error_data = args.save_error_data
536
-
537
- result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
538
- details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
539
- if args.result_csv_path:
540
- result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
541
- details_csv_path = get_validated_details_csv_path(result_csv_path)
542
- white_list = msCheckerConfig.white_list
543
- black_list = msCheckerConfig.black_list
544
- error_data_path = msCheckerConfig.error_data_path
545
- is_online = msCheckerConfig.is_online
546
- nfs_path = msCheckerConfig.nfs_path
547
- host = msCheckerConfig.host
548
- port = msCheckerConfig.port
549
- rank_list = msCheckerConfig.rank_list
550
- tls_path = msCheckerConfig.tls_path
551
- if args.config_path:
552
- config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
553
- FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
554
- checked_config_path = config_path_checker.common_check()
555
- _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
556
- white_list = task_config.white_list
557
- black_list = task_config.black_list
558
- error_data_path = task_config.error_data_path
559
- is_online = task_config.is_online
560
- nfs_path = task_config.nfs_path
561
- host = task_config.host
562
- port = task_config.port
563
- rank_list = task_config.rank_list
564
- tls_path = task_config.tls_path
565
-
566
- if save_error_data:
567
- if args.result_csv_path:
568
- time_info = result_csv_path.split('.')[0].split('_')[-1]
569
- global UT_ERROR_DATA_DIR
570
- UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
571
- error_data_path = initialize_save_error_data(error_data_path)
572
- online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
573
- run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
574
- args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
575
- online_config)
576
- run_ut(run_ut_config)
577
-
578
-
579
- if __name__ == '__main__':
580
- _run_ut()
581
- logger.info("UT task completed.")
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import argparse
19
+ import os
20
+ import csv
21
+ import sys
22
+ import time
23
+ import gc
24
+ from collections import namedtuple
25
+
26
+ try:
27
+ import torch_npu
28
+ except ImportError:
29
+ is_gpu = True
30
+ current_device = "cuda"
31
+ else:
32
+ is_gpu = False
33
+ current_device = "npu"
34
+ import torch
35
+ from tqdm import tqdm
36
+
37
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
38
+ get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info
39
+ from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
40
+ from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
41
+ initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
42
+ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
43
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
44
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
45
+ from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
46
+ from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, \
47
+ create_directory, get_json_contents, read_csv
48
+ from msprobe.pytorch.common.log import logger
49
+ from msprobe.pytorch.pt_config import parse_json_config
50
+ from msprobe.core.common.const import Const, FileCheckConst, CompareConst
51
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
52
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
53
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
54
+
55
+
56
+ current_time = time.strftime("%Y%m%d%H%M%S")
57
+ UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
58
+ RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
59
+ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
60
+ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
61
+ 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
62
+ 'black_list', 'error_data_path', 'online_config'])
63
+
64
+ OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
65
+
66
+ not_backward_list = ['repeat_interleave']
67
+
68
+
69
+ tqdm_params = {
70
+ 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
71
+ 'desc': 'Processing', # 进度条前的描述文字
72
+ 'leave': True, # 迭代完成后保留进度条的显示
73
+ 'ncols': 75, # 进度条的固定宽度
74
+ 'mininterval': 0.1, # 更新进度条的最小间隔秒数
75
+ 'maxinterval': 1.0, # 更新进度条的最大间隔秒数
76
+ 'miniters': 1, # 更新进度条之间的最小迭代次数
77
+ 'ascii': None, # 根据环境自动使用ASCII或Unicode字符
78
+ 'unit': 'it', # 迭代单位
79
+ 'unit_scale': True, # 自动根据单位缩放
80
+ 'dynamic_ncols': True, # 动态调整进度条宽度以适应控制台
81
+ 'bar_format': '{l_bar}{bar}| {n}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]' # 自定义进度条输出格式
82
+ }
83
+
84
+
85
+ def run_ut(config):
86
+ logger.info("start UT test")
87
+ if config.online_config.is_online:
88
+ logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
89
+ logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
90
+ else:
91
+ logger.info(f"UT task result will be saved in {config.result_csv_path}")
92
+ logger.info(f"UT task details will be saved in {config.details_csv_path}")
93
+
94
+ if config.save_error_data:
95
+ logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
96
+ compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
97
+
98
+ if config.online_config.is_online:
99
+ run_api_online(config, compare)
100
+ else:
101
+ csv_df = read_csv(config.result_csv_path)
102
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
103
+ run_api_offline(config, compare, api_name_set)
104
+ for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
105
+ change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
106
+ change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
107
+ logger.info(f"UT task result csv is saved in {result_csv_path}")
108
+ logger.info(f"UT task details csv is saved in {details_csv_path}")
109
+ compare.print_pretest_result()
110
+
111
+
112
+ def run_api_offline(config, compare, api_name_set):
113
+ err_column = CompareColumn()
114
+ for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
115
+ if api_full_name in api_name_set:
116
+ continue
117
+ if is_unsupported_api(api_full_name):
118
+ skip_message = f"API {api_full_name} not support for run ut. SKIP."
119
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
120
+ record_skip_info(api_full_name, compare, compare_alg_results)
121
+ continue
122
+ _, api_name = extract_basic_api_segments(api_full_name)
123
+ if not api_name:
124
+ err_message = f"API {api_full_name} not support for run ut. SKIP."
125
+ logger.error(err_message)
126
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
127
+ record_skip_info(api_full_name, compare, compare_alg_results)
128
+ continue
129
+ try:
130
+ if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
131
+ skip_message = f"API {api_name} in black list or not in white list. SKIP."
132
+ logger.info(skip_message)
133
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
134
+ record_skip_info(api_full_name, compare, compare_alg_results)
135
+ continue
136
+ data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
137
+ is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
138
+ if config.save_error_data:
139
+ do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success)
140
+ except Exception as err:
141
+ if "expected scalar type Long" in str(err):
142
+ logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
143
+ f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
144
+ else:
145
+ logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
146
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
147
+ record_skip_info(api_full_name, compare, compare_alg_results)
148
+ finally:
149
+ if is_gpu:
150
+ torch.cuda.empty_cache()
151
+ else:
152
+ torch.npu.empty_cache()
153
+ gc.collect()
154
+
155
+
156
+ def run_api_online(config, compare):
157
+ attl = init_attl(config.online_config)
158
+ dispatcher = ConsumerDispatcher(compare=compare)
159
+ dispatcher.start(handle_func=run_torch_api_online, config=config)
160
+
161
+ def tcp_communication_flow():
162
+ while True:
163
+ api_data = attl.recv()
164
+ if api_data == 'STOP_':
165
+ continue
166
+ if api_data == 'KILL_':
167
+ time.sleep(1)
168
+ logger.info("==========接收到STOP信号==========")
169
+ dispatcher.stop()
170
+ attl.stop_serve()
171
+ time.sleep(1)
172
+ break
173
+ if not isinstance(api_data, ApiData):
174
+ continue
175
+ api_full_name = api_data.name
176
+ _, api_name = extract_basic_api_segments(api_full_name)
177
+ if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
178
+ continue
179
+ if api_data.rank in config.online_config.rank_list:
180
+ dispatcher.update_consume_queue(api_data)
181
+
182
+ def shared_storage_communication_flow():
183
+ flag_num = -1
184
+ while True:
185
+ api_data = attl.download()
186
+ if api_data == "start":
187
+ if flag_num == -1:
188
+ flag_num += 1
189
+ flag_num += 1
190
+ if api_data == "end":
191
+ flag_num -= 1
192
+ if flag_num == 0:
193
+ dispatcher.stop()
194
+ break
195
+ if not isinstance(api_data, ApiData):
196
+ continue
197
+ api_full_name = api_data.name
198
+ _, api_name = extract_basic_api_segments(api_full_name)
199
+ if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
200
+ continue
201
+ if api_data.rank in config.online_config.rank_list:
202
+ dispatcher.update_consume_queue(api_data)
203
+
204
+ if config.online_config.nfs_path:
205
+ shared_storage_communication_flow()
206
+ else:
207
+ tcp_communication_flow()
208
+
209
+
210
+ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
211
+ """
212
+ run api(api_name) if api_name not in black_list and in white_list.
213
+ If api is both in black_list and black_list, black_list first.
214
+ return: False for exec api, True for not exec
215
+ """
216
+ if black_list and api_name in black_list:
217
+ return True
218
+ if white_list and api_name not in white_list:
219
+ return True
220
+ return False
221
+
222
+
223
+ def is_unsupported_api(api_name):
224
+ split_name = api_name.split(Const.SEP)[0]
225
+ flag = split_name == Const.DISTRIBUTED
226
+ if flag:
227
+ logger.info(f"{split_name} api is not supported for run ut. SKIP.")
228
+ return flag
229
+
230
+
231
+ def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
232
+ if not is_fwd_success or not is_bwd_success:
233
+ processor = UtDataProcessor(error_data_path)
234
+ for element in data_info.in_fwd_data_list:
235
+ processor.save_tensors_in_element(api_full_name + '.forward.input', element)
236
+ processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
237
+ processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
238
+ processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
239
+ processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
240
+ processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
241
+
242
+
243
+ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
244
+ in_fwd_data_list = []
245
+ backward_message = ''
246
+ api_type, api_name = extract_basic_api_segments(api_full_name)
247
+ args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
248
+ in_fwd_data_list.append(args)
249
+ in_fwd_data_list.append(kwargs)
250
+ need_backward = api_full_name in backward_content
251
+ if not need_grad:
252
+ logger.warning("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE))
253
+ backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
254
+ if api_name in not_backward_list:
255
+ need_grad = False
256
+ logger.warning("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
257
+ backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
258
+ need_backward = need_backward and need_grad
259
+ if kwargs.get("device"):
260
+ del kwargs["device"]
261
+ cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
262
+ device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
263
+ bench_grad_out, device_grad_out = None, None
264
+ out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
265
+ device_out = exec_api(api_type, api_name, current_device, device_args, device_kwargs)
266
+ current_path = os.path.dirname(os.path.realpath(__file__))
267
+ ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
268
+ api_setting_dict = get_json_contents(ut_setting_path)
269
+ grad_input_index = api_setting_dict.get(api_name)
270
+ grad_index = None
271
+ grad, bench_grad = None, None
272
+ if grad_input_index is not None:
273
+ grad_index = grad_input_index.get('grad_index')
274
+
275
+ if need_backward:
276
+ if need_to_backward(grad_index, out):
277
+ backward_args = backward_content[api_full_name].get("input")
278
+ func_options = {
279
+ 'real_data_path': real_data_path
280
+ }
281
+ grad = gen_args(backward_args, api_name, func_options)[0]
282
+ bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
283
+ bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
284
+ device_grad = grad.clone().detach().to(current_device)
285
+ device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
286
+ else:
287
+ backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
288
+ if api_name == "npu_fusion_attention":
289
+ out = out[0]
290
+ device_out = device_out[0]
291
+
292
+ return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
293
+
294
+
295
+ def run_torch_api_online(api_full_name, api_data, backward_content):
296
+ in_fwd_data_list = []
297
+ api_type, api_name = extract_basic_api_segments(api_full_name)
298
+ args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
299
+ in_fwd_data_list.append(args)
300
+ in_fwd_data_list.append(kwargs)
301
+ if kwargs.get("device"):
302
+ del kwargs["device"]
303
+
304
+ device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs)
305
+ device_out = move2device_exec(device_out, "cpu")
306
+ return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
307
+
308
+
309
+ def get_api_info(api_info_dict, api_name, real_data_path):
310
+ convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
311
+ need_grad = True
312
+ if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"):
313
+ need_grad = False
314
+ args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
315
+ return args, kwargs, need_grad
316
+
317
+
318
+ def need_to_backward(grad_index, out):
319
+ if grad_index is None and isinstance(out, (list, tuple)):
320
+ return False
321
+ return True
322
+
323
+
324
+ def run_backward(args, grad, grad_index, out):
325
+ if grad_index is not None:
326
+ out[grad_index].backward(grad)
327
+ else:
328
+ out.backward(grad)
329
+ args_grad = []
330
+ for arg in args:
331
+ if isinstance(arg, torch.Tensor):
332
+ args_grad.append(arg.grad)
333
+ grad_out = args_grad
334
+
335
+ return grad_out
336
+
337
+
338
+ def initialize_save_error_data(error_data_path):
339
+ check_path_before_create(error_data_path)
340
+ create_directory(error_data_path)
341
+ error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
342
+ ability=FileCheckConst.WRITE_ABLE)
343
+ error_data_path = error_data_path_checker.common_check()
344
+ error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
345
+ return error_data_path
346
+
347
+
348
+ def init_attl(config):
349
+ """config: OnlineConfig"""
350
+ attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
351
+ connect_ip=config.host,
352
+ connect_port=config.port,
353
+ nfs_path=config.nfs_path,
354
+ tls_path=config.tls_path))
355
+ return attl
356
+
357
+
358
+ def _run_ut_parser(parser):
359
+ parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
360
+ help="<Optional> The api param tool result file: generate from api param tool, "
361
+ "a json file.",
362
+ required=False)
363
+ parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
364
+ help="<optional> The ut task result out path.",
365
+ required=False)
366
+ parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
367
+ help="<optional> Save compare failed api output.", required=False)
368
+ parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true",
369
+ help="<optional> whether to turn on jit compile", required=False)
370
+
371
+ class UniqueDeviceAction(argparse.Action):
372
+ def __call__(self, parser, namespace, values, option_string=None):
373
+ unique_values = set(values)
374
+ if len(values) != len(unique_values):
375
+ parser.error("device id must be unique")
376
+ for device_id in values:
377
+ if not 0 <= device_id:
378
+ parser.error("device id must be greater than or equal to 0")
379
+ setattr(namespace, self.dest, values)
380
+
381
+ parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
382
+ help="<optional> set device id to run ut, must be unique and in range 0-7",
383
+ default=[0], required=False, action=UniqueDeviceAction)
384
+ parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
385
+ help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
386
+ "when run ut is interrupted, enter the file path to continue run ut.",
387
+ required=False)
388
+ parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
389
+ help="<optional> Whether to filter the api in the api_info_file.", required=False)
390
+ parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
391
+ help="<optional> The path of config.json", required=False)
392
+
393
+
394
+ def preprocess_forward_content(forward_content):
395
+ processed_content = {}
396
+ base_keys_variants = {}
397
+ arg_cache = {}
398
+
399
+ for key, value in forward_content.items():
400
+ base_key = key.rsplit(Const.SEP, 1)[0]
401
+
402
+ if key not in arg_cache:
403
+ filtered_new_args = [
404
+ {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
405
+ for arg in value['input_args']
406
+ if isinstance(arg, dict)
407
+ ]
408
+ arg_cache[key] = (filtered_new_args, value['input_kwargs'])
409
+
410
+ filtered_new_args, new_kwargs = arg_cache[key]
411
+
412
+ if base_key not in base_keys_variants:
413
+ processed_content[key] = value
414
+ base_keys_variants[base_key] = {key}
415
+ else:
416
+ is_duplicate = False
417
+ for variant in base_keys_variants.get(base_key, []):
418
+ try:
419
+ existing_args, existing_kwargs = arg_cache.get(variant)
420
+ except KeyError as e:
421
+ logger.error(f"KeyError: {e} when processing {key}")
422
+ if existing_args == filtered_new_args and existing_kwargs == new_kwargs:
423
+ is_duplicate = True
424
+ break
425
+
426
+ if not is_duplicate:
427
+ processed_content[key] = value
428
+ base_keys_variants[base_key].add(key)
429
+
430
+ return processed_content
431
+
432
+
433
+ def _run_ut(parser=None):
434
+ if not parser:
435
+ parser = argparse.ArgumentParser()
436
+ _run_ut_parser(parser)
437
+ args = parser.parse_args(sys.argv[1:])
438
+ run_ut_command(args)
439
+
440
+
441
def run_ut_command(args):
    """Drive a full run_ut session from parsed command-line arguments.

    Resolves the compute device, optionally loads/filters the api info file,
    prepares the output directory and result/details csv paths, merges config
    overrides from --config_path over msCheckerConfig defaults, and finally
    launches run_ut with the assembled RunUTConfig.

    Raises:
        NotImplementedError: if setting the requested device id fails.
    """
    if not is_gpu:
        # NPU only: apply the --jit_compile switch before any op runs.
        torch.npu.set_compile_mode(jit_compile=args.jit_compile)
    used_device = current_device + ":" + str(args.device_id[0])
    try:
        if is_gpu:
            torch.cuda.set_device(used_device)
        else:
            torch.npu.set_device(used_device)
    except Exception as error:
        logger.error(f"Set device id failed. device id is: {args.device_id}")
        raise NotImplementedError from error

    # Online-check scenario: no api info is dumped, so forward_content,
    # backward_content and real_data_path stay None.
    # Offline scenario: they are parsed from api_info_file below.
    forward_content, backward_content, real_data_path = None, None, None
    if args.api_info_file:
        api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
                                            ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
        checked_api_info = api_info_file_checker.common_check()
        forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
        if args.filter_api:
            logger.info("Start filtering the api in the api_info_file.")
            forward_content = preprocess_forward_content(forward_content)
            logger.info("Finish filtering the api in the api_info_file.")

    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
    out_path = out_path_checker.common_check()
    save_error_data = args.save_error_data

    result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
    details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
    if args.result_csv_path:
        # Resuming an interrupted run: validate and reuse the existing csvs.
        result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
        details_csv_path = get_validated_details_csv_path(result_csv_path)
    # Defaults come from msCheckerConfig; --config_path overrides them below.
    white_list = msCheckerConfig.white_list
    black_list = msCheckerConfig.black_list
    error_data_path = msCheckerConfig.error_data_path
    is_online = msCheckerConfig.is_online
    nfs_path = msCheckerConfig.nfs_path
    host = msCheckerConfig.host
    port = msCheckerConfig.port
    rank_list = msCheckerConfig.rank_list
    tls_path = msCheckerConfig.tls_path
    if args.config_path:
        config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
                                          FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
        checked_config_path = config_path_checker.common_check()
        _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
        white_list = task_config.white_list
        black_list = task_config.black_list
        error_data_path = task_config.error_data_path
        is_online = task_config.is_online
        nfs_path = task_config.nfs_path
        host = task_config.host
        port = task_config.port
        rank_list = task_config.rank_list
        tls_path = task_config.tls_path

    if save_error_data:
        if args.result_csv_path:
            # Derive the timestamp suffix from the resumed csv file name so the
            # previous run's error-data directory is reused.
            time_info = result_csv_path.split('.')[0].split('_')[-1]
            global UT_ERROR_DATA_DIR
            UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
        error_data_path = initialize_save_error_data(error_data_path)
    online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
    run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
                                args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
                                online_config)
    run_ut(run_ut_config)
514
+
515
+
516
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and execute the run_ut workflow.
    _run_ut()
    logger.info("UT task completed.")