mindstudio-probe 1.0.3-py3-none-any.whl → 1.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/common/utils.py (item 151 above, +305 -300)
@@ -1,300 +1,305 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
 # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-import logging
+import io
 import os
 import random
 import stat
-import csv
-import json
 import torch
 import torch.distributed as dist
 import numpy as np
 from functools import wraps
 from msprobe.core.common.exceptions import DistributedNotInitializedError
-from msprobe.core.common.log import logger as common_logger
-from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create, CompareException
-from msprobe.core.common.file_check import FileCheckConst, change_mode, FileOpen
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
+                                            check_file_or_directory_path, check_path_before_create)
 
 
 try:
     import torch_npu
 except ImportError:
     is_gpu = True
 else:
     is_gpu = False
 
 
 torch_without_guard_version = torch.__version__ >= '2.1'
 
 
 if not is_gpu and not torch_without_guard_version:
     from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
 
 npu_distributed_api = ['isend', 'irecv']
 
 
 def parameter_adapter(func):
 
     def handle_masked_select(input_tensor, indices):
         masked_select_func = getattr(torch._C._VariableFunctionsClass, "masked_select")
         if input_tensor.dtype == torch.bfloat16:
             # masked_select在NPU上输入数据dtype类型为bfloat16会报错,提示不支持此类型
             return masked_select_func(input_tensor.to(torch.float32), indices).to(torch.bfloat16)
         else:
             return masked_select_func(input_tensor, indices)
 
     @wraps(func)
     def inner(self, *args, **kwargs):
         if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
             input_tensor = args[0]
             indices = args[1]
             if indices.dtype == torch.uint8:
                 indices = indices.bool()
             if indices.dtype == torch.bool:
                 if indices.shape == input_tensor.shape:
                     return handle_masked_select(input_tensor, indices)
                 else:
                     indices = getattr(torch._C._VariableFunctionsClass, "nonzero")(indices, as_tuple=True)
                     return getattr(torch._C._TensorBase, "__getitem__")(input_tensor, indices)
             elif indices.dtype != torch.bool:
                 if not indices.shape or len(indices.shape) == 1:
                     return func(self, input_tensor, indices.tolist())
                 elif len(indices.shape) == 2:
                     result = [func(self, input_tensor, index) for index in indices.tolist()]
                     return getattr(torch._C._VariableFunctionsClass, "stack")(result, 0)
                 else:
                     res = [input_tensor[tensor_index] for tensor_index in indices]
                     return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
         if self.op_name_ == "__eq__" and args[1] is None:
             return False
         return func(self, *args, **kwargs)
     return inner
 
 
 def torch_device_guard(func):
     if is_gpu or torch_without_guard_version:
         return func
     # Parse args/kwargs matched torch.device objects
 
     @torch_npu_device_guard
     def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
     return wrapper
 
 
 def get_rank_if_initialized():
     """
     return rank id if it is initialized or raise Exception: DistributedNotInitializedError
     """
     if torch.distributed.is_initialized():
         return torch.distributed.get_rank()
     else:
         raise DistributedNotInitializedError("torch distributed environment is not initialized")
 
 
 def seed_all(seed=1234, mode=False):
     random.seed(seed)
     os.environ['PYTHONHASHSEED'] = str(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.use_deterministic_algorithms(mode)
     if is_gpu:
         torch.cuda.manual_seed_all(seed)
         torch.cuda.manual_seed(seed)
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.enable = False
         torch.backends.cudnn.benchmark = False
     else:
         torch_npu.npu.manual_seed_all(seed)
         torch_npu.npu.manual_seed(seed)
 
 
 class Const:
     """
     Class for const
     """
     SEP = "."
     MODEL_TYPE = ['.onnx', '.pb', '.om']
     DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*"
     SEMICOLON = ";"
     COLON = ":"
     EQUAL = "="
     COMMA = ","
     DOT = "."
     DUMP_RATIO_MAX = 100
     SUMMERY_DATA_NUMS = 256
     FLOAT_EPSILON = np.finfo(float).eps
     SUPPORT_DUMP_MODE = ['api', 'acl']
     ON = 'ON'
     OFF = 'OFF'
     KWARGS = 'kwargs'
     INPUT = 'input'
     OUTPUT = 'output'
     BACKWARD = 'backward'
     FORWARD = 'forward'
     PRE_FORWARD = "pre_forward"
     INPUT_ARGS = 'input_args'
     INPUT_KWARGS = 'input_kwargs'
     GRAD_INPUT = 'grad_input'
     GRAD_OUTPUT = 'grad_output'
     START = "start"
     STOP = "stop"
     MAX = 'Max'
     MIN = 'Min'
 
     # dump mode
     ALL = "all"
     LIST = "list"
     RANGE = "range"
     STACK = "stack"
     ACL = "acl"
     API_LIST = "api_list"
     API_STACK = "api_stack"
     DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK]
     AUTO = "auto"
     ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
     SUMMARY = "summary"
     MD5 = "md5"
     SUMMARY_MODE = [ALL, SUMMARY, MD5]
 
     WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
     OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
     WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR
 
     PKL_SUFFIX = ".pkl"
     NUMPY_SUFFIX = ".npy"
     ONE_GB = 1 * 1024 * 1024 * 1024
     TEN_GB = 10 * 1024 * 1024 * 1024
     FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
     FILE_NAME_LENGTH = 255
     DIRECTORY_LENGTH = 4096
     DISTRIBUTED_PREFIX_LENGTH = 60
     SUMMARY_COLUMN_NUM = 6
     STACK_COLUMN_NUM = 2
     # env dump path
     ASCEND_WORK_PATH = "ASCEND_WORK_PATH"
     DUMP_DIR = "dump_data"
     DATA = "data"
 
     ENV_ENABLE = "1"
     ENV_DISABLE = "0"
 
     MAX_SEED_VALUE = 2**32 - 1
 
     INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
                     "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"]
 
     TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
     LEVEL_LIST = ["L0", "L1", "L2", "mix"]
     STATISTICS = "statistics"
     TENSOR = "tensor"
     OVERFLOW_CHECK = "overflow_check"
     FREE_BENCHMARK = "free_benchmark"
 
     ATTR_NAME_PREFIX = "wrap_"
 
     FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
     BOOL_TYPE = [bool, np.uint8]
     INT_TYPE = [np.int32, np.int64]
     NPU = 'NPU'
     DISTRIBUTED = 'Distributed'
 
     RAISE_PRECISION = {
         torch.float16: torch.float32,
         torch.bfloat16: torch.float32,
         torch.float32: torch.float64
     }
     CONVERT = {
         "int32_to_int64": ["torch.int32", "torch.int64"],
     }
 
     CONVERT_API = {
         "int32_to_int64": ["cross_entropy"]
     }
 
 
 def get_tensor_rank(in_feat, out_feat):
     if dist.is_initialized():
         return dist.get_rank()
 
     def get_tensor_rank_single(x):
         if isinstance(x, (list, tuple)):
             if len(x) > 0:
                 return get_tensor_rank_single(x[0])
         elif isinstance(x, torch.Tensor):
             device = x.device
             if device.type != 'cpu':
                 return device.index
         return None
 
     in_rank = get_tensor_rank_single(in_feat)
     out_rank = get_tensor_rank_single(out_feat)
     tensor_rank = in_rank if in_rank else out_rank
     return tensor_rank
 
 
 def get_rank_id():
     if torch.distributed.is_initialized():
         return torch.distributed.get_rank()
     return 0
 
 
 def print_rank_0(message):
     if dist.is_initialized():
         if dist.get_rank() == 0:
             logger.info(message)
     else:
         logger.info(message)
 
 
 def load_pt(pt_path, to_cpu=False):
     pt_path = os.path.realpath(pt_path)
     check_file_or_directory_path(pt_path)
     try:
         if to_cpu:
             pt = torch.load(pt_path, map_location=torch.device("cpu"))
         else:
             pt = torch.load(pt_path)
     except Exception as e:
         raise RuntimeError(f"load pt file {pt_path} failed") from e
     return pt
 
 
 def save_pt(tensor, filepath):
     filepath = os.path.realpath(filepath)
     check_path_before_create(filepath)
     try:
         torch.save(tensor, filepath)
     except Exception as e:
-        common_logger.error("Save pt file failed, please check according possible error causes: "
-                            "1. out of disk space or disk error, "
-                            "2. no permission to write files, etc.")
+        logger.error("Save pt file failed, please check according possible error causes: "
+                     "1. out of disk space or disk error, "
+                     "2. no permission to write files, etc.")
         raise RuntimeError(f"save pt file {filepath} failed") from e
     change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
 
 
-def _create_logger(level=logging.INFO):
-    logger_ = logging.getLogger()
-    logger_.setLevel(level)
-    ch = logging.StreamHandler()
-    ch.setLevel(level)
-    logger_.addHandler(ch)
-    return logger_
-
-
-log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO
-logger = _create_logger(log_level)
+def save_api_data(api_data):
+    """Save data to io stream"""
+    try:
+        io_buff = io.BytesIO()
+        torch.save(api_data, io_buff)
+    except Exception as e:
+        raise RuntimeError(f"save api_data to io_buff failed") from e
+    return io_buff
+
+
+def load_api_data(api_data_bytes):
+    """Load data from bytes stream"""
+    try:
+        buffer = io.BytesIO(api_data_bytes)
+        buffer = torch.load(buffer, map_location="cpu")
+    except Exception as e:
+        raise RuntimeError(f"load api_data from bytes failed") from e
+    return buffer
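
The main functional addition in this hunk is the save_api_data/load_api_data pair, which round-trips API data through an in-memory buffer via torch.save/torch.load instead of a file on disk. A minimal usage sketch follows, assuming mindstudio-probe 1.0.4 is installed and that the hunk above is msprobe/pytorch/common/utils.py (the import path is inferred from the file list, not stated in this diff):

    import torch
    # Inferred import path; the helper definitions themselves appear in the hunk above.
    from msprobe.pytorch.common.utils import save_api_data, load_api_data

    # Any torch.save-compatible payload works; a dict of tensors is one example.
    payload = {"input": torch.randn(2, 3), "output": torch.ones(2, 3)}

    buf = save_api_data(payload)              # io.BytesIO written by torch.save
    restored = load_api_data(buf.getvalue())  # bytes -> torch.load(..., map_location="cpu")

    assert torch.equal(payload["input"], restored["input"])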