mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -0,0 +1,146 @@
1
+ # 无标杆工具场景验证和性能基线报告
2
+
3
+ ## 1 环境信息
4
+
5
+ NPU:Atlas A2 训练系列产品
6
+
7
+ CPU:
8
+
9
+ ![输入图片说明](img/cpu_info.png)
10
+
11
+ Torch:2.1.0
12
+
13
+ CANN:8.0.T5
14
+
15
+ 除上述环境信息影响性能外,API 的数量、种类以及 Shape 都会对性能产生影响,因此本次选取不同场景网络和不同算子进行测试。
16
+
17
+ ## 2 模型信息和性能基线
18
+
19
+ 大模型在使用 msprobe 采集数据时,建议先简化模型层数,减少采集数据量。
20
+
21
+ 以下场景的性能基线测试数据均为多次测试后取平均值,实际运行时性能数据可能会根据环境状态稍有浮动。
22
+
23
+ ### [2.1 ModelLink 模型](https://gitee.com/ascend/ModelLink)
24
+
25
+ NUM_LAYER:1
26
+
27
+ NPU 卡数:1
28
+
29
+ 主要数据类型:FLOAT16
30
+
31
+ #### 2.1.1 LLaMA2-7B
32
+
33
+ softmax 算子的数据类型为 FLOAT32,输入和输出大小均为 2GB,是该模型中显存开销最大的 API。在该模型下,对无标杆工具的处理模式、插装范围和扰动方式的不同组合进行性能与显存基线验证。
34
+
35
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
36
+ |-------|--------|------|-----|-------|---------|--------|-------|--------|
37
+ | / | / | / | / | 0.24 | 13.69 | 1 | 1 | 混合精度模式基线 |
38
+ | check | 前 | ["softmax"] | improve_precision | 0.26 | 13.69 | 1.08 | 1 | softmax 本身为高精度,跳过 |
39
+ | check | 前 | ["softmax"] | add_noise | 0.54 | 19.17 | 2.25 | 1.40 | |
40
+ | check | 前 | ["softmax"] | bit_noise | 0.56 | 19.17 | 2.33 | 1.40 | |
41
+ | check | 前 | ["softmax"] | change_value | 0.48 | 14.9 | 2 | 1.09 | |
42
+ | check | 前 | ["softmax"] | no_change | 0.47 | 14.9 | 1.96 | 1.09 | |
43
+ | check | 前 | ["softmax"] | to_cpu | 26.45 | 22.67 | 110.21 | 1.66 | 不建议整网 |
44
+ | check | 前 | ["matmul"] | improve_precision | 0.57 | 13.69 | 2.38 | 1 | |
45
+ | check | 前 | ["matmul"] | change_value | 0.48 | 13.69 | 2 | 1 | |
46
+ | check | 前 | ["matmul"] | to_cpu | 78.43 | 19.20 | 326.79 | 1.40 | 不建议整网 |
47
+ | check | 前 | [] | improve_precision | 3.45 | 18.79 | 14.37 | 1.37 | |
48
+ | check | 前 | [] | add_noise | 4.67 | 19.17 | 19.46 | 1.40 | |
49
+ | check | 前 | [] | bit_noise | 16.99 | 19.17 | 70.79 | 1.40 | |
50
+ | check | 前 | [] | no_change | 3.22 | 14.90 | 13.42 | 1.09 | |
51
+ | check | 反 | ["softmax"] | improve_precision | 6.23 | 25.69 | 25.96 | 1.88 | 不建议整网 |
52
+ | check | 反 | ["softmax"] | change_value | 22.76 | 25.69 | 94.83 | 1.88 | 不建议整网 |
53
+ | check | 反 | ["softmax"] | to_cpu | 141.71 | 26.19 | 590.46 | 1.91 | 不建议整网 |
54
+ | fix | 前 | ["softmax"] | to_cpu | 9.70 | 16.67 | 40.42 | 1.22 | 不支持整网、不支持反向 |
55
+ | fix | 前 | ["softmax"] | improve_precision | 0.26 | 14.67 | 1.08 | 1.07 | 不支持整网、不支持反向 |
56
+ | 预热 | 前 | [] | improve_precision | 155.07 | 24.79 | 646.13 | 1.81 | 低精度模型基线、只测预热的迭代 |
57
+ | 预热 | 反 | [] | improve_precision | 72.29 | 22.01 | 301.21 | 1.61 | 低精度模型基线、只测预热的迭代,grad_output 为高精度的算子跳过 |
58
+
59
+ #### 2.1.2 Aquila2-7B
60
+
61
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
62
+ |----------|------|-----|---|----|-----|-------|------|-------------|
63
+ | / | / | / | / | 0.17 | 13.66 | 1 | 1 | 混合精度模式基线 |
64
+ | check | 前 | [] | improve_precision | 1.57 | 14.24 | 9.24 | 1.04 | |
65
+ | check | 反 | [] | add_noise | 21.05 | 14.19 | 123.82 | 1.04 | |
66
+ | fix | 前 | [] | improve_precision | 0.95 | 15.55 | 5.59 | 1.14 | |
67
+
68
+ #### 2.1.3 Baichuan2-7B
69
+
70
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s)| 显存峰值(GB)| 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
71
+ |----|-----|---|--|----|----|------|-------|---------|
72
+ | / | / | / | / | 0.26 | 12.12 | 1 | 1 | 混合精度模式基线 |
73
+ | check | 前 | [] | improve_precision | 1.02 | 12.27 | 3.92 | 1.01 | |
74
+ | check | 反 | [] | add_noise | 11.15 | 12.67 | 42.88 | 1.05 | |
75
+ | fix | 前 | [] | improve_precision | 0.95 | 12.82 | 3.65 | 1.06 | |
76
+
77
+ #### 2.1.4 Bloom-7B
78
+
79
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s)| 显存峰值(GB)| 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
80
+ |-----|------|------|------|----|-----|-----|-------|----|
81
+ | / | / | / | / | 0.14 | 9.51 | 1 | 1 | 混合精度模式基线 |
82
+ | check | 前 | [] | improve_precision | 1.64 | 11.58 | 11.71 | 1.22 | |
83
+ | check | 反 | [] | add_noise | 17.15 | 9.51 | 122.5 | 1 | |
84
+ | fix | 前 | [] | improve_precision | 0.87 | 10.62 | 6.21 | 1.12 | |
85
+
86
+ #### 2.1.5 InternLM-7B
87
+
88
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
89
+ |-------------|--------|-------|----|------|-----|------|-------|----|
90
+ | / | / | / | / | 0.13 | 10.76 | 1 | 1 | 混合精度模式基线 |
91
+ | check | 前 | [] | improve_precision | 1.19 | 11.68 | 9.15 | 1.09 | |
92
+ | check | 反 | [] | add_noise | 11.69 | 10.89 | 89.92 | 1.01 | |
93
+ | fix | 前 | [] | improve_precision | 0.75 | 11.68 | 5.77 | 1.09 | |
94
+
95
+ #### 2.1.6 Qwen-7B
96
+
97
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
98
+ |--------|-------|-----|-----|----|------|-----|------|------|
99
+ | / | / | / | / | 0.28 | 18.41 | 1 | 1 | 混合精度模式基线 |
100
+ | check | 前 | [] | improve_precision | 2.34 | 23.18 | 8.36 | 1.26 | |
101
+ | check | 反 | [] | add_noise | 22.07 | 19.47 | 78.82 | 1.06 | |
102
+ | fix | 前 | [] | improve_precision | 1.31 | 21.11 | 4.68 | 1.15 | |
103
+
104
+ #### 2.1.7 Gemma-7B
105
+
106
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
107
+ |--------|-------|------|---|----|-----|-----|-----|---------|
108
+ | / | / | / | / | 0.15 | 11.06 | 1 | 1 | 混合精度模式基线 |
109
+ | check | 前 | [] | improve_precision | 1.49 | 13.17 | 9.93 | 1.19 | |
110
+ | check | 反 | [] | add_noise | 16.69 | 11.06 | 111.27 | 1 | |
111
+ | fix | 前 | [] | improve_precision | 0.87 | 12.25 | 5.8 | 1.11 | |
112
+
113
+ ### [2.2 ModelZoo-PyTorch 模型](https://gitee.com/ascend/ModelZoo-PyTorch)
114
+
115
+ #### 2.2.1 ResNet50-Cifar
116
+
117
+ NPU 卡数:1
118
+
119
+ 主要数据类型:FLOAT16
120
+
121
+ 主要算子为 conv2d,每个 step 有 51 个,因此对 conv2d 进行检测。该模型为 CV 模型,依赖 mmcv 实现(如果不修改 mmcv 代码,工具无法获取 step 信息和反向信息)。
122
+
123
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
124
+ |------------|---------|--------|-----|------|---|--------|-------|----|
125
+ | / | / | / | / | 0.09 | 7.63 | 1 | 1 | 基线 |
126
+ | check | 前 | ["conv2d"] | improve_precision | 0.889 | 7.94 | 9.81 | 1.04 | |
127
+ | fix | 前 | ["conv2d"] | improve_precision | 0.328 | 7.47 | 3.64 | 0.91 | |
128
+ | fix | 前 | ["conv2d"] | to_cpu | 12.23 | 7.47 | 135.88 | 0.91 | |
129
+
130
+ #### 2.2.2 OpenSora1.0
131
+
132
+ NPU 卡数:4
133
+
134
+ 主要数据类型:FLOAT16
135
+
136
+ 每张卡每个 step 中 linear 算子个数为 257 个,FA 算子(npu_fusion_attention)个数为 83 个(FA 算子反向无效)。
137
+
138
+ | 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
139
+ |------------|------|-------|----|----|-----|-----|------|-----|
140
+ | / | / | / | / | 0.99 | 17.61 | 1 | 1 | 混合精度模式基线 |
141
+ | check | 前 | ["linear","npu_fusion_attention"] | improve_precision | 3.88 | 17.61 | 3.92 | 1 | |
142
+ | check | 前 | ["linear","npu_fusion_attention"] | add_noise | 3.46 | 17.61 | 3.49 | 1 | |
143
+ | check | 反 | ["linear"] | improve_precision | 12.61 | 17.61 | 12.74 | 1 | |
144
+ | check | 反 | ["linear"] | add_noise | 9.8 | 17.61 | 9.90 | 1 | |
145
+ | fix | 前 | ["linear"] | to_cpu | 18.83 | 17.61 | 19.02 | 1 | |
146
+ | fix | 前 | ["linear"] | improve_precision | 2.83 | 17.61 | 2.86 | 1 | |
Binary file
Binary file
Binary file
Binary file
Binary file
@@ -1 +1 @@
1
- from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
1
+ from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
File without changes
@@ -0,0 +1,255 @@
1
+ import json
2
+ import os
3
+
4
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
5
+ from msprobe.core.common.utils import add_time_as_suffix
6
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
+ from msprobe.mindspore.common.log import logger
8
+ from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
+ from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
+ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
11
+ from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
12
+ trim_output_compute_element_list)
13
+
14
+
15
class BasicInfoAndStatus:
    """Comparison record for a single output slot of one api run.

    Args:
        api_name: slot name; the checker builds it by joining the api name, the
            direction, Const.OUTPUT and the output index with Const.SEP.
        bench_dtype: dtype of the benchmark (torch) output.
        tested_dtype: dtype of the tested output.
        shape: shape of the benchmark output.
        status: overall pass/fail verdict for this slot.
        err_msg: concatenated error messages; empty when the slot passes.
    """

    def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
        self.api_name = api_name
        # benchmark-side and tested-side dtypes, assigned in that order
        self.bench_dtype, self.tested_dtype = bench_dtype, tested_dtype
        self.shape = shape
        # verdict and its explanation
        self.status, self.err_msg = status, err_msg
23
+
24
class ResultCsvEntry:
    """Aggregated forward/backward verdict for one api (one row of the result csv)."""

    def __init__(self) -> None:
        # pass status per direction; left as None when that direction produced no result
        self.forward_pass_status = self.backward_pass_status = None
        # per-direction error text, concatenated when building the overall message
        self.forward_err_msg = self.backward_err_msg = ""
        self.overall_err_msg = None
31
+
32
+
33
class ApiAccuracyChecker:
    """Re-run dumped mint apis against torch, compare the outputs and write csv reports.

    Attributes:
        api_infos: dict mapping a trimmed api name (e.g. "MintFunctional.relu.0") to its ApiInfo.
        results: dict mapping (api_name, forward_or_backward) to a list of
            (BasicInfoAndStatus, {algorithm_name: compare_result}) tuples.
    """

    def __init__(self):
        self.api_infos = dict()
        self.results = dict()

    @staticmethod
    def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
        '''
        Args:
            api_info: ApiInfo
            api_name_str: str
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str: Union["forward", "backward"]

        Return:
            output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]

        Description:
            get mindspore api output, run torch api and get output.
            compare output.
            record compare result.
        '''
        # get output
        if global_context.get_is_constructed():
            # constructed situation, need use constructed input to run mindspore api getting tested_output
            tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
        else:
            tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
        bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
        tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
        bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
        if len(tested_outputs) != len(bench_outputs):
            logger.warning(f"ApiAccuracyChecker.run_and_compare_helper: api: {api_name_str}.{forward_or_backward}, "
                           "number of bench outputs and tested outputs is different, comparing result can be wrong. "
                           f"tested outputs: {len(tested_outputs)}, bench outputs: {len(bench_outputs)}")

        # compare output; zip stops at the shorter list when the counts differ
        output_list = []
        for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
            api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
            bench_dtype = bench_out.get_dtype()
            tested_dtype = tested_out.get_dtype()
            shape = bench_out.get_shape()

            compare_result_dict = dict()
            for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
                compare_result = compare_algorithm(bench_out, tested_out)
                compare_result_dict[compare_algorithm_name] = compare_result

            # a slot passes only when both cosine and max-abs-error checks pass
            if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
                    compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
                status = CompareConst.PASS
                err_msg = ""
            else:
                status = CompareConst.ERROR
                err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
                    compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
            basic_info_status = \
                BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
            output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
        return output_list

    def parse(self, api_info_path):
        """Load api_info.json, initialise the global context and collect mint api infos.

        Only entries whose first name component is Mint / MintFunctional are kept.
        Entries whose last component is neither forward nor backward are skipped.
        """
        with FileOpen(api_info_path, "r") as f:
            api_info_dict = json.load(f)

        # init global context
        task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
                                            "task field in api_info.json", accepted_type=str,
                                            accepted_value=(MsCompareConst.STATISTICS_TASK,
                                                            MsCompareConst.TENSOR_TASK))
        # statistics-only dumps carry no tensors, so inputs must be constructed
        is_constructed = task == MsCompareConst.STATISTICS_TASK
        if not is_constructed:
            dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
                                                         "dump_data_dir field in api_info.json", accepted_type=str)
        else:
            dump_data_dir = ""
        global_context.init(is_constructed, dump_data_dir)

        api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
                                                     "data field in api_info.json", accepted_type=dict)
        for api_name, api_info in api_info_data.items():
            is_mint = api_name.split(Const.SEP)[0] in \
                (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
            if not is_mint:
                continue
            forbackward_str = api_name.split(Const.SEP)[-1]
            if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
                logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
                # the warning promises a skip; without this the entry was loaded as backward info
                continue
            api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1])  # www.xxx.yyy.zzz --> www.xxx.yyy
            if api_name not in self.api_infos:
                self.api_infos[api_name] = ApiInfo(api_name)

            if forbackward_str == Const.FORWARD:
                self.api_infos[api_name].load_forward_info(api_info)
            else:
                self.api_infos[api_name].load_backward_info(api_info)

    def run_and_compare(self):
        """Run and compare the forward pass, then the backward pass, of every parsed api.

        Apis without forward info are skipped entirely; the backward pass runs only
        when backward info is present.
        """
        for api_name_str, api_info in self.api_infos.items():
            if not api_info.check_forward_info():
                logger.warning(f"api: {api_name_str} is lack of forward information, skip forward and backward check")
                continue
            forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
            kwargs = api_info.get_kwargs()
            forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
            self._run_compare_and_record(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)

            if not api_info.check_backward_info():
                logger.warning(f"api: {api_name_str} is lack of backward information, skip backward check")
                continue
            gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
            backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
            self._run_compare_and_record(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)

    def _run_compare_and_record(self, api_info, api_name_str, api_input_aggregation, forward_or_backward):
        """Run one direction of one api and record the result; failures are logged, not raised."""
        output_list = None
        try:
            output_list = self.run_and_compare_helper(api_info, api_name_str, api_input_aggregation,
                                                      forward_or_backward)
        except Exception as e:
            # deliberate best-effort boundary: one failing api must not abort the whole check
            logger.warning(f"exception occurs when running and comparing {api_name_str} {forward_or_backward} api, "
                           f"detailed exception information: {e}")
        self.record(output_list)

    def record(self, output_list):
        """Group (api_name, direction, basic_info, compare_result_dict) tuples into self.results.

        A None output_list (the run failed) is ignored.
        """
        if output_list is None:
            return
        for output in output_list:
            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
            key = tuple([api_real_name, forward_or_backward])
            if key not in self.results:
                self.results[key] = []
            self.results[key].append(tuple([basic_info, compare_result_dict]))

    def to_detail_csv(self, csv_dir):
        """Write one csv row per compared output slot into a timestamped file under csv_dir."""
        detail_csv = []
        detail_csv_header_basic_info = [
            MsCompareConst.DETAIL_CSV_API_NAME,
            MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
            MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
            MsCompareConst.DETAIL_CSV_SHAPE,
        ]
        detail_csv_header_compare_result = list(compare_algorithms.keys())
        detail_csv_header_status = [
            MsCompareConst.DETAIL_CSV_PASS_STATUS,
            MsCompareConst.DETAIL_CSV_MESSAGE,
        ]

        detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
        detail_csv.append(detail_csv_header)

        for results in self.results.values():
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = \
                    [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
                csv_row_compare_result = [compare_result_dict.get(algorithm_name).compare_value
                                          for algorithm_name in detail_csv_header_compare_result]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                detail_csv.append(csv_row_basic_info + csv_row_compare_result + csv_row_status)

        file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
        create_directory(csv_dir)
        write_csv(detail_csv, file_name, mode="w")

    def to_result_csv(self, csv_dir):
        """Write one csv row per api (forward/backward verdicts) into a timestamped file under csv_dir."""
        result_csv_dict = dict()
        for (api_real_name, forward_or_backward), results in self.results.items():
            # a direction passes only if every output slot passed; messages only accumulate on failure
            forward_or_backward_pass_status = CompareConst.PASS
            forward_or_backward_overall_err_msg = ""
            for basic_info, _ in results:
                if basic_info.status != CompareConst.PASS:
                    forward_or_backward_pass_status = CompareConst.ERROR
                    forward_or_backward_overall_err_msg += basic_info.err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            if forward_or_backward == Const.FORWARD:
                result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
                result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
            else:
                result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
                result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg

        result_csv = []
        result_csv_header = [
            MsCompareConst.DETAIL_CSV_API_NAME,
            MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
            MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
            MsCompareConst.DETAIL_CSV_MESSAGE,
        ]
        result_csv.append(result_csv_header)

        for api_name, result_csv_entry in result_csv_dict.items():
            if result_csv_entry.forward_pass_status == CompareConst.PASS and \
                    result_csv_entry.backward_pass_status == CompareConst.PASS:
                overall_err_msg = ""
            else:
                overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
            row = [api_name, result_csv_entry.forward_pass_status,
                   result_csv_entry.backward_pass_status, overall_err_msg]
            result_csv.append(row)

        file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
        create_directory(csv_dir)
        write_csv(result_csv, file_name, mode="w")
@@ -0,0 +1,69 @@
1
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
+ from msprobe.mindspore.common.log import logger
6
+
7
class ApiInfo:
    """Raw forward/backward dump entries of one api, plus helpers to turn them into ComputeElements."""

    def __init__(self, api_name):
        self.api_name = api_name
        # raw dicts taken from the "data" section of api_info.json; None until loaded
        self.forward_info = None
        self.backward_info = None

    def load_forward_info(self, forward_info_dict):
        self.forward_info = forward_info_dict

    def load_backward_info(self, backward_info_dict):
        self.backward_info = backward_info_dict

    def check_forward_info(self):
        """Return True once forward info has been loaded."""
        return self.forward_info is not None

    def check_backward_info(self):
        """Return True once backward info has been loaded."""
        return self.backward_info is not None

    def get_compute_element_list(self, forward_or_backward, input_or_output):
        '''
        Args:
            forward_or_backward: str, Union["forward", "backward"]
            input_or_output: str, Union["input", "output"]

        Return:
            compute_element_list: List[ComputeElement]

        Raises:
            ApiAccuracyCheckerException (via logger.error_log_with_exp) for an unsupported
            (forward_or_backward, input_or_output) pair.
        '''
        # forward inputs live under input_args; backward inputs/outputs and forward outputs
        # live under input/output respectively
        mapping = {
            (Const.FORWARD, Const.INPUT): [self.forward_info, Const.INPUT_ARGS,
                                           f"input_args field of {self.api_name} forward api in api_info.json"],
            (Const.FORWARD, Const.OUTPUT): [self.forward_info, Const.OUTPUT,
                                            f"output field of {self.api_name} forward api in api_info.json"],
            (Const.BACKWARD, Const.INPUT): [self.backward_info, Const.INPUT,
                                            f"input field of {self.api_name} backward api in api_info.json"],
            (Const.BACKWARD, Const.OUTPUT): [self.backward_info, Const.OUTPUT,
                                             f"output field of {self.api_name} backward api in api_info.json"]
        }
        mapping_entry = mapping.get((forward_or_backward, input_or_output))
        if mapping_entry is None:
            # previously this fell through to an opaque "cannot unpack NoneType" TypeError
            err_msg = (f"ApiInfo.get_compute_element_list failed: unsupported combination "
                       f"({forward_or_backward}, {input_or_output})")
            logger.error_log_with_exp(err_msg,
                                      ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        dict_instance, key, key_desc = mapping_entry
        compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
        return [ComputeElement(compute_element_info=compute_element_info)
                for compute_element_info in compute_element_info_list]

    def get_kwargs(self):
        '''
        Return:
            kwargs_compute_element_dict: dict{str: ComputeElement}

        Description:
            Every entry is validated (str key, list/dict value) before any ComputeElement is built.
        '''
        kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
                                                   "input_kwargs in api_info.json", accepted_type=dict)
        for key_str, compute_element_info in kwargs_dict.items():
            if not isinstance(key_str, str):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
            if not isinstance(compute_element_info, (list, dict)):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
        return {key_str: ComputeElement(compute_element_info=compute_element_info)
                for key_str, compute_element_info in kwargs_dict.items()}
69
+
@@ -0,0 +1,156 @@
1
+
2
+
3
+ import mindspore
4
+ import torch
5
+ from mindspore import ops
6
+
7
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
+ from msprobe.core.common.const import Const, MsCompareConst
9
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
+ from msprobe.mindspore.common.log import logger
11
+ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
12
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
13
+
14
+
15
class ApiInputAggregation:
    """Everything needed to run one api once.

    Args:
        inputs: List[ComputeElement], positional inputs of the forward call.
        kwargs: dict{str: ComputeElement}, keyword inputs of the forward call.
        gradient_inputs: Union[List[ComputeElement], None], incoming gradients for a
            backward run; None for a forward-only run.
    """

    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        self.inputs, self.kwargs, self.gradient_inputs = inputs, kwargs, gradient_inputs
26
+
27
# (api type, framework) -> parent module in which the api's sub name is looked up:
#   mindspore.mint.{name}                <--> torch.{name}
#   mindspore.mint.nn.functional.{name}  <--> torch.nn.functional.{name}
api_parent_module_mapping = {
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
}
33
+
34
+
35
class ApiRunner:
    """Run a mindspore.mint api or its torch counterpart on ComputeElement inputs."""

    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            outputs: list[ComputeElement]

        Description:
            run mindspore.mint/torch api
        '''
        api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str):
        '''
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"

        Return:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
        '''
        api_name_list = api_name_str.split(Const.SEP)
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
            err_msg = "ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        '''
        Args:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
            api_platform: str: Union["mindspore", "torch"]

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
        '''

        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
        # full dotted name, used only for error messages
        module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
        submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
        full_api_name = module_str + submodule_str + api_sub_name
        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        '''
        Args:
            api_instance: callable returned by get_api_instance
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            list[ComputeElement]: forward outputs, or gradients w.r.t. the inputs for a backward run.
            On torch only the floating-point tensor inputs receive gradients.
        '''
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}
        gradient_inputs = api_input_aggregation.gradient_inputs

        if forward_or_backward == Const.FORWARD:
            forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
            forward_result_tuple = convert_to_tuple(forward_result)
            res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
        else:
            if gradient_inputs is None:
                err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
                logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
            gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                    for compute_element in gradient_inputs)
            if api_platform == Const.MS_FRAMEWORK:
                # a single sens value is passed bare, several as a tuple
                if len(gradient_inputs) == 1:
                    gradient_inputs = gradient_inputs[0]

                def api_with_kwargs(*forward_inputs):
                    return api_instance(*forward_inputs, **kwargs)
                grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
                backward_result = grad_func(*inputs, gradient_inputs)  # can be single tensor or tuple
                backward_result_tuple = convert_to_tuple(backward_result)
                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
            else:
                # mark floating-point tensor inputs as requiring grad and remember their positions
                requires_grad_index = []
                for index, tensor in enumerate(inputs):
                    if isinstance(tensor, torch.Tensor) and \
                            torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
                        setattr(tensor, "requires_grad", True)
                        requires_grad_index.append(index)
                forward_results = convert_to_tuple(api_instance(*inputs, **kwargs))
                # Back-propagate all outputs in ONE autograd pass. Calling .backward() output by
                # output frees the graph after the first call, so a second output sharing a
                # grad_fn with saved tensors (e.g. torch.var_mean) would raise.
                paired = list(zip(forward_results, gradient_inputs))
                if paired:
                    outputs_to_backward = [forward_res for forward_res, _ in paired]
                    output_gradients = [gradient_in for _, gradient_in in paired]
                    torch.autograd.backward(outputs_to_backward, grad_tensors=output_gradients)
                backward_result_list = [getattr(inputs[index], "grad") for index in requires_grad_index]
                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list]

        return res_compute_element_list


api_runner = ApiRunner()