mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -0,0 +1,197 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import mindspore
4
+ import torch
5
+ import numpy as np
6
+
7
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
8
+ from msprobe.mindspore.common.log import logger
9
+ from msprobe.core.common.const import CompareConst, MsCompareConst
10
+
11
class CompareResult:
    """Outcome of applying one comparison algorithm to a bench/tested pair.

    All three constructor arguments are stored verbatim as attributes:
    ``compare_value`` (the computed metric), ``pass_status`` (status string) and
    ``err_msg`` (explanation text). No validation is performed.
    """

    def __init__(self, compare_value, pass_status, err_msg):
        # Assigned left to right, matching the constructor's parameter order.
        self.compare_value, self.pass_status, self.err_msg = compare_value, pass_status, err_msg
16
+
17
+
18
class BaseCompareAlgorithm(ABC):
    """Template for a single tensor-comparison metric.

    Subclasses set ``compare_algorithm_name`` to one of the keys of
    ``err_msg_mapping`` and implement ``check_validity``, ``run_compare`` and
    ``check_pass``. Calling an instance runs those three steps and wraps the
    outcome in a ``CompareResult``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = None
        # Message reported for each (algorithm name, pass status) pair.
        # A PASS always has an empty message.
        self.err_msg_mapping = {
            CompareConst.COSINE: {
                CompareConst.PASS: "",
                CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ",
            },
            CompareConst.MAX_ABS_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "max absolute difference is greater than " \
                                    f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
            },
            CompareConst.MAX_RELATIVE_ERR: {
                CompareConst.PASS: "",
                # Previously empty, so a failing or skipped relative-error check
                # reported no reason, unlike the other algorithms.
                CompareConst.ERROR: "max relative difference is greater than " \
                                    f"threshold: {CompareConst.MAX_RELATIVE_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max relative difference, skip comparing ",
            },
        }

    def __call__(self, bench_compute_element, tested_compute_element):
        """Run this comparison on a bench/tested pair.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_result: CompareResult. When ``check_validity`` fails, the value
            is None and the status is ``CompareConst.SKIP``.
        """
        if self.check_validity(bench_compute_element, tested_compute_element):
            compare_value = self.run_compare(bench_compute_element, tested_compute_element)
            pass_status = self.check_pass(compare_value)
        else:
            logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.")
            compare_value = None
            pass_status = CompareConst.SKIP

        err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status)

        compare_result = CompareResult(compare_value, pass_status, err_msg)
        return compare_result

    @staticmethod
    def convert_to_np_float64_ndarray(tensor):
        """Return a float64 numpy copy of a mindspore.Tensor or torch.Tensor.

        Any other input is reported through ``logger.error_log_with_exp`` with an
        ``ApiAccuracyCheckerException(UnsupportType)``.
        """
        if isinstance(tensor, mindspore.Tensor):
            ndarray = tensor.astype(mindspore.float64).numpy()
        elif isinstance(tensor, torch.Tensor):
            # Tensor.numpy() raises for tensors that require grad or live on a
            # non-CPU device. Detach and copy to host memory first.
            ndarray = tensor.detach().to(device="cpu", dtype=torch.float64, copy=True).numpy()
        else:
            err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
                      "input is not mindspore.Tensor or torch.Tensor"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        return ndarray

    @staticmethod
    def check_two_tensor(bench_compute_element, tested_compute_element):
        """Return True when both parameters are tensors and their shapes are equal."""
        bench_parameter = bench_compute_element.get_parameter()
        tested_parameter = tested_compute_element.get_parameter()

        bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor))
        tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor))
        shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape()
        return bench_is_tensor and tested_is_tensor and shape_same

    @abstractmethod
    def check_validity(self, bench_compute_element, tested_compute_element):
        """Decide whether this algorithm can be applied to the pair.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            check_res: boolean
        """
        raise NotImplementedError

    @abstractmethod
    def run_compare(self, bench_compute_element, tested_compute_element):
        """Compute the metric for a pair that passed ``check_validity``.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_value: float/int
        """
        raise NotImplementedError

    @abstractmethod
    def check_pass(self, compare_value):
        """Map a metric value to a pass status.

        Args:
            compare_value: float/int

        Return:
            pass_status: str
        """
        raise NotImplementedError
119
+
120
+
121
class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm):
    """Cosine similarity between the flattened bench and tested tensors.

    The comparison passes when the similarity is strictly greater than
    ``CompareConst.COS_THRESHOLD``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.COSINE

    def check_validity(self, bench_compute_element, tested_compute_element):
        """Return True only when both parameters are tensors of the same shape."""
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        """Return the cosine similarity of the two tensors.

        Two all-zero (or empty) tensors are treated as identical and give 1.0.
        When exactly one side is all zeros, cosine similarity is undefined and
        NaN is returned, which ``check_pass`` reports as an error.
        """
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        bench_norm = np.linalg.norm(bench_ndarray)
        tested_norm = np.linalg.norm(tested_ndarray)
        if bench_norm == 0 and tested_norm == 0:
            return 1.0
        if bench_norm == 0 or tested_norm == 0:
            # Without this guard, the epsilon-smoothed formula below evaluates to
            # EPSILON / EPSILON == 1.0, so an all-zero output would falsely pass.
            return float("nan")
        dot_product = np.dot(bench_ndarray.flatten(), tested_ndarray.flatten())
        cosine_similarity = (MsCompareConst.EPSILON + dot_product) / (MsCompareConst.EPSILON + bench_norm * tested_norm)
        return cosine_similarity

    def check_pass(self, compare_value):
        """PASS when the similarity exceeds COS_THRESHOLD; NaN never passes."""
        if compare_value > CompareConst.COS_THRESHOLD:
            return CompareConst.PASS
        return CompareConst.ERROR
144
+
145
+
146
class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise absolute difference between bench and tested tensors.

    The comparison passes when the value is strictly below
    ``CompareConst.MAX_ABS_ERR_THRESHOLD``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_ABS_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        """Return True only when both parameters are tensors of the same shape."""
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        """Return max(|bench - tested|), or 0.0 when there are no elements to compare."""
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        abs_diff = np.abs(bench_ndarray - tested_ndarray)
        if abs_diff.size == 0:
            # np.max raises ValueError on zero-size arrays, and empty tensors of
            # equal shape pass check_validity. No element differs, so report 0.
            return 0.0
        max_absolute_diff = np.max(abs_diff)
        return max_absolute_diff

    def check_pass(self, compare_value):
        """PASS when the difference is below MAX_ABS_ERR_THRESHOLD; NaN never passes."""
        if compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD:
            return CompareConst.PASS
        return CompareConst.ERROR
166
+
167
+
168
class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise relative difference |bench - tested| / |bench|.

    Zero bench elements are replaced by ``MsCompareConst.EPSILON`` in the
    denominator. The comparison passes when the value is strictly below
    ``CompareConst.MAX_RELATIVE_ERR_THRESHOLD``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        """Return True only when both parameters are tensors of the same shape."""
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        """Return the maximum relative difference, or 0.0 when there are no elements."""
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        abs_diff = np.abs(bench_ndarray - tested_ndarray)
        # Guard against division by zero: zero bench entries become EPSILON.
        bench_ndarray_nonzero = np.abs(bench_ndarray) + (bench_ndarray == 0) * MsCompareConst.EPSILON
        relative_diff = abs_diff / bench_ndarray_nonzero
        if relative_diff.size == 0:
            # np.max raises ValueError on zero-size arrays, and empty tensors of
            # equal shape pass check_validity. No element differs, so report 0.
            return 0.0
        max_relative_diff = np.max(relative_diff)
        return max_relative_diff

    def check_pass(self, compare_value):
        """PASS when the difference is below MAX_RELATIVE_ERR_THRESHOLD; NaN never passes."""
        if compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD:
            return CompareConst.PASS
        return CompareConst.ERROR
190
+
191
+
192
+
193
# Registry of available comparison algorithms, keyed by the algorithm name each
# instance assigns to itself in __init__ (cosine, max abs error, max relative error).
compare_algorithms = {
    algorithm.compare_algorithm_name: algorithm
    for algorithm in (
        CosineSimilarityCompareAlgorithm(),
        MaxAbsoluteDiffCompareAlgorithm(),
        MaxRelativeDiffCompareAlgorithm(),
    )
}
@@ -0,0 +1,6 @@
1
def add_api_accuracy_checker_argument(parser):
    """Register the API accuracy checker's command-line options on ``parser``.

    Options added:
        -api_info / --api_info_file  (required, str) -> ``api_info_file``
        -o / --out_path              (optional, str, default "./") -> ``out_path``
    """
    api_info_help = ("<Required> The api param tool result file: generate from api param tool, "
                     "a json file.")
    parser.add_argument(
        "-api_info",
        "--api_info_file",
        dest="api_info_file",
        type=str,
        required=True,
        help=api_info_help,
    )

    out_path_help = "<optional> The ut task result out path."
    parser.add_argument(
        "-o",
        "--out_path",
        dest="out_path",
        default="./",
        type=str,
        required=False,
        help=out_path_help,
    )
@@ -0,0 +1,239 @@
1
+ import os
2
+
3
+ import mindspore
4
+ import torch
5
+ import numpy as np
6
+
7
+ from msprobe.mindspore.common.log import logger
8
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
9
+ from msprobe.core.common.file_utils import load_npy
10
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
11
+ ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
12
+ dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
13
+ dtype_str_to_torch_dtype, type_to_api_info_type_str,
14
+ DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
15
+ MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
16
+ int_dtype_str_list)
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
19
+
20
+
21
class MstensorMetaData:
    """
    Plain record describing a mindspore tensor from api_info.json without loading its data.

    Attributes:
        dtype_str: dtype name in api_info.json notation (e.g. "Float32").
        npy_path: path of the dumped .npy file holding the real data, if any.
        maximum: recorded maximum value, if any.
        minimum: recorded minimum value, if any.
        shape: tensor shape as given in api_info.json.
    """

    def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
        # Multi-target assignment runs left to right, keeping the attribute order stable.
        self.dtype_str, self.npy_path, self.maximum, self.minimum, self.shape = (
            dtype_str, npy_path, maximum, minimum, shape
        )
28
+
29
class ComputeElement:
    '''
    One argument or output of an API under accuracy check.

    A ComputeElement is built either from an already materialized ``parameter`` or from its
    api_info.json description (``compute_element_info``). Tensors described in api_info.json are
    kept as MstensorMetaData and only turned into a mindspore.Tensor by ``get_parameter``.

    Member attributes:
        parameter: None, a tuple of ComputeElement (list/tuple arguments), an MstensorMetaData,
            or a raw value whose type is in ``supported_parameter_type``.
        shape: tuple, tensor shape; empty tuple for non-tensor values.
        dtype_str: str, dtype / type name in api_info.json notation ("None" for a null element).
    '''

    def __init__(self, compute_element_info=None, parameter=None):
        '''
        Args:
            compute_element_info: Union[list, dict, None], entry parsed from api_info.json
            parameter: value to wrap directly; used in preference to compute_element_info when not None

        Unsupported input is reported through logger.error_log_with_exp with
        ApiAccuracyCheckerException.UnsupportType.
        '''
        self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
        if parameter is not None:
            self._init_with_parameter(parameter)
        elif isinstance(compute_element_info, (list, dict)):
            self._init_from_compute_element_info(compute_element_info)
        elif compute_element_info is None:
            self._init_from_null_compute_element_info()
        else:
            logger.error_log_with_exp(
                "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
                ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    @staticmethod
    def transfer_to_torch_tensor(ms_tensor):
        '''
        Convert a mindspore.Tensor into a torch.Tensor of the corresponding dtype.

        Args:
            ms_tensor: mindspore.Tensor
        Return:
            torch_tensor: torch.Tensor
        '''
        ms_dtype = ms_tensor.dtype
        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
        if dtype_str not in dtype_str_to_torch_dtype:
            err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)

        # Cross the framework boundary through numpy with a wide intermediate dtype, then cast.
        # Every non-float dtype that has a torch mapping (signed ints, bool, and UInt8 -- the only
        # unsigned one) fits losslessly in int64. uint64 is avoided because torch.from_numpy does
        # not accept uint64 arrays on older torch versions.
        if dtype_str in float_dtype_str_list:
            middle_dtype = mindspore.float64
        else:
            middle_dtype = mindspore.int64
        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
        torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
        return torch_tensor

    @staticmethod
    def transfer_to_mindspore_tensor(torch_tensor):
        '''
        Convert a torch.Tensor into a mindspore.Tensor of the corresponding dtype.

        Args:
            torch_tensor: torch.Tensor

        Return:
            ms_tensor: mindspore.Tensor
        '''
        torch_dtype = torch_tensor.dtype
        dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
        if dtype_str not in dtype_str_to_ms_dtype:
            err_msg = \
                f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)

        # Non-float dtypes reaching this point (signed ints, bool and UInt8) all fit in int64, so
        # every dtype gets an intermediate -- UInt8 tensors previously left it unbound.
        if dtype_str in float_dtype_str_list:
            middle_dtype = torch.float64
        else:
            middle_dtype = torch.int64
        np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
        ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
        return ms_tensor

    @staticmethod
    def convert_inf_to_real_num(value, dtype_str):
        '''
        Replace +inf / -inf by the largest / smallest finite value of a dtype.

        Args:
            value: int or float, e.g. a Max/Min statistic read from api_info.json
            dtype_str: either a dtype name key of dtype_str_to_np_dtype (e.g. "Float16") or a numpy
                dtype (e.g. np.float16). Names without a numpy mapping and non-numeric dtypes fall
                back to DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE; integer dtypes use their integer range.

        Return:
            value unchanged when it is not infinite, otherwise a finite numpy scalar.
        '''
        if value != float("inf") and value != float("-inf"):
            return value

        if isinstance(dtype_str, str):
            np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
        else:
            np_dtype = dtype_str
        try:
            is_integer = np.issubdtype(np_dtype, np.integer)
            is_floating = np.issubdtype(np_dtype, np.floating)
        except TypeError:
            is_integer = is_floating = False

        if is_integer:
            type_info = np.iinfo(np_dtype)
        else:
            type_info = np.finfo(np_dtype if is_floating else DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
        return type_info.max if value > 0 else type_info.min

    def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
        '''
        Materialize and return the wrapped value.

        Args:
            get_origin: boolean, True returns tensors in the framework they are stored in;
                False converts tensors to ``tensor_platform``
            tensor_platform: str, Union["mindspore", "pytorch"]

        Return:
            parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
        '''
        if self.parameter is None:
            return self.parameter
        if isinstance(self.parameter, tuple):
            return tuple([compute_element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform)
                          for compute_element in self.parameter])
        elif isinstance(self.parameter, self.supported_parameter_type):
            parameter_tmp = self.parameter
        elif isinstance(self.parameter, MstensorMetaData):
            mstensor_meta_data = self.parameter
            ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
            if global_context.get_is_constructed():
                # Random data within the recorded [Min, Max] range instead of the dumped tensor.
                np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
                ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
                                                  mstensor_meta_data.minimum, np_dtype)
            else:
                ndarray = load_npy(mstensor_meta_data.npy_path)
            parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
        else:
            err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
                      "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

        # if necessary, do transfer
        if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
            parameter = self.transfer_to_torch_tensor(parameter_tmp)
        elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
            parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
        else:
            parameter = parameter_tmp

        return parameter

    def get_shape(self):
        '''Return the tensor shape as a tuple (empty tuple for non-tensor values).'''
        return self.shape

    def get_dtype(self):
        '''Return the dtype / type name in api_info.json notation.'''
        return self.dtype_str

    def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
        '''
        Build a reproducible random ndarray of the given shape and numpy dtype.

        Bool arrays have each element True with probability ~0.5. Other dtypes are sampled
        uniformly from [minimum, maximum); infinite bounds are first clamped to the finite range
        of np_dtype.
        NOTE(review): for float64 (and dtypes falling back to it, e.g. BFloat16) bounds of
        [-inf, inf] likely still exceed what np.random.uniform accepts -- confirm if needed.
        '''
        shape = tuple(shape)
        # Reseeds numpy's global RNG so constructed inputs are identical between runs.
        np.random.seed(42)
        if np_dtype == np.bool_:
            ndarray = np.random.rand(*shape) > 0.5
        else:
            maximum = self.convert_inf_to_real_num(maximum, np_dtype)
            minimum = self.convert_inf_to_real_num(minimum, np_dtype)
            ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype)
        return ndarray

    def _init_from_null_compute_element_info(self):
        '''Initialize an element that wraps None.'''
        self.parameter = None
        self.shape = tuple()
        # dtype_str is what get_dtype() reads; ``dtype`` is kept for backward compatibility.
        self.dtype_str = "None"
        self.dtype = "None"

    def _init_from_compute_element_info(self, compute_element_info):
        '''
        Args:
            compute_element_info: Union[list, dict]

        Return:
            void

        init member attributes: self.shape, self.dtype_str, self.parameter
        '''
        if isinstance(compute_element_info, list):
            self.shape = tuple()
            self.dtype_str = TUPLE_TYPE_STR
            self.parameter = tuple([ComputeElement(compute_element_info=sub_info)
                                    for sub_info in compute_element_info])
        else:
            type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
                                                    accepted_type=str, accepted_value=api_info_type_str_to_type.keys())

            if type_str == MINDSPORE_TENSOR_TYPE_STR:
                self._init_from_mstensor_compute_element_info(compute_element_info)
            else:  # type_str in ("slice", "int", "float", "bool", "str")
                value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
                self.shape = tuple()
                self.dtype_str = type_str
                self.parameter = slice(*tuple(value)) if type_str == "slice" else value

    def _init_from_mstensor_compute_element_info(self, compute_element_info):
        '''
        Record tensor meta data only; the real tensor is not loaded here.

        In constructed mode the Max/Min statistics are required; otherwise "data_name" is
        required and joined with the global dump data directory to form the .npy path.
        '''
        dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
                                                 accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
        shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
                                             accepted_type=(list,))
        if global_context.get_is_constructed():
            maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
                                                   accepted_type=(int, float))
            minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
                                                   accepted_type=(int, float))

            npy_path = None
        else:
            maximum, minimum = None, None
            data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
                                                     "data_name field in api_info.json", accepted_type=(str,))
            npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
        mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
        self.parameter = mstensor_meta_data
        self.dtype_str = dtype_str
        self.shape = tuple(shape)

    def _init_with_parameter(self, parameter):
        '''Initialize from an already materialized value; tuples are wrapped element-wise.'''
        self.parameter = parameter
        if not isinstance(parameter, self.supported_parameter_type):
            err_msg = "ComputeElement._init_with_parameter failed: " \
                      "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        if isinstance(parameter, mindspore.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, torch.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, tuple):
            self.shape = tuple()
            self.dtype_str = TUPLE_TYPE_STR
            self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
        else:
            # Tuples were handled above, so only the scalar / slice / str type names remain.
            self.shape = tuple()
            self.dtype_str = type_to_api_info_type_str.get(type(parameter))
@@ -0,0 +1,9 @@
1
+ from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
2
+
3
+
4
def api_checker_main(args):
    """
    Run the MindSpore API accuracy check command.

    Parses ``args.api_info_file``, runs the APIs and compares their outputs, then writes the
    detail csv and the result csv to ``args.out_path``.
    """
    checker = ApiAccuracyChecker()
    checker.parse(args.api_info_file)
    checker.run_and_compare()
    for csv_writer_name in ("to_detail_csv", "to_result_csv"):
        getattr(checker, csv_writer_name)(args.out_path)
@@ -0,0 +1,114 @@
1
+ from mindspore.common import dtype as mstype
2
+ import numpy as np
3
+ import mindspore
4
+ import torch
5
+
6
# dtype names used as keys by the dtype lookup tables of this module (api_info.json notation).
(INT8, UINT8, INT16, UINT16, INT32, UINT32, INT64, UINT64,
 FLOAT16, FLOAT32, FLOAT64, BOOL, BFLOAT16, INT4) = (
    "Int8", "UInt8", "Int16", "UInt16", "Int32", "UInt32", "Int64", "UInt64",
    "Float16", "Float32", "Float64", "Bool", "BFloat16", "Int4",
)
20
+
21
+
22
# dtype name -> mindspore dtype; Int4 maps to mstype.qint4x2.
dtype_str_to_ms_dtype = dict([
    (INT8, mstype.int8),
    (UINT8, mstype.uint8),
    (INT16, mstype.int16),
    (UINT16, mstype.uint16),
    (INT32, mstype.int32),
    (UINT32, mstype.uint32),
    (INT64, mstype.int64),
    (UINT64, mstype.uint64),
    (FLOAT16, mstype.float16),
    (FLOAT32, mstype.float32),
    (FLOAT64, mstype.float64),
    (BOOL, mstype.bool_),
    (BFLOAT16, mstype.bfloat16),
    (INT4, mstype.qint4x2),
])
# Inverse lookup: mindspore dtype -> dtype name.
ms_dtype_to_dtype_str = dict(zip(dtype_str_to_ms_dtype.values(), dtype_str_to_ms_dtype.keys()))
39
+
40
+
41
# dtype name -> numpy dtype; BFloat16 and Int4 have no entry.
dtype_str_to_np_dtype = dict([
    (INT8, np.int8),
    (UINT8, np.uint8),
    (INT16, np.int16),
    (UINT16, np.uint16),
    (INT32, np.int32),
    (UINT32, np.uint32),
    (INT64, np.int64),
    (UINT64, np.uint64),
    (FLOAT16, np.float16),
    (FLOAT32, np.float32),
    (FLOAT64, np.float64),
    (BOOL, np.bool_),
])
# Inverse lookup: numpy dtype -> dtype name.
np_dtype_to_dtype_str = dict(zip(dtype_str_to_np_dtype.values(), dtype_str_to_np_dtype.keys()))
56
+
57
# dtype name -> torch dtype; UInt16/UInt32/UInt64 and Int4 have no entry.
dtype_str_to_torch_dtype = dict([
    (INT8, torch.int8),
    (UINT8, torch.uint8),
    (INT16, torch.int16),
    (INT32, torch.int32),
    (INT64, torch.int64),
    (FLOAT16, torch.float16),
    (FLOAT32, torch.float32),
    (FLOAT64, torch.float64),
    (BOOL, torch.bool),
    (BFLOAT16, torch.bfloat16),
])
# Inverse lookup: torch dtype -> dtype name.
torch_dtype_to_dtype_str = dict(zip(dtype_str_to_torch_dtype.values(), dtype_str_to_torch_dtype.keys()))
70
+
71
# Values of the "type" field in api_info.json (plus "tuple" for list-valued arguments).
(MINDSPORE_TENSOR_TYPE_STR, BOOL_TYPE_STR, INT_TYPE_STR, FLOAT_TYPE_STR,
 SLICE_TYPE_STR, TUPLE_TYPE_STR, STR_TYPE_STR) = (
    "mindspore.Tensor", "bool", "int", "float", "slice", "tuple", "str",
)
78
+
79
# api_info.json "type" value -> Python / mindspore type.
api_info_type_str_to_type = dict([
    (MINDSPORE_TENSOR_TYPE_STR, mindspore.Tensor),
    (BOOL_TYPE_STR, bool),
    (INT_TYPE_STR, int),
    (FLOAT_TYPE_STR, float),
    (SLICE_TYPE_STR, slice),
    (STR_TYPE_STR, str),
])
# Inverse lookup: type -> api_info.json "type" value.
type_to_api_info_type_str = dict(zip(api_info_type_str_to_type.values(), api_info_type_str_to_type.keys()))
88
+
89
# Fallback numpy dtypes for building tensors from recorded statistics.
# NOTE(review): the INT and UINT defaults are float64 too, which does not match their names --
# confirm whether np.int64 / np.uint64 was intended before relying on them.
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = DEFAULT_CONSTRUCT_NP_INT_DTYPE = DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64
92
+
93
# dtype-name categories. Bool and Int4 are grouped with the signed integers.
float_dtype_str_list = [FLOAT16, FLOAT32, FLOAT64, BFLOAT16]
int_dtype_str_list = [INT8, INT16, INT32, INT64, BOOL, INT4]
uint_dtype_str_list = [UINT8, UINT16, UINT32, UINT64]
@@ -0,0 +1,80 @@
1
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list
4
+ from msprobe.mindspore.common.log import logger
5
+
6
def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
    '''
    Fetch ``dict_instance[key]`` and validate it.

    Args:
        dict_instance: dict, dict parsed from input json
        key: str
        key_description: str, field name used in error messages
        accepted_type: type or tuple of types for isinstance; None skips the type check
        accepted_value: Union[tuple, list] (any container supporting ``in``); None skips the value check

    Return:
        value, the corresponding value of "key" in "dict_instance"

    Exception:
        reports ApiAccuracyCheckerException.ParseJsonFailed through logger.error_log_with_exp when
        1. dict_instance is not a dict
        2. value is None (key missing or explicit null)
        3. value is not accepted type
        4. value is not accepted value
        Only the first failing check among 2-4 is reported.
    '''
    json_error = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
    if not isinstance(dict_instance, dict):
        logger.error_log_with_exp("check_and_get_from_json_dict failed: input is not a dict", json_error)
    value = dict_instance.get(key)

    problem = None
    if value is None:
        problem = f"{key_description} is missing"
    elif accepted_type is not None and not isinstance(value, accepted_type):
        problem = f"{key_description} is not accepted type: {accepted_type}"
    elif accepted_value is not None and value not in accepted_value:
        problem = f"{key_description} is not accepted value: {accepted_value}"

    if problem is not None:
        logger.error_log_with_exp(f"check_and_get_from_json_dict failed: {problem}", json_error)
    return value
41
+
42
def convert_to_tuple(input):
    """Return *input* as a tuple: tuples and lists are copied element-wise, any other value becomes a 1-tuple."""
    return tuple(input) if isinstance(input, (tuple, list)) else (input,)
48
+
49
def trim_output_compute_element_list(compute_element_list, forward_or_backward):
    '''
    Drop output elements that are not compared.

    An element is removed when its parameter is None, or, for backward outputs, when its dtype
    is not in float_dtype_str_list.

    Args:
        compute_element_list: List[ComputeElement]
        forward_or_backward: str, Union["forward", "backward"]

    Return:
        list of the kept elements, in their original order
    '''
    return [
        compute_element
        for compute_element in compute_element_list
        if compute_element.get_parameter() is not None
        and not (forward_or_backward == Const.BACKWARD
                 and compute_element.get_dtype() not in float_dtype_str_list)
    ]
63
+
64
class GlobalContext:
    """
    Run-wide settings of the API accuracy checker, shared via the module-level ``global_context``.

    Attributes:
        is_constructed: bool, True by default; when True input tensors are generated from the
            recorded Max/Min statistics instead of being loaded from dumped .npy files.
        dump_data_dir: str, directory containing the dumped .npy files ("" by default).
    """

    def __init__(self):
        self.is_constructed, self.dump_data_dir = True, ""

    def init(self, is_constructed, dump_data_dir):
        """Overwrite both settings."""
        self.is_constructed, self.dump_data_dir = is_constructed, dump_data_dir

    def get_dump_data_dir(self):
        return self.dump_data_dir

    def get_is_constructed(self):
        return self.is_constructed


# Single shared instance imported by the other checker modules.
global_context = GlobalContext()
@@ -0,0 +1,34 @@
1
+ from msprobe.core.data_dump.scope import ModuleRangeScope
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.mindspore.common.log import logger
4
+
5
+
6
class CellProcessor:
    """
    Builds begin/end hooks that name cells and report them to a ModuleRangeScope.

    ``cell_count`` is a class attribute, so call indices are shared by every CellProcessor.
    """
    cell_count = {}

    def __init__(self, scope):
        # Only a ModuleRangeScope is tracked; any other scope object is ignored.
        self.scope = scope if isinstance(scope, ModuleRangeScope) else None

    @staticmethod
    def set_cell_count(cell_name):
        """Advance the counter of *cell_name* and return it (0 on the first call)."""
        next_index = CellProcessor.cell_count.get(cell_name, -1) + 1
        CellProcessor.cell_count[cell_name] = next_index
        return next_index

    def node_hook(self, name_prefix, start_or_stop, **kwargs):
        """
        Return the begin hook when *start_or_stop* equals Const.START, otherwise the end hook.

        The begin hook stores "<name_prefix><Const.SEP><index>" on the cell as
        ``mindstudio_reserved_name`` and opens that module in the scope; the end hook closes the
        module recorded on the cell. ``kwargs`` is accepted but not used.
        """
        def begin_hook(cell, input):
            full_name = name_prefix + Const.SEP + str(self.set_cell_count(name_prefix))
            cell.mindstudio_reserved_name = full_name
            if self.scope:
                self.scope.begin_module(full_name)

        def end_hook(cell, input, output):
            if self.scope:
                self.scope.end_module(cell.mindstudio_reserved_name)

        if Const.START == start_or_stop:
            return begin_hook
        return end_hook
+ return begin_hook if Const.START == start_or_stop else end_hook