mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,36 +1,34 @@
1
- import os
2
- import yaml
3
- from msprobe.core.common.file_check import FileOpen
4
- from msprobe.core.common.utils import CompareException
5
-
6
-
7
- class AtenIrMapping():
8
- def __init__(self):
9
- cur_path = os.path.dirname(os.path.realpath(__file__))
10
- yaml_path = os.path.join(cur_path, "mapping.yaml")
11
- with FileOpen(yaml_path, 'r') as f:
12
- self.aten_mapping = yaml.safe_load(f)
13
-
14
- def match(self, op1, op2):
15
- if "Aten" in op1 and "Aten" not in op2:
16
- return self.match_op(op1, op2)
17
- else:
18
- return self.match_op(op2, op1)
19
-
20
- def match_op(self, aten_op, torch_op):
21
- try:
22
- aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
23
- aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
24
- torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
25
- except IndexError as e:
26
- err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
27
- raise CompareException.INVALID_DATA_ERROR(err_msg) from e
28
- matching_op = self.aten_mapping.get(aten_op_raw_name)
29
- if matching_op is None:
30
- return False
31
- if matching_op.lower() == torch_op_raw_name:
32
- return True
33
- return False
34
-
35
-
36
- graph_mapping = AtenIrMapping()
1
+ import os
2
+ from msprobe.core.common.utils import CompareException
3
+ from msprobe.core.common.file_utils import load_yaml
4
+
5
+
6
+ class AtenIrMapping():
7
+ def __init__(self):
8
+ cur_path = os.path.dirname(os.path.realpath(__file__))
9
+ yaml_path = os.path.join(cur_path, "mapping.yaml")
10
+ self.aten_mapping = load_yaml(yaml_path)
11
+
12
+ def match(self, op1, op2):
13
+ if "Aten" in op1 and "Aten" not in op2:
14
+ return self.match_op(op1, op2)
15
+ else:
16
+ return self.match_op(op2, op1)
17
+
18
+ def match_op(self, aten_op, torch_op):
19
+ try:
20
+ aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
21
+ aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
22
+ torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
23
+ except IndexError as e:
24
+ err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
25
+ raise CompareException.INVALID_DATA_ERROR(err_msg) from e
26
+ matching_op = self.aten_mapping.get(aten_op_raw_name)
27
+ if matching_op is None:
28
+ return False
29
+ if matching_op.lower() == torch_op_raw_name:
30
+ return True
31
+ return False
32
+
33
+
34
+ graph_mapping = AtenIrMapping()
@@ -0,0 +1,50 @@
1
+ import os.path
2
+ import torch
3
+ from msprobe.core.common.const import FileCheckConst
4
+ from msprobe.pytorch.common.log import logger
5
+ from msprobe.core.common.exceptions import FileCheckException
6
+ from msprobe.core.compare.acc_compare import Comparator
7
+ from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, CompareException
8
+ from msprobe.core.common.file_utils import FileChecker, create_directory
9
+ from msprobe.pytorch.common.utils import load_pt
10
+
11
+
12
+ class PTComparator (Comparator):
13
+ def __init__(self):
14
+ self.frame_name = PTComparator.__name__
15
+
16
+ def read_npy_data(self, dir_path, file_name):
17
+ data_path = os.path.join(dir_path, file_name)
18
+ path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
19
+ FileCheckConst.PT_SUFFIX, False)
20
+ data_path = path_checker.common_check()
21
+ try:
22
+ data_value = load_pt(data_path,
23
+ to_cpu=True).detach() # detach because numpy can not process gradient information
24
+ except RuntimeError as e:
25
+ # 这里捕获 load_pt 中抛出的异常
26
+ logger.error(f"Failed to load the .pt file at {data_path}.")
27
+ raise CompareException(CompareException.INVALID_FILE_ERROR) from e
28
+ except AttributeError as e:
29
+ # 这里捕获 detach 方法抛出的异常
30
+ logger.error(f"Failed to detach the loaded tensor.")
31
+ raise CompareException(CompareException.DETACH_ERROR) from e
32
+ if data_value.dtype == torch.bfloat16:
33
+ data_value = data_value.to(torch.float32)
34
+ data_value = data_value.numpy()
35
+ return data_value
36
+
37
+
38
+ def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
39
+ try:
40
+ summary_compare, md5_compare = task_dumppath_get(input_param)
41
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
42
+ create_directory(output_path)
43
+ check_compare_param(input_param, output_path, summary_compare, md5_compare)
44
+ except (CompareException, FileCheckException) as error:
45
+ logger.error('Compare failed. Please check the arguments and do it again!')
46
+ raise CompareException(error.code) from error
47
+ pt_comparator = PTComparator()
48
+ pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
49
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
50
+ md5_compare=md5_compare)
@@ -1,86 +1,95 @@
1
- from msprobe.pytorch.common import seed_all
2
- from msprobe.pytorch.common.log import logger
3
- from msprobe.core.common.const import Const
4
-
5
-
6
- class DebuggerConfig:
7
- def __init__(self, common_config, task_config, task, dump_path, level):
8
- self.dump_path = dump_path if dump_path else common_config.dump_path
9
- self.task = task or common_config.task or Const.STATISTICS
10
- self.rank = common_config.rank if common_config.rank else []
11
- self.step = common_config.step if common_config.step else []
12
- self.level = level or common_config.level or "L1"
13
- self.seed = common_config.seed if common_config.seed else 1234
14
- self.is_deterministic = common_config.is_deterministic
15
- self.enable_dataloader = common_config.enable_dataloader
16
- self.scope = task_config.scope if task_config.scope else []
17
- self.list = task_config.list if task_config.list else []
18
- self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
19
- self.backward_input_list = task_config.backward_input if task_config.backward_input else []
20
- self.backward_input = {}
21
- self.acl_config = common_config.acl_config if common_config.acl_config else ""
22
- self.is_forward_acl_dump = True
23
- self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
24
- self.overflow_num = task_config.overflow_num if task_config.overflow_num else 1
25
- self.framework = Const.PT_FRAMEWORK
26
-
27
- if self.task == Const.FREE_BENCHMARK:
28
- self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
29
- self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
30
- self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
31
- self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
32
- self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
33
- self.preheat_config = {
34
- "if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
35
- "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
36
- "max_sample": task_config.max_sample if task_config.max_sample else 20,
37
- }
38
-
39
- self.check()
40
- if self.step:
41
- self.step.sort()
42
- if self.level == "L2":
43
- if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
44
- raise ValueError("scope must be configured as a list with one api name")
45
- if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
46
- raise ValueError("backward_input must be configured when scope contains 'backward'")
47
- if Const.BACKWARD in self.scope[0]:
48
- self.is_forward_acl_dump = False
49
- for index, scope_spec in enumerate(self.scope):
50
- self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
51
- self.backward_input[self.scope[index]] = self.backward_input_list[index]
52
- seed_all(self.seed, self.is_deterministic)
53
-
54
- def check_kwargs(self):
55
- if self.task and self.task not in Const.TASK_LIST:
56
- raise Exception("task is invalid")
57
- if self.level and self.level not in Const.LEVEL_LIST:
58
- raise Exception("level is invalid")
59
- if not self.dump_path:
60
- raise Exception("Invalid dump path, please check your config")
61
-
62
- def check(self):
63
- self.check_kwargs()
64
- self._check_rank()
65
- self._check_step()
66
- return True
67
-
68
- def check_model(self, model):
69
- if self.level in ["L0", "mix"] and not model:
70
- raise Exception(
71
- f"For level {self.level}, PrecisionDebugger must receive a model argument."
72
- )
73
-
74
- def _check_rank(self):
75
- if self.rank:
76
- for rank_id in self.rank:
77
- if not isinstance(rank_id, int) or rank_id < 0:
78
- raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.")
79
- else:
80
- logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
81
-
82
- def _check_step(self):
83
- if self.step:
84
- for s in self.step:
85
- if not isinstance(s, int) or s < 0:
86
- raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
1
+ from msprobe.pytorch.common import seed_all
2
+ from msprobe.pytorch.common.log import logger
3
+ from msprobe.core.common.const import Const
4
+
5
+
6
+ class DebuggerConfig:
7
+ def __init__(self, common_config, task_config, task, dump_path, level):
8
+ self.dump_path = dump_path if dump_path else common_config.dump_path
9
+ self.task = task or common_config.task or Const.STATISTICS
10
+ self.rank = common_config.rank if common_config.rank else []
11
+ self.step = common_config.step if common_config.step else []
12
+ self.level = level or common_config.level or "L1"
13
+ self.seed = common_config.seed if common_config.seed else 1234
14
+ self.is_deterministic = common_config.is_deterministic
15
+ self.enable_dataloader = common_config.enable_dataloader
16
+ self.scope = task_config.scope if task_config.scope else []
17
+ self.list = task_config.list if task_config.list else []
18
+ self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
19
+ self.backward_input_list = task_config.backward_input if task_config.backward_input else []
20
+ self.backward_input = {}
21
+ self.acl_config = common_config.acl_config if common_config.acl_config else ""
22
+ self.is_forward_acl_dump = True
23
+ self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
24
+ self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
25
+ self.framework = Const.PT_FRAMEWORK
26
+
27
+ if self.task == Const.FREE_BENCHMARK:
28
+ self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
29
+ self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
30
+ self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
31
+ self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
32
+ self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
33
+ self.preheat_config = {
34
+ "if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
35
+ "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
36
+ "max_sample": task_config.max_sample if task_config.max_sample else 20,
37
+ }
38
+
39
+ self.online_run_ut = False
40
+ if self.task == Const.TENSOR:
41
+ # dump api tensor and collaborate with online run_ut
42
+ self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
43
+ self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
44
+ self.tls_path = task_config.tls_path if task_config.tls_path else ""
45
+ self.host = task_config.host if task_config.host else ""
46
+ self.port = task_config.port if task_config.port else -1
47
+
48
+ self.check()
49
+ if self.step:
50
+ self.step.sort()
51
+ if self.level == "L2":
52
+ if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
53
+ raise ValueError("scope must be configured as a list with one api name")
54
+ if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
55
+ raise ValueError("backward_input must be configured when scope contains 'backward'")
56
+ if Const.BACKWARD in self.scope[0]:
57
+ self.is_forward_acl_dump = False
58
+ for index, scope_spec in enumerate(self.scope):
59
+ self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
60
+ self.backward_input[self.scope[index]] = self.backward_input_list[index]
61
+ seed_all(self.seed, self.is_deterministic)
62
+
63
+ def check_kwargs(self):
64
+ if self.task and self.task not in Const.TASK_LIST:
65
+ raise Exception("task is invalid")
66
+ if self.level and self.level not in Const.LEVEL_LIST:
67
+ raise Exception("level is invalid")
68
+ if not self.dump_path:
69
+ raise Exception("Invalid dump path, please check your config")
70
+
71
+ def check(self):
72
+ self.check_kwargs()
73
+ self._check_rank()
74
+ self._check_step()
75
+ return True
76
+
77
+ def check_model(self, model):
78
+ if self.level in ["L0", "mix"] and not model:
79
+ raise Exception(
80
+ f"For level {self.level}, PrecisionDebugger must receive a model argument."
81
+ )
82
+
83
+ def _check_rank(self):
84
+ if self.rank:
85
+ for rank_id in self.rank:
86
+ if not isinstance(rank_id, int) or rank_id < 0:
87
+ raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.")
88
+ else:
89
+ logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
90
+
91
+ def _check_step(self):
92
+ if self.step:
93
+ for s in self.step:
94
+ if not isinstance(s, int) or s < 0:
95
+ raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
@@ -1,95 +1,125 @@
1
- import torch
2
- from torch.utils.data import dataloader
3
- from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
4
- from msprobe.pytorch.service import Service
5
- from msprobe.pytorch.common.log import logger
6
- from msprobe.pytorch.pt_config import parse_json_config
7
- from msprobe.core.common.exceptions import MsaccException
8
-
9
-
10
class PrecisionDebugger:
    """Singleton front end that wires a ``DebuggerConfig`` to a ``Service``.

    Only the first construction runs the initialisation; later constructions
    return the same object and leave it untouched.  The dump life cycle is
    driven through the class methods ``start``, ``stop`` and ``step``.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it with default attributes on first use."""
        if cls._instance is not None:
            return cls._instance
        singleton = super(PrecisionDebugger, cls).__new__(cls)
        cls._instance = singleton
        singleton.config = None
        singleton.enable_dataloader = False
        return singleton

    def __init__(
        self,
        config_path=None,
        task=None,
        dump_path=None,
        level=None,
        model=None,
        step=None,
    ):
        """Build the configuration and service the first time the singleton is constructed.

        Args:
            config_path: forwarded to ``parse_json_config``.
            task: forwarded to ``parse_json_config`` and ``DebuggerConfig``.
            dump_path: forwarded to ``DebuggerConfig``.
            level: forwarded to ``DebuggerConfig``.
            model: a ``torch.nn.Module`` or a falsy value.
            step: when truthy, replaces ``step`` on the parsed common config.
        """
        if hasattr(self, "initialized"):
            # The singleton has already been set up by an earlier construction.
            return
        self.initialized = True
        self.model = self.check_model_valid(model)
        common_config, task_config = parse_json_config(config_path, task)
        if step:
            common_config.step = step
        debugger_config = DebuggerConfig(
            common_config, task_config, task, dump_path, level
        )
        self.config = debugger_config
        self.config.check_model(self.model)
        self.service = Service(self.config)
        self.enable_dataloader = self.config.enable_dataloader
        if not self.enable_dataloader:
            return
        logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
        # Patch the dataloader iterator so every fetch goes through iter_tracer.
        original_next = dataloader._BaseDataLoaderIter.__next__
        dataloader._BaseDataLoaderIter.__next__ = iter_tracer(original_next)

    @property
    def instance(self):
        """The shared singleton stored on the class."""
        return self._instance

    @staticmethod
    def check_model_valid(model):
        """Return ``model`` if it is falsy or a ``torch.nn.Module``; raise ``MsaccException`` otherwise."""
        if model and not isinstance(model, torch.nn.Module):
            raise MsaccException(
                MsaccException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
            )
        return model

    @classmethod
    def start(cls):
        """Start the service on the stored model unless dataloader mode is enabled."""
        debugger = cls._instance
        if not debugger:
            raise Exception("No instance of PrecisionDebugger found.")
        if not debugger.enable_dataloader:
            debugger.service.start(debugger.model)
            return
        logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")

    @classmethod
    def stop(cls):
        """Stop the service unless dataloader mode is enabled."""
        debugger = cls._instance
        if not debugger:
            raise Exception("PrecisionDebugger instance is not created.")
        if not debugger.enable_dataloader:
            debugger.service.stop()
            return
        logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")

    @classmethod
    def step(cls):
        """Forward a step notification to the service."""
        debugger = cls._instance
        if not debugger:
            raise Exception("PrecisionDebugger instance is not created.")
        debugger.service.step()
82
-
83
-
84
def iter_tracer(func):
    """Wrap a dataloader ``__next__`` so each fetch drives the debugger life cycle.

    Before delegating to ``func`` the wrapper stops and steps the debugger
    (skipped while the service reports ``first_start``); afterwards it starts
    the debugger again.  ``enable_dataloader`` is cleared for the duration of
    the call so that ``start``/``stop`` are not skipped, then set back to True.

    Args:
        func: the original ``__next__`` callable being wrapped.

    Returns:
        The wrapping function, which returns whatever ``func`` returns.
    """
    def func_wrapper(*args, **kwargs):
        # ``PrecisionDebugger.instance`` is a property: read on the class it yields the
        # property object rather than the singleton, so the attribute writes below
        # would fail. Read the singleton slot directly instead.
        debugger_instance = PrecisionDebugger._instance
        debugger_instance.enable_dataloader = False
        if not debugger_instance.service.first_start:
            debugger_instance.stop()
            debugger_instance.step()
        result = func(*args, **kwargs)
        debugger_instance.start()
        debugger_instance.enable_dataloader = True
        return result
    return func_wrapper
1
+ import torch
2
+ from torch.utils.data import dataloader
3
+ from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
4
+ from msprobe.pytorch.service import Service
5
+ from msprobe.pytorch.common.log import logger
6
+ from msprobe.pytorch.pt_config import parse_json_config
7
+ from msprobe.core.common.exceptions import MsprobeException
8
+ from msprobe.core.common.const import Const
9
+ from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
10
+
11
+
12
class PrecisionDebugger:
    """Singleton entry point for the PyTorch precision-debugging tooling.

    The first construction parses the JSON configuration and, depending on the
    configured task, builds either a ``GradientMonitor`` (task
    ``Const.GRAD_PROBE``) or a ``DebuggerConfig`` plus ``Service``.  Later
    constructions return the same object and skip initialisation.  The dump
    life cycle is driven through the class methods below.
    """

    _instance = None
    # Tasks that do not use the dump service: the life-cycle methods return
    # immediately for them.
    tasks_not_need_debugger = [Const.GRAD_PROBE]

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it with default attributes on first use."""
        if cls._instance is None:
            cls._instance = super(PrecisionDebugger, cls).__new__(cls)
            cls._instance.config = None
            cls._instance.enable_dataloader = False
        return cls._instance

    def __init__(
        self,
        config_path=None,
        task=None,
        dump_path=None,
        level=None,
        model=None,
        step=None,
    ):
        """Initialise the singleton on its first construction only.

        Args:
            config_path: forwarded to ``parse_json_config``.
            task: forwarded to ``parse_json_config`` and ``DebuggerConfig``.
            dump_path: forwarded to ``DebuggerConfig``.
            level: forwarded to ``DebuggerConfig``.
            model: a ``torch.nn.Module`` or a falsy value.
            step: when truthy, replaces ``step`` on the parsed common config.
        """
        if hasattr(self, "initialized"):
            # Already set up by an earlier construction.
            return
        self.api_origin = False
        self.initialized = True
        self.model = self.check_model_valid(model)
        common_config, task_config = parse_json_config(config_path, task)
        self.task = common_config.task
        if self.task == Const.GRAD_PROBE:
            # Gradient probing uses its own monitor; no DebuggerConfig/Service is built.
            self.gm = GradientMonitor(common_config, task_config)
            return
        if step:
            common_config.step = step
        self.config = DebuggerConfig(
            common_config, task_config, task, dump_path, level
        )
        self.config.check_model(self.model)
        self.service = Service(self.config)
        self.enable_dataloader = self.config.enable_dataloader
        if self.enable_dataloader:
            logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
            # Patch the dataloader iterator so every fetch goes through iter_tracer.
            dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)

    @property
    def instance(self):
        """The shared singleton (only meaningful when read on an instance, not on the class)."""
        return self._instance

    @staticmethod
    def check_model_valid(model):
        """Return ``model`` if it is falsy or a ``torch.nn.Module``; raise ``MsprobeException`` otherwise."""
        if not model or isinstance(model, torch.nn.Module):
            return model
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
        )

    @classmethod
    def start(cls):
        """Start the service on the stored model.

        Does nothing for tasks in ``tasks_not_need_debugger``; only logs a
        warning when dataloader mode is enabled.

        Raises:
            Exception: if no instance has been created yet.
        """
        instance = cls._instance
        # The existence check must come first: reading ``task`` on a missing
        # instance would raise AttributeError instead of the intended error.
        if not instance:
            raise Exception("No instance of PrecisionDebugger found.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        if instance.enable_dataloader:
            logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
        else:
            instance.service.start(instance.model, instance.api_origin)
            # The flag has been handed to the service; clear it for the next start.
            instance.api_origin = False

    # Marks the backward end point for a code-segment dump: computation data
    # produced after this marker is ignored and cannot be dumped.
    @classmethod
    def forward_backward_dump_end(cls):
        """Tell the service the dumped forward/backward segment has ended and set ``api_origin``.

        Does nothing for tasks in ``tasks_not_need_debugger``, which have no service.

        Raises:
            Exception: if no instance has been created yet.
        """
        instance = cls._instance
        if not instance:
            raise Exception("PrecisionDebugger instance is not created.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        instance.service.forward_backward_dump_end()
        instance.api_origin = True

    @classmethod
    def stop(cls):
        """Stop the service.

        Does nothing for tasks in ``tasks_not_need_debugger``; only logs a
        warning when dataloader mode is enabled.

        Raises:
            Exception: if no instance has been created yet.
        """
        instance = cls._instance
        if not instance:
            raise Exception("PrecisionDebugger instance is not created.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        if instance.enable_dataloader:
            logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
        else:
            instance.service.stop()

    @classmethod
    def step(cls):
        """Forward a step notification to the service.

        Does nothing for tasks in ``tasks_not_need_debugger``.

        Raises:
            Exception: if no instance has been created yet.
        """
        instance = cls._instance
        if not instance:
            raise Exception("PrecisionDebugger instance is not created.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        instance.service.step()

    @classmethod
    def monitor(cls, model):
        """Hand ``model`` to the gradient monitor; a no-op unless the task is ``Const.GRAD_PROBE``.

        Raises:
            Exception: if no instance has been created yet.
        """
        if not cls._instance:
            raise Exception("PrecisionDebugger instance is not created.")
        if cls._instance.task != Const.GRAD_PROBE:
            return
        cls._instance.gm.monitor(model)
112
+
113
+
114
def iter_tracer(func):
    """Wrap a dataloader ``__next__`` so each fetch drives the debugger life cycle.

    Before delegating to ``func`` the wrapper stops and steps the debugger
    (skipped while the service reports ``first_start``); afterwards it starts
    the debugger again.  ``enable_dataloader`` is cleared for the duration of
    the call so that ``start``/``stop`` are not skipped, then set back to True.

    Args:
        func: the original ``__next__`` callable being wrapped.

    Returns:
        The wrapping function, which returns whatever ``func`` returns.
    """
    def func_wrapper(*args, **kwargs):
        # ``PrecisionDebugger.instance`` is a property: read on the class it yields the
        # property object rather than the singleton, so the attribute writes below
        # would fail. Read the singleton slot directly instead.
        debugger_instance = PrecisionDebugger._instance
        debugger_instance.enable_dataloader = False
        if not debugger_instance.service.first_start:
            debugger_instance.stop()
            debugger_instance.step()
        result = func(*args, **kwargs)
        debugger_instance.start()
        debugger_instance.enable_dataloader = True
        return result
    return func_wrapper
@@ -1,8 +1,8 @@
1
- from msprobe.core.common.log import logger
2
- from msprobe.core.common.exceptions import FreeBenchmarkException
3
- from msprobe.core.common.const import Const
4
-
5
- from .main import FreeBenchmarkCheck
6
- from .common.params import UnequalRow
7
-
8
# ``__all__`` must hold the exported names as strings; listing the objects
# themselves makes ``from <package> import *`` fail.
__all__ = ["FreeBenchmarkCheck", "UnequalRow"]
1
+ from msprobe.pytorch.common.log import logger
2
+ from msprobe.core.common.exceptions import FreeBenchmarkException
3
+ from msprobe.core.common.const import Const
4
+
5
+ from .main import FreeBenchmarkCheck
6
+ from .common.params import UnequalRow
7
+
8
# ``__all__`` must hold the exported names as strings; listing the objects
# themselves makes ``from <package> import *`` fail.
__all__ = ["FreeBenchmarkCheck", "UnequalRow"]