mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,252 +1,311 @@
1
- import functools
2
- import os
3
- import time
4
- from pathlib import Path
5
-
6
- from collections import namedtuple
7
- import torch
8
- from msprobe.core.common.const import Const, FileCheckConst
9
- from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
10
- from msprobe.core.common.file_check import FileChecker, check_path_before_create
11
- from msprobe.core.data_dump.data_collector import build_data_collector
12
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
13
- from msprobe.core.data_dump.scope import BaseScope
14
- from msprobe.pytorch.common.log import logger
15
- from msprobe.pytorch.common.utils import get_rank_if_initialized
16
- from msprobe.pytorch.hook_module import remove_dropout
17
- from msprobe.pytorch.hook_module.api_registry import api_register
18
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
19
- from msprobe.pytorch.module_processer import ModuleProcesser
20
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL, ApiData
21
- torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
22
-
23
- HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])
24
-
25
-
26
- class Service:
27
- def __init__(self, config):
28
- self.model = None
29
- self.config = config
30
- self.data_collector = build_data_collector(config)
31
- self.module_processor = ModuleProcesser(self.data_collector.scope)
32
- self.switch = False
33
- self.current_iter = 0
34
- self.first_start = True
35
- self.current_rank = None
36
- self.dump_iter_dir = None
37
- self.attl = None
38
-
39
- @staticmethod
40
- def forward_backward_dump_end():
41
- logger.info_on_rank_0("Data needed ends here.")
42
- api_register.api_originality()
43
-
44
- def build_hook(self, module_type, name):
45
- def pre_hook(api_or_module_name, module, args, kwargs):
46
- if module_type == BaseScope.Module_Type_Module:
47
- api_or_module_name = module.mindstudio_reserved_name
48
- self.data_collector.visit_and_clear_overflow_status(api_or_module_name)
49
-
50
- if not self.switch:
51
- return args, kwargs
52
- if self.config.online_run_ut:
53
- return None, None
54
- if self.data_collector:
55
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
56
- self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output)
57
- return args, kwargs
58
-
59
- def forward_hook(api_or_module_name, module, args, kwargs, output):
60
- if module_type == BaseScope.Module_Type_Module:
61
- api_or_module_name = module.mindstudio_reserved_name
62
- self.data_collector.visit_and_clear_overflow_status(api_or_module_name)
63
-
64
- if not self.switch:
65
- return None
66
-
67
- if self.config.online_run_ut:
68
- if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
69
- return None
70
- api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank)
71
- self.attl_send(api_data)
72
- return None
73
-
74
- if self.data_collector:
75
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
76
- self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
77
- if self.data_collector.if_return_forward_new_output():
78
- return self.data_collector.get_forward_new_output()
79
- return output
80
-
81
- def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
82
- return forward_hook(api_or_module_name, module, args, {}, output)
83
-
84
- def backward_hook(api_or_module_name, module, grad_input, grad_output):
85
- if module_type == BaseScope.Module_Type_Module:
86
- api_or_module_name = module.mindstudio_reserved_name
87
- self.data_collector.visit_and_clear_overflow_status(api_or_module_name)
88
-
89
- if not self.switch:
90
- return
91
-
92
- if self.config.online_run_ut:
93
- if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
94
- return
95
- api_data = ApiData(name[:-1], grad_input, {}, grad_output, self.current_iter, self.current_rank)
96
- self.attl_send(api_data)
97
- return
98
-
99
- if self.data_collector:
100
- # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序
101
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
102
- self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
103
-
104
- pid = os.getpid()
105
- forward_name_template = name + Const.FORWARD
106
- backward_name_template = name + Const.BACKWARD
107
- pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
108
- forward_hook_fn = functools.partial(forward_hook, forward_name_template)
109
- backward_hook_fn = functools.partial(backward_hook, backward_name_template)
110
- forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2, forward_name_template)
111
- return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
112
-
113
- def step(self):
114
- self.current_iter += 1
115
- self.data_collector.update_iter(self.current_iter)
116
-
117
- ModuleProcesser.reset_module_stats()
118
- HOOKModule.reset_module_stats()
119
-
120
- def start(self, model, api_origin=False):
121
- self.model = model
122
- if self.config.step and self.current_iter > max(self.config.step):
123
- if self.config.online_run_ut:
124
- # send stop signal if online_run_ut
125
- self.attl_stop()
126
- self.stop()
127
- raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
128
- if self.config.step and self.current_iter not in self.config.step:
129
- return
130
- if self.first_start:
131
- try:
132
- self.current_rank = get_rank_if_initialized()
133
- except DistributedNotInitializedError:
134
- self.current_rank = None
135
- self.attl_init()
136
-
137
- if self.config.rank and self.current_rank not in self.config.rank:
138
- return
139
- self.register_hook_new()
140
- self.first_start = False
141
- if api_origin:
142
- api_register.api_modularity()
143
- self.switch = True
144
- logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
145
- if self.config.level != "L2" and not self.config.online_run_ut:
146
- self.create_dirs()
147
- logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
148
-
149
- def stop(self):
150
- if self.config.level == "L2":
151
- return
152
- if self.config.step and self.current_iter not in self.config.step:
153
- return
154
- if self.config.rank and self.current_rank not in self.config.rank:
155
- return
156
- self.switch = False
157
- if self.config.online_run_ut:
158
- return
159
- self.data_collector.write_json()
160
-
161
- def create_dirs(self):
162
- check_path_before_create(self.config.dump_path)
163
- if not os.path.exists(self.config.dump_path):
164
- Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True)
165
- file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
166
- file_check.common_check()
167
- self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
168
- cur_rank = self.current_rank if self.current_rank is not None else ''
169
- dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
170
- if not os.path.exists(dump_dir):
171
- Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)
172
- if self.config.task in self.data_collector.tasks_need_tensor_data:
173
- dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
174
- Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
175
- else:
176
- dump_data_dir = None
177
-
178
- dump_file_path = os.path.join(dump_dir, "dump.json")
179
- stack_file_path = os.path.join(dump_dir, "stack.json")
180
- construct_file_path = os.path.join(dump_dir, "construct.json")
181
- free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
182
- self.data_collector.update_dump_paths(
183
- dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path)
184
-
185
- def register_hook_new(self):
186
- logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
187
- if self.config.level in ["L0", "mix"]:
188
- if self.model is None:
189
- logger.error_log_with_exp("The model is None.", MsprobeException.INVALID_PARAM_ERROR)
190
- logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available")
191
- for name, module in self.model.named_modules():
192
- if module == self.model:
193
- continue
194
- prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
195
- module.__class__.__name__ + Const.SEP
196
-
197
- pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 \
198
- = self.build_hook(BaseScope.Module_Type_Module, prefix)
199
- if torch_version_above_or_equal_2:
200
- module.register_forward_hook(forward_hook, with_kwargs=True)
201
- else:
202
- module.register_full_backward_hook(
203
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
204
- module.register_forward_hook(forward_hook_torch_version_below_2)
205
- module.register_full_backward_hook(backward_hook)
206
-
207
- module.register_forward_pre_hook(
208
- self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
209
- module.register_forward_hook(
210
- self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
211
- if torch_version_above_or_equal_2:
212
- module.register_full_backward_pre_hook(
213
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
214
- module.register_full_backward_hook(
215
- self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
216
-
217
- if self.config.level in ["mix", "L1", "L2"]:
218
- api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
219
- api_register.api_modularity()
220
-
221
- if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
222
- remove_dropout()
223
-
224
- def attl_init(self):
225
- if self.config.online_run_ut:
226
- attl_config = ATTLConfig(is_benchmark_device=False,
227
- connect_ip=self.config.host,
228
- connect_port=self.config.port,
229
- nfs_path=self.config.nfs_path,
230
- tls_path=self.config.tls_path)
231
- need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank
232
- self.attl = ATTL('npu', attl_config, need_dump=need_dump)
233
- if self.config.nfs_path:
234
- self.attl.upload("start")
235
-
236
- def attl_send(self, api_data):
237
- logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}")
238
- api_type, _, _ = api_data.name.split(Const.SEP)
239
- if api_type in [Const.DISTRIBUTED]:
240
- logger.info(f"api {api_data.name} is not supported, skip")
241
- return
242
- if self.config.nfs_path:
243
- self.attl.upload(api_data)
244
- else:
245
- self.attl.send(api_data)
246
-
247
- def attl_stop(self):
248
- if self.config.nfs_path:
249
- self.attl.upload("end")
250
- elif self.attl.socket_manager is not None:
251
- logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
252
- self.attl.socket_manager.send_stop_signal()
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import os
18
+
19
+ from collections import namedtuple
20
+ import torch
21
+ from msprobe.core.common.const import Const
22
+ from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
23
+ from msprobe.core.common.file_utils import create_directory
24
+ from msprobe.core.common.utils import print_tools_ends_info
25
+ from msprobe.core.data_dump.data_collector import build_data_collector
26
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
27
+ from msprobe.core.data_dump.scope import BaseScope
28
+ from msprobe.pytorch.common.log import logger
29
+ from msprobe.pytorch.common.utils import get_rank_if_initialized
30
+ from msprobe.pytorch.hook_module import remove_dropout
31
+ from msprobe.pytorch.hook_module.api_registry import api_register
32
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
33
+ from msprobe.pytorch.module_processer import ModuleProcesser
34
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
35
+
36
def _torch_major_at_least(version_str, required_major=2):
    """Return True if *version_str*'s major component is >= *required_major*.

    Compares the major version as an integer. The previous raw string
    comparison (``version >= '2.0'``) orders versions lexicographically and
    would wrongly classify e.g. ``'10.0'`` as below ``'2.0'``.
    Returns False when the major component is not a plain integer.
    """
    try:
        return int(version_str.split('.')[0]) >= required_major
    except (ValueError, IndexError):
        return False


# Strip any local-build suffix (e.g. '2.1.0+cu118') before the version check.
torch_version_above_or_equal_2 = _torch_major_at_least(torch.__version__.split('+')[0])
if torch_version_above_or_equal_2:
    # run_ut_dispatch depends on torch >= 2.0 APIs, so import it conditionally.
    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch

# Bundle of the four hook callables produced by Service.build_hook.
HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])
41
+
42
+
43
class Service:
    """Orchestrates msprobe data dumping for a PyTorch model.

    Registers forward/backward hooks on modules and/or wrapped APIs
    (depending on ``config.level``), collects dump data through a
    ``data_collector`` built from the config, and — when
    ``config.online_run_ut`` is enabled — streams captured API data to a
    benchmark device via ATTL instead of writing it locally.
    """

    def __init__(self, config):
        # Model is attached later via start(); may stay None until then.
        self.model = None
        self.config = config
        self.data_collector = build_data_collector(config)
        self.module_processor = ModuleProcesser(self.data_collector.scope)
        # Master on/off switch consulted by every hook (see should_execute_hook).
        self.switch = False
        # Current training step; advanced by step().
        self.current_iter = 0
        # True until the first start() call finishes hook registration.
        self.first_start = True
        # Distributed rank, or None when torch.distributed is not initialized.
        self.current_rank = None
        # Per-step dump directory, set by create_dirs().
        self.dump_iter_dir = None
        # Once True, start()/stop()/step() become no-ops for good.
        self.should_stop_service = False
        # ATTL transport handle; only created when online_run_ut is enabled.
        self.attl = None

    @staticmethod
    def forward_backward_dump_end():
        """Stop dumping by restoring the original (un-hooked) API implementations."""
        logger.info_on_rank_0("Data needed ends here.")
        api_register.api_originality()

    @staticmethod
    def is_registered_backward_hook(module):
        """Return True if *module* carries deprecated (non-full) backward hooks.

        ``_is_full_backward_hook is False`` means register_backward_hook()
        (the deprecated variant) was used; None/True are fine.
        """
        if hasattr(module, '_backward_hooks') and \
                len(module._backward_hooks) > 0 and \
                module._is_full_backward_hook is False:
            return True
        return False

    def check_register_full_backward_hook(self, module):
        """Drop deprecated backward hooks so full backward hooks can be registered.

        Mixing the two hook kinds on one module is not supported by torch,
        so existing deprecated hooks are cleared first.
        """
        if self.is_registered_backward_hook(module):
            module._backward_hooks.clear()
            module._is_full_backward_hook = None
            logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")

    def build_hook(self, module_type, name):
        """Build the four hook callables for one module or API.

        Args:
            module_type: BaseScope.Module_Type_Module or Module_Type_API;
                decides whether the dump name comes from the module itself.
            name: name prefix for this target; FORWARD/BACKWARD suffixes are
                appended to form the per-direction dump names.

        Returns:
            HookFn namedtuple (pre_hook, forward_hook, backward_hook,
            forward_hook_torch_version_below_2), each partially applied with
            its dump-name template.
        """
        def pre_hook(api_or_module_name, module, args, kwargs):
            if not self.should_execute_hook():
                return args, kwargs

            if module_type == BaseScope.Module_Type_Module:
                # Module dumps use the reserved per-instance name, not the template.
                api_or_module_name = module.mindstudio_reserved_name
            self.data_collector.update_api_or_module_name(api_or_module_name)

            if self.config.online_run_ut:
                # Online run_ut captures data in the forward hook only.
                return None, None
            if self.data_collector:
                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
                self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output)
            return args, kwargs

        def forward_hook(api_or_module_name, module, args, kwargs, output):
            if not self.should_execute_hook():
                return None

            if module_type == BaseScope.Module_Type_Module:
                api_or_module_name = module.mindstudio_reserved_name
            self.data_collector.update_api_or_module_name(api_or_module_name)

            if self.config.online_run_ut:
                # Ship the captured API call over ATTL instead of dumping locally.
                if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
                    return None
                # name[:-1] drops the trailing separator from the name template
                # (presumably Const.SEP appended by the caller — see register_hook_new).
                api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank)
                self.attl_send(api_data)
                return None

            if self.data_collector:
                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
                self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
                if self.data_collector.if_return_forward_new_output():
                    # The collector may substitute a replacement output (e.g. overflow check).
                    return self.data_collector.get_forward_new_output()
            return output

        def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
            # torch < 2.0 forward hooks have no kwargs parameter; delegate with empty kwargs.
            return forward_hook(api_or_module_name, module, args, {}, output)

        def backward_hook(api_or_module_name, module, grad_input, grad_output):
            if not self.should_execute_hook():
                return

            if module_type == BaseScope.Module_Type_Module:
                api_or_module_name = module.mindstudio_reserved_name
            self.data_collector.update_api_or_module_name(api_or_module_name)

            if self.config.online_run_ut:
                return

            if self.data_collector:
                # grad_input obtained here is actually the OUTPUT of the backward pass
                # and grad_output is its INPUT, so the two are swapped when passed on.
                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
                self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)

        # pid is captured by the closures so hooks can detect fork/spawn children.
        pid = os.getpid()
        forward_name_template = name + Const.FORWARD
        backward_name_template = name + Const.BACKWARD
        pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
        forward_hook_fn = functools.partial(forward_hook, forward_name_template)
        backward_hook_fn = functools.partial(backward_hook, backward_name_template)
        forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
                                                                  forward_name_template)
        return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)

    def start(self, model, api_origin=False):
        """Turn dumping on for the current step.

        On the first call this also resolves the rank, initializes ATTL and
        registers all hooks. No-op when the service has decided to stop or the
        current step/rank is filtered out by the config.

        Args:
            model: the torch model to hook (used for L0/mix module hooks).
            api_origin: when True, re-wrap the APIs (they were restored to
                their originals earlier, e.g. by forward_backward_dump_end).
        """
        if self.need_stop_service():
            return

        self.model = model
        if self.first_start:
            try:
                self.current_rank = get_rank_if_initialized()
            except DistributedNotInitializedError:
                # Not a distributed run: rank stays None and rank filters are moot.
                self.current_rank = None
            self.attl_init()

            if self.config.rank and self.current_rank not in self.config.rank:
                return
            self.register_hook_new()
            self.first_start = False
        if api_origin:
            api_register.api_modularity()
        if self.config.online_run_ut and torch_version_above_or_equal_2:
            # Enable dispatch of captured aten calls to the online run_ut pipeline.
            run_ut_dispatch(self.attl, True)
        self.switch = True
        logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
        if self.config.level != "L2" and not self.config.online_run_ut:
            self.create_dirs()
            logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")

    def stop(self):
        """Turn dumping off for the current step and flush collected data.

        No-op for L2 level (handled elsewhere) and for filtered steps/ranks.
        In online_run_ut mode nothing is written locally.
        """
        if self.should_stop_service:
            return
        if self.config.level == "L2":
            return
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.config.rank and self.current_rank not in self.config.rank:
            return
        self.switch = False
        if self.config.online_run_ut and torch_version_above_or_equal_2:
            run_ut_dispatch(self.attl, False)
            return
        self.data_collector.write_json()

    def step(self):
        """Advance to the next iteration and reset all per-step statistics."""
        if self.should_stop_service:
            return
        self.current_iter += 1
        self.data_collector.update_iter(self.current_iter)

        ModuleProcesser.reset_module_stats()
        HOOKModule.reset_module_stats()
        self.data_collector.data_writer.reset_cache()

    def need_stop_service(self):
        """Decide whether the service should stop (permanently or for this step).

        Returns True either when dumping is finished for good (past the last
        configured step, or the data processor terminated) — in which case
        hooks are unwound and a final message is printed — or when the current
        step is simply not selected by config.step.
        """
        if self.should_stop_service:
            return True
        # Parses as: (step configured AND past last step) OR (collector terminated).
        end_service = self.config.step and self.current_iter > max(self.config.step) or \
                      self.data_collector and self.data_collector.data_processor.is_terminated
        if end_service:
            if self.config.online_run_ut:
                # send stop signal if online_run_ut
                self.attl_stop()
            if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
                # Restore original API implementations for levels that wrapped them.
                api_register.api_originality()
            self.switch = False
            self.should_stop_service = True
            print_tools_ends_info()
            return True
        if self.config.step and self.current_iter not in self.config.step:
            return True
        return False

    def should_execute_hook(self):
        """Return True when hooks should collect data (switch on, not terminated)."""
        if not self.switch:
            return False
        if self.data_collector and self.data_collector.data_processor.is_terminated:
            return False
        return True

    def create_dirs(self):
        """Create the step/rank dump directory tree and register file paths.

        Layout: <dump_path>/step<i>/rank<r>/{dump.json, stack.json,
        construct.json[, dump_tensor_data/]} plus a top-level
        free_benchmark.csv.
        """
        create_directory(self.config.dump_path)
        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
        # Non-distributed runs get a bare "rank" directory (empty suffix).
        cur_rank = self.current_rank if self.current_rank is not None else ''
        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
        create_directory(dump_dir)
        if self.config.task in self.data_collector.tasks_need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            create_directory(dump_data_dir)
        else:
            dump_data_dir = None

        dump_file_path = os.path.join(dump_dir, "dump.json")
        stack_file_path = os.path.join(dump_dir, "stack.json")
        construct_file_path = os.path.join(dump_dir, "construct.json")
        free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
        self.data_collector.update_dump_paths(
            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path)

    def register_hook_new(self):
        """Register dump hooks according to config.level.

        L0/mix: per-module forward/backward hooks plus node hooks that track
        the module call tree. mix/L1/L2: wrap the torch APIs through
        api_register. statistics/tensor tasks additionally disable dropout.
        """
        logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
        if self.config.level in ["L0", "mix"]:
            if self.model is None:
                logger.error_log_with_exp("The model is None.", MsprobeException.INVALID_PARAM_ERROR)
            logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available")
            for name, module in self.model.named_modules():
                # Skip the root module itself; only submodules are hooked.
                if module == self.model:
                    continue
                prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
                         module.__class__.__name__ + Const.SEP

                pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
                    BaseScope.Module_Type_Module, prefix)
                if torch_version_above_or_equal_2:
                    # torch >= 2.0: forward hooks can receive kwargs directly.
                    module.register_forward_hook(forward_hook, with_kwargs=True)
                else:
                    self.check_register_full_backward_hook(module)
                    module.register_full_backward_hook(
                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
                    module.register_forward_hook(forward_hook_torch_version_below_2)
                self.check_register_full_backward_hook(module)
                module.register_full_backward_hook(backward_hook)

                # node_hook calls maintain the module construct (call-tree) records.
                module.register_forward_pre_hook(
                    self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
                module.register_forward_hook(
                    self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
                if torch_version_above_or_equal_2:
                    # register_full_backward_pre_hook only exists on torch >= 2.0.
                    module.register_full_backward_pre_hook(
                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
                    self.check_register_full_backward_hook(module)
                    module.register_full_backward_hook(
                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))

        if self.config.level in ["mix", "L1", "L2"]:
            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
                                         self.config.online_run_ut)
            api_register.api_modularity()

        if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
            # Dropout randomness would make dumped data non-reproducible.
            remove_dropout()

    def attl_init(self):
        """Create the ATTL transport when online_run_ut is enabled.

        Uses NFS file transfer when config.nfs_path is set (and announces the
        session with a "start" marker), otherwise a socket connection.
        """
        if self.config.online_run_ut:
            from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
            attl_config = ATTLConfig(is_benchmark_device=False,
                                     connect_ip=self.config.host,
                                     connect_port=self.config.port,
                                     nfs_path=self.config.nfs_path,
                                     tls_path=self.config.tls_path)
            # Dump on this rank when no rank filter is set or this rank is selected.
            need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank
            self.attl = ATTL('npu', attl_config, need_dump=need_dump)
            if self.config.nfs_path:
                self.attl.upload("start")

    def attl_send(self, api_data):
        """Send one captured API call over ATTL (NFS upload or socket send).

        Distributed APIs are skipped — they are not supported by online run_ut.
        """
        logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}")
        api_type, _, _ = api_data.name.split(Const.SEP)
        if api_type in [Const.DISTRIBUTED]:
            logger.info(f"api {api_data.name} is not supported, skip")
            return
        if self.config.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)

    def attl_stop(self):
        """Signal the benchmark side that this process has finished dumping."""
        if self.config.nfs_path:
            self.attl.upload("end")
        elif self.attl.socket_manager is not None:
            logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
            self.attl.socket_manager.send_stop_signal()