mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/compare/match.py
@@ -1,33 +1,34 @@
-import os
-from msprobe.core.common.utils import CompareException, load_yaml
-
-
-class AtenIrMapping():
-    def __init__(self):
-        cur_path = os.path.dirname(os.path.realpath(__file__))
-        yaml_path = os.path.join(cur_path, "mapping.yaml")
-        self.aten_mapping = load_yaml(yaml_path)
-
-    def match(self, op1, op2):
-        if "Aten" in op1 and "Aten" not in op2:
-            return self.match_op(op1, op2)
-        else:
-            return self.match_op(op2, op1)
-
-    def match_op(self, aten_op, torch_op):
-        try:
-            aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
-            aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
-            torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
-        except IndexError as e:
-            err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
-            raise CompareException.INVALID_DATA_ERROR(err_msg) from e
-        matching_op = self.aten_mapping.get(aten_op_raw_name)
-        if matching_op is None:
-            return False
-        if matching_op.lower() == torch_op_raw_name:
-            return True
-        return False
-
-
-graph_mapping = AtenIrMapping()
+import os
+from msprobe.core.common.utils import CompareException
+from msprobe.core.common.file_utils import load_yaml
+
+
+class AtenIrMapping():
+    def __init__(self):
+        cur_path = os.path.dirname(os.path.realpath(__file__))
+        yaml_path = os.path.join(cur_path, "mapping.yaml")
+        self.aten_mapping = load_yaml(yaml_path)
+
+    def match(self, op1, op2):
+        if "Aten" in op1 and "Aten" not in op2:
+            return self.match_op(op1, op2)
+        else:
+            return self.match_op(op2, op1)
+
+    def match_op(self, aten_op, torch_op):
+        try:
+            aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
+            aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
+            torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
+        except IndexError as e:
+            err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
+            raise CompareException.INVALID_DATA_ERROR(err_msg) from e
+        matching_op = self.aten_mapping.get(aten_op_raw_name)
+        if matching_op is None:
+            return False
+        if matching_op.lower() == torch_op_raw_name:
+            return True
+        return False
+
+
+graph_mapping = AtenIrMapping()
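The only functional change in this hunk is the import split: load_yaml now comes from msprobe.core.common.file_utils (the file_check.py → file_utils.py rename listed above), while CompareException stays in msprobe.core.common.utils. A minimal sketch of how downstream code would follow the same migration (module paths are taken from the diff; everything else is illustrative only):

# Illustrative only: mirrors the 1.0.3 -> 1.0.4 import migration shown above.
import os

from msprobe.core.common.utils import CompareException    # unchanged location
from msprobe.core.common.file_utils import load_yaml       # moved out of utils in 1.0.4

# Load the same mapping file the class above reads; the path logic is copied
# from AtenIrMapping.__init__ and only works inside the installed package.
cur_path = os.path.dirname(os.path.realpath(__file__))
aten_mapping = load_yaml(os.path.join(cur_path, "mapping.yaml"))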
msprobe/pytorch/compare/pt_compare.py
@@ -1,40 +1,50 @@
-import os.path
-import torch
-from msprobe.core.common.const import FileCheckConst, Const
-from msprobe.core.common.log import logger
-from msprobe.core.common.exceptions import FileCheckException
-from msprobe.core.compare.acc_compare import Comparator
-from msprobe.core.common.utils import create_directory, check_configuration_param, task_dumppath_get, \
-    check_compare_param, FileChecker
-from msprobe.core.common.utils import CompareException
-
-
-class PTComparator (Comparator):
-    def __init__(self):
-        self.frame_name = PTComparator.__name__
-
-    def read_npy_data(self, dir_path, file_name):
-        data_path = os.path.join(dir_path, file_name)
-        path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                                   FileCheckConst.PT_SUFFIX, False)
-        data_path = path_checker.common_check()
-        data_value = torch.load(data_path, map_location=torch.device('cpu')).detach()  # detach for less memory
-        if data_value.dtype == torch.bfloat16:
-            data_value = data_value.to(torch.float32)
-        data_value = data_value.numpy()
-        return data_value
-
-
-def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
-    try:
-        summary_compare, md5_compare = task_dumppath_get(input_param)
-        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
-        create_directory(output_path)
-        check_compare_param(input_param, output_path, summary_compare, md5_compare)
-    except (CompareException, FileCheckException) as error:
-        logger.error('Compare failed. Please check the arguments and do it again!')
-        raise CompareException(error.code) from error
-    pt_comparator = PTComparator()
-    pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
-                               auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
-                               md5_compare=md5_compare)
+import os.path
+import torch
+from msprobe.core.common.const import FileCheckConst
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.exceptions import FileCheckException
+from msprobe.core.compare.acc_compare import Comparator
+from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, CompareException
+from msprobe.core.common.file_utils import FileChecker, create_directory
+from msprobe.pytorch.common.utils import load_pt
+
+
+class PTComparator (Comparator):
+    def __init__(self):
+        self.frame_name = PTComparator.__name__
+
+    def read_npy_data(self, dir_path, file_name):
+        data_path = os.path.join(dir_path, file_name)
+        path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
+                                   FileCheckConst.PT_SUFFIX, False)
+        data_path = path_checker.common_check()
+        try:
+            data_value = load_pt(data_path,
+                                 to_cpu=True).detach()  # detach because numpy can not process gradient information
+        except RuntimeError as e:
+            # catch the exception raised inside load_pt
+            logger.error(f"Failed to load the .pt file at {data_path}.")
+            raise CompareException(CompareException.INVALID_FILE_ERROR) from e
+        except AttributeError as e:
+            # catch the exception raised by the detach method
+            logger.error(f"Failed to detach the loaded tensor.")
+            raise CompareException(CompareException.DETACH_ERROR) from e
+        if data_value.dtype == torch.bfloat16:
+            data_value = data_value.to(torch.float32)
+        data_value = data_value.numpy()
+        return data_value
+
+
+def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
+    try:
+        summary_compare, md5_compare = task_dumppath_get(input_param)
+        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
+        create_directory(output_path)
+        check_compare_param(input_param, output_path, summary_compare, md5_compare)
+    except (CompareException, FileCheckException) as error:
+        logger.error('Compare failed. Please check the arguments and do it again!')
+        raise CompareException(error.code) from error
+    pt_comparator = PTComparator()
+    pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
+                               auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
+                               md5_compare=md5_compare)
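The key change here is that read_npy_data no longer calls torch.load directly; it goes through load_pt from msprobe.pytorch.common.utils and converts load/detach failures into CompareException codes. The diff only shows the call shape, load_pt(path, to_cpu=True), and that it may raise RuntimeError, so the helper below is a hedged sketch of what such a wrapper plausibly looks like, based on the torch.load call it replaces, not the actual msprobe implementation:

import torch


def load_pt(pt_path, to_cpu=False):
    """Hedged sketch of the load_pt helper referenced above (assumed behaviour)."""
    map_location = torch.device("cpu") if to_cpu else None
    try:
        # mirrors the torch.load(..., map_location=torch.device('cpu')) call it replaces
        return torch.load(pt_path, map_location=map_location)
    except Exception as e:
        # re-raise as RuntimeError so callers such as read_npy_data handle one error type
        raise RuntimeError(f"load pt file {pt_path} failed") from e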
msprobe/pytorch/debugger/debugger_config.py
@@ -1,95 +1,95 @@
-from msprobe.pytorch.common import seed_all
-from msprobe.pytorch.common.log import logger
-from msprobe.core.common.const import Const
-
-
-class DebuggerConfig:
-    def __init__(self, common_config, task_config, task, dump_path, level):
-        self.dump_path = dump_path if dump_path else common_config.dump_path
-        self.task = task or common_config.task or Const.STATISTICS
-        self.rank = common_config.rank if common_config.rank else []
-        self.step = common_config.step if common_config.step else []
-        self.level = level or common_config.level or "L1"
-        self.seed = common_config.seed if common_config.seed else 1234
-        self.is_deterministic = common_config.is_deterministic
-        self.enable_dataloader = common_config.enable_dataloader
-        self.scope = task_config.scope if task_config.scope else []
-        self.list = task_config.list if task_config.list else []
-        self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
-        self.backward_input_list = task_config.backward_input if task_config.backward_input else []
-        self.backward_input = {}
-        self.acl_config = common_config.acl_config if common_config.acl_config else ""
-        self.is_forward_acl_dump = True
-        self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
-        self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
-        self.framework = Const.PT_FRAMEWORK
-
-        if self.task == Const.FREE_BENCHMARK:
-            self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
-            self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
-            self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
-            self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
-            self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
-            self.preheat_config = {
-                "if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
-                "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
-                "max_sample": task_config.max_sample if task_config.max_sample else 20,
-            }
-
-        self.online_run_ut = False
-        if self.task == Const.TENSOR:
-            # dump api tensor and collaborate with online run_ut
-            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
-            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
-            self.tls_path = task_config.tls_path if task_config.tls_path else ""
-            self.host = task_config.host if task_config.host else ""
-            self.port = task_config.port if task_config.port else -1
-
-        self.check()
-        if self.step:
-            self.step.sort()
-        if self.level == "L2":
-            if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
-                raise ValueError("scope must be configured as a list with one api name")
-            if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
-                raise ValueError("backward_input must be configured when scope contains 'backward'")
-            if Const.BACKWARD in self.scope[0]:
-                self.is_forward_acl_dump = False
-                for index, scope_spec in enumerate(self.scope):
-                    self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
-                    self.backward_input[self.scope[index]] = self.backward_input_list[index]
-        seed_all(self.seed, self.is_deterministic)
-
-    def check_kwargs(self):
-        if self.task and self.task not in Const.TASK_LIST:
-            raise Exception("task is invalid")
-        if self.level and self.level not in Const.LEVEL_LIST:
-            raise Exception("level is invalid")
-        if not self.dump_path:
-            raise Exception("Invalid dump path, please check your config")
-
-    def check(self):
-        self.check_kwargs()
-        self._check_rank()
-        self._check_step()
-        return True
-
-    def check_model(self, model):
-        if self.level in ["L0", "mix"] and not model:
-            raise Exception(
-                f"For level {self.level}, PrecisionDebugger must receive a model argument."
-            )
-
-    def _check_rank(self):
-        if self.rank:
-            for rank_id in self.rank:
-                if not isinstance(rank_id, int) or rank_id < 0:
-                    raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.")
-            else:
-                logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
-
-    def _check_step(self):
-        if self.step:
-            for s in self.step:
-                if not isinstance(s, int) or s < 0:
-                    raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
+from msprobe.pytorch.common import seed_all
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import Const
+
+
+class DebuggerConfig:
+    def __init__(self, common_config, task_config, task, dump_path, level):
+        self.dump_path = dump_path if dump_path else common_config.dump_path
+        self.task = task or common_config.task or Const.STATISTICS
+        self.rank = common_config.rank if common_config.rank else []
+        self.step = common_config.step if common_config.step else []
+        self.level = level or common_config.level or "L1"
+        self.seed = common_config.seed if common_config.seed else 1234
+        self.is_deterministic = common_config.is_deterministic
+        self.enable_dataloader = common_config.enable_dataloader
+        self.scope = task_config.scope if task_config.scope else []
+        self.list = task_config.list if task_config.list else []
+        self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
+        self.backward_input_list = task_config.backward_input if task_config.backward_input else []
+        self.backward_input = {}
+        self.acl_config = common_config.acl_config if common_config.acl_config else ""
+        self.is_forward_acl_dump = True
+        self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
+        self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
+        self.framework = Const.PT_FRAMEWORK
+
+        if self.task == Const.FREE_BENCHMARK:
+            self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
+            self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
+            self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
+            self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
+            self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
+            self.preheat_config = {
+                "if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
+                "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
+                "max_sample": task_config.max_sample if task_config.max_sample else 20,
+            }
+
+        self.online_run_ut = False
+        if self.task == Const.TENSOR:
+            # dump api tensor and collaborate with online run_ut
+            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
+            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
+            self.tls_path = task_config.tls_path if task_config.tls_path else ""
+            self.host = task_config.host if task_config.host else ""
+            self.port = task_config.port if task_config.port else -1
+
+        self.check()
+        if self.step:
+            self.step.sort()
+        if self.level == "L2":
+            if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
+                raise ValueError("scope must be configured as a list with one api name")
+            if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
+                raise ValueError("backward_input must be configured when scope contains 'backward'")
+            if Const.BACKWARD in self.scope[0]:
+                self.is_forward_acl_dump = False
+                for index, scope_spec in enumerate(self.scope):
+                    self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
+                    self.backward_input[self.scope[index]] = self.backward_input_list[index]
+        seed_all(self.seed, self.is_deterministic)
+
+    def check_kwargs(self):
+        if self.task and self.task not in Const.TASK_LIST:
+            raise Exception("task is invalid")
+        if self.level and self.level not in Const.LEVEL_LIST:
+            raise Exception("level is invalid")
+        if not self.dump_path:
+            raise Exception("Invalid dump path, please check your config")
+
+    def check(self):
+        self.check_kwargs()
+        self._check_rank()
+        self._check_step()
+        return True
+
+    def check_model(self, model):
+        if self.level in ["L0", "mix"] and not model:
+            raise Exception(
+                f"For level {self.level}, PrecisionDebugger must receive a model argument."
+            )
+
+    def _check_rank(self):
+        if self.rank:
+            for rank_id in self.rank:
+                if not isinstance(rank_id, int) or rank_id < 0:
+                    raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.")
+            else:
+                logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
+
+    def _check_step(self):
+        if self.step:
+            for s in self.step:
+                if not isinstance(s, int) or s < 0:
+                    raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
msprobe/pytorch/debugger/precision_debugger.py
@@ -1,125 +1,125 @@
-import torch
-from torch.utils.data import dataloader
-from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
-from msprobe.pytorch.service import Service
-from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.pt_config import parse_json_config
-from msprobe.core.common.exceptions import MsprobeException
-from msprobe.core.common.const import Const
-from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
-
-
-class PrecisionDebugger:
-    _instance = None
-    tasks_not_need_debugger = [Const.GRAD_PROBE]
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super(PrecisionDebugger, cls).__new__(cls)
-            cls._instance.config = None
-            cls._instance.enable_dataloader = False
-        return cls._instance
-
-    def __init__(
-        self,
-        config_path=None,
-        task=None,
-        dump_path=None,
-        level=None,
-        model=None,
-        step=None,
-    ):
-        if not hasattr(self, "initialized"):
-            self.api_origin = False
-            self.initialized = True
-            self.model = self.check_model_valid(model)
-            common_config, task_config = parse_json_config(config_path, task)
-            self.task = common_config.task
-            if self.task == Const.GRAD_PROBE:
-                self.gm = GradientMonitor(common_config, task_config)
-                return
-            if step:
-                common_config.step = step
-            self.config = DebuggerConfig(
-                common_config, task_config, task, dump_path, level
-            )
-            self.config.check_model(self.model)
-            self.service = Service(self.config)
-            self.enable_dataloader = self.config.enable_dataloader
-            if self.enable_dataloader:
-                logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
-                dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
-
-    @property
-    def instance(self):
-        return self._instance
-
-    @staticmethod
-    def check_model_valid(model):
-        if not model or isinstance(model, torch.nn.Module):
-            return model
-        raise MsprobeException(
-            MsprobeException.INVALID_PARAM_ERROR, "The model argument must be of type torch.nn.Module."
-        )
-
-    @classmethod
-    def start(cls):
-        instance = cls._instance
-        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
-            return
-        if not instance:
-            raise Exception("No instance of PrecisionDebugger found.")
-        if instance.enable_dataloader:
-            logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
-        else:
-            instance.service.start(instance.model, instance.api_origin)
-            instance.api_origin = False
-
-    # end-of-backward marker for dumping a specified code section; computation data after this point is ignored and cannot be dumped
-    @classmethod
-    def forward_backward_dump_end(cls):
-        instance = cls._instance
-        instance.service.forward_backward_dump_end()
-        instance.api_origin = True
-
-    @classmethod
-    def stop(cls):
-        instance = cls._instance
-        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
-            return
-        if not instance:
-            raise Exception("PrecisionDebugger instance is not created.")
-        if instance.enable_dataloader:
-            logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
-        else:
-            instance.service.stop()
-
-    @classmethod
-    def step(cls):
-        if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
-            return
-        if not cls._instance:
-            raise Exception("PrecisionDebugger instance is not created.")
-        cls._instance.service.step()
-
-    @classmethod
-    def monitor(cls, model):
-        if not cls._instance:
-            raise Exception("PrecisionDebugger instance is not created.")
-        if cls._instance.task != Const.GRAD_PROBE:
-            return
-        cls._instance.gm.monitor(model)
-
-
-def iter_tracer(func):
-    def func_wrapper(*args, **kwargs):
-        debugger_instance = PrecisionDebugger.instance
-        debugger_instance.enable_dataloader = False
-        if not debugger_instance.service.first_start:
-            debugger_instance.stop()
-            debugger_instance.step()
-        result = func(*args, **kwargs)
-        debugger_instance.start()
-        debugger_instance.enable_dataloader = True
-        return result
-    return func_wrapper
+import torch
+from torch.utils.data import dataloader
+from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
+from msprobe.pytorch.service import Service
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.pt_config import parse_json_config
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common.const import Const
+from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
+
+
+class PrecisionDebugger:
+    _instance = None
+    tasks_not_need_debugger = [Const.GRAD_PROBE]
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(PrecisionDebugger, cls).__new__(cls)
+            cls._instance.config = None
+            cls._instance.enable_dataloader = False
+        return cls._instance
+
+    def __init__(
+        self,
+        config_path=None,
+        task=None,
+        dump_path=None,
+        level=None,
+        model=None,
+        step=None,
+    ):
+        if not hasattr(self, "initialized"):
+            self.api_origin = False
+            self.initialized = True
+            self.model = self.check_model_valid(model)
+            common_config, task_config = parse_json_config(config_path, task)
+            self.task = common_config.task
+            if self.task == Const.GRAD_PROBE:
+                self.gm = GradientMonitor(common_config, task_config)
+                return
+            if step:
+                common_config.step = step
+            self.config = DebuggerConfig(
+                common_config, task_config, task, dump_path, level
+            )
+            self.config.check_model(self.model)
+            self.service = Service(self.config)
+            self.enable_dataloader = self.config.enable_dataloader
+            if self.enable_dataloader:
+                logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
+                dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
+
+    @property
+    def instance(self):
+        return self._instance
+
+    @staticmethod
+    def check_model_valid(model):
+        if not model or isinstance(model, torch.nn.Module):
+            return model
+        raise MsprobeException(
+            MsprobeException.INVALID_PARAM_ERROR, "The model argument must be of type torch.nn.Module."
+        )
+
+    @classmethod
+    def start(cls):
+        instance = cls._instance
+        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
+            return
+        if not instance:
+            raise Exception("No instance of PrecisionDebugger found.")
+        if instance.enable_dataloader:
+            logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
+        else:
+            instance.service.start(instance.model, instance.api_origin)
+            instance.api_origin = False
+
+    # end-of-backward marker for dumping a specified code section; computation data after this point is ignored and cannot be dumped
+    @classmethod
+    def forward_backward_dump_end(cls):
+        instance = cls._instance
+        instance.service.forward_backward_dump_end()
+        instance.api_origin = True
+
+    @classmethod
+    def stop(cls):
+        instance = cls._instance
+        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
+            return
+        if not instance:
+            raise Exception("PrecisionDebugger instance is not created.")
+        if instance.enable_dataloader:
+            logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
+        else:
+            instance.service.stop()
+
+    @classmethod
+    def step(cls):
+        if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
+            return
+        if not cls._instance:
+            raise Exception("PrecisionDebugger instance is not created.")
+        cls._instance.service.step()
+
+    @classmethod
+    def monitor(cls, model):
+        if not cls._instance:
+            raise Exception("PrecisionDebugger instance is not created.")
+        if cls._instance.task != Const.GRAD_PROBE:
+            return
+        cls._instance.gm.monitor(model)
+
+
+def iter_tracer(func):
+    def func_wrapper(*args, **kwargs):
+        debugger_instance = PrecisionDebugger.instance
+        debugger_instance.enable_dataloader = False
+        if not debugger_instance.service.first_start:
+            debugger_instance.stop()
+            debugger_instance.step()
+        result = func(*args, **kwargs)
+        debugger_instance.start()
+        debugger_instance.enable_dataloader = True
+        return result
+    return func_wrapper
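The class above is driven from a training loop through its classmethods. The sketch below uses only the constructor arguments and methods visible in this hunk; the config path and model are placeholders, and a valid msprobe config.json is still required for parse_json_config to succeed:

import torch

from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger

model = torch.nn.Linear(4, 4)                               # placeholder model
debugger = PrecisionDebugger(config_path="./config.json",   # placeholder path
                             model=model)

for _ in range(3):                                          # stand-in for a real training loop
    debugger.start()                                        # begin collection for this step
    model(torch.randn(2, 4)).sum().backward()               # placeholder forward/backward
    debugger.stop()                                         # stop collection
    debugger.step()                                         # advance the debugger step counter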
msprobe/pytorch/free_benchmark/__init__.py
@@ -1,8 +1,8 @@
-from msprobe.core.common.log import logger
-from msprobe.core.common.exceptions import FreeBenchmarkException
-from msprobe.core.common.const import Const
-
-from .main import FreeBenchmarkCheck
-from .common.params import UnequalRow
-
-__all__ = [FreeBenchmarkCheck, UnequalRow]
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.exceptions import FreeBenchmarkException
+from msprobe.core.common.const import Const
+
+from .main import FreeBenchmarkCheck
+from .common.params import UnequalRow
+
+__all__ = [FreeBenchmarkCheck, UnequalRow]