mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -0,0 +1,378 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+ import copy
18
+ import functools
19
+ from collections import defaultdict
20
+
21
+ import mindspore as ms
22
+ from mindspore.common.tensor import Tensor
23
+ from mindspore import ops
24
+ from mindspore import nn
25
+ try:
26
+ from mindspore.common._pijit_context import PIJitCaptureContext
27
+ pijit_label = True
28
+ except ImportError:
29
+ pijit_label = False
30
+
31
+
32
+ from msprobe.core.data_dump.data_collector import build_data_collector
33
+ from msprobe.core.data_dump.scope import BaseScope
34
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
35
+ from msprobe.core.common.file_utils import create_directory
36
+ from msprobe.mindspore.common.log import logger
37
+ from msprobe.core.common.utils import Const
38
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
39
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
40
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
41
+ ModuleBackwardInputs, ModuleBackwardOutputs
42
+ from msprobe.core.common.exceptions import MsprobeException
43
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
44
+ from msprobe.mindspore.cell_processor import CellProcessor
45
+ from msprobe.mindspore.dump.jit_dump import JitDump
46
+
47
+
48
+ class Service:
49
def __init__(self, config):
    """Create the dump service from a private copy of ``config``.

    The configuration is deep-copied, so the ``level`` assignment made
    here does not touch the caller's object. ``level`` is reset from
    ``level_ori`` before the data collector is built from the copy.

    Args:
        config: configuration object exposing at least ``level`` and
            ``level_ori``.

    Raises:
        MsprobeException: propagated from ``check_level_valid`` when the
            configured level is "L2".
    """
    self.model = None

    # Work on our own copy and restore the original level on it.
    own_config = copy.deepcopy(config)
    own_config.level = own_config.level_ori
    self.config = own_config

    # Collaborators: the cell processor shares the collector's scope.
    collector = build_data_collector(own_config)
    self.data_collector = collector
    self.cell_processor = CellProcessor(collector.scope)

    # Runtime state.
    self.switch = False
    self.start_call = False
    self.first_start = True
    self.current_iter = 0
    self.current_rank = None
    self.dump_iter_dir = None
    self.primitive_counters = {}

    # The level check runs after the collaborators are built, as before.
    self.check_level_valid()
    self.should_stop_service = False
64
+
65
@staticmethod
def check_model_valid(model):
    """Validate the ``model`` argument handed to the service.

    Args:
        model: ``None`` (no model supplied) or a ``mindspore.nn.Cell``
            instance.

    Returns:
        ``model`` unchanged when it is ``None`` or an ``nn.Cell`` instance.

    Raises:
        MsprobeException: with ``INVALID_PARAM_ERROR`` for any other value.
    """
    # Compare against None explicitly: a truthiness test would let falsy
    # non-Cell values such as 0, "" or [] pass through unvalidated and be
    # stored as the model.
    if model is None or isinstance(model, nn.Cell):
        return model
    raise MsprobeException(
        MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
    )
72
+
73
def check_level_valid(self):
    """Reject the "L2" dump level, which this service does not support.

    Returns:
        None when ``self.config.level`` is anything other than "L2".

    Raises:
        MsprobeException: with ``INVALID_PARAM_ERROR`` when
            ``self.config.level`` equals "L2".
    """
    level_is_l2 = self.config.level == "L2"
    if not level_is_l2:
        return
    raise MsprobeException(
        MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
    )
78
+
79
def build_hook(self, target_type, name):
    """Build the forward/backward hook pair for one cell or API.

    Args:
        target_type: compared against ``BaseScope.Module_Type_Module`` and
            ``BaseScope.Module_Type_API`` to choose how names and kwargs
            are obtained.
        name: prefix to which ``Const.FORWARD`` / ``Const.BACKWARD`` is
            appended to form the default hook names.

    Returns:
        A ``(wrap_forward_hook, wrap_backward_hook)`` tuple. The forward
        hook takes ``(cell, input, output)`` and returns the output (or the
        collector's replacement output); the backward hook takes
        ``(cell, grad_input, grad_output)`` and returns None.
    """
    # The process id is captured once, when the hooks are built.
    pid = os.getpid()
    forward_name_template = name + Const.FORWARD
    backward_name_template = name + Const.BACKWARD

    def _collect_forward(default_name, cell, inputs, output):
        if not self.should_excute_hook():
            return None

        if target_type == BaseScope.Module_Type_Module:
            # Cells carry their own reserved name and record no kwargs.
            hook_name = cell.mindstudio_reserved_name
            record = ModuleForwardInputsOutputs(args=inputs, kwargs={}, output=output)
        else:
            hook_name = default_name
            record = ModuleForwardInputsOutputs(args=inputs, kwargs=cell.input_kwargs,
                                                output=output)

        collector = self.data_collector
        collector.update_api_or_module_name(hook_name)
        collector.forward_data_collect(hook_name, cell, pid, record)
        if collector.if_return_forward_new_output():
            return collector.get_forward_new_output()
        if target_type == BaseScope.Module_Type_API:
            # Drop the kwargs stored on the API cell once they are collected.
            del cell.input_kwargs
        return output

    def _collect_backward(default_name, cell, grad_input, grad_output):
        if not self.should_excute_hook():
            return

        hook_name = default_name
        if target_type == BaseScope.Module_Type_Module:
            hook_name = cell.mindstudio_reserved_name
        self.data_collector.update_api_or_module_name(hook_name)
        if self.data_collector:
            # The framework's latest interface change altered the meaning of
            # grad_input and grad_output to match torch, so the two are
            # passed in swapped order here.
            record = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
            self.data_collector.backward_data_collect(hook_name, cell, pid, record)

    def wrap_forward_hook(cell, input, output):
        return _collect_forward(forward_name_template, cell, input, output)

    def wrap_backward_hook(cell, grad_input, grad_output):
        return _collect_backward(backward_name_template, cell, grad_input, grad_output)

    return wrap_forward_hook, wrap_backward_hook
124
+
125
def wrap_primitive(self, origin_func, primitive_name):
    """Wrap a primitive's call so its forward and backward data is dumped.

    Args:
        origin_func: the original callable; it is invoked with the
            positional arguments (hooked when dumping is on) and kwargs.
        primitive_name: name used for the per-primitive call counter and
            for building the dump names.

    Returns:
        ``wrapped_primitive_call(instance_self, *args, **kwargs)``. On every
        call it bumps the counter for ``primitive_name``. While
        ``self.switch`` is off it simply forwards to ``origin_func``;
        otherwise it hooks tensor inputs and outputs with
        ``ops.HookBackward`` and feeds the data collector.
    """
    service_instance = self

    def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
        # Builds a gradient hook that accumulates grads in ``captured_grads``
        # and triggers collection once ``num_tensors`` of them have arrived.
        def backward_hook(grad):
            captured_grads.append(grad)
            backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
            try:
                if len(captured_grads) != num_tensors:
                    # Still waiting for the remaining gradients.
                    return
                if hook_type == Const.INPUT:
                    service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
                    grad_record = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
                    service_instance.data_collector.backward_output_data_collect(
                        backward_primitive_name, service_instance, os.getpid(), grad_record
                    )
                    captured_grads.clear()
                elif hook_type == Const.OUTPUT:
                    service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
                    grad_record = ModuleBackwardInputs(grad_input=tuple(captured_grads))
                    service_instance.data_collector.backward_input_data_collect(
                        backward_primitive_name, service_instance, os.getpid(), grad_record
                    )
                    captured_grads.clear()

            except Exception as exception:
                raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
                                f" updated_primitive_name: {updated_primitive_name}") from exception

        return backward_hook

    def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
        # Every Tensor argument shares one hook; other arguments pass through.
        tensor_total = sum(isinstance(arg, Tensor) for arg in args)
        grad_hook = create_backward_hook(captured_grads_input, tensor_total, updated_primitive_name,
                                         Const.INPUT)
        return [
            ops.HookBackward(grad_hook)(arg) if isinstance(arg, Tensor) else arg
            for arg in args
        ]

    def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
        out_is_tuple = isinstance(out, tuple)
        if out_is_tuple:
            tensor_total = sum(isinstance(item, Tensor) for item in out)
        else:
            tensor_total = 1
        grad_hook = create_backward_hook(captured_grads_output, tensor_total,
                                         updated_primitive_name, Const.OUTPUT)

        if isinstance(out, Tensor):
            return ops.HookBackward(grad_hook)(out)
        if not out_is_tuple:
            # Neither a Tensor nor a tuple: returned untouched.
            return out
        return tuple(
            ops.HookBackward(grad_hook)(item) if isinstance(item, Tensor) else item
            for item in out
        )

    def wrapped_primitive_call(instance_self, *args, **kwargs):
        # The counter is advanced on every call, even while dumping is off.
        service_instance.update_primitive_counters(primitive_name)
        call_index = service_instance.primitive_counters.get(primitive_name, 0)
        updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{call_index}"

        if not service_instance.switch:
            return origin_func(*args, **kwargs)

        grads_of_inputs, grads_of_outputs = [], []

        try:
            hooked_inputs = hook_primitive_inputs(args, grads_of_inputs, updated_primitive_name)
        except Exception as exception:
            raise Exception("This is a primitive op dump error during input hooking: {},"
                            " primitive_name: {}".format(exception, primitive_name)) from exception

        try:
            out = origin_func(*hooked_inputs, **kwargs)
        except Exception as exception:
            raise Exception("This is a primitive op dump error during function call: {},"
                            " primitive_name: {}".format(exception, primitive_name)) from exception

        forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
        service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
        if service_instance.data_collector:
            forward_record = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
            try:
                service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
                                                                     os.getpid(), forward_record)
            except Exception as exception:
                raise Exception("This is a primitive op dump error during forward data collection: {},"
                                " primitive_name: {}".format(exception, primitive_name)) from exception

            # The collector may supply a replacement for the forward output.
            if service_instance.data_collector.if_return_forward_new_output():
                out = service_instance.data_collector.get_forward_new_output()

        try:
            out = hook_primitive_outputs(out, grads_of_outputs, updated_primitive_name)
        except Exception as exception:
            raise Exception("This is a primitive op dump error during output hooking: {},"
                            " primitive_name: {}".format(exception, primitive_name)) from exception

        return out

    return wrapped_primitive_call
232
+
233
def update_primitive_counters(self, primitive_name):
    """Advance the call counter kept for ``primitive_name``.

    The first call for a name stores 0; each later call adds 1, so the
    stored value is the zero-based index of the most recent call.

    Args:
        primitive_name: key into ``self.primitive_counters``.
    """
    counters = self.primitive_counters
    if primitive_name in counters:
        counters[primitive_name] += 1
        return
    counters[primitive_name] = 0
238
+
239
def register_hooks(self):
    """Patch every primitive found on the model's cells with a dump wrapper.

    Walks ``self.model.cells_and_names()``, gathers the distinct
    ``(name, primitive)`` pairs from each cell's ``_primitives`` mapping,
    and swaps each primitive's class for a dynamically created subclass
    named ``NewPrimitive`` whose ``__call__`` is the wrapper returned by
    ``wrap_primitive``.
    """
    # A set removes pairs that are reachable through more than one cell.
    unique_primitives = {
        (prim_name, prim)
        for _, cell in self.model.cells_and_names()
        for prim_name, prim in cell._primitives.items()
    }

    for prim_name, prim in unique_primitives:
        # ``prim.__call__`` is taken before the class swap, so the wrapper
        # still reaches the original implementation.
        wrapped_call = self.wrap_primitive(prim.__call__, prim_name)
        patched_class = type('NewPrimitive', (prim.__class__,), {'__call__': wrapped_call})
        prim.__class__ = patched_class
249
+
250
def step(self):
    """Move to the next iteration and reset the per-iteration counters.

    Increments ``current_iter``, forwards the new value to the data
    collector, and resets the cell counters on ``HOOKCell`` and
    ``CellProcessor`` as well as this service's primitive counters.
    """
    next_iter = self.current_iter + 1
    self.current_iter = next_iter
    self.data_collector.update_iter(next_iter)

    # Counters are per-iteration, so start each one from scratch.
    HOOKCell.cell_count = defaultdict(int)
    CellProcessor.cell_count = {}
    self.primitive_counters.clear()
256
+
257
def start(self, model=None):
    """Turn the dump switch on for the current iteration.

    Always records that ``start`` was called. Returns without enabling the
    switch when the service has already been stopped, when
    ``need_end_service()`` reports that it should end now, when the current
    iteration is not among the configured steps, or (on the first start)
    when the current rank is not among the configured ranks.

    Args:
        model: passed through ``check_model_valid`` and stored on
            ``self.model``.
    """
    self.start_call = True
    if self.should_stop_service:
        return

    if self.need_end_service():
        api_register.api_set_ori_func()
        self.should_stop_service = True
        self.switch = False
        border = "************************************************"
        for banner_line in (border, f"* {Const.TOOL_NAME} ends successfully. *", border):
            logger.info(banner_line)
        return

    wanted_steps = self.config.step
    if wanted_steps and self.current_iter not in wanted_steps:
        return

    self.model = self.check_model_valid(model)
    logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")

    if self.first_start:
        # One-time setup: resolve the rank and mount the hooks.
        try:
            self.current_rank = get_rank_if_initialized()
        except DistributedNotInitializedError:
            self.current_rank = None

        wanted_ranks = self.config.rank
        if wanted_ranks and self.current_rank not in wanted_ranks:
            # first_start stays True, so later calls re-check the rank.
            return

        self.register_hook_new()
        if self.config.level == "L1":
            JitDump.set_config(self.config)
            JitDump.set_data_collector(self.data_collector)
            ms.common.api._MindsporeFunctionExecutor = JitDump
            ms.common.api._PyNativeExecutor.grad = JitDump.grad
            if pijit_label:
                # Replace the capture context's enter/exit with no-ops.
                PIJitCaptureContext.__enter__ = self.empty
                PIJitCaptureContext.__exit__ = self.empty
        self.first_start = False

    self.switch = True
    logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
    self.create_dirs()
    logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
298
+
299
def stop(self):
    """Turn the dump switch off and flush the collected data to JSON.

    Does nothing once the service has been stopped for good. When the
    current iteration or rank is excluded by the configuration the call
    returns before touching the switch.

    Raises:
        Exception: if ``start()`` has not been called in the current scope.
    """
    if self.should_stop_service:
        return

    # Validate first so that a misuse is not preceded by a "set
    # successfully" message in the log.
    if not self.start_call:
        logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
        raise Exception("debugger.start() is not set in the current scope.")

    logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
                "Please set debugger.start() to turn on the dump switch again. ")

    if self.config.step and self.current_iter not in self.config.step:
        return
    if self.config.rank and self.current_rank not in self.config.rank:
        return

    self.switch = False
    self.start_call = False
    self.data_collector.write_json()
314
+
315
def need_end_service(self):
    """Tell whether the service should shut down.

    Returns:
        bool: True when steps are configured and the current iteration is
        past the largest of them, or when a data collector exists and its
        data processor reports ``is_terminated``; False otherwise.
    """
    steps = self.config.step
    if steps and self.current_iter > max(steps):
        return True

    collector = self.data_collector
    return bool(collector and collector.data_processor.is_terminated)
321
+
322
def should_excute_hook(self):
    """Tell whether hooks should currently collect data.

    Returns:
        bool: True only when the dump switch is on, a data collector
        exists, and its data processor is not terminated.
    """
    return bool(
        self.switch
        and self.data_collector
        and not self.data_collector.data_processor.is_terminated
    )
328
+
329
def create_dirs(self):
    """Create the dump directories for this step and register file paths.

    Layout built under ``config.dump_path``::

        step<current_iter>/rank<current_rank>/
            dump.json, stack.json, construct.json
            dump_tensor_data/     (only for tasks that need tensor data)

    When ``current_rank`` is None the rank directory is named just
    ``rank``. The resulting paths are handed to the data collector through
    ``update_dump_paths``.
    """
    dump_root = self.config.dump_path
    create_directory(dump_root)

    self.dump_iter_dir = os.path.join(dump_root, f"step{self.current_iter}")
    rank_suffix = '' if self.current_rank is None else self.current_rank
    rank_dir = os.path.join(self.dump_iter_dir, f"rank{rank_suffix}")
    create_directory(rank_dir)

    # The tensor-data directory exists only for tasks that need it.
    dump_data_dir = None
    if self.config.task in self.data_collector.tasks_need_tensor_data:
        dump_data_dir = os.path.join(rank_dir, "dump_tensor_data")
        create_directory(dump_data_dir)

    json_paths = [
        os.path.join(rank_dir, file_name)
        for file_name in ("dump.json", "stack.json", "construct.json")
    ]
    self.data_collector.update_dump_paths(*json_paths, dump_data_dir, None)
346
+
347
def empty(self, *args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    return None
349
+
350
def register_hook_new(self):
    """Mount the dump hooks that match the configured level.

    * Level ``"L1"``: initialises and sets the API hooks through
      ``api_register`` and, if a model is set, wraps its primitives via
      ``register_hooks``.
    * Level ``"L0"``: registers forward/backward hooks plus start/stop node
      hooks on every cell yielded by ``model.cells_and_names()`` except the
      model itself.

    Raises:
        MsprobeException: level is ``"L0"`` but no model is set.
    """
    logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))

    if self.config.level == "L1":
        api_hook_builder = functools.partial(self.build_hook, BaseScope.Module_Type_API)
        api_register.initialize_hook(api_hook_builder)
        api_register.api_set_hook_func()
        if self.model:
            self.register_hooks()

    if self.config.level != "L0":
        return

    if not self.model:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               "The current level is L0, the model cannot be None")

    make_node_hook = self.cell_processor.node_hook
    for name, cell in self.model.cells_and_names():
        if cell == self.model:
            continue

        prefix = Const.SEP.join(('Cell', name, cell.__class__.__name__)) + Const.SEP
        forward_name = prefix + Const.FORWARD
        backward_name = prefix + Const.BACKWARD

        # Data-collection hooks are registered before the node hooks; the
        # registration order is kept exactly as it was.
        forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
        cell.register_forward_hook(forward_hook)
        cell.register_backward_hook(backward_hook)

        cell.register_forward_pre_hook(make_node_hook(forward_name, Const.START))
        cell.register_forward_hook(make_node_hook(forward_name, Const.STOP))
        cell.register_backward_pre_hook(make_node_hook(backward_name, Const.START))
        cell.register_backward_hook(make_node_hook(backward_name, Const.STOP))
@@ -1,21 +1,24 @@
1
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
2
- from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
3
- from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
4
-
5
-
6
class TaskHandlerFactory:
    """Select the tool factory for a configured task and build its handler."""

    # Task name -> factory class whose ``create(config)`` builds the handler.
    tasks = {
        "tensor": DumpToolFactory,
        "statistics": DumpToolFactory,
        "overflow_check": OverflowCheckToolFactory
    }

    @staticmethod
    def create(config: DebuggerConfig):
        """Return the handler built for ``config.task``.

        Raises:
            Exception: the task has no registered factory, or the factory
                returned a falsy handler.
        """
        tool_factory = TaskHandlerFactory.tasks.get(config.task)
        if tool_factory:
            task_handler = tool_factory.create(config)
            if task_handler:
                return task_handler
            raise Exception("Can not find task handler")
        raise Exception("valid task is needed.")
1
+ from msprobe.core.common.const import Const
2
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
+ from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
4
+ from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
5
+ from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory
6
+
7
+
8
class TaskHandlerFactory:
    """Select the tool factory for a configured task and build its handler."""

    # Task name -> factory class whose ``create(config)`` builds the handler.
    tasks = {
        Const.TENSOR: DumpToolFactory,
        Const.STATISTICS: DumpToolFactory,
        Const.OVERFLOW_CHECK: OverflowCheckToolFactory,
        Const.FREE_BENCHMARK: SelfCheckToolFactory
    }

    @staticmethod
    def create(config: DebuggerConfig):
        """Return the handler built for ``config.task``.

        Raises:
            Exception: the task has no registered factory, or the factory
                returned a falsy handler.
        """
        tool_factory = TaskHandlerFactory.tasks.get(config.task)
        if tool_factory:
            task_handler = tool_factory.create(config)
            if task_handler:
                return task_handler
            raise Exception("Can not find task handler")
        raise Exception("Valid task is needed.")
msprobe/msprobe.py CHANGED
@@ -1,67 +1,105 @@
1
- # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import sys
18
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
19
- from msprobe.pytorch.parse_tool.cli import parse as cli_parse
20
- from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
21
- from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
22
- _api_precision_compare_command
23
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
24
- _run_overflow_check_command
25
-
26
-
27
def main():
    """Command-line entry point for msprobe (PyTorch only).

    Builds the argument parser with one sub-parser per tool, parses
    ``sys.argv`` and dispatches to the selected tool. With no arguments the
    help text is printed and the process exits with status 0. When no
    sub-command is supplied, argparse reports a usage error (exit status 2).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="msprobe(mindstudio probe), [Powered by MindStudio].\n"
                    "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n"
                    "For any issue, refer README.md first",
    )
    parser.set_defaults(print_help=parser.print_help)
    parser.add_argument('-f', '--framework', required=True, choices=['pytorch'],
                        help='Deep learning framework.')
    # ``dest`` records the chosen sub-command on the parsed namespace so that
    # dispatch does not depend on the raw position of words in sys.argv
    # (which breaks for e.g. ``--framework=pytorch run_ut``).
    subparsers = parser.add_subparsers(dest='msprobe_command')
    subparsers.add_parser('parse')
    run_ut_cmd_parser = subparsers.add_parser('run_ut')
    multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
    api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
    run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
    _run_ut_parser(run_ut_cmd_parser)
    _run_ut_parser(multi_run_ut_cmd_parser)
    multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
                                         help='Number of splits for parallel processing. Range: 1-64')
    _api_precision_compare_parser(api_precision_compare_cmd_parser)
    _run_overflow_check_parser(run_overflow_check_cmd_parser)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args(sys.argv[1:])
    command = args.msprobe_command
    if command is None:
        # Previously this case crashed with IndexError on sys.argv[3].
        parser.error("a sub-command is required")

    if command == "run_ut":
        run_ut_command(args)
    elif command == "parse":
        cli_parse()
    elif command == "multi_run_ut":
        config = prepare_config(args)
        run_parallel_ut(config)
    elif command == "api_precision_compare":
        _api_precision_compare_command(args)
    elif command == "run_overflow_check":
        _run_overflow_check_command(args)
64
-
65
-
66
- if __name__ == "__main__":
67
- main()
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import sys
18
+ import importlib.util
19
+ from msprobe.core.compare.utils import _compare_parser
20
+ from msprobe.core.common.log import logger
21
+ from msprobe.core.compare.compare_cli import compare_cli
22
+ from msprobe.core.common.const import Const
23
+
24
+
25
def is_module_available(module_name):
    """Return True if ``module_name`` can be located by the import system.

    Args:
        module_name: absolute module name, e.g. ``"torch"``.

    Returns:
        bool: True when a module spec is found; False when none is found or
        when the lookup itself fails.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec imports the parent of a dotted name and raises
        # ModuleNotFoundError if that parent is missing; it raises
        # ValueError for an already-loaded module whose __spec__ is unset.
        # Either way the module is not usable.
        return False
28
+
29
+
30
def main():
    """Command-line entry point for msprobe.

    Builds the argument parser, registers the tool-specific arguments for
    whichever framework package is installed, parses ``sys.argv`` and
    dispatches to the selected tool. With no arguments the help text is
    printed and the process exits with status 0. When no sub-command is
    supplied, argparse reports a usage error (exit status 2).

    Raises:
        Exception: the requested framework is not installed, or the
            ``compare`` mapping arguments are used with PyTorch.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="msprobe(mindstudio probe), [Powered by MindStudio].\n"
                    "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n"
                    "For any issue, refer README.md first",
    )

    parser.set_defaults(print_help=parser.print_help)
    parser.add_argument('-f', '--framework', required=True, choices=[Const.PT_FRAMEWORK, Const.MS_FRAMEWORK],
                        help='Deep learning framework.')
    # ``dest`` records the chosen sub-command on the parsed namespace so that
    # dispatch does not depend on the raw position of words in sys.argv
    # (which breaks for e.g. ``--framework=pytorch run_ut``).
    subparsers = parser.add_subparsers(dest='msprobe_command')
    subparsers.add_parser('parse')
    compare_cmd_parser = subparsers.add_parser('compare')
    run_ut_cmd_parser = subparsers.add_parser('run_ut')
    multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
    api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
    run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
    _compare_parser(compare_cmd_parser)

    is_torch_available = is_module_available("torch")
    is_mindspore_available = is_module_available("mindspore")
    # NOTE(review): tool arguments are registered by which package is
    # installed, not by the framework requested with -f; when both are
    # installed, ``-f mindspore run_ut`` gets the PyTorch run_ut arguments.
    # Confirm whether that is intended.
    if is_torch_available:
        from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
        from msprobe.pytorch.parse_tool.cli import parse as cli_parse
        from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
        from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
            _api_precision_compare_command
        from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
            _run_overflow_check_command

        _run_ut_parser(run_ut_cmd_parser)
        _run_ut_parser(multi_run_ut_cmd_parser)
        multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
                                             help='Number of splits for parallel processing. Range: 1-64')
        _api_precision_compare_parser(api_precision_compare_cmd_parser)
        _run_overflow_check_parser(run_overflow_check_cmd_parser)
    elif is_mindspore_available:
        from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
        add_api_accuracy_checker_argument(run_ut_cmd_parser)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args(sys.argv[1:])
    command = args.msprobe_command
    if command is None:
        # Previously this case crashed with IndexError on sys.argv[3].
        parser.error("a sub-command is required")

    if args.framework == Const.PT_FRAMEWORK:
        if not is_torch_available:
            logger.error("PyTorch does not exist, please install PyTorch library")
            raise Exception("PyTorch does not exist, please install PyTorch library")
        if command == "run_ut":
            run_ut_command(args)
        elif command == "parse":
            cli_parse()
        elif command == "multi_run_ut":
            config = prepare_config(args)
            run_parallel_ut(config)
        elif command == "api_precision_compare":
            _api_precision_compare_command(args)
        elif command == "run_overflow_check":
            _run_overflow_check_command(args)
        elif command == "compare":
            if args.cell_mapping is not None or args.api_mapping is not None:
                logger.error("Argument -cm or -am is not supported in PyTorch framework")
                raise Exception("Argument -cm or -am is not supported in PyTorch framework")
            compare_cli(args)
    else:
        if not is_module_available(Const.MS_FRAMEWORK):
            logger.error("MindSpore does not exist, please install MindSpore library")
            raise Exception("MindSpore does not exist, please install MindSpore library")
        if command == "compare":
            compare_cli(args)
        elif command == "run_ut":
            from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
            api_checker_main(args)
103
+
104
+ if __name__ == "__main__":
105
+ main()
@@ -1,4 +1,4 @@
1
- from .debugger.precision_debugger import PrecisionDebugger
2
- from .common.utils import seed_all
3
- from .compare.acc_compare import compare
4
- from .compare.distributed_compare import compare_distributed
1
+ from .debugger.precision_debugger import PrecisionDebugger
2
+ from .common.utils import seed_all
3
+ from .compare.distributed_compare import compare_distributed
4
+ from .compare.pt_compare import compare