mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,345 +1,386 @@
1
- # 进行比对及结果展示
2
- import os
3
- from collections import namedtuple
4
- import torch
5
- import numpy as np
6
- from msprobe.pytorch.common.log import logger
7
- from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv
8
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
9
- DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \
10
- ULPStandardApi, ThousandthStandardApi, apis_threshold
11
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
12
- from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
13
- get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
14
- get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
15
- check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
16
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
17
- from msprobe.core.common.const import Const, CompareConst
18
-
19
-
20
# Record of one API's comparison outcome. compare_output builds it and hands it
# to record_results; the CSV writers index its fields positionally.
ResultInfo = namedtuple(
    'ResultInfo',
    'full_api_name fwd_success_status bwd_success_status '
    'fwd_compare_alg_results bwd_compare_alg_results rank',
)
22
-
23
-
24
# Positional indices used by Comparator.write_summary_csv on a ResultInfo-like tuple.
INDEX_TEST_RESULT__GROUP = 3  # field after (name, fwd status, bwd status): the forward results list
INDEX_FIRST_GROUP = 0  # first entry of that forward results list
INDEX_MESSAGE = -1  # last element of that entry, written to the summary "Message" column
27
-
28
-
29
- class Comparator:
30
- # consts for result csv
31
- COLUMN_API_NAME = "API name"
32
- COLUMN_FORWARD_SUCCESS = "Forward Test Success"
33
- COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
34
- COLUMN_STACK_INFO = "Traceback callstack info"
35
-
36
def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
    """Store the output CSV paths and optionally load call-stack info.

    Args:
        result_csv_path: path of the summary CSV.
        details_csv_path: path of the detailed CSV.
        is_continue_run_ut: when truthy, header rows are never written here.
        stack_info_json_path: optional JSON file; its contents are stored in
            self.stack_info (presumably a mapping of API name -> stack lines).
    """
    self.save_path = result_csv_path
    self.detail_save_path = details_csv_path
    # Headers are written only for a fresh run where neither CSV exists yet.
    if not is_continue_run_ut and not (os.path.exists(self.save_path)
                                       or os.path.exists(self.detail_save_path)):
        self.write_csv_title()
    self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None
45
-
46
@staticmethod
def print_pretest_result():
    """Log a completion notice for the run_ut / multi_run_ut pass."""
    done_msg = "Successfully completed run_ut/multi_run_ut."
    logger.info(done_msg)
49
-
50
@staticmethod
def _compare_dropout(bench_output, device_output):
    """Compare dropout outputs by how many elements are zero on each side.

    Element-wise comparison is not used here (presumably because dropout masks
    are random); instead the zero counts are compared. Tensors with fewer than
    100 elements always pass.

    Returns:
        (CompareConst.PASS, 1) when the zero-count gap is under 10% of the
        elements, otherwise (CompareConst.ERROR, 0).
    """
    element_count = bench_output.numel()
    if element_count < 100:
        return CompareConst.PASS, 1
    bench_zeros = (bench_output == 0).sum()
    device_zeros = (device_output == 0).cpu().sum()
    if abs(bench_zeros - device_zeros) / element_count < 0.1:
        return CompareConst.PASS, 1
    return CompareConst.ERROR, 0
60
-
61
@staticmethod
def _compare_builtin_type(bench_output, device_output, compare_column):
    """Compare bool/int/float/str outputs by exact equality.

    Values of any other type pass without comparison. On an equal match the
    column's error_rate is set to 0.

    Returns:
        (status, compare_column, "") -- the message is always empty.
    """
    if isinstance(bench_output, (bool, int, float, str)):
        if bench_output != device_output:
            return CompareConst.ERROR, compare_column, ""
        compare_column.error_rate = 0
    return CompareConst.PASS, compare_column, ""
69
-
70
@staticmethod
def _compare_bool_tensor(bench_output, device_output):
    """Compare bool/integer NumPy arrays by the fraction of mismatching elements.

    Returns:
        (error_rate, status, message). An empty bench array yields
        (CompareConst.NAN, CompareConst.ERROR, <explanation>); otherwise status
        is PASS only when no element differs.
    """
    mismatches = (bench_output != device_output).sum()
    total = bench_output.size
    if total == 0:
        return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
    rate = float(mismatches / total)
    status = CompareConst.ERROR if rate != 0 else CompareConst.PASS
    return rate, status, ""
78
-
79
@staticmethod
def _get_absolute_threshold_attribute(api_name, dtype):
    """Return (small_value_threshold, small_value_atol, rtol) for api_name at dtype.

    Values come from apis_threshold[api_name][dtype] via chained .get calls, so
    a missing api_name or dtype entry surfaces as an AttributeError on None
    (assuming apis_threshold is a dict).
    """
    dtype_cfg = apis_threshold.get(api_name).get(dtype)
    return dtype_cfg.get('small_value'), dtype_cfg.get('small_value_atol'), dtype_cfg.get('rtol')
85
-
86
def write_csv_title(self):
    """Write header rows to the summary and detail CSVs that do not exist yet."""
    if not os.path.exists(self.save_path):
        header = [self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
                  self.COLUMN_BACKWARD_SUCCESS, "Message"]
        write_csv([header], self.save_path)
    if not os.path.exists(self.detail_save_path):
        write_csv(DETAIL_TEST_ROWS, self.detail_save_path)
93
-
94
def write_summary_csv(self, test_result):
    """Append one summary row for test_result to the summary CSV.

    The row is (API name, forward status, backward status), followed by:
      * the message of the first forward result, when the forward status is "SKIP";
      * the API's call stack joined with newlines, when stack info was loaded.

    Args:
        test_result: a ResultInfo (or equivalent sequence), indexed positionally.
    """
    # Bug fix: the old code ran ``test_rows[0].append(self.COLUMN_STACK_INFO)``
    # on an empty list, raising IndexError whenever stack info was enabled.
    # Header rows belong to write_csv_title, so that append is dropped here.
    # NOTE(review): write_csv_title adds no stack-info header column, so rows
    # that carry stack info are one column wider than the header -- confirm.
    name = test_result[0]
    row = list(test_result[:INDEX_TEST_RESULT__GROUP])
    if test_result[1] == "SKIP":
        row.append(test_result[INDEX_TEST_RESULT__GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
    if self.stack_info:
        row.append("\n".join(self.stack_info[name]))
    write_csv([row], self.save_path)
108
-
109
def write_detail_csv(self, test_result):
    """Append per-output detail rows for test_result to the detail CSV.

    Every entry of the forward (index 3) and backward (index 4) result lists
    becomes one row. The row's first cell is "<api>.forward.output.<i>" or
    "<api>.backward.output.<i>", and float cells are formatted with
    msCheckerConfig.precision decimals. Result fields that are not lists are
    skipped. Forward rows are written before backward rows.
    """
    api_prefix = test_result[0]
    sections = ((".forward.output.", test_result[3]), (".backward.output.", test_result[4]))
    rows = []
    for suffix, results in sections:
        if not isinstance(results, list):
            continue
        for idx, detail in enumerate(results):
            cells = ["{:.{}f}".format(cell, msCheckerConfig.precision) if isinstance(cell, float) else cell
                     for cell in detail]
            rows.append([api_prefix + suffix + str(idx)] + cells)
    write_csv(rows, self.detail_save_path)
129
-
130
def record_results(self, args):
    """Persist one comparison result to the summary CSV and then the detail CSV."""
    for writer in (self.write_summary_csv, self.write_detail_csv):
        writer(args)
133
-
134
def compare_output(self, full_api_name, data_info):
    """Compare one API's forward outputs and gradients, then record the result.

    Args:
        full_api_name: name split on Const.SEP into exactly three parts; the
            middle part is used as the API name for comparisons.
        data_info: object exposing bench_output, device_output, bench_grad,
            device_grad, backward_message and rank.

    Returns:
        (forward_passed, backward_passed). Backward counts as passed when its
        status is PASS or SPACE (SPACE is used when no backward comparison ran).
    """
    _, api_name, _ = full_api_name.split(Const.SEP)
    bench_out, device_out = data_info.bench_output, data_info.device_output
    bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
    backward_message = data_info.backward_message
    is_dropout = "dropout" in full_api_name

    if is_dropout:
        fwd_status, fwd_results = self._compare_dropout(bench_out, device_out)
    else:
        fwd_status, fwd_results = self._compare_core_wrapper(api_name, bench_out, device_out)

    has_grads = bench_grad and device_grad
    if not has_grads:
        bwd_status, bwd_results = CompareConst.SPACE, []
    elif is_dropout:
        bwd_status, bwd_results = self._compare_dropout(bench_grad[0], device_grad[0])
    else:
        bwd_status, bwd_results = self._compare_core_wrapper(api_name, bench_grad, device_grad)

    if backward_message:
        # The backward detail rows are replaced by one SKIP column carrying the message.
        bwd_results = [CompareColumn().to_column_value(CompareConst.SKIP, backward_message)]
    elif bwd_results is None:
        bwd_status = CompareConst.SPACE

    self.record_results(ResultInfo(full_api_name, fwd_status, bwd_status,
                                   fwd_results, bwd_results, data_info.rank))
    forward_passed = fwd_status == CompareConst.PASS
    backward_passed = bwd_status in (CompareConst.PASS, CompareConst.SPACE)
    return forward_passed, backward_passed
166
-
167
def _compare_core_wrapper(self, api_name, bench_output, device_output):
    """Compare possibly sequence-valued outputs and aggregate the per-item statuses.

    A list/tuple bench output is compared element-wise against device_output
    via _compare_core; extra trailing device elements are ignored. If the bench
    sequence is longer than the device one, a single ERROR entry is produced
    instead. Any other bench output is compared directly with _compare_core.

    Returns:
        (final_status, detailed_rows). final_status is ERROR if any item
        errored, else WARNING if any item warned, else PASS. detailed_rows
        holds one CompareColumn.to_column_value(...) result per compared item.
    """
    if isinstance(bench_output, (list, tuple)):
        if len(bench_output) > len(device_output):
            # Bug fix: compare_result used to stay empty here, so indexing it
            # below raised IndexError instead of reporting the mismatch.
            statuses = [CompareConst.ERROR]
            columns = [CompareColumn()]
            messages = ["bench and npu output structure is different."]
        else:
            statuses, columns, messages = [], [], []
            # zip stops at the bench length, so trailing device outputs are ignored.
            for bench_item, device_item in zip(bench_output, device_output):
                status_i, column_i, message_i = self._compare_core(api_name, bench_item, device_item)
                statuses.append(status_i)
                columns.append(column_i)
                messages.append(message_i)
    else:
        status, column, message = self._compare_core(api_name, bench_output, device_output)
        statuses, columns, messages = [status], [column], [message]

    final_status = CompareConst.PASS
    detailed_rows = []
    for status_i, column_i, message_i in zip(statuses, columns, messages):
        detailed_rows.append(column_i.to_column_value(status_i, message_i))
        if status_i == CompareConst.ERROR:
            final_status = CompareConst.ERROR
        elif status_i == CompareConst.WARNING and final_status != CompareConst.ERROR:
            # Bug fix: a later WARNING must not downgrade an earlier ERROR.
            final_status = CompareConst.WARNING
    return final_status, detailed_rows
198
-
199
def _compare_core(self, api_name, bench_output, device_output):
    """Compare one bench output against the matching device output.

    The comparison depends on the bench output's type:
      * mismatched types -> ERROR;
      * dict -> the key sets must match, then the values lists are compared recursively;
      * torch.Tensor -> detached clones are compared by _compare_torch_tensor;
      * bool/int/float/str -> _compare_builtin_type;
      * None -> SKIP;
      * anything else -> PASS with an "Unexpected output type" message.

    Returns:
        (status, compare_column, message) -- always a 3-tuple.
    """
    compare_column = CompareColumn()
    if not isinstance(bench_output, type(device_output)):
        return CompareConst.ERROR, compare_column, "bench and npu output type is different."
    if isinstance(bench_output, dict):
        if set(bench_output.keys()) != set(device_output.keys()):
            return CompareConst.ERROR, compare_column, "bench and npu output dict keys are different."
        # NOTE(review): the values are passed as lists, which fall through to the
        # "unexpected type" fallback below, so dict contents are not actually
        # compared -- confirm whether _compare_core_wrapper was intended.
        return self._compare_core(api_name, list(bench_output.values()), list(device_output.values()))
    if isinstance(bench_output, torch.Tensor):
        bench_copy = bench_output.detach().clone()
        device_copy = device_output.detach().clone()
        compare_column.bench_type = str(bench_copy.dtype)
        compare_column.npu_type = str(device_copy.dtype)
        compare_column.shape = tuple(device_output.shape)
        return self._compare_torch_tensor(api_name, bench_copy, device_copy, compare_column)
    if isinstance(bench_output, (bool, int, float, str)):
        compare_column.bench_type = str(type(bench_output))
        compare_column.npu_type = str(type(device_output))
        return self._compare_builtin_type(bench_output, device_output, compare_column)
    if bench_output is None:
        return CompareConst.SKIP, compare_column, "Bench output is None, skip this test."
    # Bug fix: the message used to sit on its own line after
    # ``return CompareConst.PASS, compare_column,``. That made it dead code and
    # returned a 2-tuple, which broke every caller that unpacks three values.
    return (CompareConst.PASS, compare_column,
            "Unexpected output type in compare_core: {}".format(type(bench_output)))
229
-
230
def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
    """Compare two tensors after converting them to numpy arrays.

    Processing steps:
    - When the device dtype is bfloat16, both sides are upcast to float32 before the numpy
      conversion.
    - Shape mismatches and incomparable dtypes are reported as ERROR.
    - Integer/bool data is judged only by error rate.
    - Float data is delegated to ``_compare_float_tensor``.

    Returns:
        ``(status, compare_column, message)``.
    """
    cpu_shape = bench_output.shape
    npu_shape = device_output.shape
    npu_dtype = device_output.dtype
    if npu_dtype == torch.bfloat16:
        bench_output = bench_output.to(torch.float32)
        device_output = device_output.to(torch.float32)
    # ``Tensor.numpy()`` only works on CPU tensors. ``.cpu()`` is a no-op for tensors
    # already on the CPU, so apply it to the bench side too instead of assuming its device.
    bench_output = bench_output.cpu().numpy()
    device_output = device_output.cpu().numpy()
    if cpu_shape != npu_shape:
        return CompareConst.ERROR, compare_column, f"The shape of bench{str(cpu_shape)} " \
                                                  f"and npu{str(npu_shape)} not equal."
    if not check_dtype_comparable(bench_output, device_output):
        return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
                                                  f"npu output dtype is {device_output.dtype}, cannot compare."
    message = ""
    if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
                              np.int64, np.uint64]:
        message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
                   f"Only judged by Error Rate."
        err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
        message += msg + "\n"
        compare_column.error_rate = err_rate
        return status, compare_column, message
    status, compare_column, message = self._compare_float_tensor(api_name, bench_output, device_output,
                                                                 compare_column, npu_dtype)
    return status, compare_column, message
258
-
259
def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
    """Collect float-tensor metrics into ``compare_column`` and grade the comparison.

    Metric collection:
    - Standard-specific metrics (binary, absolute, ULP or the default benchmark metrics) are
      collected only when ``str(dtype)`` is in BENCHMARK_COMPARE_SUPPORT_LIST.

    Grading ladder, in order:
    1. Cosine similarity.
    2. Max absolute error.
    3. Relative-error ratios at 1/100 (float16/bfloat16 only), 1/1000 and 1/10000.

    Returns:
        ``(status, compare_column, message)``.
    """
    abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
    abs_err = get_abs_err(bench_output, device_output)
    rel_err_origin = get_rel_err_origin(abs_err, abs_bench_with_eps)

    if api_name in ThousandthStandardApi:
        compare_column.rel_err_thousandth, _ = get_rel_err_ratio(rel_err_origin,
                                                                 CompareConst.THOUSAND_RATIO_THRESHOLD)

    if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
        both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
        if api_name in BinaryStandardApi:
            compare_column.error_rate = self._compare_bool_tensor(bench_output, device_output)[0]
        elif api_name in AbsoluteStandardApi:
            small_threshold, small_atol, rtol = self._get_absolute_threshold_attribute(api_name, str(dtype))
            relative_err = abs_err / abs_bench_with_eps
            small_mask = get_small_value_mask(abs_bench, both_finite_mask, small_threshold)
            normal_mask = np.logical_and(both_finite_mask, np.logical_not(small_mask))
            compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output,
                                                                     device_output, dtype, rtol)
            compare_column.rel_err_ratio = check_norm_value(normal_mask, relative_err, rtol)
            compare_column.abs_err_ratio = check_small_value(abs_err, small_mask, small_atol)
        elif api_name in ULPStandardApi:
            if bench_output.size:
                ulp_err = get_ulp_err(bench_output, device_output, dtype)
                compare_column.max_ulp_error = np.max(ulp_err)
                compare_column.mean_ulp_error = np.mean(ulp_err)
                ulp_threshold = (CompareConst.ULP_FLOAT32_THRESHOLD if dtype == torch.float32
                                 else CompareConst.ULP_FLOAT16_THRESHOLD)
                compare_column.ulp_error_proportion = np.sum(ulp_err > ulp_threshold) / bench_output.size
            else:
                compare_column.max_ulp_error = 0
                compare_column.mean_ulp_error = 0
                compare_column.ulp_error_proportion = 0
        else:
            config_for_dtype = precision_configs.get(dtype)
            small_mask = get_small_value_mask(abs_bench, both_finite_mask, config_for_dtype['small_value'][0])
            over_atol_mask = np.greater(abs_err, config_for_dtype['small_value_atol'][0])
            compare_column.small_value_err_ratio = get_small_value_err_ratio(small_mask, over_atol_mask)
            relative_err = get_rel_err(abs_err, abs_bench_with_eps, small_mask, inf_nan_mask)
            compare_column.RMSE = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_mask))
            compare_column.EB = get_error_balance(bench_output, device_output)
            if relative_err.size == 0:
                return CompareConst.ERROR, compare_column, "Relative error result list is empty."
            compare_column.Max_rel_error = get_max_rel_err(relative_err)
            compare_column.Mean_rel_error = get_mean_rel_err(relative_err)

    # Step 1: cosine similarity.
    cos_value, cos_ok, cos_msg = cosine_sim(bench_output, device_output)
    compare_column.cosine_sim = cos_value
    message = cos_msg + "\n"
    if not cos_ok:
        return CompareConst.ERROR, compare_column, \
            message + "Cosine similarity is less than 0.99, consider as error, skip other check and set to SPACE.\n"

    # Step 2: max absolute error.
    compare_column.max_abs_err, max_abs_ok = get_max_abs_err(abs_err)
    if max_abs_ok:
        return CompareConst.PASS, compare_column, \
            message + "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"

    # Step 3: relative-error ratios.
    half_precision = dtype in [torch.float16, torch.bfloat16]
    if half_precision:
        compare_column.rel_err_hundredth, hundred_ok = get_rel_err_ratio(rel_err_origin,
                                                                         CompareConst.HUNDRED_RATIO_THRESHOLD)
        if not hundred_ok:
            return CompareConst.ERROR, compare_column, \
                message + "Relative error is greater than 0.01, consider as error, skip other check and set to SPACE.\n"

    compare_column.rel_err_thousandth, thousand_ok = get_rel_err_ratio(rel_err_origin,
                                                                       CompareConst.THOUSAND_RATIO_THRESHOLD)
    if half_precision:
        if thousand_ok:
            return CompareConst.PASS, compare_column, \
                message + "Relative error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
        return CompareConst.WARNING, compare_column, \
            message + "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"

    compare_column.rel_err_ten_thousandth, ten_thousand_ok = get_rel_err_ratio(
        rel_err_origin, CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
    if dtype in [torch.float32, torch.float64]:
        if not thousand_ok:
            return CompareConst.ERROR, compare_column, \
                message + "Relative error is greater than 0.001, consider as error, skip other check and set to SPACE.\n"
        if not ten_thousand_ok:
            return CompareConst.WARNING, compare_column, \
                message + "Relative error is greater than 0.0001, consider as warning, skip other check and set to SPACE.\n"
    return CompareConst.PASS, compare_column, message + "Relative error is less than 0.0001, consider as pass.\n"
1
+ # 进行比对及结果展示
2
+ import os
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ from msprobe.core.common.utils import CompareException
7
+ from msprobe.core.common.file_utils import get_json_contents, write_csv
8
+ import torch
9
+ from msprobe.core.common.const import CompareConst
10
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
11
+ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
12
+ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
13
+ check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
14
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
15
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
16
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
17
+ DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
18
+ ulp_standard_api, thousandth_standard_api, apis_threshold
19
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
20
+ from msprobe.pytorch.common.log import logger
21
+
22
+
23
# Per-API record built by Comparator.compare_output. Fields:
# - the full API name;
# - the forward/backward overall statuses;
# - the forward/backward detail rows;
# - the rank the result belongs to.
ResultInfo = namedtuple(
    "ResultInfo",
    "full_api_name fwd_success_status bwd_success_status "
    "fwd_compare_alg_results bwd_compare_alg_results rank",
)


# Indexes used by write_summary_csv for SKIP results. The expression
# test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE] is the last field
# of the first forward detail row. It is copied into the summary's "Message" column.
INDEX_TEST_RESULT_GROUP = 3
INDEX_FIRST_GROUP = 0
INDEX_MESSAGE = -1
30
+
31
+
32
class Comparator:
    """Compare bench outputs with device (NPU) outputs and record the results as CSV.

    Two CSV files are maintained:
    - a summary file with one row per API (forward/backward status and message);
    - a detail file with one row per compared output.

    When ``config.online_config.is_online`` is set, one pair of files is kept per rank listed
    in ``config.online_config.rank_list``.
    """

    # consts for result csv
    COLUMN_API_NAME = "API name"
    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
    COLUMN_STACK_INFO = "Traceback callstack info"

    # Dropout outputs are only checked by their share of zeros. Tensors with fewer elements
    # than the minimum always pass. Larger ones pass when the zero-count difference, as a
    # fraction of the element count, is below the tolerance.
    _DROPOUT_MIN_ELEMENTS = 100
    _DROPOUT_ZERO_RATIO_TOLERANCE = 0.1

    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None):
        """Set up the output CSV paths, and optionally write headers and load stack info.

        Args:
            result_csv_path: path of the summary CSV.
            details_csv_path: path of the detail CSV.
            is_continue_run_ut: when False, CSV headers are written to files that do not
                exist yet.
            stack_info_json_path: optional JSON file mapping API names to call-stack lines.
            config: optional run config. In online mode, paths become per-rank
                ``*_rank{N}.csv`` files.
        """
        self.save_path_str = result_csv_path
        self.detail_save_path_str = details_csv_path
        self.save_path_list = [result_csv_path]
        self.detail_save_path_list = [details_csv_path]

        if config and config.online_config.is_online:
            self.save_path_str = self._to_rank_path_pattern(result_csv_path)
            self.detail_save_path_str = self._to_rank_path_pattern(details_csv_path)
            rank_list = config.online_config.rank_list
            self.save_path_list = [self.save_path_str.format(rank) for rank in rank_list]
            self.detail_save_path_list = [self.detail_save_path_str.format(rank) for rank in rank_list]

        if not is_continue_run_ut:
            self.write_csv_title()
        self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None

    @staticmethod
    def _to_rank_path_pattern(csv_path):
        """Turn ``.../name.csv`` into the format pattern ``.../name_rank{}.csv``.

        Only the LAST ``.csv`` occurrence is rewritten, so directories whose names contain
        ``.csv`` are left intact. The previous ``str.replace`` rewrote every occurrence.
        A path without ``.csv`` is returned unchanged, as before.
        """
        head, sep, tail = csv_path.rpartition(".csv")
        if not sep:
            return csv_path
        return head + "_rank{}.csv" + tail

    @staticmethod
    def get_path_from_rank(rank, path_list, path_pattern):
        """Return the CSV path for ``rank``.

        With a single configured path (offline mode), every rank shares it. Otherwise the
        per-rank pattern is formatted with ``rank``.
        """
        return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank)

    @staticmethod
    def print_pretest_result():
        """Log that run_ut/multi_run_ut has finished."""
        logger.info("Successfully completed run_ut/multi_run_ut.")

    @staticmethod
    def _compare_dropout(bench_output, device_output):
        """Compare dropout outputs by the number of zeros they contain.

        Returns:
            ``(status, result)``, where result is 1 on pass and 0 on error.
        """
        tensor_num = bench_output.numel()
        if tensor_num < Comparator._DROPOUT_MIN_ELEMENTS:
            return CompareConst.PASS, 1
        zero_count_diff = abs((bench_output == 0).sum() - (device_output == 0).cpu().sum())
        if zero_count_diff / tensor_num < Comparator._DROPOUT_ZERO_RATIO_TOLERANCE:
            return CompareConst.PASS, 1
        return CompareConst.ERROR, 0

    @staticmethod
    def _compare_builtin_type(bench_output, device_output, compare_column):
        """Compare bool/int/float/str outputs by equality. Other types are passed unchecked."""
        if not isinstance(bench_output, (bool, int, float, str)):
            return CompareConst.PASS, compare_column, ""
        if bench_output != device_output:
            return CompareConst.ERROR, compare_column, ""
        compare_column.error_rate = 0
        return CompareConst.PASS, compare_column, ""

    @staticmethod
    def _compare_bool_tensor(bench_output, device_output):
        """Return ``(error_rate, status, message)`` for element-wise equality of two arrays.

        Any mismatching element makes the status ERROR. An empty bench array yields
        ``(NAN, ERROR, ...)``.
        """
        if bench_output.size == 0:
            return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
        error_nums = (bench_output != device_output).sum()
        error_rate = float(error_nums / bench_output.size)
        result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
        return error_rate, result, ""

    @staticmethod
    def _get_absolute_threshold_attribute(api_name, dtype):
        """Return ``(small_value, small_value_atol, rtol)`` from ``apis_threshold[api_name][dtype]``."""
        threshold_config = apis_threshold.get(api_name).get(dtype)
        return (threshold_config.get('small_value'),
                threshold_config.get('small_value_atol'),
                threshold_config.get('rtol'))

    @staticmethod
    def _get_run_ut_detail(test_result):
        """Build the detail-CSV rows for one result; also used directly by online run_ut.

        ``test_result`` is indexed like ``ResultInfo``:
        - ``[0]`` is the API name;
        - ``[3]`` and ``[4]`` are the forward and backward detail rows.

        Float cells are formatted with ``msCheckerConfig.precision`` decimals.

        Raises:
            CompareException: if ``test_result`` is too short.
        """
        try:
            subject_prefix = test_result[0]
            fwd_result = test_result[3]
            bwd_result = test_result[4]
        except IndexError as e:
            logger.error("List index out of bounds when writing detail CSV.")
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e

        test_rows = []
        for direction, results in ((".forward.output.", fwd_result), (".backward.output.", bwd_result)):
            # Non-list results (e.g. the integer returned for dropout) produce no detail rows.
            if not isinstance(results, list):
                continue
            for i, test_subject in enumerate(results):
                subject = subject_prefix + direction + str(i)
                formatted = ["{:.{}f}".format(item, msCheckerConfig.precision)
                             if isinstance(item, float) else item for item in test_subject]
                test_rows.append([subject] + formatted)
        return test_rows

    def write_csv_title(self):
        """Write the summary and detail CSV headers to every output file that does not exist yet."""
        summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
                              self.COLUMN_BACKWARD_SUCCESS, "Message"]]
        for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list):
            if not os.path.exists(save_path):
                write_csv(summary_test_rows, save_path)
            if not os.path.exists(detail_save_path):
                write_csv(DETAIL_TEST_ROWS, detail_save_path)

    def write_summary_csv(self, test_result):
        """Append one summary row (name, forward status, backward status, ...) for ``test_result``.

        Raises:
            CompareException: if ``test_result`` is too short.
        """
        test_rows = []
        try:
            name = test_result[0]
            df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
            if test_result[1] == CompareConst.SKIP:
                df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
            if self.stack_info:
                # NOTE(review): for non-SKIP results no message is appended above, so the
                # stack info lands in the "Message" column. Confirm the intended layout.
                stack_info = "\n".join(self.stack_info[name])
                df_row.append(stack_info)
            test_rows.append(df_row)
            save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str)
        except IndexError as e:
            logger.error("List index out of bounds when writing summary CSV.")
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
        write_csv(test_rows, save_path)

    def write_detail_csv(self, test_result):
        """Append the detail rows for ``test_result`` to the CSV of its rank."""
        test_rows = self._get_run_ut_detail(test_result)
        detail_save_path = self.get_path_from_rank(test_result[-1],
                                                   self.detail_save_path_list,
                                                   self.detail_save_path_str)
        write_csv(test_rows, detail_save_path)

    def record_results(self, args):
        """Write ``args`` (a ResultInfo) to both the summary and the detail CSV."""
        self.write_summary_csv(args)
        self.write_detail_csv(args)

    def compare_output(self, full_api_name, data_info, is_online=False):
        """Get compare result and write to result and detail csv.

        Args:
            full_api_name: full API name; its basic API name selects the comparison standard.
            data_info: object providing ``bench_output``, ``device_output``, ``bench_grad``,
                ``device_grad``, ``backward_message`` and ``rank``.
            is_online: bool, default False. True means the call comes from online API
                precision compare; the detail rows are returned and no CSV is written.

        Returns:
            The detail rows when ``is_online``. Otherwise ``(forward_passed, backward_passed)``,
            where a backward status of SPACE (not compared) also counts as passed.

        Raises:
            ValueError: if the API name cannot be extracted from ``full_api_name``.
        """
        _, api_name = extract_basic_api_segments(full_api_name)
        if not api_name:
            raise ValueError(f"API name {full_api_name} has not been adapted.")
        is_dropout = "dropout" in full_api_name
        bench_output, device_output = data_info.bench_output, data_info.device_output
        bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
        backward_message = data_info.backward_message

        if is_dropout:
            fwd_success_status, fwd_compare_alg_results = self._compare_dropout(bench_output, device_output)
        else:
            fwd_success_status, fwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_output,
                                                                                     device_output)

        if not (bench_grad and device_grad):
            bwd_success_status, bwd_compare_alg_results = (CompareConst.SPACE, [])
        elif is_dropout:
            bwd_success_status, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], device_grad[0])
        else:
            bwd_success_status, bwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_grad,
                                                                                     device_grad)

        if backward_message:
            backward_column = CompareColumn()
            bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)]
        else:
            bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE

        result_info = ResultInfo(full_api_name,
                                 fwd_success_status,
                                 bwd_success_status,
                                 fwd_compare_alg_results,
                                 bwd_compare_alg_results,
                                 data_info.rank)
        if is_online:
            # get run_ut compare detail
            return self._get_run_ut_detail(result_info)
        self.record_results(result_info)
        return (fwd_success_status == CompareConst.PASS,
                bwd_success_status in (CompareConst.PASS, CompareConst.SPACE))

    @staticmethod
    def _merge_status(current, new):
        """Return the more severe of two statuses: ERROR > WARNING > anything else (kept as ``current``)."""
        if CompareConst.ERROR in (current, new):
            return CompareConst.ERROR
        if CompareConst.WARNING in (current, new):
            return CompareConst.WARNING
        return current

    def _compare_core_wrapper(self, api_name, bench_output, device_output):
        """Compare (possibly sequence-valued) outputs and build the detail rows.

        Behaviour:
        - List/tuple bench outputs are compared element-wise against the first
          ``len(bench_output)`` device outputs.
        - A shorter device output is a single ERROR.
        - The overall status is the most severe per-element status. Previously, a later
          WARNING could overwrite an earlier ERROR.

        Returns:
            ``(overall_status, detail_rows)``.
        """
        if isinstance(bench_output, (list, tuple)):
            status, compare_result, message = [], [], []
            if len(bench_output) > len(device_output):
                status = [CompareConst.ERROR]
                # One column per status, otherwise building the detail row raised IndexError.
                compare_result = [CompareColumn()]
                message = ["bench and npu output structure is different."]
            else:
                device_output = device_output[:len(bench_output)]
                for b_out_i, n_out_i in zip(bench_output, device_output):
                    status_i, compare_result_i, message_i = self._compare_core(api_name, b_out_i, n_out_i)
                    if isinstance(status_i, list):
                        # dict elements yield parallel lists; flatten them into this level.
                        status.extend(status_i)
                        compare_result.extend(compare_result_i)
                        message.extend(message_i)
                    else:
                        status.append(status_i)
                        compare_result.append(compare_result_i)
                        message.append(message_i)
        else:
            status, compare_result, message = self._compare_core(api_name, bench_output, device_output)
            if not isinstance(status, list):
                status, compare_result, message = [status], [compare_result], [message]

        detailed_result_total = []
        test_final_success = CompareConst.PASS
        for item_status, item_column, item_message in zip(status, compare_result, message):
            detailed_result_total.append(item_column.to_column_value(item_status, item_message))
            test_final_success = self._merge_status(test_final_success, item_status)
        return test_final_success, detailed_result_total

    def _compare_core(self, api_name, bench_output, device_output):
        """Compare one bench output with the matching device output.

        Returns:
            ``(status, compare_column, message)``. For dict outputs, the three items are
            parallel lists with one entry per value. Unsupported types are reported as PASS,
            with a message naming the type.
        """
        compare_column = CompareColumn()
        if not isinstance(bench_output, type(device_output)):
            return CompareConst.ERROR, compare_column, "bench and npu output type is different."
        if isinstance(bench_output, dict):
            if set(bench_output.keys()) != set(device_output.keys()):
                return CompareConst.ERROR, compare_column, "bench and npu output dict keys are different."
            # Recursing with plain value lists (as before) always hit the "unexpected type"
            # branch and paired values by insertion order. Compare by key instead.
            return self._compare_dict(api_name, bench_output, device_output)
        if isinstance(bench_output, torch.Tensor):
            copy_bench_out = bench_output.detach().clone()
            copy_device_output = device_output.detach().clone()
            compare_column.bench_type = str(copy_bench_out.dtype)
            compare_column.npu_type = str(copy_device_output.dtype)
            compare_column.shape = tuple(device_output.shape)
            return self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, compare_column)
        if isinstance(bench_output, (bool, int, float, str)):
            compare_column.bench_type = str(type(bench_output))
            compare_column.npu_type = str(type(device_output))
            return self._compare_builtin_type(bench_output, device_output, compare_column)
        if bench_output is None:
            return CompareConst.SKIP, compare_column, "Bench output is None, skip this test."
        # Parenthesised so the message is part of the returned 3-tuple. Previously the
        # trailing comma ended the statement and a 2-tuple broke every caller's unpacking.
        return (CompareConst.PASS, compare_column,
                "Unexpected output type in compare_core: {}".format(type(bench_output)))

    def _compare_dict(self, api_name, bench_output, device_output):
        """Compare two dicts with identical key sets value by value, matching entries by key.

        Returns:
            Parallel ``(statuses, compare_columns, messages)`` lists. Nested dict results
            are flattened into the same lists.
        """
        statuses, compare_columns, messages = [], [], []
        for key, bench_value in bench_output.items():
            status, column, message = self._compare_core(api_name, bench_value, device_output[key])
            if isinstance(status, list):
                statuses.extend(status)
                compare_columns.extend(column)
                messages.extend(message)
            else:
                statuses.append(status)
                compare_columns.append(column)
                messages.append(message)
        return statuses, compare_columns, messages

    def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
        """Compare two tensors after converting both to CPU numpy arrays.

        Processing steps:
        - When the device dtype is bfloat16, both sides are upcast to float32 first.
        - Shape mismatches and incomparable dtypes are reported as ERROR.
        - Integer/bool data is judged only by error rate.
        - Float data goes through ``_compare_float_tensor``.
        """
        cpu_shape = bench_output.shape
        npu_shape = device_output.shape
        npu_dtype = device_output.dtype
        if npu_dtype == torch.bfloat16:
            bench_output = bench_output.to(torch.float32)
            device_output = device_output.to(torch.float32)
        bench_output = bench_output.cpu().numpy()
        device_output = device_output.cpu().numpy()
        if cpu_shape != npu_shape:
            return CompareConst.ERROR, compare_column, f"The shape of bench{str(cpu_shape)} " \
                                                      f"and npu{str(npu_shape)} not equal."
        if not check_dtype_comparable(bench_output, device_output):
            return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
                                                      f"npu output dtype is {device_output.dtype}, cannot compare."
        message = ""
        if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
                                  np.int64, np.uint64]:
            message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
                       f"Only judged by Error Rate."
            err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
            message += msg + "\n"
            compare_column.error_rate = err_rate
            return status, compare_column, message
        status, compare_column, message = self._compare_float_tensor(api_name, bench_output, device_output,
                                                                     compare_column, npu_dtype)
        return status, compare_column, message

    def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
        """Collect float-tensor metrics into ``compare_column`` and grade the comparison.

        Metric collection:
        - Standard-specific metrics (binary, absolute, ULP or default benchmark) are
          collected only when ``str(dtype)`` is in BENCHMARK_COMPARE_SUPPORT_LIST.

        Grading ladder, in order:
        1. Cosine similarity.
        2. Max absolute error.
        3. Relative-error ratios at 1/100 (float16/bfloat16 only), 1/1000 and 1/10000.

        Returns:
            ``(status, compare_column, message)``.
        """
        message = ""
        abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
        abs_err = get_abs_err(bench_output, device_output)
        rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
        if api_name in thousandth_standard_api:
            thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
            compare_column.rel_err_thousandth = thousand_res
        if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
            both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
            if api_name in binary_standard_api:
                err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
                compare_column.error_rate = err_rate
            elif api_name in absolute_standard_api:
                small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
                    api_name, str(dtype))
                rel_err = abs_err / abs_bench_with_eps
                small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold)
                normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
                compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output,
                                                                         dtype, rtol)
                compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
                compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
            elif api_name in ulp_standard_api:
                if bench_output.size == 0:
                    compare_column.max_ulp_error = 0
                    compare_column.mean_ulp_error = 0
                    compare_column.ulp_error_proportion = 0
                else:
                    ulp_err = get_ulp_err(bench_output, device_output, dtype)
                    compare_column.max_ulp_error = np.max(ulp_err)
                    compare_column.mean_ulp_error = np.mean(ulp_err)
                    # float32 has its own ULP threshold; every other dtype uses the float16 one.
                    ulp_threshold = (CompareConst.ULP_FLOAT32_THRESHOLD if dtype == torch.float32
                                     else CompareConst.ULP_FLOAT16_THRESHOLD)
                    compare_column.ulp_error_proportion = np.sum(ulp_err > ulp_threshold) / bench_output.size
            else:
                dtype_config = precision_configs.get(dtype)
                small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
                abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
                compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
                rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
                compare_column.RMSE = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
                compare_column.EB = get_error_balance(bench_output, device_output)
                if rel_err.size == 0:
                    return CompareConst.ERROR, compare_column, "Relative error result list is empty."
                compare_column.Max_rel_error = get_max_rel_err(rel_err)
                compare_column.Mean_rel_error = get_mean_rel_err(rel_err)

        cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
        compare_column.cosine_sim = cos_res
        message += msg + "\n"
        if not cos_status:
            message += "Cosine similarity is less than 0.99, consider as error, skip other check and set to SPACE.\n"
            return CompareConst.ERROR, compare_column, message

        max_abs_res, max_abs_status = get_max_abs_err(abs_err)
        compare_column.max_abs_err = max_abs_res
        if max_abs_status:
            message += "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
            return CompareConst.PASS, compare_column, message

        if dtype in [torch.float16, torch.bfloat16]:
            hundred_res, hundred_status = get_rel_err_ratio(rel_err_orign, CompareConst.HUNDRED_RATIO_THRESHOLD)
            compare_column.rel_err_hundredth = hundred_res
            if not hundred_status:
                message += "Relative error is greater than 0.01, consider as error, skip other check and set to SPACE.\n"
                return CompareConst.ERROR, compare_column, message
        thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
        compare_column.rel_err_thousandth = thousand_res
        if dtype in [torch.float16, torch.bfloat16]:
            if thousand_status:
                message += "Relative error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
                return CompareConst.PASS, compare_column, message
            message += "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"
            return CompareConst.WARNING, compare_column, message
        ten_thousand_res, ten_thousand_status = get_rel_err_ratio(rel_err_orign,
                                                                  CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
        compare_column.rel_err_ten_thousandth = ten_thousand_res
        if dtype in [torch.float32, torch.float64]:
            if not thousand_status:
                message += "Relative error is greater than 0.001, consider as error, skip other check and set to SPACE.\n"
                return CompareConst.ERROR, compare_column, message
            if not ten_thousand_status:
                message += "Relative error is greater than 0.0001, consider as warning, skip other check and set to SPACE.\n"
                return CompareConst.WARNING, compare_column, message
        message += "Relative error is less than 0.0001, consider as pass.\n"
        return CompareConst.PASS, compare_column, message