mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,236 +1,236 @@
1
- # 进行比对及结果展示
2
- import os
3
- import sys
4
- import csv
5
- import json
6
- from collections import namedtuple
7
- from rich.table import Table
8
- from rich.console import Console
9
- from .single_compare import single_benchmark_compare_wrap
10
- from .utils import DispatchException
11
- from msprobe.core.common.const import CompareConst
12
- from msprobe.core.common.file_check import FileOpen
13
- from msprobe.pytorch.common.log import logger
14
- from msprobe.core.common.utils import CompareException
15
-
16
# Dropout outputs with fewer elements than this are accepted without the zero-ratio check.
ELEMENT_NUM_THRESHOLD = 100
# Largest tolerated difference (as a fraction of all elements) between the zero counts
# of the bench and NPU dropout outputs.
ZERO_NUM_THRESHOLD = 0.1
# Number of decimal places used when float metrics are written to the detail CSV.
FLOAT_PRECISION = 14

# Per-API record handed from the comparator to the saver.
ResultInfo = namedtuple(
    'ResultInfo',
    'api_name is_fwd_success is_bwd_success fwd_compare_alg_results bwd_compare_alg_results',
)
22
-
23
def get_file_content_bytes(file):
    """Return the raw bytes of *file*, read through the project's ``FileOpen`` wrapper."""
    with FileOpen(file, 'rb') as handle:
        content = handle.read()
    return content
26
-
27
-
28
def get_json_contents(file_path):
    """Load *file_path* as JSON and return the resulting dict.

    Raises:
        CompareException: (INVALID_FILE_ERROR) when the content is not valid JSON
            or its top level is not a JSON object.
    """
    raw = get_file_content_bytes(file_path)
    try:
        parsed = json.loads(raw)
    except ValueError as err:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error('Failed to load "%s". %s' % (file_path, str(err)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from err
    if isinstance(parsed, dict):
        return parsed
    logger.error('Json file %s, content is not a dictionary!' % file_path)
    raise CompareException(CompareException.INVALID_FILE_ERROR)
39
-
40
-
41
def write_csv(data, filepath):
    """Append every row of *data* to the CSV file at *filepath* (``utf-8-sig`` encoding)."""
    with FileOpen(filepath, 'a', encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
45
-
46
-
47
class Saver:
    """Append API comparison results to a summary CSV and a detail CSV.

    ``print_pretest_result`` re-reads the summary CSV and prints pass/fail
    statistics as ``rich`` tables.
    """

    # consts for result csv
    COLUMN_API_NAME = "API name"
    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
    COLUMN_STACK_INFO = "Traceback callstack info"

    def __init__(self, save_path, detail_save_path, stack_info):
        """
        Args:
            save_path: summary CSV path (one row per API).
            detail_save_path: detail CSV path (one row per compared output).
            stack_info: dict mapping an API name to its call-stack lines, or None.
        """
        self.save_path = save_path
        self.detail_save_path = detail_save_path
        self.stack_info = stack_info

        self.test_result_cnt = {
            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
            "total_num": 0, "forward_or_backward_fail_num": 0
        }

    def _summary_header(self):
        """Return the summary CSV header; the stack column is included only when stack info is set."""
        header = [self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]
        if self.stack_info:
            header.append(self.COLUMN_STACK_INFO)
        return header

    def write_csv_title(self):
        """Write the header row of the summary CSV and of the detail CSV."""
        write_csv([self._summary_header()], self.save_path)

        detail_test_rows = [[
            "Npu Name", "Bench Dtype", "NPU Dtype", "Shape",
            "error_balance", "max_abs_diff", "max_abs_idx",
            "max_rel_diff", "max_rel_idx", "eb_thd",
            "error_thd", "Status", "Message"
        ]]
        write_csv(detail_test_rows, self.detail_save_path)

    def print_pretest_result(self):
        """Tally the summary CSV and print overall and detailed statistics tables."""
        self.get_statistics_from_result_csv()
        total_num = self.test_result_cnt.get("total_num")
        if total_num != 0:
            # total_num is non-zero in this branch, so divide directly; the former
            # "+ sys.float_info.epsilon" turned e.g. 1/1 into 0.9999999999999998.
            passing_rate = str(self.test_result_cnt.get("success_num") / total_num)
        else:
            passing_rate = "0"

        console = Console()
        table_total = Table(
            show_header=True, title="Overall Statistics", show_lines=True, width=75
        )
        table_total.add_column("Result")
        table_total.add_column("Statistics")
        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num")))
        table_total.add_row("[red]Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num") +
                                                   self.test_result_cnt.get("forward_or_backward_fail_num")))
        table_total.add_row("Passing Rate", passing_rate)

        table_detail = Table(
            show_header=True, title="Detail Statistics", show_lines=True, width=75
        )
        table_detail.add_column("Result")
        table_detail.add_column("Statistics")
        table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num")))
        table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num")))
        table_detail.add_row(
            "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num")))

        console.print(table_total)
        console.print(table_detail)

    def get_statistics_from_result_csv(self):
        """Add pass/fail counts from the summary CSV into ``self.test_result_cnt``.

        The header row is skipped, and rows whose forward column is SKIP are ignored.
        Counts are added to the current totals; they are not reset here.

        Raises:
            ValueError: a data row has fewer than 3 columns, or its forward/backward
                value is not TRUE, FALSE, N/A or SKIP.
        """
        checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
        with FileOpen(self.save_path, 'r') as file:
            reader = csv.reader(file)
            result_csv_rows = [row for row in reader]
        result_csv_name = os.path.basename(self.save_path)
        for item in result_csv_rows[1:]:
            if not isinstance(item, list) or len(item) < 3:
                raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
            if not all(item[i] and item[i].upper() in checklist for i in (1, 2)):
                raise ValueError(
                    "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
                    % result_csv_name)
            column1 = item[1].upper()
            column2 = item[2].upper()
            if column1 == CompareConst.SKIP:
                continue
            self.test_result_cnt["total_num"] += 1
            if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
                self.test_result_cnt['success_num'] += 1
            elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
                self.test_result_cnt['forward_and_backward_fail_num'] += 1
            elif column1 == CompareConst.FALSE:
                self.test_result_cnt['forward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1
            else:
                self.test_result_cnt['backward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1

    def _build_summary_row(self, test_result):
        """Build the summary CSV row for *test_result*, aligned with ``_summary_header``."""
        name = test_result.api_name
        row = [name, test_result.is_fwd_success, test_result.is_bwd_success]
        if test_result.is_fwd_success == "SKIP" or test_result.is_bwd_success == "SKIP":
            # The forward result field is written into the "Message" column here
            # (presumably the skip reason).
            row.append(test_result.fwd_compare_alg_results)
        if self.stack_info:
            if len(row) == 3:
                # Leave "Message" empty so the stack info lands in its own column.
                row.append("")
            row.append("\n".join(self.stack_info[name]))
        return row

    def write_summary_csv(self, test_result):
        """Append the summary row of *test_result* to the summary CSV.

        Previously this indexed ``test_rows[0]`` on an empty list whenever stack info
        was set, which always raised IndexError; the stack column now lives in the header.
        """
        write_csv([self._build_summary_row(test_result)], self.save_path)

    def write_detail_csv(self, test_result):
        """Append one detail row per forward/backward output of *test_result*."""
        def get_rows_from_list(result, name, sub_prefix):
            # Only list results are expanded; anything else (e.g. None) yields no rows.
            rows = []
            if isinstance(result, list):
                for i, test_subject in enumerate(result):
                    subject = sub_prefix + "." + name + ".output." + str(i)
                    test_subject = ["{:.{}f}".format(item, FLOAT_PRECISION) if isinstance(item, float) else item for
                                    item in test_subject]
                    rows.append([subject] + list(test_subject))
            return rows

        test_rows = []
        subject_prefix = test_result.api_name
        fwd_result = test_result.fwd_compare_alg_results
        bwd_result = test_result.bwd_compare_alg_results

        test_rows.extend(get_rows_from_list(fwd_result, "forward", subject_prefix))
        test_rows.extend(get_rows_from_list(bwd_result, "backward", subject_prefix))

        write_csv(test_rows, self.detail_save_path)

    def record_results(self, result_info):
        """Write *result_info* to both the summary and the detail CSV."""
        self.write_summary_csv(result_info)
        self.write_detail_csv(result_info)
176
-
177
-
178
class Comparator:
    """Compare bench and NPU outputs of one API and hand the results to a ``Saver``."""

    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
        """
        Args:
            result_csv_path: summary CSV path (passed to ``Saver``).
            details_csv_path: detail CSV path (passed to ``Saver``).
            is_continue_run_ut: CSV titles are written only when this is truthy and
                neither CSV exists yet.
            stack_info_json_path: optional JSON file mapping API names to call stacks.
        """
        self.save_path = result_csv_path
        self.detail_save_path = details_csv_path
        self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None
        self.saver = Saver(result_csv_path, details_csv_path, self.stack_info)

        # NOTE(review): titles are written only when is_continue_run_ut is truthy;
        # confirm this is not meant to be `not is_continue_run_ut`.
        if (is_continue_run_ut
                and not os.path.exists(self.save_path)
                and not os.path.exists(self.detail_save_path)):
            self.saver.write_csv_title()

    @staticmethod
    def _compare_core_wrapper(bench_out, npu_out):
        """Run ``single_benchmark_compare_wrap`` and normalize its result.

        Returns:
            (success, details_list). When the comparison reports a list of statuses,
            success requires every status to be truthy and details are looked up by index.
        """
        status, details = single_benchmark_compare_wrap(npu_out, bench_out)
        if not isinstance(status, list):
            return status, [details]
        per_output_details = [details.get(idx, 'key does not exist') for idx in range(len(status))]
        return all(status), per_output_details

    @staticmethod
    def _compare_dropout(bench_out, npu_out):
        """Compare dropout outputs by the fraction of zeroed elements.

        Returns (True, 1) when the tensor is small or the zero counts are close,
        otherwise (False, 0).
        """
        tensor_num = bench_out.numel()
        if tensor_num < ELEMENT_NUM_THRESHOLD:
            # Small outputs are accepted without the statistical zero-ratio check.
            return True, 1
        zero_gap = abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum())
        if zero_gap / tensor_num < ZERO_NUM_THRESHOLD:
            return True, 1
        return False, 0

    def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None):
        """Compare forward (and, when both grads are given, backward) results and record them.

        Dropout APIs (name contains "dropout") use the zero-ratio comparison. When no
        backward comparison happens, the backward status is recorded as N/A.

        Returns:
            (is_fwd_success, is_bwd_success)
        """
        is_dropout = "dropout" in api_name
        if is_dropout:
            fwd_ok, fwd_results = self._compare_dropout(bench_out, npu_out)
        else:
            fwd_ok, fwd_results = self._compare_core_wrapper(bench_out, npu_out)

        if not (bench_grad and npu_grad):
            bwd_ok, bwd_results = True, None
        elif is_dropout:
            bwd_ok, bwd_results = self._compare_dropout(bench_grad[0], npu_grad[0])
        else:
            bwd_ok, bwd_results = self._compare_core_wrapper(bench_grad, npu_grad)

        recorded_bwd = CompareConst.NA if (bwd_ok and bwd_results is None) else bwd_ok
        self.saver.record_results(ResultInfo(api_name, fwd_ok, recorded_bwd, fwd_results, bwd_results))
        return fwd_ok, bwd_ok
1
+ # 进行比对及结果展示
2
+ import os
3
+ import sys
4
+ import csv
5
+ import json
6
+ from collections import namedtuple
7
+ from rich.table import Table
8
+ from rich.console import Console
9
+ from msprobe.core.common.const import CompareConst, FileCheckConst
10
+ from msprobe.core.common.file_utils import FileOpen, change_mode
11
+ from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap
12
+ from msprobe.pytorch.common.log import logger
13
+ from msprobe.core.common.utils import CompareException
14
+
15
# Dropout outputs with fewer elements than this are accepted without the zero-ratio check.
ELEMENT_NUM_THRESHOLD = 100
# Largest tolerated difference (as a fraction of all elements) between the zero counts
# of the bench and NPU dropout outputs.
ZERO_NUM_THRESHOLD = 0.1
# Number of decimal places used when float metrics are written to the detail CSV.
FLOAT_PRECISION = 14

# Per-API record handed from the comparator to the saver.
ResultInfo = namedtuple(
    'ResultInfo',
    'api_name is_fwd_success is_bwd_success fwd_compare_alg_results bwd_compare_alg_results',
)
21
+
22
def get_file_content_bytes(file):
    """Return the whole content of ``file`` as bytes, opened via ``FileOpen``."""
    with FileOpen(file, 'rb') as handle:
        content = handle.read()
    return content
25
+
26
+
27
def get_json_contents(file_path):
    """Parse ``file_path`` as JSON and return it; the top level must be an object.

    Raises:
        CompareException: ``INVALID_FILE_ERROR`` when the content cannot be
            decoded as JSON or the decoded value is not a dict.
    """
    raw = get_file_content_bytes(file_path)
    try:
        parsed = json.loads(raw)
    except ValueError as err:
        logger.error('Failed to load "%s". %s' % (file_path, str(err)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from err
    if isinstance(parsed, dict):
        return parsed
    logger.error('Json file %s, content is not a dictionary!' % file_path)
    raise CompareException(CompareException.INVALID_FILE_ERROR)
38
+
39
+
40
def write_csv(data, filepath):
    """Append the rows in ``data`` to the CSV at ``filepath``.

    The file is opened in append mode with ``utf-8-sig`` encoding; afterwards its
    permissions are set to ``FileCheckConst.DATA_FILE_AUTHORITY``.
    """
    with FileOpen(filepath, 'a', encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
45
+
46
+
47
class Saver:
    """Write per-API comparison results to CSV files and print summary statistics.

    Args:
        save_path: path of the summary CSV (one row per API).
        detail_save_path: path of the detail CSV (one row per compared output).
        stack_info: dict from API name to an iterable of call-stack strings, or
            a falsy value to omit the stack-info column.
    """

    # consts for result csv
    COLUMN_API_NAME = "API name"
    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
    COLUMN_STACK_INFO = "Traceback callstack info"

    def __init__(self, save_path, detail_save_path, stack_info):
        self.save_path = save_path
        self.detail_save_path = detail_save_path
        self.stack_info = stack_info

        # Filled in by get_statistics_from_result_csv().
        self.test_result_cnt = {
            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
            "total_num": 0, "forward_or_backward_fail_num": 0
        }

    def write_csv_title(self):
        """Append the header row to both the summary CSV and the detail CSV."""
        summary_header = [self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]
        if self.stack_info:
            # write_summary_csv() emits a trailing stack-info column in this case.
            summary_header.append(self.COLUMN_STACK_INFO)
        write_csv([summary_header], self.save_path)

        detail_test_rows = [[
            "Npu Name", "Bench Dtype", "NPU Dtype", "Shape",
            "error_balance", "max_abs_diff", "max_abs_idx",
            "max_rel_diff", "max_rel_idx", "eb_thd",
            "error_thd", "Status", "Message"
        ]]
        write_csv(detail_test_rows, self.detail_save_path)

    def print_pretest_result(self):
        """Recount results from the summary CSV and print overall/detail tables."""
        self.get_statistics_from_result_csv()
        total_num = self.test_result_cnt.get("total_num")
        if total_num != 0:
            # total_num is known to be non-zero here. An extra epsilon in the
            # denominator made a single passing API print "0.9999999999999998".
            passing_rate = str(self.test_result_cnt.get("success_num") / total_num)
        else:
            passing_rate = "0"

        console = Console()
        table_total = Table(
            show_header=True, title="Overall Statistics", show_lines=True, width=75
        )
        table_total.add_column("Result")
        table_total.add_column("Statistics")
        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num")))
        table_total.add_row("[red]Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num") +
                                                   self.test_result_cnt.get("forward_or_backward_fail_num")))
        table_total.add_row("Passing Rate", passing_rate)

        table_detail = Table(
            show_header=True, title="Detail Statistics", show_lines=True, width=75
        )
        table_detail.add_column("Result")
        table_detail.add_column("Statistics")
        table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num")))
        table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num")))
        table_detail.add_row(
            "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num")))

        console.print(table_total)
        console.print(table_detail)

    def get_statistics_from_result_csv(self):
        """Count pass/fail outcomes from the data rows of the summary CSV.

        The header row is skipped, and so is any row whose forward column is SKIP.
        Counters are reset first, so repeated calls do not double-count.

        Raises:
            ValueError: if a row has fewer than 3 columns, or if its 2nd/3rd
                column (upper-cased) is not one of CompareConst.TRUE / FALSE /
                NA / SKIP.
        """
        checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
        with FileOpen(self.save_path, 'r') as file:
            result_csv_rows = list(csv.reader(file))
        result_csv_name = os.path.basename(self.save_path)
        # Recount from scratch instead of accumulating across calls.
        self.test_result_cnt = dict.fromkeys(self.test_result_cnt, 0)
        for item in result_csv_rows[1:]:
            if not isinstance(item, list) or len(item) < 3:
                raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
            if not all(item[i] and item[i].upper() in checklist for i in (1, 2)):
                raise ValueError(
                    "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
                    % result_csv_name)
            column1 = item[1].upper()
            column2 = item[2].upper()
            if column1 == CompareConst.SKIP:
                continue
            self.test_result_cnt["total_num"] += 1
            if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
                self.test_result_cnt['success_num'] += 1
            elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
                self.test_result_cnt['forward_and_backward_fail_num'] += 1
            elif column1 == CompareConst.FALSE:
                self.test_result_cnt['forward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1
            else:
                self.test_result_cnt['backward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1

    def write_summary_csv(self, test_result):
        """Append one summary row for ``test_result`` to the summary CSV.

        Columns: API name, forward result, backward result, then a Message
        column, and finally, when stack info is enabled, the API's call stack
        joined with newlines.
        """
        name = test_result.api_name
        df_row = [name, test_result.is_fwd_success, test_result.is_bwd_success]
        is_skipped = test_result.is_fwd_success == "SKIP" or test_result.is_bwd_success == "SKIP"
        if is_skipped:
            # The Message column receives fwd_compare_alg_results for skipped
            # APIs (presumably the skip reason; confirm with callers).
            df_row.append(test_result.fwd_compare_alg_results)
        if self.stack_info:
            if not is_skipped:
                # Pad the empty Message cell so the stack lands in its own column.
                df_row.append("")
            df_row.append("\n".join(self.stack_info[name]))
        # Previously this did `[][0].append(...)`, which raised IndexError
        # whenever stack_info was set.
        write_csv([df_row], self.save_path)

    def write_detail_csv(self, test_result):
        """Append one detail row per compared output (forward, then backward)."""
        def get_rows_from_list(result, name, sub_prefix):
            # Non-list results (e.g. None, or the int from a dropout check) give no rows.
            rows = []
            if isinstance(result, list):
                for i, test_subject in enumerate(result):
                    subject = sub_prefix + "." + name + ".output." + str(i)
                    test_subject = ["{:.{}f}".format(item, FLOAT_PRECISION) if isinstance(item, float) else item for
                                    item in test_subject]
                    rows.append([subject] + list(test_subject))
            return rows

        test_rows = []
        subject_prefix = test_result.api_name
        fwd_result = test_result.fwd_compare_alg_results
        bwd_result = test_result.bwd_compare_alg_results

        test_rows.extend(get_rows_from_list(fwd_result, "forward", subject_prefix))
        test_rows.extend(get_rows_from_list(bwd_result, "backward", subject_prefix))

        write_csv(test_rows, self.detail_save_path)

    def record_results(self, result_info):
        """Write ``result_info`` to both the summary and the detail CSV."""
        self.write_summary_csv(result_info)
        self.write_detail_csv(result_info)
176
+
177
+
178
class Comparator:
    """Compare bench and NPU outputs per API and record results through ``Saver``.

    Args:
        result_csv_path: path of the summary CSV.
        details_csv_path: path of the detail CSV.
        is_continue_run_ut: CSV headers are written only when this is truthy
            and neither CSV file exists yet.
        stack_info_json_path: optional JSON file whose top-level object is
            passed to ``Saver`` as stack info (looked up by API name).
    """

    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
        self.save_path = result_csv_path
        self.detail_save_path = details_csv_path
        self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None
        self.saver = Saver(result_csv_path, details_csv_path, self.stack_info)

        # NOTE(review): headers are written only when continuing a run and no
        # output file exists yet; confirm this condition is intended.
        if is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path):
            self.saver.write_csv_title()

    @staticmethod
    def _compare_core_wrapper(bench_out, npu_out):
        """Run ``single_benchmark_compare_wrap`` and normalise its result.

        Returns:
            tuple: ``(success, details_list)``. If the wrapped call returns a
            list of statuses, ``success`` is True only when every status is
            truthy, and ``details_list`` holds ``details.get(i)`` for each status
            index ``i`` ('key does not exist' when missing). Otherwise
            ``success`` is the status as returned and ``details_list`` is
            ``[details]``.
        """
        status, details = single_benchmark_compare_wrap(npu_out, bench_out)
        if not isinstance(status, list):
            return status, [details]
        detailed_result_total = [details.get(idx, 'key does not exist') for idx in range(len(status))]
        return all(status), detailed_result_total

    @staticmethod
    def _compare_dropout(bench_out, npu_out):
        """Compare dropout results by their number of zero elements.

        Inputs with fewer than ELEMENT_NUM_THRESHOLD elements always pass.
        Otherwise the result passes when the zero-count difference, divided by
        the element count, is below ZERO_NUM_THRESHOLD. ``npu_out`` is moved to
        CPU before counting (assumes torch-like tensors; confirm).

        Returns:
            tuple: ``(True, 1)`` on pass, ``(False, 0)`` on failure.
        """
        tensor_num = bench_out.numel()
        if tensor_num < ELEMENT_NUM_THRESHOLD:
            return True, 1
        zero_diff = abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum())
        if zero_diff / tensor_num < ZERO_NUM_THRESHOLD:
            return True, 1
        return False, 0

    def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None):
        """Compare outputs (and gradients when both are given) and record the result.

        Returns:
            tuple: ``(is_fwd_success, is_bwd_success)``. ``is_bwd_success`` is
            True when no gradients were compared.
        """
        is_dropout = "dropout" in api_name
        if is_dropout:
            is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out)
        else:
            is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out)
        if bench_grad and npu_grad:
            if is_dropout:
                is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], npu_grad[0])
            else:
                is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad)
        else:
            is_bwd_success, bwd_compare_alg_results = True, None
        if is_bwd_success and bwd_compare_alg_results is None:
            # No backward comparison ran. Record CompareConst.NA: it is in the
            # checklist that Saver.get_statistics_from_result_csv validates
            # against, whereas CompareConst.NAN is not and made the statistics
            # step raise ValueError.
            bwd_column = CompareConst.NA
        else:
            bwd_column = is_bwd_success
        self.saver.record_results(ResultInfo(api_name, is_fwd_success, bwd_column, fwd_compare_alg_results,
                                             bwd_compare_alg_results))
        return is_fwd_success, is_bwd_success