mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,125 +1,173 @@
1
- import torch
2
- from torch.utils.data import dataloader
3
- from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
4
- from msprobe.pytorch.service import Service
5
- from msprobe.pytorch.common.log import logger
6
- from msprobe.pytorch.pt_config import parse_json_config
7
- from msprobe.core.common.exceptions import MsprobeException
8
- from msprobe.core.common.const import Const
9
- from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
10
-
11
-
12
- class PrecisionDebugger:
13
- _instance = None
14
- tasks_not_need_debugger = [Const.GRAD_PROBE]
15
-
16
- def __new__(cls, *args, **kwargs):
17
- if cls._instance is None:
18
- cls._instance = super(PrecisionDebugger, cls).__new__(cls)
19
- cls._instance.config = None
20
- cls._instance.enable_dataloader = False
21
- return cls._instance
22
-
23
- def __init__(
24
- self,
25
- config_path=None,
26
- task=None,
27
- dump_path=None,
28
- level=None,
29
- model=None,
30
- step=None,
31
- ):
32
- if not hasattr(self, "initialized"):
33
- self.api_origin = False
34
- self.initialized = True
35
- self.model = self.check_model_valid(model)
36
- common_config, task_config = parse_json_config(config_path, task)
37
- self.task = common_config.task
38
- if self.task == Const.GRAD_PROBE:
39
- self.gm = GradientMonitor(common_config, task_config)
40
- return
41
- if step:
42
- common_config.step = step
43
- self.config = DebuggerConfig(
44
- common_config, task_config, task, dump_path, level
45
- )
46
- self.config.check_model(self.model)
47
- self.service = Service(self.config)
48
- self.enable_dataloader = self.config.enable_dataloader
49
- if self.enable_dataloader:
50
- logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
51
- dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
52
-
53
- @property
54
- def instance(self):
55
- return self._instance
56
-
57
- @staticmethod
58
- def check_model_valid(model):
59
- if not model or isinstance(model, torch.nn.Module):
60
- return model
61
- raise MsprobeException(
62
- MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
63
- )
64
-
65
- @classmethod
66
- def start(cls):
67
- instance = cls._instance
68
- if instance.task in PrecisionDebugger.tasks_not_need_debugger:
69
- return
70
- if not instance:
71
- raise Exception("No instance of PrecisionDebugger found.")
72
- if instance.enable_dataloader:
73
- logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
74
- else:
75
- instance.service.start(instance.model, instance.api_origin)
76
- instance.api_origin = False
77
-
78
- # 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
79
- @classmethod
80
- def forward_backward_dump_end(cls):
81
- instance = cls._instance
82
- instance.service.forward_backward_dump_end()
83
- instance.api_origin = True
84
-
85
- @classmethod
86
- def stop(cls):
87
- instance = cls._instance
88
- if instance.task in PrecisionDebugger.tasks_not_need_debugger:
89
- return
90
- if not instance:
91
- raise Exception("PrecisionDebugger instance is not created.")
92
- if instance.enable_dataloader:
93
- logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
94
- else:
95
- instance.service.stop()
96
-
97
- @classmethod
98
- def step(cls):
99
- if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
100
- return
101
- if not cls._instance:
102
- raise Exception("PrecisionDebugger instance is not created.")
103
- cls._instance.service.step()
104
-
105
- @classmethod
106
- def monitor(cls, model):
107
- if not cls._instance:
108
- raise Exception("PrecisionDebugger instance is not created.")
109
- if cls._instance.task != Const.GRAD_PROBE:
110
- return
111
- cls._instance.gm.monitor(model)
112
-
113
-
114
- def iter_tracer(func):
115
- def func_wrapper(*args, **kwargs):
116
- debugger_instance = PrecisionDebugger.instance
117
- debugger_instance.enable_dataloader = False
118
- if not debugger_instance.service.first_start:
119
- debugger_instance.stop()
120
- debugger_instance.step()
121
- result = func(*args, **kwargs)
122
- debugger_instance.start()
123
- debugger_instance.enable_dataloader = True
124
- return result
125
- return func_wrapper
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import namedtuple
17
+
18
+ import torch
19
+ from msprobe.core.common.const import Const, FileCheckConst, MsgConst
20
+ from msprobe.core.common.exceptions import MsprobeException
21
+ from msprobe.core.common.file_utils import FileChecker
22
+ from msprobe.core.common.utils import get_real_step_or_rank
23
+ from msprobe.pytorch.common.log import logger
24
+ from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
25
+ from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
26
+ from msprobe.pytorch.pt_config import parse_json_config
27
+ from msprobe.pytorch.service import Service
28
+ from torch.utils.data import dataloader
29
+
30
+ ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
31
+ "dump_path", "level", "model"])
32
+
33
+
34
+ class PrecisionDebugger:
35
+ _instance = None
36
+ tasks_not_need_debugger = [Const.GRAD_PROBE]
37
+
38
+ def __new__(cls, *args, **kwargs):
39
+ if cls._instance is None:
40
+ cls._instance = super(PrecisionDebugger, cls).__new__(cls)
41
+ cls._instance.config = None
42
+ cls._instance.enable_dataloader = False
43
+ return cls._instance
44
+
45
+ def __init__(
46
+ self,
47
+ config_path=None,
48
+ task=None,
49
+ dump_path=None,
50
+ level=None,
51
+ model=None,
52
+ step=None,
53
+ ):
54
+ if not hasattr(self, "initialized"):
55
+ config_params = ConfigParameters(config_path,
56
+ task,
57
+ dump_path,
58
+ level,
59
+ model)
60
+ self.check_input_params(config_params)
61
+
62
+ self.api_origin = False
63
+ self.initialized = True
64
+ self.model = model
65
+ common_config, task_config = parse_json_config(config_path, task)
66
+ self.task = task if task else common_config.task
67
+ if self.task == Const.GRAD_PROBE:
68
+ self.gm = GradientMonitor(common_config, task_config)
69
+ return
70
+ if step:
71
+ common_config.step = get_real_step_or_rank(step, Const.STEP)
72
+ self.config = DebuggerConfig(
73
+ common_config, task_config, task, dump_path, level
74
+ )
75
+ self.service = Service(self.config)
76
+ self.enable_dataloader = self.config.enable_dataloader
77
+ if self.enable_dataloader:
78
+ logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
79
+ dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
80
+
81
+ @property
82
+ def instance(self):
83
+ return self._instance
84
+
85
+ @staticmethod
86
+ def check_input_params(args):
87
+ if args.config_path is not None:
88
+ if not isinstance(args.config_path, str):
89
+ raise MsprobeException(
90
+ MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
91
+ file_checker = FileChecker(
92
+ file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
93
+ file_checker.common_check()
94
+
95
+ if args.task is not None and args.task not in Const.TASK_LIST:
96
+ raise MsprobeException(
97
+ MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
98
+
99
+ if args.dump_path is not None:
100
+ if not isinstance(args.dump_path, str):
101
+ raise MsprobeException(
102
+ MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
103
+
104
+ if args.level is not None and args.level not in Const.LEVEL_LIST:
105
+ raise MsprobeException(
106
+ MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
107
+
108
+ if args.model is not None and not isinstance(args.model, torch.nn.Module):
109
+ raise MsprobeException(
110
+ MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
111
+
112
+ @classmethod
113
+ def start(cls, model=None):
114
+ instance = cls._instance
115
+ if not instance:
116
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
117
+ if instance.task in PrecisionDebugger.tasks_not_need_debugger:
118
+ return
119
+ instance.config.check_model(instance, model)
120
+ if instance.enable_dataloader:
121
+ logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
122
+ else:
123
+ instance.service.start(instance.model, instance.api_origin)
124
+ instance.api_origin = False
125
+
126
+ # 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
127
+ @classmethod
128
+ def forward_backward_dump_end(cls):
129
+ instance = cls._instance
130
+ instance.service.forward_backward_dump_end()
131
+ instance.api_origin = True
132
+
133
+ @classmethod
134
+ def stop(cls):
135
+ instance = cls._instance
136
+ if not instance:
137
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
138
+ if instance.task in PrecisionDebugger.tasks_not_need_debugger:
139
+ return
140
+ if instance.enable_dataloader:
141
+ logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
142
+ else:
143
+ instance.service.stop()
144
+
145
+ @classmethod
146
+ def step(cls):
147
+ if not cls._instance:
148
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
149
+ if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
150
+ return
151
+ cls._instance.service.step()
152
+
153
+ @classmethod
154
+ def monitor(cls, model):
155
+ if not cls._instance:
156
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
157
+ if cls._instance.task != Const.GRAD_PROBE:
158
+ return
159
+ cls._instance.gm.monitor(model)
160
+
161
+
162
+ def iter_tracer(func):
163
+ def func_wrapper(*args, **kwargs):
164
+ debugger_instance = PrecisionDebugger.instance
165
+ debugger_instance.enable_dataloader = False
166
+ if not debugger_instance.service.first_start:
167
+ debugger_instance.stop()
168
+ debugger_instance.step()
169
+ result = func(*args, **kwargs)
170
+ debugger_instance.start()
171
+ debugger_instance.enable_dataloader = True
172
+ return result
173
+ return func_wrapper
@@ -1,8 +1,23 @@
1
- from msprobe.core.common.log import logger
2
- from msprobe.core.common.exceptions import FreeBenchmarkException
3
- from msprobe.core.common.const import Const
4
-
5
- from .main import FreeBenchmarkCheck
6
- from .common.params import UnequalRow
7
-
8
- __all__ = [FreeBenchmarkCheck, UnequalRow]
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ __all__ = ["FreeBenchmarkCheck", "UnequalRow"]
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import FreeBenchmarkException
20
+ from msprobe.pytorch.common.log import logger
21
+
22
+ from .common.params import UnequalRow
23
+ from .main import FreeBenchmarkCheck
@@ -1,70 +1,70 @@
1
- from typing import Dict
2
-
3
- import numpy as np
4
- import torch
5
- from msprobe.pytorch.free_benchmark.common.enums import FuzzThreshold
6
- from msprobe.pytorch.free_benchmark.common.params import BenchmarkThd
7
-
8
-
9
- class CommonField:
10
- DEVICE = "device"
11
- META = "meta"
12
- FUZZ_TENSOR = "fuzz_tensor"
13
- REQUIRES_GRAD = "requires_grad"
14
- HOLD_PLACE = "hold_place"
15
- DISTRIBUTED_OP = "torch.distributed"
16
- GRADSAVER = "grad_saver"
17
-
18
-
19
- class ThresholdConfig:
20
- PERTURBATION_VALUE_DICT: Dict = {
21
- torch.bfloat16: FuzzThreshold.BF16_THD,
22
- torch.float16: FuzzThreshold.F16_THD,
23
- torch.float32: FuzzThreshold.F32_THD,
24
- torch.float64: FuzzThreshold.F64_THD,
25
- }
26
-
27
- ABS_TOL_VALUE_DICT: Dict = {
28
- torch.bfloat16: FuzzThreshold.BF16_THD,
29
- torch.float16: FuzzThreshold.F16_THD,
30
- torch.float32: FuzzThreshold.F32_THD,
31
- torch.float64: FuzzThreshold.F64_THD,
32
- }
33
-
34
- # bit翻转需要匹配到等长或更长的整型
35
- PERTURBATION_BIT_DICT = {
36
- torch.bfloat16: torch.int16,
37
- torch.float16: torch.int16,
38
- torch.float32: torch.int32,
39
- torch.float64: torch.int64,
40
- }
41
-
42
- # 输入噪声下界
43
- NOISE_INPUT_LOWER_BOUND = 1e-8
44
- COMP_CONSISTENT = 1.0
45
- COMP_NAN = np.nan
46
- SYMBOL_FLIPPING = "symbol_flipping"
47
- BACKWARD_OUTPUT_LOWER_BOUND = 1e-3
48
- SMALL_VALUE = 1.0
49
- # 预热初始阈值
50
- PREHEAT_INITIAL_THD = 2.05
51
- API_THD_STEP = 2.0
52
-
53
- DTYPE_PER_THD = {
54
- torch.float16: 1.002,
55
- torch.bfloat16: 1.004,
56
- torch.float32: 1.0002,
57
- }
58
- BENCHMARK_THD_DICT = {
59
- torch.float32: BenchmarkThd(2**-14, 1.0, 2**-14, 1e-4),
60
- torch.float16: BenchmarkThd(2**-11, 1.0, 2**-11, 1e-4),
61
- torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4),
62
- }
63
-
64
- TENSOR_SPLIT_MAX_CHUNK = 128
65
-
66
-
67
- class PreheatConfig:
68
- IF_PREHEAT = "if_preheat"
69
- PREHEAT_STEP = "preheat_step"
70
- MAX_SAMPLE = "max_sample"
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+ from msprobe.pytorch.free_benchmark.common.enums import FuzzThreshold
6
+ from msprobe.pytorch.free_benchmark.common.params import BenchmarkThd
7
+
8
+
9
class CommonField:
    """String constants shared across the free-benchmark package."""

    # Generic field names.
    DEVICE = "device"
    META = "meta"
    FUZZ_TENSOR = "fuzz_tensor"
    REQUIRES_GRAD = "requires_grad"
    HOLD_PLACE = "hold_place"

    # Dotted name of the torch distributed namespace.
    DISTRIBUTED_OP = "torch.distributed"

    GRADSAVER = "grad_saver"
17
+
18
+
19
class ThresholdConfig:
    """Threshold constants for free-benchmark, mostly keyed by ``torch`` dtype."""

    # torch floating dtype -> FuzzThreshold constant.
    PERTURBATION_VALUE_DICT: Dict = {
        torch.bfloat16: FuzzThreshold.BF16_THD,
        torch.float16: FuzzThreshold.F16_THD,
        torch.float32: FuzzThreshold.F32_THD,
        torch.float64: FuzzThreshold.F64_THD,
    }

    # Same dtype -> FuzzThreshold mapping, kept as a separate dict object.
    ABS_TOL_VALUE_DICT: Dict = {
        torch.bfloat16: FuzzThreshold.BF16_THD,
        torch.float16: FuzzThreshold.F16_THD,
        torch.float32: FuzzThreshold.F32_THD,
        torch.float64: FuzzThreshold.F64_THD,
    }

    # Bit flipping must map to an integer type of equal or greater width.
    PERTURBATION_BIT_DICT = {
        torch.bfloat16: torch.int16,
        torch.float16: torch.int16,
        torch.float32: torch.int32,
        torch.float64: torch.int64,
    }

    # Lower bound of the input noise.
    NOISE_INPUT_LOWER_BOUND = 1e-8
    COMP_CONSISTENT = 1.0
    COMP_NAN = np.nan
    SYMBOL_FLIPPING = "symbol_flipping"
    BACKWARD_OUTPUT_LOWER_BOUND = 1e-3
    SMALL_VALUE = 1.0
    # Initial threshold used during preheat.
    PREHEAT_INITIAL_THD = 2.05
    API_THD_STEP = 2.0

    # torch dtype -> per-dtype threshold value.
    DTYPE_PER_THD = {
        torch.float16: 1.002,
        torch.bfloat16: 1.004,
        torch.float32: 1.0002,
    }
    # torch dtype -> BenchmarkThd record.
    BENCHMARK_THD_DICT = {
        torch.float32: BenchmarkThd(2**-14, 1.0, 2**-14, 1e-4),
        torch.float16: BenchmarkThd(2**-11, 1.0, 2**-11, 1e-4),
        torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4),
    }

    TENSOR_SPLIT_MAX_CHUNK = 128
65
+
66
+
67
class PreheatConfig:
    """String keys naming the preheat-related settings."""

    IF_PREHEAT = "if_preheat"
    PREHEAT_STEP = "preheat_step"
    MAX_SAMPLE = "max_sample"
@@ -1,72 +1,72 @@
1
- from collections import defaultdict
2
- from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
3
-
4
-
5
- class PreheatCounter:
6
- def __init__(self) -> None:
7
- self.api_called_time: dict = defaultdict(int)
8
- self.api_sample_time: dict = defaultdict(int)
9
- self.one_step_used_api: dict = defaultdict(int)
10
- self.api_thd: dict = defaultdict(dict)
11
- self.preheat_record: dict = defaultdict(dict)
12
- self.dtype_map: dict = {}
13
- self.if_preheat: dict = defaultdict(dict)
14
- self.step = 0
15
-
16
- def clear_step(self):
17
- self.preheat_record.clear()
18
- self.api_called_time.clear()
19
- self.api_sample_time.clear()
20
-
21
- def check_step(self, current_step):
22
- if current_step != self.step:
23
- self.clear_step()
24
- self.step = current_step
25
-
26
- def add_api_called_time(self, api_name: str):
27
- self.api_called_time[api_name] += 1
28
-
29
- def get_api_called_time(self, api_name: str) -> int:
30
- return self.api_called_time[api_name]
31
-
32
- def add_api_sample_time(self, api_name: str):
33
- self.api_sample_time[api_name] += 1
34
-
35
- def get_api_sample_time(self, api_name: str) -> int:
36
- return self.api_sample_time[api_name]
37
-
38
- def add_one_step_used_api(self, api_name: str):
39
- self.one_step_used_api[api_name] += 1
40
-
41
- def get_one_step_used_api(self, api_name: str):
42
- return self.one_step_used_api[api_name]
43
-
44
- def update_preheat_record(self, api_name, dtype, cmp_result):
45
- # 记录预热阶段CPU标杆比对的结果
46
- if str(dtype) not in self.preheat_record[api_name].keys():
47
- self.preheat_record[api_name][str(dtype)] = list()
48
- self.preheat_record[api_name][str(dtype)].append(cmp_result)
49
- self.dtype_map[str(dtype)] = dtype
50
-
51
- def update_api_thd(self, api_name, dtype, threshold, dthreshold):
52
- self.api_thd[api_name][str(dtype)] = (
53
- threshold if threshold > dthreshold else dthreshold
54
- )
55
-
56
- def get_api_thd(self, api_name, dtype):
57
- if not str(dtype) in self.api_thd[api_name]:
58
- self.api_thd[api_name][str(dtype)] = ThresholdConfig.PREHEAT_INITIAL_THD
59
- self.dtype_map[str(dtype)] = dtype
60
- return self.api_thd[api_name][str(dtype)]
61
-
62
- def set_api_preheat(self, api_name, dtype_str, is_preheat=True):
63
- # 标记cpu不一致的dtype 不再进行预热
64
- self.if_preheat[api_name][dtype_str] = is_preheat
65
-
66
- def get_api_preheat(self, api_name, dtype):
67
- # 标记cpu不一致的dtype 不再进行预热
68
- if str(dtype) not in self.if_preheat[api_name]:
69
- return True
70
- return self.if_preheat[api_name][str(dtype)]
71
-
1
+ from collections import defaultdict
2
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
3
+
4
+
5
+ class PreheatCounter:
6
+ def __init__(self) -> None:
7
+ self.api_called_time: dict = defaultdict(int)
8
+ self.api_sample_time: dict = defaultdict(int)
9
+ self.one_step_used_api: dict = defaultdict(int)
10
+ self.api_thd: dict = defaultdict(dict)
11
+ self.preheat_record: dict = defaultdict(dict)
12
+ self.dtype_map: dict = {}
13
+ self.if_preheat: dict = defaultdict(dict)
14
+ self.step = 0
15
+
16
+ def clear_step(self):
17
+ self.preheat_record.clear()
18
+ self.api_called_time.clear()
19
+ self.api_sample_time.clear()
20
+
21
+ def check_step(self, current_step):
22
+ if current_step != self.step:
23
+ self.clear_step()
24
+ self.step = current_step
25
+
26
+ def add_api_called_time(self, api_name: str):
27
+ self.api_called_time[api_name] += 1
28
+
29
+ def get_api_called_time(self, api_name: str) -> int:
30
+ return self.api_called_time[api_name]
31
+
32
+ def add_api_sample_time(self, api_name: str):
33
+ self.api_sample_time[api_name] += 1
34
+
35
+ def get_api_sample_time(self, api_name: str) -> int:
36
+ return self.api_sample_time[api_name]
37
+
38
+ def add_one_step_used_api(self, api_name: str):
39
+ self.one_step_used_api[api_name] += 1
40
+
41
+ def get_one_step_used_api(self, api_name: str):
42
+ return self.one_step_used_api[api_name]
43
+
44
+ def update_preheat_record(self, api_name, dtype, cmp_result):
45
+ # 记录预热阶段CPU标杆比对的结果
46
+ if str(dtype) not in self.preheat_record[api_name].keys():
47
+ self.preheat_record[api_name][str(dtype)] = list()
48
+ self.preheat_record[api_name][str(dtype)].append(cmp_result)
49
+ self.dtype_map[str(dtype)] = dtype
50
+
51
+ def update_api_thd(self, api_name, dtype, threshold, dthreshold):
52
+ self.api_thd[api_name][str(dtype)] = (
53
+ threshold if threshold > dthreshold else dthreshold
54
+ )
55
+
56
+ def get_api_thd(self, api_name, dtype):
57
+ if not str(dtype) in self.api_thd[api_name]:
58
+ self.api_thd[api_name][str(dtype)] = ThresholdConfig.PREHEAT_INITIAL_THD
59
+ self.dtype_map[str(dtype)] = dtype
60
+ return self.api_thd[api_name][str(dtype)]
61
+
62
+ def set_api_preheat(self, api_name, dtype_str, is_preheat=True):
63
+ # 标记cpu不一致的dtype 不再进行预热
64
+ self.if_preheat[api_name][dtype_str] = is_preheat
65
+
66
+ def get_api_preheat(self, api_name, dtype):
67
+ # 标记cpu不一致的dtype 不再进行预热
68
+ if str(dtype) not in self.if_preheat[api_name]:
69
+ return True
70
+ return self.if_preheat[api_name][str(dtype)]
71
+
72
72
  # Module-level PreheatCounter instance, created once when this module is imported.
  preheat_counter = PreheatCounter()