mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,493 +1,592 @@
1
- import argparse
2
- import os
3
- import csv
4
- import re
5
- import sys
6
- import time
7
- import gc
8
- from collections import namedtuple
9
-
10
- try:
11
- import torch_npu
12
- except ImportError:
13
- is_gpu = True
14
- current_device = "cuda"
15
- else:
16
- is_gpu = False
17
- current_device = "npu"
18
- import torch
19
- from tqdm import tqdm
20
-
21
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api
22
- from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
23
- from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, api_info_preprocess, \
24
- initialize_save_path, UtDataProcessor
25
- from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
26
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
27
- from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
28
- from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
29
- from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
30
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
31
- from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
32
- from msprobe.core.common.file_check import FileOpen, FileChecker, \
33
- change_mode, check_file_suffix, check_link, check_path_before_create, create_directory
34
- from msprobe.pytorch.common.log import logger
35
- from msprobe.core.common.const import Const, FileCheckConst, CompareConst
36
-
37
- current_time = time.strftime("%Y%m%d%H%M%S")
38
- UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
39
- RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
40
- DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
41
- RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
42
- 'save_error_data', 'is_continue_run_ut', 'real_data_path'])
43
- not_backward_list = ['repeat_interleave']
44
- not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
45
- not_raise_dtype_set = {'type_as'}
46
-
47
- RAISE_PRECISION = {
48
- torch.float16: torch.float32,
49
- torch.bfloat16: torch.float32,
50
- torch.float32: torch.float64
51
- }
52
-
53
- tqdm_params = {
54
- 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
55
- 'desc': 'Processing', # 进度条前的描述文字
56
- 'leave': True, # 迭代完成后保留进度条的显示
57
- 'ncols': 75, # 进度条的固定宽度
58
- 'mininterval': 0.1, # 更新进度条的最小间隔秒数
59
- 'maxinterval': 1.0, # 更新进度条的最大间隔秒数
60
- 'miniters': 1, # 更新进度条之间的最小迭代次数
61
- 'ascii': None, # 根据环境自动使用ASCII或Unicode字符
62
- 'unit': 'it', # 迭代单位
63
- 'unit_scale': True, # 自动根据单位缩放
64
- 'dynamic_ncols': True, # 动态调整进度条宽度以适应控制台
65
- 'bar_format': '{l_bar}{bar}| {n}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]' # 自定义进度条输出格式
66
- }
67
-
68
-
69
- def exec_api(api_type, api_name, args, kwargs):
70
- if api_type == "Functional":
71
- functional_api = FunctionalOPTemplate(api_name, str, False)
72
- out = functional_api.forward(*args, **kwargs)
73
- if api_type == "Tensor":
74
- tensor_api = TensorOPTemplate(api_name, str, False)
75
- out = tensor_api.forward(*args, **kwargs)
76
- if api_type == "Torch":
77
- torch_api = TorchOPTemplate(api_name, str, False)
78
- out = torch_api.forward(*args, **kwargs)
79
- return out
80
-
81
-
82
- def deal_detach(arg, to_detach=True):
83
- return arg.detach() if to_detach else arg
84
-
85
-
86
- def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
87
- '''
88
- 将标杆数据的dtype转换为raise_dtype
89
- 输入:
90
- api_name:api名称
91
- arg:标杆输入
92
- raise_dtype:需要转换的dtype
93
- 输出:
94
- arg: 转换dtype的标杆输入
95
- '''
96
- if api_name in hf_32_standard_api and arg.dtype == torch.float32:
97
- return arg
98
- if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
99
- return arg
100
- return arg.type(raise_dtype)
101
-
102
-
103
- def generate_device_params(input_args, input_kwargs, need_backward, api_name):
104
- def recursive_arg_to_device(arg_in, to_detach):
105
- if isinstance(arg_in, (list, tuple)):
106
- return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
107
- elif isinstance(arg_in, torch.Tensor):
108
- if need_backward and arg_in.requires_grad:
109
- arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
110
- temp_arg_in = arg_in * 1
111
- arg_in = temp_arg_in.type_as(arg_in)
112
- arg_in.retain_grad()
113
- return arg_in
114
- else:
115
- return deal_detach(arg_in.clone(), to_detach).to(current_device)
116
- else:
117
- return arg_in
118
-
119
- is_detach = api_name not in not_detach_set
120
- device_args = recursive_arg_to_device(input_args, is_detach)
121
- device_kwargs = \
122
- {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
123
- return device_args, device_kwargs
124
-
125
-
126
- def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
127
- def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
128
- if isinstance(arg_in, (list, tuple)):
129
- return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
130
- elif isinstance(arg_in, torch.Tensor):
131
- if need_backward and arg_in.requires_grad:
132
- arg_in = deal_detach(raise_bench_data_dtype(
133
- api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
134
- temp_arg_in = arg_in * 1
135
- arg_in = temp_arg_in.type_as(arg_in)
136
- arg_in.retain_grad()
137
- return arg_in
138
- else:
139
- return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
140
- else:
141
- return arg_in
142
-
143
- def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
144
- if arg_in.dtype in RAISE_PRECISION:
145
- return True
146
- if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
147
- return True
148
- return False
149
-
150
- def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
151
- if isinstance(arg_in, (list, tuple)):
152
- return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
153
- elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
154
- return set([arg_in.dtype])
155
- elif isinstance(arg_in, dict) and check_kwargs:
156
- return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
157
- return set()
158
-
159
- raise_dtype = None
160
- need_raise_dtypes = recursive_find_dtypes(input_args)
161
- need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
162
- if len(need_raise_dtypes) == 1:
163
- raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
164
- elif len(need_raise_dtypes) >= 2:
165
- raise_dtype = torch.float32
166
-
167
- raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
168
- is_detach = api_name not in not_detach_set
169
- cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
170
- cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
171
- return cpu_args, cpu_kwargs
172
-
173
-
174
- def run_ut(config):
175
- logger.info("start UT test")
176
- logger.info(f"UT task result will be saved in {config.result_csv_path}")
177
- logger.info(f"UT task details will be saved in {config.details_csv_path}")
178
- if config.save_error_data:
179
- error_data_path = os.path.abspath(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR))
180
- logger.info(f"UT task error_datas will be saved in {error_data_path}")
181
- compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut)
182
- with FileOpen(config.result_csv_path, 'r') as file:
183
- csv_reader = csv.reader(file)
184
- next(csv_reader)
185
- api_name_set = {row[0] for row in csv_reader}
186
- for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
187
- if api_full_name in api_name_set:
188
- continue
189
- if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api
190
- continue
191
- try:
192
- if msCheckerConfig.white_list:
193
- [_, api_name, _] = api_full_name.split(Const.SEP)
194
- if api_name not in set(msCheckerConfig.white_list):
195
- continue
196
- data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
197
- is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
198
- if config.save_error_data:
199
- do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
200
- except Exception as err:
201
- [_, api_name, _] = api_full_name.split(Const.SEP)
202
- if "expected scalar type Long" in str(err):
203
- logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
204
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
205
- else:
206
- logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
207
- err_column = CompareColumn()
208
- fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
209
- result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
210
- compare.record_results(result_info)
211
- finally:
212
- if is_gpu:
213
- torch.cuda.empty_cache()
214
- else:
215
- torch.npu.empty_cache()
216
- gc.collect()
217
- change_mode(compare.save_path, FileCheckConst.DATA_FILE_AUTHORITY)
218
- change_mode(compare.detail_save_path, FileCheckConst.DATA_FILE_AUTHORITY)
219
- compare.print_pretest_result()
220
-
221
-
222
- def is_unsupported_api(api_name):
223
- split_name = api_name.split(Const.SEP)[0]
224
- flag = split_name in [Const.NPU, Const.DISTRIBUTED]
225
- if flag:
226
- logger.info(f"{split_name} api is not supported for run ut. SKIP.")
227
- return flag
228
-
229
-
230
def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
    """Persist all tensors of a failed api comparison for offline debugging.

    Saves forward inputs/outputs and backward inputs/outputs (bench and device
    sides) under the configured error-data directory. No-op when both the
    forward and backward comparisons succeeded.

    :param api_full_name: dotted api name used as the file-name prefix
    :param data_info: UtDataInfo bundle produced by run_torch_api
    :param is_fwd_success: forward comparison verdict
    :param is_bwd_success: backward comparison verdict
    """
    if not is_fwd_success or not is_bwd_success:
        processor = UtDataProcessor(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR))
        for element in data_info.in_fwd_data_list:
            processor.save_tensors_in_element(api_full_name + '.forward.input', element)
        # Bug fix: UtDataInfo stores these fields as bench_output / device_output /
        # bench_grad / device_grad; the original accessed non-existent attributes
        # (bench_out, device_out, bench_grad_out, device_grad_out) and raised
        # AttributeError whenever an api failed comparison.
        processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
        processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
        processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
        processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
        processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
240
-
241
-
242
def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
    """Execute one dumped api call on CPU (benchmark) and on the device, run the
    backward pass when possible, and bundle every artifact into a UtDataInfo.

    :param api_full_name: dotted name "<type>.<name>.<index>" from the dump file
    :param real_data_path: root directory of dumped real tensors ('' in statistics mode)
    :param backward_content: mapping api_full_name -> dumped grad_output description
    :param api_info_dict: dumped forward-input description for this call
    :return: UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad,
             in_fwd_data_list, backward_message)
    """
    in_fwd_data_list = []
    backward_message = ''
    [api_type, api_name, _] = api_full_name.split(Const.SEP)
    args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
    in_fwd_data_list.append(args)
    in_fwd_data_list.append(kwargs)
    need_backward = api_full_name in backward_content
    if not need_grad:
        # out= style calls cannot require grad (see get_api_info)
        logger.warning("%s %s" % (api_full_name, Backward_Message.UNSUPPORT_BACKWARD_MESSAGE))
        backward_message += Backward_Message.UNSUPPORT_BACKWARD_MESSAGE
    if api_name in not_backward_list:
        need_grad = False
        logger.warning("%s %s" % (api_full_name, Backward_Message.NO_BACKWARD_RESULT_MESSAGE))
        backward_message += Backward_Message.NO_BACKWARD_RESULT_MESSAGE
    need_backward = need_backward and need_grad
    if kwargs.get("device"):
        # device placement is decided by this runner, not by the dumped kwargs
        del kwargs["device"]
    cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
    device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
    bench_grad_out, device_grad_out = None, None
    out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
    device_out = exec_api(api_type, api_name, device_args, device_kwargs)
    # per-api overrides: torch_ut_setting.json may name which output index
    # receives the gradient for multi-output apis
    current_path = os.path.dirname(os.path.realpath(__file__))
    ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
    api_setting_dict = get_json_contents(ut_setting_path)
    grad_input_index = api_setting_dict.get(api_name)
    grad_index = None
    grad, bench_grad = None, None
    if grad_input_index is not None:
        grad_index = grad_input_index.get('grad_index')

    if need_backward:
        if need_to_backward(grad_index, out):
            backward_args = backward_content[api_full_name].get("grad_output")
            grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
            bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
            bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
            device_grad = grad.clone().detach().to(current_device)
            device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
        else:
            # multi-output api without a configured grad_index: backward is skipped
            backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE

    return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
286
-
287
-
288
def get_api_info(api_info_dict, api_name, real_data_path):
    """Generate concrete args/kwargs for one api from its dumped description.

    Grad is disabled for out= variants, since tensors written through the
    ``out`` kwarg cannot require grad.
    """
    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
    need_grad = True
    if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"):
        need_grad = False
    args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
    return args, kwargs, need_grad
295
-
296
-
297
def need_to_backward(grad_index, out):
    """Backward is possible unless the output is a multi-element container
    without a designated gradient index."""
    return not (grad_index is None and isinstance(out, (list, tuple)))
301
-
302
-
303
def run_backward(args, grad, grad_index, out):
    """Backpropagate *grad* through the forward output and collect the
    gradients of every tensor in *args*, in positional order."""
    target = out if grad_index is None else out[grad_index]
    target.backward(grad)
    return [arg.grad for arg in args if isinstance(arg, torch.Tensor)]
315
-
316
-
317
def initialize_save_error_data():
    """Create and permission-check the configured error-data directory, then
    initialize the timestamped UT_ERROR_DATA_DIR subdirectory inside it."""
    error_data_path = msCheckerConfig.error_data_path
    check_path_before_create(error_data_path)
    create_directory(error_data_path)
    error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR,
                                          ability=FileCheckConst.WRITE_ABLE)
    error_data_path = error_data_path_checker.common_check()
    initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
325
-
326
-
327
def get_validated_result_csv_path(result_csv_path, mode):
    """Validate a user-supplied csv path and return its checked real path.

    :param result_csv_path: path passed on the command line (resume mode)
    :param mode: 'result' or 'detail'; result csvs must keep their original
        timestamped name so the run can be resumed consistently
    :raises ValueError: for an unknown mode or a renamed result csv
    """
    if mode not in ['result', 'detail']:
        raise ValueError("The csv mode must be result or detail")
    result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
                                          file_type=FileCheckConst.CSV_SUFFIX)
    validated_result_csv_path = result_csv_path_checker.common_check()
    if mode == 'result':
        # the 14-digit timestamp in the name identifies the interrupted run
        result_csv_name = os.path.basename(validated_result_csv_path)
        pattern = r"^accuracy_checking_result_\d{14}\.csv$"
        if not re.match(pattern, result_csv_name):
            raise ValueError("When continue run ut, please do not modify the result csv name.")
    return validated_result_csv_path
339
-
340
-
341
def get_validated_details_csv_path(validated_result_csv_path):
    """Derive the details csv path from an already-validated result csv path
    (same directory, 'result' replaced by 'details') and validate it too."""
    result_csv_name = os.path.basename(validated_result_csv_path)
    details_csv_name = result_csv_name.replace('result', 'details')
    details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
    details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
                                           ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
    validated_details_csv_path = details_csv_path_checker.common_check()
    return validated_details_csv_path
349
-
350
-
351
def _run_ut_parser(parser):
    """Register all run_ut command-line options on *parser*."""
    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
                        help="<Required> The api param tool result file: generate from api param tool, "
                             "a json file.",
                        required=True)
    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                        help="<optional> The ut task result out path.",
                        required=False)
    parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
                        help="<optional> Save compare failed api output.", required=False)
    parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true",
                        help="<optional> whether to turn on jit compile", required=False)

    class UniqueDeviceAction(argparse.Action):
        # custom action: reject duplicate or negative device ids at parse time
        def __call__(self, parser, namespace, values, option_string=None):
            unique_values = set(values)
            if len(values) != len(unique_values):
                parser.error("device id must be unique")
            for device_id in values:
                if not 0 <= device_id:
                    parser.error("device id must be greater than or equal to 0")
            setattr(namespace, self.dest, values)

    parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
                        help="<optional> set device id to run ut, must be unique and in range 0-7",
                        default=[0], required=False, action=UniqueDeviceAction)
    parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
                        help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
                             "when run ut is interrupted, enter the file path to continue run ut.",
                        required=False)
    parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str,
                        help="<optional> In real data mode, the root directory for storing real data "
                             "must be configured.",
                        required=False)
    parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
                        help="<optional> Whether to filter the api in the api_info_file.", required=False)
387
-
388
-
389
def preprocess_forward_content(forward_content):
    """Deduplicate forward api entries that share the same base key and have
    identical args/kwargs.

    The base key is the dotted name without its trailing call index. The
    'Max'/'Min' summary statistics of each arg are ignored when comparing,
    so repeated calls that differ only in observed value range collapse to one.

    :param forward_content: mapping api_full_name -> dumped call description
    :return: filtered mapping with duplicate calls removed
    """
    processed_content = {}
    base_keys_variants = {}
    arg_cache = {}

    for key, value in forward_content.items():
        base_key = key.rsplit(Const.SEP, 1)[0]

        if key not in arg_cache:
            # strip value-range statistics so they don't defeat deduplication
            filtered_new_args = [
                {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
                for arg in value['args'] if isinstance(arg, dict)
            ]
            arg_cache[key] = (filtered_new_args, value['kwargs'])

        filtered_new_args, new_kwargs = arg_cache[key]

        if base_key not in base_keys_variants:
            processed_content[key] = value
            base_keys_variants[base_key] = {key}
        else:
            is_duplicate = False
            for variant in base_keys_variants.get(base_key, []):
                try:
                    # Bug fix: the original used arg_cache.get(variant), which
                    # returns None on a miss and raises TypeError (not KeyError)
                    # when unpacked; it then fell through and compared the
                    # possibly-unbound existing_args. Index directly so KeyError
                    # actually fires, and skip the broken variant.
                    existing_args, existing_kwargs = arg_cache[variant]
                except KeyError as e:
                    logger.error(f"KeyError: {e} when processing {key}")
                    continue
                if existing_args == filtered_new_args and existing_kwargs == new_kwargs:
                    is_duplicate = True
                    break

            if not is_duplicate:
                processed_content[key] = value
                base_keys_variants[base_key].add(key)

    return processed_content
425
-
426
-
427
def _run_ut(parser=None):
    """Command-line entry point: build a parser when none is supplied,
    parse sys.argv and launch the UT task."""
    if not parser:
        parser = argparse.ArgumentParser()
    _run_ut_parser(parser)
    args = parser.parse_args(sys.argv[1:])
    run_ut_command(args)
433
-
434
-
435
def run_ut_command(args):
    """Validate CLI arguments, bind the compute device, prepare all paths
    (fresh run or resume), then launch run_ut.

    :param args: parsed argparse namespace from _run_ut_parser
    :raises NotImplementedError: when the requested device cannot be selected
    """
    if not is_gpu:
        torch.npu.set_compile_mode(jit_compile=args.jit_compile)
    # only the first device id is used by this process
    used_device = current_device + ":" + str(args.device_id[0])
    try:
        if is_gpu:
            torch.cuda.set_device(used_device)
        else:
            torch.npu.set_device(used_device)
    except Exception as error:
        logger.error(f"Set device id failed. device id is: {args.device_id}")
        raise NotImplementedError from error
    check_link(args.api_info_file)
    api_info = os.path.realpath(args.api_info_file)
    check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX)
    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
    out_path = out_path_checker.common_check()
    save_error_data = args.save_error_data
    forward_content, backward_content, real_data_path = parse_json_info_forward_backward(api_info)
    if args.filter_api:
        logger.info("Start filtering the api in the forward_input_file.")
        forward_content = preprocess_forward_content(forward_content)
        logger.info("Finish filtering the api in the forward_input_file.")

    result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
    details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
    if args.result_csv_path:
        # resume mode: reuse the interrupted run's csv files
        result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
        details_csv_path = get_validated_details_csv_path(result_csv_path)
    if save_error_data:
        if args.result_csv_path:
            # keep the error-data dir aligned with the resumed run's timestamp
            time_info = result_csv_path.split('.')[0].split('_')[-1]
            global UT_ERROR_DATA_DIR
            UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
        initialize_save_error_data()
    run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
                                args.result_csv_path, real_data_path)
    run_ut(run_ut_config)
476
-
477
-
478
class UtDataInfo:
    """Bundle of one api's forward/backward results on both the benchmark (CPU)
    and the device side, plus the inputs and any backward diagnostics."""

    def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
                 backward_message, rank=0):
        # forward results
        self.bench_output = bench_output
        self.device_output = device_output
        # backward results
        self.bench_grad = bench_grad
        self.device_grad = device_grad
        self.grad_in = grad_in
        # bookkeeping
        self.in_fwd_data_list = in_fwd_data_list
        self.backward_message = backward_message
        self.rank = rank
489
-
490
-
491
if __name__ == '__main__':
    # script entry: parse CLI args and run the full UT comparison task
    _run_ut()
    logger.info("UT task completed.")
1
+ import argparse
2
+ import os
3
+ import csv
4
+ import sys
5
+ import time
6
+ import gc
7
+ from collections import namedtuple
8
+
9
+ try:
10
+ import torch_npu
11
+ except ImportError:
12
+ is_gpu = True
13
+ current_device = "cuda"
14
+ else:
15
+ is_gpu = False
16
+ current_device = "npu"
17
+ import torch
18
+ from tqdm import tqdm
19
+
20
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
21
+ get_validated_result_csv_path, get_validated_details_csv_path, exec_api
22
+ from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
23
+ from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
24
+ initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
25
+ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
26
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
27
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
28
+ from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
29
+ from msprobe.core.common.file_utils import FileOpen, FileChecker, \
30
+ change_mode, check_path_before_create, create_directory, get_json_contents
31
+ from msprobe.pytorch.common.log import logger
32
+ from msprobe.pytorch.pt_config import parse_json_config
33
+ from msprobe.core.common.const import Const, FileCheckConst, CompareConst
34
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
35
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
36
+
37
+
38
# Timestamp shared by every artifact of this run (csv names, error-data dir).
current_time = time.strftime("%Y%m%d%H%M%S")
UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
# Immutable bag of everything run_ut needs for one task.
RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
                                         'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
                                         'black_list', 'error_data_path', 'online_config'])

# Settings for the online (tensor-transport) checking mode.
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])

# APIs whose backward result is not comparable in this harness.
not_backward_list = ['repeat_interleave']
# In-place/view ops whose inputs must NOT be detached before execution.
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
# APIs exempt from dtype promotion on the benchmark side.
not_raise_dtype_set = {'type_as'}

# Benchmark-side dtype promotion table: run the CPU reference in higher
# precision so it can serve as the accuracy standard.
RAISE_PRECISION = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.float32: torch.float64
}

tqdm_params = {
    'smoothing': 0,  # smoothing factor for the ETA estimate, range 0 to 1
    'desc': 'Processing',  # text shown before the progress bar
    'leave': True,  # keep the bar on screen after completion
    'ncols': 75,  # fixed bar width
    'mininterval': 0.1,  # minimum seconds between bar refreshes
    'maxinterval': 1.0,  # maximum seconds between bar refreshes
    'miniters': 1,  # minimum iterations between bar refreshes
    'ascii': None,  # auto-select ASCII or Unicode characters per environment
    'unit': 'it',  # iteration unit label
    'unit_scale': True,  # auto-scale the unit
    'dynamic_ncols': True,  # adapt bar width to the console
    'bar_format': '{l_bar}{bar}| {n}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'  # custom output format
}
72
+
73
+
74
def deal_detach(arg, to_detach=True):
    """Detach *arg* from the autograd graph when requested, otherwise pass it through."""
    if not to_detach:
        return arg
    return arg.detach()
76
+
77
+
78
def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
    '''
    Promote the benchmark tensor's dtype to raise_dtype for a higher-precision reference.
    Inputs:
        api_name: name of the api under test
        arg: benchmark input tensor
        raise_dtype: target dtype to promote to
    Output:
        arg: the benchmark input, promoted to raise_dtype or returned unchanged
    '''
    # hf32-standard apis keep float32 inputs as-is (their standard is hf32, not fp64)
    if api_name in hf_32_standard_api and arg.dtype == torch.float32:
        return arg
    # no promotion requested, dtype not promotable, or already at the target dtype
    if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
        return arg
    return arg.type(raise_dtype)
93
+
94
+
95
def generate_device_params(input_args, input_kwargs, need_backward, api_name):
    """Deep-copy args/kwargs onto the current device (cuda/npu).

    Tensors are cloned and detached (unless *api_name* is an in-place op in
    not_detach_set); tensors that participated in backward are re-marked
    requires_grad and retain_grad so run_backward can read their .grad later.
    """
    def recursive_arg_to_device(arg_in, to_detach):
        if isinstance(arg_in, (list, tuple)):
            # preserve the container type while converting elements
            return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
        elif isinstance(arg_in, torch.Tensor):
            if need_backward and arg_in.requires_grad:
                arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
                # NOTE(review): * 1 appears to create a non-leaf copy so that
                # retain_grad keeps the gradient on this node — confirm intent
                temp_arg_in = arg_in * 1
                arg_in = temp_arg_in.type_as(arg_in)
                arg_in.retain_grad()
                return arg_in
            else:
                return deal_detach(arg_in.clone(), to_detach).to(current_device)
        else:
            return arg_in

    is_detach = api_name not in not_detach_set
    device_args = recursive_arg_to_device(input_args, is_detach)
    # the "out" kwarg is never detached: the api writes into that tensor
    device_kwargs = \
        {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
    return device_args, device_kwargs
116
+
117
+
118
def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
    """Deep-copy args/kwargs for the CPU benchmark run, promoting low-precision
    float tensors (per RAISE_PRECISION) so the CPU reference is more accurate.

    A single promotable dtype maps through RAISE_PRECISION; mixed promotable
    dtypes fall back to float32. Apis in not_raise_dtype_set are exempt.
    """
    def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
        if isinstance(arg_in, (list, tuple)):
            return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
        elif isinstance(arg_in, torch.Tensor):
            if need_backward and arg_in.requires_grad:
                arg_in = deal_detach(raise_bench_data_dtype(
                    api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
                # NOTE(review): * 1 appears to create a non-leaf copy so that
                # retain_grad keeps the gradient on this node — confirm intent
                temp_arg_in = arg_in * 1
                arg_in = temp_arg_in.type_as(arg_in)
                arg_in.retain_grad()
                return arg_in
            else:
                return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
        else:
            return arg_in

    def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
        # promotable positional tensors; kwargs additionally count half/bf16
        if arg_in.dtype in RAISE_PRECISION:
            return True
        if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
            return True
        return False

    def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
        # collect the set of promotable dtypes across nested args/kwargs
        if isinstance(arg_in, (list, tuple)):
            return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
        elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
            return set([arg_in.dtype])
        elif isinstance(arg_in, dict) and check_kwargs:
            return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
        return set()

    raise_dtype = None
    need_raise_dtypes = recursive_find_dtypes(input_args)
    need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
    if len(need_raise_dtypes) == 1:
        raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
    elif len(need_raise_dtypes) >= 2:
        # mixed promotable dtypes: promote everything to float32
        raise_dtype = torch.float32

    raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
    is_detach = api_name not in not_detach_set
    cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
    # the "out" kwarg is never detached: the api writes into that tensor
    cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
    return cpu_args, cpu_kwargs
164
+
165
+
166
def run_ut(config):
    """Top-level UT driver: announce output locations, build the Comparator,
    dispatch to online or offline execution, then finalize csv permissions
    and print the summary.

    :param config: RunUTConfig namedtuple
    """
    logger.info("start UT test")
    if config.online_config.is_online:
        # online mode writes one csv pair per rank
        logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
        logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
    else:
        logger.info(f"UT task result will be saved in {config.result_csv_path}")
        logger.info(f"UT task details will be saved in {config.details_csv_path}")

    if config.save_error_data:
        logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
    compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)

    if config.online_config.is_online:
        run_api_online(config, compare)
    else:
        # resume support: apis already in the result csv are skipped
        with FileOpen(config.result_csv_path, 'r') as file:
            csv_reader = csv.reader(file)
            next(csv_reader)  # skip the header row
            api_name_set = {row[0] for row in csv_reader}
        run_api_offline(config, compare, api_name_set)
    for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
        change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
        change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
        logger.info(f"UT task result csv is saved in {result_csv_path}")
        logger.info(f"UT task details csv is saved in {details_csv_path}")
    compare.print_pretest_result()
193
+
194
+
195
def run_api_offline(config, compare, api_name_set):
    """Iterate every dumped forward api, execute and compare it, recording
    SKIP rows for unparseable apis and for execution failures.

    :param config: RunUTConfig
    :param compare: Comparator accumulating results
    :param api_name_set: api full names already present in the result csv (resume)
    """
    err_column = CompareColumn()
    for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
        if api_full_name in api_name_set:
            continue  # already compared in a previous (interrupted) run
        if is_unsupported_api(api_full_name):
            continue
        _, api_name = extract_basic_api_segments(api_full_name)
        if not api_name:
            # malformed name: record an explicit SKIP row instead of dying
            err_message = f"API {api_full_name} not support for run ut. SKIP."
            logger.error(err_message)
            fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
            result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
            compare.record_results(result_info)
            continue
        try:
            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
                continue
            data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
            is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
            if config.save_error_data:
                do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success)
        except Exception as err:
            if "expected scalar type Long" in str(err):
                logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                               f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
            else:
                logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
            fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
            result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
            compare.record_results(result_info)
        finally:
            # release per-api device memory and python garbage whether or not
            # the comparison succeeded
            if is_gpu:
                torch.cuda.empty_cache()
            else:
                torch.npu.empty_cache()
            gc.collect()
232
+
233
+
234
def run_api_online(config, compare):
    """Online mode: receive live ApiData from the training side (over TCP or a
    shared NFS directory) and feed it to a consumer dispatcher for comparison.

    :param config: RunUTConfig whose online_config selects the transport
    :param compare: Comparator passed to the dispatcher
    """
    attl = init_attl(config.online_config)
    dispatcher = ConsumerDispatcher(compare=compare)
    dispatcher.start(handle_func=run_torch_api_online, config=config)

    def tcp_communication_flow():
        # pull packets until the sender signals shutdown
        while True:
            api_data = attl.recv()
            if api_data == 'STOP_':
                continue
            if api_data == 'KILL_':
                time.sleep(1)
                # log message (Chinese): "received STOP signal"
                logger.info("==========接收到STOP信号==========")
                dispatcher.stop()
                attl.stop_serve()
                time.sleep(1)
                break
            if not isinstance(api_data, ApiData):
                continue
            api_full_name = api_data.name
            _, api_name = extract_basic_api_segments(api_full_name)
            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
                continue
            if api_data.rank in config.online_config.rank_list:
                dispatcher.update_consume_queue(api_data)

    def shared_storage_communication_flow():
        # flag_num counts active producers: each "start" increments, each
        # "end" decrements; reaching zero means all producers finished
        flag_num = -1
        while True:
            api_data = attl.download()
            if api_data == "start":
                if flag_num == -1:
                    flag_num += 1  # first producer seen: leave the initial state
                flag_num += 1
            if api_data == "end":
                flag_num -= 1
            if flag_num == 0:
                dispatcher.stop()
                break
            if not isinstance(api_data, ApiData):
                continue
            api_full_name = api_data.name
            _, api_name = extract_basic_api_segments(api_full_name)
            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
                continue
            if api_data.rank in config.online_config.rank_list:
                dispatcher.update_consume_queue(api_data)

    if config.online_config.nfs_path:
        shared_storage_communication_flow()
    else:
        tcp_communication_flow()
286
+
287
+
288
def blacklist_and_whitelist_filter(api_name, black_list, white_list):
    """Decide whether *api_name* should be skipped.

    An api is skipped when it appears on the blacklist, or when a non-empty
    whitelist exists and the api is not on it; the blacklist takes precedence
    over the whitelist.
    Returns False to execute the api, True to skip it.
    """
    blocked = bool(black_list) and api_name in black_list
    not_whitelisted = bool(white_list) and api_name not in white_list
    return blocked or not_whitelisted
299
+
300
+
301
def is_unsupported_api(api_name):
    """Return True (and log a SKIP notice) when the api belongs to the
    Distributed family, which run_ut cannot execute."""
    split_name = api_name.split(Const.SEP)[0]
    unsupported = split_name == Const.DISTRIBUTED
    if unsupported:
        logger.info(f"{split_name} api is not supported for run ut. SKIP.")
    return unsupported
307
+
308
+
309
def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
    """Persist all tensors of a failed api comparison for offline debugging.

    Saves forward inputs/outputs and backward inputs/outputs (bench and device
    sides) under *error_data_path*. No-op when both comparisons succeeded.
    """
    if not is_fwd_success or not is_bwd_success:
        processor = UtDataProcessor(error_data_path)
        for element in data_info.in_fwd_data_list:
            processor.save_tensors_in_element(api_full_name + '.forward.input', element)
        processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
        processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
        processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
        processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
        processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
319
+
320
+
321
def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
    """Execute one dumped api call on CPU (benchmark) and on the device, run the
    backward pass when possible, and bundle every artifact into a UtDataInfo.

    :param api_full_name: dotted api name from the dump file
    :param real_data_path: root directory of dumped real tensors ('' in statistics mode)
    :param backward_content: mapping api_full_name -> dumped backward input description
    :param api_info_dict: dumped forward-input description for this call
    :return: UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad,
             in_fwd_data_list, backward_message)
    """
    in_fwd_data_list = []
    backward_message = ''
    api_type, api_name = extract_basic_api_segments(api_full_name)
    args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
    in_fwd_data_list.append(args)
    in_fwd_data_list.append(kwargs)
    need_backward = api_full_name in backward_content
    if not need_grad:
        # out= style calls cannot require grad (see get_api_info)
        logger.warning("%s %s" % (api_full_name, Backward_Message.UNSUPPORT_BACKWARD_MESSAGE))
        backward_message += Backward_Message.UNSUPPORT_BACKWARD_MESSAGE
    if api_name in not_backward_list:
        need_grad = False
        logger.warning("%s %s" % (api_full_name, Backward_Message.NO_BACKWARD_RESULT_MESSAGE))
        backward_message += Backward_Message.NO_BACKWARD_RESULT_MESSAGE
    need_backward = need_backward and need_grad
    if kwargs.get("device"):
        # device placement is decided by this runner, not by the dumped kwargs
        del kwargs["device"]
    cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
    device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
    bench_grad_out, device_grad_out = None, None
    out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
    device_out = exec_api(api_type, api_name, current_device, device_args, device_kwargs)
    # per-api overrides: torch_ut_setting.json may name which output index
    # receives the gradient for multi-output apis
    current_path = os.path.dirname(os.path.realpath(__file__))
    ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
    api_setting_dict = get_json_contents(ut_setting_path)
    grad_input_index = api_setting_dict.get(api_name)
    grad_index = None
    grad, bench_grad = None, None
    if grad_input_index is not None:
        grad_index = grad_input_index.get('grad_index')

    if need_backward:
        if need_to_backward(grad_index, out):
            backward_args = backward_content[api_full_name].get("input")
            grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
            bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
            bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
            device_grad = grad.clone().detach().to(current_device)
            device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
        else:
            # multi-output api without a configured grad_index: backward is skipped
            backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE
    if api_name == "npu_fusion_attention":
        # only the first output of npu_fusion_attention is compared
        out = out[0]
        device_out = device_out[0]

    return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
368
+
369
+
370
def run_torch_api_online(api_full_name, api_data, backward_content):
    """Online-mode handler: re-execute a received api on this (GPU) side and
    pair its output with the result shipped from the training side.

    :param api_full_name: dotted api name from the ApiData packet
    :param api_data: ApiData carrying args, kwargs and the remote result
    :param backward_content: unused here; kept for handler-signature parity
    :return: UtDataInfo with remote result as device_output and the local
             re-execution (moved to cpu) as bench_output, tagged with the rank
    """
    in_fwd_data_list = []
    api_type, api_name = extract_basic_api_segments(api_full_name)
    args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
    in_fwd_data_list.append(args)
    in_fwd_data_list.append(kwargs)
    if kwargs.get("device"):
        # device placement is decided locally, not by the shipped kwargs
        del kwargs["device"]

    device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs)
    device_out = move2device_exec(device_out, "cpu")
    return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
382
+
383
+
384
def get_api_info(api_info_dict, api_name, real_data_path):
    """Generate concrete args/kwargs for one api from its dumped description.

    Grad is disabled for out= variants, since tensors written through the
    ``out`` kwarg cannot require grad.
    """
    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
    input_kwargs = api_info_dict.get("input_kwargs")
    need_grad = not (input_kwargs and "out" in input_kwargs)
    args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
    return args, kwargs, need_grad
391
+
392
+
393
def need_to_backward(grad_index, out):
    """Backward is possible unless the output is a multi-element container
    without a designated gradient index."""
    multi_output_without_index = isinstance(out, (list, tuple)) and grad_index is None
    return not multi_output_without_index
397
+
398
+
399
def run_backward(args, grad, grad_index, out):
    """Backpropagate *grad* through the forward output and collect the
    gradients of every tensor in *args*, in positional order."""
    if grad_index is None:
        out.backward(grad)
    else:
        out[grad_index].backward(grad)
    collected = []
    for candidate in args:
        if isinstance(candidate, torch.Tensor):
            collected.append(candidate.grad)
    return collected
411
+
412
+
413
def initialize_save_error_data(error_data_path):
    """Create and permission-check *error_data_path*, then initialize the
    timestamped UT_ERROR_DATA_DIR subdirectory inside it.

    :return: the path of the created error-data subdirectory
    """
    check_path_before_create(error_data_path)
    create_directory(error_data_path)
    error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
                                          ability=FileCheckConst.WRITE_ABLE)
    error_data_path = error_data_path_checker.common_check()
    error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
    return error_data_path
421
+
422
+
423
def init_attl(config):
    """Build the tensor-transport link for online mode.

    :param config: OnlineConfig namedtuple (host, port, nfs_path, tls_path)
    :return: an ATTL instance acting as the benchmark-device endpoint
    """
    attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
                                  connect_ip=config.host,
                                  connect_port=config.port,
                                  nfs_path=config.nfs_path,
                                  tls_path=config.tls_path))
    return attl
431
+
432
+
433
def _run_ut_parser(parser):
    """Register all run_ut command-line options on *parser*."""
    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
                        help="<Optional> The api param tool result file: generate from api param tool, "
                             "a json file.",
                        required=False)
    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                        help="<optional> The ut task result out path.",
                        required=False)
    parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
                        help="<optional> Save compare failed api output.", required=False)
    parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true",
                        help="<optional> whether to turn on jit compile", required=False)

    class UniqueDeviceAction(argparse.Action):
        # custom action: reject duplicate or negative device ids at parse time
        def __call__(self, parser, namespace, values, option_string=None):
            unique_values = set(values)
            if len(values) != len(unique_values):
                parser.error("device id must be unique")
            for device_id in values:
                if not 0 <= device_id:
                    parser.error("device id must be greater than or equal to 0")
            setattr(namespace, self.dest, values)

    parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
                        help="<optional> set device id to run ut, must be unique and in range 0-7",
                        default=[0], required=False, action=UniqueDeviceAction)
    parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
                        help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
                             "when run ut is interrupted, enter the file path to continue run ut.",
                        required=False)
    parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
                        help="<optional> Whether to filter the api in the api_info_file.", required=False)
    parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
                        help="<optional> The path of config.json", required=False)
467
+
468
+
469
def preprocess_forward_content(forward_content):
    """Filter duplicate API calls out of *forward_content*.

    Two entries are duplicates when they share the same base key (the key
    with its trailing call index stripped at the last Const.SEP) and have
    identical input args (ignoring the volatile 'Max'/'Min' statistics)
    and identical input kwargs.

    Returns a new dict keeping only the first occurrence of each distinct
    (base key, args, kwargs) combination, in original insertion order.
    """
    processed_content = {}
    base_keys_variants = {}
    arg_cache = {}

    for key, value in forward_content.items():
        base_key = key.rsplit(Const.SEP, 1)[0]

        if key not in arg_cache:
            # Drop the Max/Min statistics before comparing args so that
            # calls differing only in value ranges are treated as equal.
            filtered_new_args = [
                {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
                for arg in value['input_args'] if isinstance(arg, dict)
            ]
            arg_cache[key] = (filtered_new_args, value['input_kwargs'])

        filtered_new_args, new_kwargs = arg_cache[key]

        if base_key not in base_keys_variants:
            processed_content[key] = value
            base_keys_variants[base_key] = {key}
        else:
            is_duplicate = False
            for variant in base_keys_variants.get(base_key, []):
                try:
                    # Index directly: the original `arg_cache.get(variant)`
                    # returned None for a missing key, so unpacking raised
                    # TypeError, which the `except KeyError` never caught.
                    existing_args, existing_kwargs = arg_cache[variant]
                except KeyError as e:
                    # Skip an uncached variant; the original fell through
                    # here and hit a NameError on the unbound variables.
                    logger.error(f"KeyError: {e} when processing {key}")
                    continue
                if existing_args == filtered_new_args and existing_kwargs == new_kwargs:
                    is_duplicate = True
                    break

            if not is_duplicate:
                processed_content[key] = value
                base_keys_variants[base_key].add(key)

    return processed_content
505
+
506
+
507
def _run_ut(parser=None):
    """CLI entry point: build the argument parser if absent, then run UT."""
    if not parser:
        parser = argparse.ArgumentParser()
    _run_ut_parser(parser)
    parsed_args = parser.parse_args(sys.argv[1:])
    run_ut_command(parsed_args)
513
+
514
+
515
def run_ut_command(args):
    """Run the UT accuracy-checking task from parsed command-line *args*.

    Selects the compute device, optionally parses the offline api info
    file, validates the output/config paths, merges config-file settings
    over the msCheckerConfig defaults, and launches run_ut.

    Raises:
        NotImplementedError: if setting the device id fails.
    """
    if not is_gpu:
        # NPU only: honor the -j/--jit_compile switch.
        torch.npu.set_compile_mode(jit_compile=args.jit_compile)
    # Only the first id from -d/--device is actually used here.
    used_device = current_device + ":" + str(args.device_id[0])
    try:
        if is_gpu:
            torch.cuda.set_device(used_device)
        else:
            torch.npu.set_device(used_device)
    except Exception as error:
        logger.error(f"Set device id failed. device id is: {args.device_id}")
        raise NotImplementedError from error

    # Online pre-check scenario: no api info is exported, so forward_content,
    # backward_content and real_data_path stay None.
    # Offline scenario: they are parsed from api_info_file below.
    forward_content, backward_content, real_data_path = None, None, None
    if args.api_info_file:
        api_info_file_checker = FileChecker(file_path = args.api_info_file, path_type = FileCheckConst.FILE,
                                            ability = FileCheckConst.READ_ABLE, file_type = FileCheckConst.JSON_SUFFIX)
        checked_api_info = api_info_file_checker.common_check()
        forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
        if args.filter_api:
            # Drop duplicate forward api calls before running the checks.
            logger.info("Start filtering the api in the forward_input_file.")
            forward_content = preprocess_forward_content(forward_content)
            logger.info("Finish filtering the api in the forward_input_file.")

    # Validate (and create if needed) the writable output directory.
    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
    check_path_before_create(out_path)
    create_directory(out_path)
    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
    out_path = out_path_checker.common_check()
    save_error_data = args.save_error_data

    result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
    details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
    if args.result_csv_path:
        # Resume mode: reuse the caller-supplied result csv and derive the
        # matching details csv path from it.
        result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
        details_csv_path = get_validated_details_csv_path(result_csv_path)
    # Start from the msCheckerConfig defaults …
    white_list = msCheckerConfig.white_list
    black_list = msCheckerConfig.black_list
    error_data_path = msCheckerConfig.error_data_path
    is_online = msCheckerConfig.is_online
    nfs_path = msCheckerConfig.nfs_path
    host = msCheckerConfig.host
    port = msCheckerConfig.port
    rank_list = msCheckerConfig.rank_list
    tls_path = msCheckerConfig.tls_path
    if args.config_path:
        # … then override every setting with the values from config.json.
        config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
                                          FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
        checked_config_path = config_path_checker.common_check()
        _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
        white_list = task_config.white_list
        black_list = task_config.black_list
        error_data_path = task_config.error_data_path
        is_online = task_config.is_online
        nfs_path = task_config.nfs_path
        host = task_config.host
        port = task_config.port
        rank_list = task_config.rank_list
        tls_path = task_config.tls_path

    if save_error_data:
        if args.result_csv_path:
            # Resume mode: reuse the timestamp embedded in the csv file name
            # so error data lands in the directory of the interrupted run.
            time_info = result_csv_path.split('.')[0].split('_')[-1]
            global UT_ERROR_DATA_DIR
            UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
        error_data_path = initialize_save_error_data(error_data_path)
    online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
    run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
                                args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
                                online_config)
    run_ut(run_ut_config)
588
+
589
+
590
# Script entry point: parse the command line and run the UT task.
if __name__ == '__main__':
    _run_ut()
    logger.info("UT task completed.")