mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,274 +1,272 @@
1
- import os
2
- import time
3
- import json
4
- from pathlib import Path
5
- from multiprocessing import Manager, Pool
6
-
7
- import yaml
8
- import torch
9
-
10
- from torch.utils._python_dispatch import TorchDispatchMode
11
-
12
- try:
13
- import torch_npu
14
- except ImportError:
15
- is_npu = False
16
- else:
17
- is_npu = True
18
-
19
- from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
20
- DispatchRunParam, DisPatchDataInfo
21
- from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \
22
- DispatchException
23
- from .compare import Comparator
24
- from msprobe.core.common.file_check import FileOpen
25
- from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
26
- from msprobe.core.common.const import Const, CompareConst
27
-
28
- current_time = time.strftime("%Y%m%d%H%M%S")
29
- RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
30
- DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
31
-
32
-
33
- class PtdbgDispatch(TorchDispatchMode):
34
- def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0):
35
- super(PtdbgDispatch, self).__init__()
36
- logger_logo()
37
- if not is_npu:
38
- logger_error("Please confirm you run environment installed torch_npu!")
39
- return
40
- if dump_path is None:
41
- logger_error("Please set dump_path when dump_mode is config!")
42
- check_file_or_directory_path(dump_path, True)
43
-
44
- self.device_id = torch_npu._C._npu_getDevice()
45
- self.dump_mode = dump_mode
46
- self.dump_api_list = api_list
47
- self.debug_flag = debug
48
- self.api_index = 0
49
- self.single_api_index_dict = {}
50
- self.device_dump_path_cpu = None
51
- self.device_dump_path_npu = None
52
- self.all_summery = []
53
- self.call_stack_list = []
54
- self.process_num = process_num
55
- self.filter_dump_api()
56
- self.check_param()
57
- dir_name = self.get_dir_name(tag)
58
- self.root_path = os.path.join(os.path.realpath(dump_path), dir_name)
59
- self.root_cpu_path = os.path.join(self.root_path, f'cpu')
60
- self.root_npu_path = os.path.join(self.root_path, f'npu')
61
- check_path_before_create(self.root_cpu_path)
62
- check_path_before_create(self.root_npu_path)
63
- Path(self.root_cpu_path).mkdir(mode=0o750, parents=True, exist_ok=True)
64
- Path(self.root_npu_path).mkdir(mode=0o750, parents=True, exist_ok=True)
65
-
66
- self.result_csv_path = os.path.join(self.root_path, RESULT_FILE_NAME)
67
- self.detail_csv_path = os.path.join(self.root_path, DETAILS_FILE_NAME)
68
- self.comparator = Comparator(self.result_csv_path, self.detail_csv_path, False)
69
-
70
- self.aten_ops_blacklist = []
71
- self.npu_adjust_autogard = []
72
- yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
73
- self.load_yaml_file(yaml_path)
74
-
75
- self.lock = None
76
- if process_num > 0:
77
- self.pool = Pool(process_num)
78
- if debug:
79
- logger_debug(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
80
- f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
81
- f'process[{process_num}]')
82
-
83
- def __exit__(self, exc_type, exc_val, exc_tb):
84
- super().__exit__(exc_type, exc_val, exc_tb)
85
-
86
- if not is_npu:
87
- return
88
- logger_debug(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}')
89
-
90
- if self.process_num > 0:
91
- self.pool.close()
92
- self.pool.join()
93
- summery_path = os.path.join(self.root_cpu_path, f'summary.json')
94
- if not os.path.exists(summery_path):
95
- logger_error("Please check train log, An exception may have occurred!")
96
- return
97
- check_file_or_directory_path(summery_path, False)
98
- fp_handle = open(summery_path, "r")
99
- while True:
100
- json_line_data = fp_handle.readline()
101
- if json_line_data == '\n':
102
- continue
103
- if len(json_line_data) == 0:
104
- break
105
- msg = json.loads(json_line_data)
106
- self.all_summery[msg[0]] = msg[1]
107
- fp_handle.close()
108
-
109
- if self.debug_flag:
110
- input_num = 0
111
- output_num = 0
112
- total_num = 0
113
-
114
- for list_data in self.all_summery:
115
- for data in list_data:
116
- logger_debug(f'summery: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]')
117
- if "_input" in data[CompareConst.NPU_NAME]:
118
- input_num = input_num + 1
119
- if "_output" in data[CompareConst.NPU_NAME]:
120
- output_num = output_num + 1
121
- total_num = total_num + 1
122
- logger_debug(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
123
- f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')
124
-
125
- def __torch_dispatch__(self, func, types, args=(), kwargs=None):
126
- if not is_npu:
127
- logger_error("Please confirm you run environment installed torch_npu!")
128
- return func(*args, **kwargs)
129
-
130
- func_name_split_list = func.__name__.split(".")
131
- aten_api = func_name_split_list[0]
132
- try:
133
- aten_api_overload_name = func_name_split_list[1]
134
- except IndexError:
135
- logger_error(f"Please check the func name {func.__name__}!")
136
- return func(*args, **kwargs)
137
-
138
- self.enable_autogard(aten_api)
139
- if aten_api in self.aten_ops_blacklist:
140
- npu_out = func(*args, **kwargs)
141
- return npu_out
142
-
143
- call_stack = get_callstack()
144
- self.call_stack_list.append(call_stack)
145
- self.api_index += 1
146
- if aten_api not in self.single_api_index_dict:
147
- self.single_api_index_dict[aten_api] = 1
148
- else:
149
- self.single_api_index_dict[aten_api] += 1
150
-
151
- run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
152
-
153
- if self.debug_flag:
154
- logger_debug(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
155
- f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
156
- f'Count[{self.api_index}], Sys[{get_sys_info()}]')
157
-
158
- cpu_args = []
159
- cpu_kwargs = []
160
- data_to_cpu(args, 0, cpu_args)
161
- data_to_cpu(kwargs, 0, cpu_kwargs)
162
- cpu_args = cpu_args[0]
163
- cpu_kwargs = cpu_kwargs[0]
164
-
165
- with TimeStatistics("NPU RUN", run_param):
166
- npu_out = func(*args, **kwargs)
167
- npu_out_cpu = []
168
- data_to_cpu(npu_out, 0, npu_out_cpu)
169
- npu_out_cpu = npu_out_cpu[0]
170
-
171
- with TimeStatistics("CPU RUN", run_param):
172
- cpu_out = func(*cpu_args, **cpu_kwargs)
173
-
174
- if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]:
175
- cpu_out = cpu_out.float()
176
-
177
- if self.process_num == 0:
178
- self.all_summery.append([])
179
- data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, func, npu_out_cpu, cpu_out, self.lock)
180
- dispatch_workflow(run_param, data_info)
181
- else:
182
- self.lock.acquire()
183
- self.all_summery.append([])
184
- self.lock.release()
185
- run_param.process_flag = True
186
- if self.check_fun(func, run_param):
187
- data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, None, npu_out_cpu, cpu_out,
188
- self.lock)
189
- self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
190
- error_callback=error_call)
191
- else:
192
- logger_error("can not get correct function please set process_num=0")
193
- return npu_out
194
-
195
- @staticmethod
196
- def check_fun(func, run_param):
197
- if hasattr(torch.ops.aten, run_param.aten_api):
198
- aten_func = getattr(torch.ops.aten, run_param.aten_api)
199
- if hasattr(aten_func, run_param.aten_api_overload_name):
200
- aten_overload_func = getattr(aten_func, run_param.aten_api_overload_name)
201
- if id(aten_overload_func) == id(func):
202
- run_param.func_namespace = "aten"
203
- return True
204
- return False
205
-
206
- def get_dir_name(self, tag):
207
- # guarantee file uniqueness
208
- time.sleep(1)
209
- time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
210
- if tag is None or not isinstance(tag, str):
211
- logger_warn('There is not tag or the type of tag is not string.')
212
- dir_name = f'msprobe_rank{self.device_id}_{time_now}'
213
- else:
214
- dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
215
- return dir_name
216
-
217
- def load_yaml_file(self, file_path):
218
- with FileOpen(file_path, 'r') as f:
219
- yaml_file = yaml.safe_load(f)
220
- self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
221
- self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
222
-
223
- def filter_dump_api(self):
224
- if self.dump_mode != Const.LIST or not self.dump_api_list:
225
- self.dump_api_list = []
226
- return
227
- aten_api_list = dir(torch.ops.aten)
228
- dump_api_list = []
229
- for aten_api in self.dump_api_list:
230
- if aten_api in aten_api_list:
231
- dump_api_list.append(aten_api)
232
- else:
233
- logger_warn(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten')
234
- self.dump_api_list = dump_api_list
235
-
236
- def get_run_param(self, aten_api, func_name, aten_api_overload_name):
237
- run_param = DispatchRunParam(self.debug_flag, self.device_id, self.root_npu_path, self.root_cpu_path,
238
- self.process_num, self.comparator)
239
- run_param.dump_flag, run_param.auto_dump_flag = self.get_dump_flag(aten_api)
240
- run_param.func_name = func_name
241
- run_param.aten_api = aten_api
242
- run_param.aten_api_overload_name = aten_api_overload_name
243
- run_param.single_api_index = self.single_api_index_dict[aten_api]
244
- run_param.api_index = self.api_index
245
- return run_param
246
-
247
- def get_dump_flag(self, aten_api):
248
- dump_flag = False
249
- auto_dump_flag = False
250
- if self.dump_mode == Const.ALL:
251
- dump_flag = True
252
- if self.dump_mode == Const.LIST and aten_api in self.dump_api_list:
253
- dump_flag = True
254
- if self.dump_mode == Const.AUTO:
255
- auto_dump_flag = True
256
- return dump_flag, auto_dump_flag
257
-
258
def check_param(self):
    """Validate the constructor arguments stored on ``self``.

    The checks run in order; the first failing one is logged with ``logger_error``
    and ``DispatchException(DispatchException.INVALID_PARAMETER)`` is raised.
    """
    validations = (
        (lambda: self.dump_mode in Const.ONLINE_DUMP_MODE,
         'The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE)),
        (lambda: isinstance(self.dump_api_list, list),
         'The type of parameter "api_list" can only be list.'),
        (lambda: isinstance(self.debug_flag, bool),
         'The type of parameter "debug" can only be bool.'),
        # Short-circuit keeps the ``>= 0`` comparison away from non-int values.
        (lambda: isinstance(self.process_num, int) and self.process_num >= 0,
         'The type of parameter "process_num" can only be int and it should not be less than 0.'),
    )
    for is_valid, message in validations:
        if not is_valid():
            logger_error(message)
            raise DispatchException(DispatchException.INVALID_PARAMETER)
271
-
272
- def enable_autogard(self, aten_api):
273
- if aten_api in self.npu_adjust_autogard:
1
+ import os
2
+ import time
3
+ import json
4
+ from multiprocessing import Pool
5
+
6
+ import torch
7
+
8
+ from torch.utils._python_dispatch import TorchDispatchMode
9
+
10
+ try:
11
+ import torch_npu
12
+ except ImportError:
13
+ is_npu = False
14
+ else:
15
+ is_npu = True
16
+
17
+ from msprobe.core.common.file_utils import check_path_before_create, check_file_or_directory_path, load_yaml
18
+ from msprobe.core.common.const import Const, CompareConst
19
+ from msprobe.pytorch.common.log import logger
20
+ from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
21
+ DispatchRunParam, DisPatchDataInfo
22
+ from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, COMPARE_LOGO
23
+ from msprobe.pytorch.online_dispatch.compare import Comparator
24
+ from msprobe.core.common.file_utils import FileOpen, create_directory
25
+
26
+
27
# Timestamp taken once at import time so the result and details CSVs produced by
# this module share the same suffix.
current_time = time.strftime("%Y%m%d%H%M%S")
RESULT_FILE_NAME = f"accuracy_checking_result_{current_time}.csv"
DETAILS_FILE_NAME = f"accuracy_checking_details_{current_time}.csv"
30
+
31
+
32
class PtdbgDispatch(TorchDispatchMode):
    """Dispatch mode that re-runs each intercepted aten op on CPU and compares it with the NPU run.

    Inside the mode every aten call is executed on the NPU (that result is returned to the
    caller) and again on CPU copies of its inputs. Both outputs are then handed either to
    ``dispatch_workflow`` in-process (``process_num == 0``) or to ``dispatch_multiprocess``
    in a worker pool. On ``__exit__`` the per-API summary is read back from
    ``<root>/cpu/summary.json`` and, in debug mode, logged with counts.
    """

    def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0):
        """Create output directories, load the op config and optionally start a worker pool.

        Args:
            dump_mode: one of ``Const.ONLINE_DUMP_MODE``.
            api_list: aten API names to dump when ``dump_mode == Const.LIST``.
            debug: when True, log per-call dispatch information and exit statistics.
            dump_path: existing directory under which the output directory is created.
            tag: optional string inserted into the output directory name.
            process_num: number of worker processes; 0 runs the comparison in-process.

        Raises:
            DispatchException: if ``dump_path`` is missing or a parameter is invalid.
        """
        super(PtdbgDispatch, self).__init__()
        logger.info(COMPARE_LOGO)
        if not is_npu:
            logger.error("Please confirm you run environment installed torch_npu!")
            return
        if dump_path is None:
            logger.error("Please set dump_path when dump_mode is config!")
            # Stop here instead of passing None into the path checker below.
            raise DispatchException(DispatchException.INVALID_PARAMETER)
        check_file_or_directory_path(dump_path, True)

        self.device_id = torch_npu._C._npu_getDevice()
        self.dump_mode = dump_mode
        self.dump_api_list = api_list
        self.debug_flag = debug
        self.api_index = 0
        self.single_api_index_dict = {}
        self.device_dump_path_cpu = None
        self.device_dump_path_npu = None
        self.all_summary = []
        self.call_stack_list = []
        self.process_num = process_num
        self.filter_dump_api()
        self.check_param()
        dir_name = self.get_dir_name(tag)
        self.root_path = os.path.join(os.path.realpath(dump_path), dir_name)
        self.root_cpu_path = os.path.join(self.root_path, 'cpu')
        self.root_npu_path = os.path.join(self.root_path, 'npu')
        check_path_before_create(self.root_cpu_path)
        check_path_before_create(self.root_npu_path)
        create_directory(self.root_cpu_path)
        create_directory(self.root_npu_path)

        self.result_csv_path = os.path.join(self.root_path, RESULT_FILE_NAME)
        self.detail_csv_path = os.path.join(self.root_path, DETAILS_FILE_NAME)
        self.comparator = Comparator(self.result_csv_path, self.detail_csv_path, False)

        self.aten_ops_blacklist = []
        self.npu_adjust_autogard = []
        yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
        self.get_ops(yaml_path)

        self.lock = None
        self._lock_manager = None
        if process_num > 0:
            from multiprocessing import Manager
            # The lock travels to workers inside pickled apply_async arguments, so it must be a
            # picklable cross-process lock; a Manager lock proxy is (a plain Lock is not).
            self._lock_manager = Manager()
            self.lock = self._lock_manager.Lock()
            self.pool = Pool(process_num)
        if debug:
            logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
                        f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
                        f'process[{process_num}]')

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the mode, wait for workers, reload the summary and log statistics in debug mode."""
        super().__exit__(exc_type, exc_val, exc_tb)

        if not is_npu:
            return
        logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}]')

        if self.process_num > 0:
            self.pool.close()
            self.pool.join()
            # All workers are done, so the lock server is no longer needed.
            if self._lock_manager is not None:
                self._lock_manager.shutdown()
                self._lock_manager = None
        summary_path = os.path.join(self.root_cpu_path, 'summary.json')
        if not os.path.exists(summary_path):
            logger.error("Please check train log, An exception may have occurred!")
            return
        check_file_or_directory_path(summary_path, False)
        self._load_summary(summary_path)

        if self.debug_flag:
            self._log_summary_statistics()

    def _load_summary(self, summary_path):
        """Store each JSON line of ``summary_path`` into ``self.all_summary``.

        Every non-blank line is decoded with ``json.loads`` and ``msg[1]`` is stored at
        ``self.all_summary[msg[0]]``. Presumably ``msg`` is ``[api_index, summary_rows]``
        as written by the dump workers; confirm against dump_compare.
        """
        # FileOpen is a context manager; the object itself has no readline()/close().
        with FileOpen(summary_path, "r") as fp_handle:
            for json_line_data in fp_handle:
                if not json_line_data.strip():
                    continue
                msg = json.loads(json_line_data)
                self.all_summary[msg[0]] = msg[1]

    def _log_summary_statistics(self):
        """Log every summary row and the counts of input, output and total rows."""
        input_num = 0
        output_num = 0
        total_num = 0
        for list_data in self.all_summary:
            for data in list_data:
                logger.info(f'summary: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]')
                npu_name = data[CompareConst.NPU_NAME]
                if "_input" in npu_name:
                    input_num += 1
                if "_output" in npu_name:
                    output_num += 1
                total_num += 1
        logger.info(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()}], Input[{input_num}] '
                    f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]')

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` on NPU and on CPU copies of its inputs, then schedule the comparison.

        Returns:
            The NPU result of ``func``, so the model's behaviour is unchanged.
        """
        # The dispatch protocol allows kwargs=None; normalise it so ``**kwargs`` and data_to_cpu work.
        if kwargs is None:
            kwargs = {}
        if not is_npu:
            logger.error("Please confirm you run environment installed torch_npu!")
            return func(*args, **kwargs)

        func_name_split_list = func.__name__.split(".")
        aten_api = func_name_split_list[0]
        try:
            aten_api_overload_name = func_name_split_list[1]
        except IndexError:
            logger.error(f"Please check the func name {func.__name__}!")
            return func(*args, **kwargs)

        self.enable_autogard(aten_api)
        if aten_api in self.aten_ops_blacklist:
            return func(*args, **kwargs)

        call_stack = get_callstack()
        self.call_stack_list.append(call_stack)
        self.api_index += 1
        self.single_api_index_dict[aten_api] = self.single_api_index_dict.get(aten_api, 0) + 1

        run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)

        if self.debug_flag:
            logger.info(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
                        f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
                        f'Count[{self.api_index}], Sys[{get_sys_info()}]')

        # data_to_cpu appends the converted structure into the given list.
        cpu_args = []
        cpu_kwargs = []
        data_to_cpu(args, 0, cpu_args)
        data_to_cpu(kwargs, 0, cpu_kwargs)
        cpu_args = cpu_args[0]
        cpu_kwargs = cpu_kwargs[0]

        with TimeStatistics("NPU RUN", run_param):
            npu_out = func(*args, **kwargs)
        npu_out_cpu = []
        data_to_cpu(npu_out, 0, npu_out_cpu)
        npu_out_cpu = npu_out_cpu[0]

        with TimeStatistics("CPU RUN", run_param):
            cpu_out = func(*cpu_args, **cpu_kwargs)

        if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]:
            cpu_out = cpu_out.float()

        if self.process_num == 0:
            self.all_summary.append([])
            data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, func, npu_out_cpu, cpu_out, self.lock)
            dispatch_workflow(run_param, data_info)
        else:
            with self.lock:
                self.all_summary.append([])
            run_param.process_flag = True
            if self.check_fun(func, run_param):
                data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
                                             self.lock)
                self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
                                      error_callback=error_call)
            else:
                logger.error("can not get correct function please set process_num=0")
        return npu_out

    @staticmethod
    def check_fun(func, run_param):
        """Return True and set ``run_param.func_namespace`` when ``func`` is ``torch.ops.aten.<api>.<overload>``."""
        if hasattr(torch.ops.aten, run_param.aten_api):
            aten_func = getattr(torch.ops.aten, run_param.aten_api)
            if hasattr(aten_func, run_param.aten_api_overload_name):
                aten_overload_func = getattr(aten_func, run_param.aten_api_overload_name)
                if id(aten_overload_func) == id(func):
                    run_param.func_namespace = "aten"
                    return True
        return False

    def get_dir_name(self, tag):
        """Return ``msprobe[_<tag>]_rank<device>_<timestamp>`` for the output directory."""
        # Sleep so two instances created back-to-back get distinct second-resolution timestamps.
        time.sleep(1)
        time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
        if tag is None or not isinstance(tag, str):
            logger.warning('There is not tag or the type of tag is not string.')
            dir_name = f'msprobe_rank{self.device_id}_{time_now}'
        else:
            dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
        return dir_name

    def get_ops(self, file_path):
        """Load the aten blacklist and autograd-adjust list from the YAML file at ``file_path``.

        Missing keys (or an empty file) yield empty lists, so later ``in`` checks stay valid.
        """
        yaml_file = load_yaml(file_path) or {}
        self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist') or []
        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard') or []

    def filter_dump_api(self):
        """Keep only the requested API names that exist under ``torch.ops.aten``.

        Outside LIST mode, or when no list was given, the dump list becomes ``[]``.
        A value that is not a list/tuple is left untouched so ``check_param`` rejects it.
        """
        if self.dump_mode != Const.LIST or not self.dump_api_list:
            self.dump_api_list = []
            return
        if not isinstance(self.dump_api_list, (list, tuple)):
            return
        aten_api_set = set(dir(torch.ops.aten))
        dump_api_list = []
        for aten_api in self.dump_api_list:
            if aten_api in aten_api_set:
                dump_api_list.append(aten_api)
            else:
                logger.warning(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten')
        self.dump_api_list = dump_api_list

    def get_run_param(self, aten_api, func_name, aten_api_overload_name):
        """Build the ``DispatchRunParam`` for one intercepted aten call."""
        run_param = DispatchRunParam(self.debug_flag, self.device_id, self.root_npu_path, self.root_cpu_path,
                                     self.process_num, self.comparator)
        run_param.dump_flag, run_param.auto_dump_flag = self.get_dump_flag(aten_api)
        run_param.func_name = func_name
        run_param.aten_api = aten_api
        run_param.aten_api_overload_name = aten_api_overload_name
        run_param.single_api_index = self.single_api_index_dict[aten_api]
        run_param.api_index = self.api_index
        return run_param

    def get_dump_flag(self, aten_api):
        """Return ``(dump_flag, auto_dump_flag)`` for ``aten_api`` under the current dump mode."""
        dump_flag = False
        auto_dump_flag = False
        if self.dump_mode == Const.ALL:
            dump_flag = True
        if self.dump_mode == Const.LIST and aten_api in self.dump_api_list:
            dump_flag = True
        if self.dump_mode == Const.AUTO:
            auto_dump_flag = True
        return dump_flag, auto_dump_flag

    def check_param(self):
        """Validate constructor arguments; raise ``DispatchException(INVALID_PARAMETER)`` on the first bad one."""
        if self.dump_mode not in Const.ONLINE_DUMP_MODE:
            logger.error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE))
            raise DispatchException(DispatchException.INVALID_PARAMETER)
        if not isinstance(self.dump_api_list, list):
            logger.error('The type of parameter "api_list" can only be list.')
            raise DispatchException(DispatchException.INVALID_PARAMETER)
        if not isinstance(self.debug_flag, bool):
            logger.error('The type of parameter "debug" can only be bool.')
            raise DispatchException(DispatchException.INVALID_PARAMETER)
        if not isinstance(self.process_num, int) or self.process_num < 0:
            logger.error('The type of parameter "process_num" can only be int and it should not be less than 0.')
            raise DispatchException(DispatchException.INVALID_PARAMETER)

    def enable_autogard(self, aten_api):
        """Re-enable the AutogradFunctionality dispatch key for APIs listed in ``npu_adjust_autogard``."""
        if aten_api in self.npu_adjust_autogard:
            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)