mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,165 +1,214 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- import os
18
- import re
19
-
20
- import torch
21
-
22
- try:
23
- import torch_npu
24
- except ImportError:
25
- IS_GPU = True
26
- else:
27
- IS_GPU = False
28
-
29
- from msprobe.pytorch.common.log import logger
30
- from msprobe.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory
31
- from msprobe.core.common.const import Const, FileCheckConst
32
- from msprobe.core.common.utils import CompareException
33
-
34
-
35
class DumpException(CompareException):
    """Error raised by the dump / UT-data helpers; reuses CompareException's error codes."""
37
-
38
-
39
def check_object_type(check_object, allow_type):
    """
    Function Description:
        Ensure that ``check_object`` is an instance of ``allow_type``.
    Parameter:
        check_object: value to validate
        allow_type: a type, or tuple of types, as accepted by isinstance
    Exception Description:
        Logs an error and raises CompareException(INVALID_DATA_ERROR) on mismatch.
    """
    if isinstance(check_object, allow_type):
        return
    logger.error(f"{check_object} not of {allow_type} type")
    raise CompareException(CompareException.INVALID_DATA_ERROR)
52
-
53
-
54
class SoftlinkCheckException(Exception):
    """Exception type named for soft-link checks; it adds no behaviour of its own."""
56
-
57
-
58
def check_need_convert(api_name):
    """
    Function Description:
        Find the conversion category (a key of Const.CONVERT_API) whose API list contains api_name.
    Parameter:
        api_name: name of the API to look up
    Return:
        The matching key, or None when no category lists api_name.
        If several categories list it, the last one in iteration order wins.
    """
    matched_type = None
    for candidate_type, api_names in Const.CONVERT_API.items():
        if api_name in api_names:
            matched_type = candidate_type
    return matched_type
66
-
67
-
68
def api_info_preprocess(api_name, api_info_dict):
    """
    Function Description:
        Preprocess the API information before it is used.
    Parameter:
        api_name: Name of the API.
        api_info_dict: argument description of the API.
    Return:
        convert_type: conversion category from check_need_convert (or None).
        api_info_dict: the argument description, passed through cross_entropy_process
            when api_name is 'cross_entropy', otherwise unchanged.
    """
    convert_type = check_need_convert(api_name)
    needs_cross_entropy_fix = api_name == 'cross_entropy'
    processed_info = cross_entropy_process(api_info_dict) if needs_cross_entropy_fix else api_info_dict
    return convert_type, processed_info
83
-
84
-
85
def cross_entropy_process(api_info_dict):
    """
    Function Description:
        Preprocess the cross_entropy API information so the recorded minimum of its
        second positional argument (presumably the target tensor -- confirm) is not negative.
    Parameter:
        api_info_dict: argument description of the API; positional-argument
            descriptions are read from its 'args' entry.
    Return api_info_dict:
        The same dict object, modified in place when the clamp applies.
    """
    input_args = api_info_dict.get('args')
    if not input_args or len(input_args) < 2:
        return api_info_dict
    second_arg = input_args[1]
    # Only dict-shaped descriptions can carry a 'Min'; leave None or other entries untouched
    # instead of failing with a TypeError.
    if not isinstance(second_arg, dict):
        return api_info_dict
    min_value = second_arg.get('Min')
    # 'Min' may be absent or None (no statistic recorded); only clamp real numbers.
    if isinstance(min_value, (int, float)) and min_value <= 0:
        # The second argument in cross_entropy should be -100 or not less than 0
        second_arg['Min'] = 0
    return api_info_dict
99
-
100
-
101
def initialize_save_path(save_path, dir_name):
    """
    Function Description:
        Ensure the directory save_path/dir_name exists and passes the directory checks.
    Parameter:
        save_path: parent directory
        dir_name: name of the sub-directory to prepare
    Return:
        The joined directory path.
    Note:
        An existing directory only triggers a warning; this function itself does not
        delete or overwrite anything inside it.
    """
    data_path = os.path.join(save_path, dir_name)
    if not os.path.exists(data_path):
        os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
    else:
        logger.warning(f"{data_path} already exists, it will be overwritten")
    FileChecker(data_path, FileCheckConst.DIR).common_check()
    return data_path
110
-
111
-
112
def write_pt(file_path, tensor):
    """
    Function Description:
        Save ``tensor`` with torch.save, refusing to overwrite an existing file, then
        apply the data-file permission mode.
    Parameter:
        file_path: destination path
        tensor: object passed to torch.save
    Return:
        The resolved real path of the written file.
    Exception Description:
        ValueError when file_path already exists.
    """
    if os.path.exists(file_path):
        raise ValueError(f"File {file_path} already exists")
    torch.save(tensor, file_path)
    resolved_path = os.path.realpath(file_path)
    change_mode(resolved_path, FileCheckConst.DATA_FILE_AUTHORITY)
    return resolved_path
119
-
120
-
121
def get_real_data_path(file_path):
    """
    Function Description:
        Strip everything before the first dump-data directory marker in file_path.
    Parameter:
        file_path: path containing 'forward_real_data', 'backward_real_data'
            or 'ut_error_data<digits>'.
    Return:
        The suffix of file_path starting at the leftmost marker match.
    Exception Description:
        DumpException(INVALID_PATH_ERROR) when no marker is present.
    """
    # Raw strings: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    targets = [r'forward_real_data', r'backward_real_data', r'ut_error_data\d+']
    pattern = re.compile(r'({})'.format('|'.join(targets)))
    match = pattern.search(file_path)
    if not match:
        raise DumpException(DumpException.INVALID_PATH_ERROR)
    return file_path[match.start():]
131
-
132
-
133
def get_full_data_path(data_path, real_data_path):
    """
    Function Description:
        Resolve data_path relative to real_data_path.
    Return:
        os.path.realpath of the joined path, or data_path itself when it is falsy
        (e.g. '' or None).
    """
    if data_path:
        return os.path.realpath(os.path.join(real_data_path, data_path))
    return data_path
138
-
139
-
140
class UtDataProcessor:
    """
    Save every tensor found in a (possibly nested) argument structure as
    '<api_name><SEP><index>.pt' under save_path.

    The index advances once per leaf value, tensor or not, so a file name records
    the tensor's position among all flattened leaves.
    """

    def __init__(self, save_path):
        self.save_path = save_path
        self.index = 0

    def save_tensors_in_element(self, api_name, element):
        """Reset the leaf counter and save all tensors contained in element."""
        self.index = 0
        self._save_recursive(api_name, element)

    def _save_recursive(self, api_name, element):
        """Depth-first walk; lists, tuples and dict values are descended, everything else is a leaf."""
        if isinstance(element, torch.Tensor):
            self._save_single_tensor(api_name, element)
            return
        if isinstance(element, (list, tuple)):
            children = element
        elif isinstance(element, dict):
            children = element.values()
        else:
            # None, bool, int, float, str, slice and any other non-container leaf
            # still consume an index so later tensors keep their position.
            self.index += 1
            return
        for child in children:
            self._save_recursive(api_name, child)

    def _save_single_tensor(self, api_name, tensor):
        """Write one tensor (made contiguous, moved to CPU, detached) and advance the index."""
        api_args = api_name + Const.SEP + str(self.index)
        create_directory(self.save_path)
        file_path = os.path.join(self.save_path, f'{api_args}.pt')
        write_pt(file_path, tensor.contiguous().cpu().detach())
        self.index += 1
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import os
18
+ import re
19
+ from collections import namedtuple
20
+
21
+ import torch
22
+
23
+ try:
24
+ import torch_npu
25
+ except ImportError:
26
+ IS_GPU = True
27
+ else:
28
+ IS_GPU = False
29
+
30
+ from msprobe.pytorch.common.log import logger
31
+ from msprobe.pytorch.common.utils import save_pt
32
+ from msprobe.core.common.file_utils import create_directory
33
+ from msprobe.core.common.const import Const
34
+ from msprobe.core.common.utils import CompareException
35
+
36
# Record bundling an API's name, positional/keyword inputs, result and the step/rank
# it was captured on (field meanings inferred from their names -- confirm with callers).
# Every field is optional; defaults are listed in field order.
ApiData = namedtuple(
    'ApiData',
    ('name', 'args', 'kwargs', 'result', 'step', 'rank'),
    defaults=('unknown', None, None, None, 0, 0),
)
38
+
39
+
40
class DumpException(CompareException):
    """Error raised by the dump / UT-data helpers; reuses CompareException's error codes."""
42
+
43
+
44
def check_object_type(check_object, allow_type):
    """
    Function Description:
        Ensure that ``check_object`` is an instance of ``allow_type``.
    Parameter:
        check_object: value to validate
        allow_type: a type, or tuple of types, as accepted by isinstance
    Exception Description:
        Logs an error and raises CompareException(INVALID_DATA_ERROR) on mismatch.
    """
    if isinstance(check_object, allow_type):
        return
    logger.error(f"{check_object} not of {allow_type} type")
    raise CompareException(CompareException.INVALID_DATA_ERROR)
57
+
58
+
59
class SoftlinkCheckException(Exception):
    """Exception type named for soft-link checks; it adds no behaviour of its own."""
61
+
62
+
63
def check_need_convert(api_name):
    """
    Function Description:
        Find the conversion category (a key of Const.CONVERT_API) whose API list contains api_name.
    Parameter:
        api_name: name of the API to look up
    Return:
        The matching key, or None when no category lists api_name.
        If several categories list it, the last one in iteration order wins.
    """
    matched_type = None
    for candidate_type, api_names in Const.CONVERT_API.items():
        if api_name in api_names:
            matched_type = candidate_type
    return matched_type
71
+
72
+
73
def api_info_preprocess(api_name, api_info_dict):
    """
    Function Description:
        Preprocess the API information before it is used.
    Parameter:
        api_name: Name of the API.
        api_info_dict: argument description of the API.
    Return:
        convert_type: conversion category from check_need_convert (or None).
        api_info_dict: the argument description, passed through cross_entropy_process
            when api_name is 'cross_entropy', otherwise unchanged.
    """
    convert_type = check_need_convert(api_name)
    needs_cross_entropy_fix = api_name == 'cross_entropy'
    processed_info = cross_entropy_process(api_info_dict) if needs_cross_entropy_fix else api_info_dict
    return convert_type, processed_info
88
+
89
+
90
def cross_entropy_process(api_info_dict):
    """
    Function Description:
        Preprocess the cross_entropy API information so the recorded minimum of its
        second positional argument (presumably the target tensor -- confirm) is not negative.
    Parameter:
        api_info_dict: argument description of the API; positional-argument
            descriptions are read from its 'input_args' entry.
    Return api_info_dict:
        The same dict object, modified in place when the clamp applies.
    """
    input_args = api_info_dict.get('input_args')
    if not input_args or len(input_args) < 2:
        return api_info_dict
    second_arg = input_args[1]
    # Only dict-shaped descriptions can carry a 'Min'; leave None or other entries untouched
    # instead of failing with a TypeError.
    if not isinstance(second_arg, dict):
        return api_info_dict
    min_value = second_arg.get('Min')
    # 'Min' may be absent or None (no statistic recorded); only clamp real numbers.
    if isinstance(min_value, (int, float)) and min_value <= 0:
        # The second argument in cross_entropy should be -100 or not less than 0
        second_arg['Min'] = 0
    return api_info_dict
104
+
105
+
106
def initialize_save_path(save_path, dir_name):
    """
    Function Description:
        Build save_path/dir_name and make sure that directory exists via create_directory.
    Return:
        The joined directory path.
    """
    target_dir = os.path.join(save_path, dir_name)
    create_directory(target_dir)
    return target_dir
110
+
111
+
112
def get_real_data_path(file_path):
    """
    Function Description:
        Strip everything before the first dump-data directory marker in file_path.
    Parameter:
        file_path: path containing 'forward_real_data', 'backward_real_data'
            or 'ut_error_data<digits>'.
    Return:
        The suffix of file_path starting at the leftmost marker match.
    Exception Description:
        DumpException(INVALID_PATH_ERROR) when no marker is present.
    """
    # Raw strings: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    targets = [r'forward_real_data', r'backward_real_data', r'ut_error_data\d+']
    pattern = re.compile(r'({})'.format('|'.join(targets)))
    match = pattern.search(file_path)
    if not match:
        raise DumpException(DumpException.INVALID_PATH_ERROR)
    return file_path[match.start():]
122
+
123
+
124
def get_full_data_path(data_path, real_data_path):
    """
    Function Description:
        Resolve data_path relative to real_data_path.
    Return:
        os.path.realpath of the joined path, or data_path itself when it is falsy
        (e.g. '' or None).
    """
    if data_path:
        return os.path.realpath(os.path.join(real_data_path, data_path))
    return data_path
129
+
130
+
131
class UtDataProcessor:
    """
    Save every tensor found in a (possibly nested) argument structure as
    '<api_name><SEP><index>.pt' under save_path.

    The index advances once per leaf value, tensor or not, so a file name records
    the tensor's position among all flattened leaves.
    """

    def __init__(self, save_path):
        self.save_path = save_path
        self.index = 0

    def save_tensors_in_element(self, api_name, element):
        """Reset the leaf counter and save all tensors contained in element."""
        self.index = 0
        self._save_recursive(api_name, element)

    def _save_recursive(self, api_name, element):
        """Depth-first walk; lists, tuples and dict values are descended, everything else is a leaf."""
        if isinstance(element, torch.Tensor):
            self._save_single_tensor(api_name, element)
            return
        if isinstance(element, (list, tuple)):
            children = element
        elif isinstance(element, dict):
            children = element.values()
        else:
            # None, bool, int, float, str, slice and any other non-container leaf
            # still consume an index so later tensors keep their position.
            self.index += 1
            return
        for child in children:
            self._save_recursive(api_name, child)

    def _save_single_tensor(self, api_name, tensor):
        """Copy one tensor to CPU (contiguous, detached), save it and advance the index."""
        api_args = api_name + Const.SEP + str(self.index)
        create_directory(self.save_path)
        file_path = os.path.join(self.save_path, f'{api_args}.pt')
        try:
            cpu_tensor = tensor.contiguous().detach().cpu()
        except Exception as err:
            logger.error(f"Failed to transfer tensor to cpu for {api_args}")
            raise DumpException(DumpException.INVALID_DATA_ERROR) from err
        save_pt(cpu_tensor, file_path)
        self.index += 1
162
+
163
+
164
def extract_basic_api_segments(api_full_name):
    """
    Function Description:
        Split a full API name into its type and (possibly dotted) name.
    Parameter:
        api_full_name: Full name of the API. Example: torch.matmul.0, torch.linalg.inv.0
    Return:
        api_type: Type of api. Example: torch, tensor, etc.
        api_name: Name of api. Example: matmul, linalg.inv, etc.
        Both are None when the name has neither Const.THREE_SEGMENT nor
        Const.FOUR_SEGMENT parts. The trailing segment (e.g. '0') is discarded.
    """
    segments = api_full_name.split(Const.SEP)
    segment_count = len(segments)
    if segment_count == Const.THREE_SEGMENT:
        api_type, api_name, _call_suffix = segments
        return api_type, api_name
    if segment_count == Const.FOUR_SEGMENT:
        api_type, namespace, short_name, _call_suffix = segments
        return api_type, Const.SEP.join([namespace, short_name])
    return None, None
185
+
186
+
187
def extract_detailed_api_segments(full_api_name_with_direction_status):
    """
    Function Description:
        Split a full API name that carries direction/status suffixes.
    Parameter:
        full_api_name_with_direction_status: Full name of the API. Example: torch.matmul.0.forward.output.0
    Return:
        api_name: Name of api. Example: matmul, linalg.inv, etc.
        full_api_name: Full name of api. Example: torch.matmul.0
        direction_status: Direction status of api. Example: forward, backward, etc.
        All three are None when the name has neither Const.SIX_SEGMENT nor
        Const.SEVEN_SEGMENT parts.
    """
    segments = full_api_name_with_direction_status.split(Const.SEP)
    segment_count = len(segments)
    if segment_count == Const.SIX_SEGMENT:
        api_type, api_name, api_order, direction_status, _, _ = segments
        full_api_name = Const.SEP.join([api_type, api_name, api_order])
        return api_name, full_api_name, direction_status
    if segment_count == Const.SEVEN_SEGMENT:
        api_type, prefix, short_name, api_order, direction_status, _, _ = segments
        full_api_name = Const.SEP.join([api_type, prefix, short_name, api_order])
        return Const.SEP.join([prefix, short_name]), full_api_name, direction_status
    return None, None, None
214
+