mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,516 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import collections
18
+ import os
19
+ import re
20
+ import shutil
21
+ import stat
22
+ import subprocess
23
+ import time
24
+ import json
25
+ from datetime import datetime, timezone
26
+ from pathlib import Path
27
+ import numpy as np
28
+
29
+ from msprobe.core.common.file_check import FileOpen, FileChecker
30
+ from msprobe.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst
31
+ from msprobe.core.common.log import logger
32
+
33
+
34
# Lightweight record describing a device by its type and index.
device = collections.namedtuple('device', ('type', 'index'))
# Name prefixes kept at module level; their consumers are not visible in this file.
prefixes = ['api_stack', 'list', 'range', 'acl']
36
+
37
+
38
class CompareException(Exception):
    """
    Exception raised by the accuracy-compare utilities.

    Every instance carries a numeric ``code`` (normally one of the class-level
    constants below) and an optional ``error_info`` text, which is also what
    ``str(exception)`` returns.
    """
    NONE_ERROR = 0
    INVALID_PATH_ERROR = 1
    OPEN_FILE_ERROR = 2
    CLOSE_FILE_ERROR = 3
    READ_FILE_ERROR = 4
    WRITE_FILE_ERROR = 5
    INVALID_FILE_ERROR = 6
    PERMISSION_ERROR = 7
    INDEX_OUT_OF_BOUNDS_ERROR = 8
    NO_DUMP_FILE_ERROR = 9
    INVALID_DATA_ERROR = 10
    INVALID_PARAM_ERROR = 11
    INVALID_DUMP_RATIO = 12
    INVALID_DUMP_FILE = 13
    UNKNOWN_ERROR = 14
    INVALID_DUMP_MODE = 15
    PARSE_FILE_ERROR = 16
    INVALID_COMPARE_MODE = 17
    OVER_SIZE_FILE_ERROR = 18
    INVALID_SUMMARY_MODE = 19
    INVALID_TASK_ERROR = 20

    def __init__(self, code, error_info: str = ""):
        # The base Exception is deliberately initialised without arguments;
        # the message lives in ``error_info`` and is exposed via __str__.
        super().__init__()
        self.code = code
        self.error_info = error_info

    def __str__(self):
        return self.error_info
71
+
72
+
73
class DumpException(CompareException):
    """Exception type for dump failures; adds nothing to CompareException."""
75
+
76
+
77
def make_dump_path_if_not_exists(dump_path):
    """
    Function Description:
        create the dump directory (with parents, mode 0o750) when it is missing
    Parameter:
        dump_path: directory path to create
    Exception Description:
        CompareException(INVALID_PATH_ERROR) when the directory cannot be created.
        An existing path that is not a directory is only logged, not raised.
    """
    if os.path.exists(dump_path):
        if not os.path.isdir(dump_path):
            logger.error('{} already exists and is not a directory.'.format(dump_path))
        return

    try:
        Path(dump_path).mkdir(mode=0o750, exist_ok=True, parents=True)
    except OSError as ex:
        logger.error(
            'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex)))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
88
+
89
+
90
def check_mode_valid(mode, scope=None, api_list=None):
    """
    Function Description:
        validate a dump mode together with its scope / api_list arguments
    Parameter:
        mode: dump mode, must be one of Const.DUMP_MODE
        scope: list whose allowed length depends on the mode (None means [])
        api_list: list of api names, must be non-empty in api_list mode (None means [])
    Exception Description:
        ValueError when scope / api_list is not a list or violates the
        constraint of the selected mode;
        CompareException(INVALID_DUMP_MODE) when the mode is not supported
    """
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not isinstance(scope, list):
        raise ValueError("scope param set invalid, it's must be a list.")
    if not isinstance(api_list, list):
        raise ValueError("api_list param set invalid, it's must be a list.")

    if mode not in Const.DUMP_MODE:
        msg = "Current mode '%s' is not supported. Please use the field in %s" % \
              (mode, Const.DUMP_MODE)
        raise CompareException(CompareException.INVALID_DUMP_MODE, msg)

    # mode -> (constraint violated?, error message). Modes without an entry
    # (e.g. all / api_stack) place no constraint on scope or api_list.
    constraints = {
        Const.RANGE: (
            len(scope) != 2,
            "set_dump_switch, scope param set invalid, it's must be [start, end]."),
        Const.LIST: (
            len(scope) == 0,
            "set_dump_switch, scope param set invalid, it's should not be an empty list."),
        Const.STACK: (
            len(scope) > 2,
            "set_dump_switch, scope param set invalid, it's must be [start, end] or []."),
        Const.ACL: (
            len(scope) != 1,
            "set_dump_switch, scope param set invalid, only one api name is supported in acl mode."),
        Const.API_LIST: (
            len(api_list) < 1,
            "Current dump mode is 'api_list', but the content of api_list parameter is empty or invalid."),
    }
    # Look the constraint up once; a supported mode that has no entry is
    # treated as unconstrained instead of failing on a missing validator.
    violated, message = constraints.get(mode, (False, ""))
    if violated:
        raise ValueError(message)
115
+
116
+
117
def check_switch_valid(switch):
    """
    Function Description:
        make sure the switch value is exactly 'ON' or 'OFF'
    Parameter:
        switch: value to validate
    Exception Description:
        CompareException(INVALID_PARAM_ERROR) for any other value
    """
    if switch in ("ON", "OFF"):
        return
    logger.error("Please set switch with 'ON' or 'OFF'.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
121
+
122
+
123
def check_dump_mode_valid(dump_mode):
    """
    Function Description:
        validate and normalise the dump_mode setting
    Parameter:
        dump_mode: list containing any of 'all', 'forward', 'backward',
            'input', 'output'; a non-list value is wrapped into a list
            after a warning
    Return Value:
        a new list of modes. When neither 'input' nor 'output' is given both
        are appended, likewise for 'forward' / 'backward'. If 'all' is present
        or all four modes end up selected, the canonical four-element list is
        returned.
    Exception Description:
        ValueError when an unsupported mode is present
    """
    if isinstance(dump_mode, list):
        # Work on a copy so the caller's list is never extended in place.
        modes = list(dump_mode)
    else:
        logger.warning("Please set dump_mode as a list.")
        modes = [dump_mode]

    if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in modes):
        raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.")

    if 'input' not in modes and 'output' not in modes:
        modes.extend(['input', 'output'])
    if 'forward' not in modes and 'backward' not in modes:
        modes.extend(['forward', 'backward'])
    if 'all' in modes or {"forward", "backward", "input", "output"}.issubset(modes):
        return ["forward", "backward", "input", "output"]
    return modes
136
+
137
+
138
def check_summary_mode_valid(summary_mode):
    """
    Function Description:
        make sure summary_mode is one of Const.SUMMARY_MODE
    Parameter:
        summary_mode: value to validate
    Exception Description:
        CompareException(INVALID_SUMMARY_MODE) when the value is not supported
    """
    if summary_mode in Const.SUMMARY_MODE:
        return
    raise CompareException(CompareException.INVALID_SUMMARY_MODE, "The summary_mode is not valid")
142
+
143
+
144
def check_summary_only_valid(summary_only):
    """
    Function Description:
        make sure summary_only is a bool
    Parameter:
        summary_only: value to validate
    Return Value:
        the validated value, unchanged
    Exception Description:
        CompareException(INVALID_PARAM_ERROR) when the value is not a bool
    """
    if isinstance(summary_only, bool):
        return summary_only
    logger.error("Params summary_only only support True or False.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
149
+
150
+
151
def check_compare_param(input_parma, output_path, stack_mode=False, summary_compare=False, md5_compare=False):
    """
    Function Description:
        validate the paths used by a compare run and check that the three
        json files are not empty
    Parameter:
        input_parma: dict holding npu_json_path, bench_json_path, stack_json_path
            and, for tensor compare, npu_dump_data_dir / bench_dump_data_dir
        output_path: directory the compare result is written to
        stack_mode: accepted but not used by this function
        summary_compare: True when comparing statistics only
        md5_compare: True when comparing md5 values only
    Exception Description:
        CompareException(INVALID_PARAM_ERROR) when the argument types are wrong;
        path and json checks raise through the helpers they call
    """
    types_ok = isinstance(input_parma, dict) and isinstance(output_path, str)
    if not types_ok:
        logger.error("Invalid input parameters")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    for json_key in ("npu_json_path", "bench_json_path", "stack_json_path"):
        check_file_or_directory_path(input_parma.get(json_key), False)

    # Dump data directories are only needed when real tensors are compared.
    needs_dump_data = not (summary_compare or md5_compare)
    if needs_dump_data:
        for dir_key in ("npu_dump_data_dir", "bench_dump_data_dir"):
            check_file_or_directory_path(input_parma.get(dir_key), True)

    check_file_or_directory_path(output_path, True)

    with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
            FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
            FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
        check_json_file(input_parma, npu_json, bench_json, stack_json)
166
+
167
+
168
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
    """
    Function Description:
        make sure the three compare switches are all bool values
    Parameter:
        stack_mode, auto_analyze, fuzzy_match: values to validate
    Exception Description:
        CompareException(INVALID_PARAM_ERROR) when any of them is not a bool
    """
    switches = (stack_mode, auto_analyze, fuzzy_match)
    if all(isinstance(flag, bool) for flag in switches):
        return
    logger.error("Invalid input parameters which should be only bool type.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
172
+
173
+
174
def check_file_or_directory_path(path, isdir=False):
    """
    Function Description:
        run FileChecker.common_check() on the path
    Parameter:
        path: the path to check
        isdir: True  -> checked as a writable directory
               False -> checked as a readable file
    Exception Description:
        whatever FileChecker.common_check() raises for an invalid path
    """
    if isdir:
        path_type, permission = FileCheckConst.DIR, FileCheckConst.WRITE_ABLE
    else:
        path_type, permission = FileCheckConst.FILE, FileCheckConst.READ_ABLE
    FileChecker(path, path_type, permission).common_check()
189
+
190
+
191
def is_starts_with(string, prefix_list):
    """Return True when string starts with at least one prefix of prefix_list."""
    for candidate in prefix_list:
        if string.startswith(candidate):
            return True
    return False
193
+
194
+
195
def _check_json(json_file_handle, file_name):
    """
    Function Description:
        reject a dump file whose first readline() returns nothing, then
        rewind the handle to the start
    Parameter:
        json_file_handle: opened, seekable file handle
        file_name: name used in the error message
    Exception Description:
        CompareException(INVALID_DUMP_FILE) when nothing can be read
    """
    first_line = json_file_handle.readline()
    if first_line:
        # Rewind so later readers see the whole file.
        json_file_handle.seek(0, 0)
        return
    logger.error("dump file {} have empty line!".format(file_name))
    raise CompareException(CompareException.INVALID_DUMP_FILE)
201
+
202
+
203
def check_json_file(input_param, npu_json, bench_json, stack_json):
    """
    Function Description:
        apply _check_json to the npu, bench and stack json handles, in that order
    Parameter:
        input_param: dict providing the *_json_path values used in error messages
        npu_json, bench_json, stack_json: opened file handles
    Exception Description:
        CompareException(INVALID_DUMP_FILE) raised by _check_json
    """
    handles_and_keys = (
        (npu_json, "npu_json_path"),
        (bench_json, "bench_json_path"),
        (stack_json, "stack_json_path"),
    )
    for handle, path_key in handles_and_keys:
        _check_json(handle, input_param.get(path_key))
207
+
208
+
209
def check_file_size(input_file, max_size):
    """
    Function Description:
        make sure the file size does not exceed max_size bytes
    Parameter:
        input_file: path of the file to measure
        max_size: largest accepted size in bytes (a file of exactly this size passes)
    Exception Description:
        CompareException(INVALID_FILE_ERROR) when the size cannot be read
        or exceeds max_size
    """
    try:
        size_in_bytes = os.path.getsize(input_file)
    except OSError as os_error:
        logger.error('Failed to open "%s". %s' % (input_file, str(os_error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from os_error

    too_large = size_in_bytes > max_size
    if too_large:
        logger.error('The size (%d) of %s exceeds (%d) bytes, tools not support.'
                     % (size_in_bytes, input_file, max_size))
        raise CompareException(CompareException.INVALID_FILE_ERROR)
219
+
220
+
221
def check_file_not_exists(file_path):
    """
    Function Description:
        hand an already existing path (or symbolic link) to remove_path
    Parameter:
        file_path: path that should not exist afterwards
    """
    already_present = os.path.islink(file_path) or os.path.exists(file_path)
    if already_present:
        remove_path(file_path)
224
+
225
+
226
def check_regex_prefix_format_valid(prefix):
    """
    Check the length and the character set of a regex prefix.

    Args:
        prefix (str): The prefix string to validate.

    Returns:
        None

    Raises:
        ValueError: when the prefix is longer than Const.REGEX_PREFIX_MAX_LENGTH
            or does not match Const.REGEX_PREFIX_PATTERN.
    """
    too_long = len(prefix) > Const.REGEX_PREFIX_MAX_LENGTH
    if too_long:
        raise ValueError(f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, while current length "
                         f"is {len(prefix)}")

    if re.match(Const.REGEX_PREFIX_PATTERN, prefix) is None:
        raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
245
+
246
+
247
def remove_path(path):
    """
    Function Description:
        delete a file, a symbolic link or a whole directory tree
    Parameter:
        path: path to delete; a missing path is silently ignored
    Exception Description:
        CompareException(INVALID_PATH_ERROR) when deletion is refused with
        a PermissionError
    """
    # lexists() also reports dangling symbolic links, which exists() treats
    # as missing and which would otherwise never be cleaned up.
    if not os.path.lexists(path):
        return
    try:
        if os.path.islink(path) or os.path.isfile(path):
            # Links are unlinked themselves; their target is left alone.
            os.remove(path)
        else:
            shutil.rmtree(path)
    except PermissionError as err:
        logger.error("Failed to delete {}. Please check the permission.".format(path))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from err
258
+
259
+
260
def get_dump_data_path(dump_dir):
    """
    Function Description:
        walk dump_dir and locate the first directory that contains files
    Parameter:
        dump_dir: dump data directory
    Return Value:
        (path, found): when a directory with files is met, its path and True;
        otherwise the last directory visited by the walk (None if the walk
        yields nothing) and False
    """
    located_dir = None
    has_files = False

    check_file_or_directory_path(dump_dir, True)
    for current_dir, _, file_names in os.walk(dump_dir):
        located_dir = current_dir
        if file_names:
            has_files = True
            break
    return located_dir, has_files
280
+
281
+
282
def create_directory(dir_path):
    """
    Function Description:
        create the directory (and missing parents) with mode 0o700 when it
        does not exist yet
    Parameter:
        dir_path: directory path
    Exception Description:
        CompareException(INVALID_PATH_ERROR) when the directory cannot be created
    """
    if os.path.exists(dir_path):
        return
    try:
        os.makedirs(dir_path, mode=0o700)
    except OSError as ex:
        logger.error(
            'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex)))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
298
+
299
+
300
def execute_command(cmd):
    """
    Function Description:
        run a command without a shell and echo its combined stdout/stderr
    Parameter:
        cmd: command passed to subprocess.Popen with shell=False
    Exception Description:
        CompareException(INVALID_DATA_ERROR) when the command exits with a
        non-zero return code
    """
    logger.info('Execute command:%s' % cmd)
    # The context manager closes the pipe and waits for the child on exit.
    with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        # Read until EOF rather than until poll() reports termination, so
        # output still buffered when the process ends is not dropped.
        for line in process.stdout:
            line = line.strip()
            if line:
                # The pipe is opened in binary mode, so lines are bytes.
                print(line)
    if process.returncode != 0:
        logger.error('Failed to execute command:%s' % " ".join(cmd))
        raise CompareException(CompareException.INVALID_DATA_ERROR)
319
+
320
+
321
def save_numpy_data(file_path, data):
    """
    Function Description:
        save data with numpy.save, creating the parent directory when needed
    Parameter:
        file_path: destination file path; a bare file name is saved into
            the current working directory
        data: object handed to numpy.save
    """
    parent_dir = os.path.dirname(file_path)
    # dirname is '' for a bare file name and os.makedirs('') would raise;
    # exist_ok avoids failing when the directory appears concurrently.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    np.save(file_path, data)
328
+
329
+
330
def parse_value_by_comma(value):
    """
    Function Description:
        parse a Const.COMMA separated string such as '1,2,4,8' into integers
    Parameter:
        value: string to parse; every item must be all digits or exactly '-1'
    Return Value:
        list of int
    Exception Description:
        CompareException(INVALID_PARAM_ERROR) when an item is not accepted
    """
    parsed = []
    for token in value.split(Const.COMMA):
        token = token.strip()
        if not (token.isdigit() or token == '-1'):
            logger.error("please check your input shape.")
            raise CompareException(CompareException.INVALID_PARAM_ERROR)
        parsed.append(int(token))
    return parsed
344
+
345
+
346
def get_data_len_by_shape(shape):
    """
    Function Description:
        multiply all dimensions of a shape
    Parameter:
        shape: iterable of dimension sizes
    Return Value:
        the product of the dimensions (1 for an empty shape), or -1 as soon
        as a dimension equal to -1 is met
    """
    element_count = 1
    for dim in shape:
        if dim == -1:
            logger.error("please check your input shape, one dim in shape is -1.")
            return -1
        element_count = element_count * dim
    return element_count
354
+
355
+
356
def add_time_as_suffix(name):
    """Return '<name>_<local time as YYYYmmddHHMMSS>.csv'."""
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return '{}_{}.csv'.format(name, timestamp)
358
+
359
+
360
def add_time_with_xlsx(name):
    """Return '<name>_<local time as YYYYmmddHHMMSS>.xlsx'."""
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return '{}_{}.xlsx'.format(name, timestamp)
362
+
363
+
364
def get_time():
    """Return the current UTC time formatted as 'YYYYmmdd_HHMMSS'."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.strftime("%Y%m%d_%H%M%S")
366
+
367
+
368
def format_value(value):
    """Return value rounded to 12 decimal places (via fixed-point text) as a float."""
    fixed_point_text = '{:.12f}'.format(value)
    return float(fixed_point_text)
370
+
371
+
372
def check_seed_all(seed, mode):
    """
    Function Description:
        validate the arguments of seed_all
    Parameter:
        seed: must be an int within [0, Const.MAX_SEED_VALUE]
        mode: must be a bool
    Exception Description:
        CompareException(INVALID_PARAM_ERROR) when either argument is invalid
    """
    if not isinstance(seed, int):
        logger.error(f"Seed must be integer.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not 0 <= seed <= Const.MAX_SEED_VALUE:
        logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if isinstance(mode, bool):
        return
    logger.error(f"seed_all mode must be bool.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
383
+
384
+
385
def get_process_rank(model):
    """
    Function Description:
        derive the rank id from the device of the model's first parameter
    Parameter:
        model: object exposing parameters(), whose items have a .device
            with .type and .index
    Return Value:
        (rank, found): (device index, True) when the first parameter is not
        on cpu; (0, False) when the model has no parameter or sits on cpu
    """
    logger.info("Rank id is not provided. Trying to get the rank id of the model.")
    try:
        first_param_device = next(model.parameters()).device
    except StopIteration:
        logger.warning('There is no parameter in the model. Fail to get rank id.')
        return 0, False

    if first_param_device.type != 'cpu':
        return first_param_device.index, True

    logger.warning("Warning: the debugger is unable to get the rank id. "
                   "This may cause the dumpped data to be corrupted in the "
                   "case of distributed training. (You may ignore this if you are using only one card.) "
                   "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
    return 0, False
400
+
401
+
402
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
    """
    Function Description:
        render compare_script.template (looked up next to this module) into
        compare_data.py inside the directory of the pkl file
    Parameter:
        dump_path: dump directory substituted into the template
        pkl_file_path: pkl file path substituted into the template; its
            directory receives the generated script
        dump_switch_mode: the template gets "True" when this equals
            Const.API_STACK, otherwise "False"
    Exception Description:
        OSError is logged, not raised; nothing is reported as generated then
    """
    template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
    pkl_dir = os.path.dirname(pkl_file_path)
    compare_script_path = os.path.join(pkl_dir, "compare_data.py")
    is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"

    try:
        with FileOpen(template_path, 'r') as ftemp, \
                os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
            code_temp = ftemp.read()
            fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
        # Stop here so the success message below is not logged after a failure.
        return

    logger.info(f"Generate compare script successfully which is {compare_script_path}.")
417
+
418
+
419
def check_file_valid(file_path):
    """
    Function Description:
        validate a file path: no soft link, bounded length, allowed
        characters and, for existing pkl / numpy files, bounded size
    Parameter:
        file_path: path to validate
    Exception Description:
        CompareException(INVALID_PATH_ERROR) when any check fails
    """
    if os.path.islink(file_path):
        logger.error('The file path {} is a soft link.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    real_path = os.path.realpath(file_path)
    base_name = os.path.basename(file_path)
    if len(real_path) > Const.DIRECTORY_LENGTH or len(base_name) > Const.FILE_NAME_LENGTH:
        logger.error('The file path length exceeds limit.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if re.match(Const.FILE_PATTERN, real_path) is None:
        logger.error('The file path {} contains special characters.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    # Size limits only apply to files that already exist.
    if not os.path.isfile(file_path):
        return
    file_size = os.path.getsize(file_path)
    if file_path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
        logger.error('The file {} size is greater than 1GB.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    if file_path.endswith(Const.NUMPY_SUFFIX) and file_size > Const.TEN_GB:
        logger.error('The file {} size is greater than 10GB.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
441
+
442
+
443
def check_path_before_create(path):
    """
    Function Description:
        validate the length and the characters of a path that is about to
        be created
    Parameter:
        path: path to validate
    Exception Description:
        CompareException(INVALID_PATH_ERROR) when the path is too long or
        its real path does not match Const.FILE_PATTERN
    """
    real_path = os.path.realpath(path)
    base_name = os.path.basename(path)
    if len(real_path) > Const.DIRECTORY_LENGTH or len(base_name) > Const.FILE_NAME_LENGTH:
        logger.error('The file path length exceeds limit.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if re.match(Const.FILE_PATTERN, real_path) is None:
        logger.error('The file path {} contains special characters.'.format(path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
452
+
453
+
454
def check_inplace_op(prefix):
    """
    Function Description:
        tell whether a 'Distributed.<op>.<digit>' style prefix names an op
        listed in Const.INPLACE_LIST
    Parameter:
        prefix: api name prefix
    Return Value:
        False when the prefix is longer than Const.DISTRIBUTED_PREFIX_LENGTH;
        otherwise whether the extracted op name (None when the pattern does
        not match) is in Const.INPLACE_LIST
    """
    if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
        return False
    found = re.search(r"Distributed\.(.+?)\.\d", prefix)
    op_name = found.group(1) if found else None
    return op_name in Const.INPLACE_LIST
460
+
461
+
462
def md5_find(data):
    """
    Function Description:
        look for an 'md5' entry in the dumped api data
    Parameter:
        data: mapping op name -> mapping field name -> either a list of
            items or a single container
    Return Value:
        True as soon as a truthy list item or a non-list field value
        contains 'md5', otherwise False
    """
    for op_name in data:
        op_entry = data[op_name]
        for field_name in op_entry:
            field_value = op_entry[field_name]
            if isinstance(field_value, list):
                # Falsy items (None, empty dict) are skipped before the lookup.
                if any(item and 'md5' in item for item in field_value):
                    return True
            elif 'md5' in field_value:
                return True
    return False
472
+
473
+
474
def task_dumppath_get(input_param):
    """
    Function Description:
        read both dump json files, decide the compare type from their task
        and store the dump data directories back into input_param
    Parameter:
        input_param: dict with npu_json_path and bench_json_path; it is
            updated with npu_dump_data_dir and bench_dump_data_dir
    Return Value:
        (summary_compare, md5_compare)
    Exception Description:
        CompareException(INVALID_PATH_ERROR) when a json path is missing;
        CompareException(INVALID_TASK_ERROR) when the tasks differ or the
        task is neither Const.TENSOR nor Const.STATISTICS
    """
    npu_json_path = input_param.get("npu_json_path")
    bench_json_path = input_param.get("bench_json_path")
    if not (npu_json_path and bench_json_path):
        logger.error(f"Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    with FileOpen(npu_json_path, 'r') as npu_f:
        npu_json_data = json.load(npu_f)
    with FileOpen(bench_json_path, 'r') as bench_f:
        bench_json_data = json.load(bench_f)

    task = npu_json_data['task']
    if task != bench_json_data['task']:
        logger.error(f"Please check the dump task is consistent.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if task == Const.TENSOR:
        summary_compare, md5_compare = False, False
    elif task == Const.STATISTICS:
        # Statistics dumps are compared by md5 when md5 values were dumped,
        # otherwise by their summary values.
        md5_compare = md5_find(npu_json_data['data'])
        summary_compare = not md5_compare
    else:
        logger.error(f"Compare is not required for overflow_check or free_benchmark.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    input_param['npu_dump_data_dir'] = npu_json_data['dump_data_dir']
    input_param['bench_dump_data_dir'] = bench_json_data['dump_data_dir']
    return summary_compare, md5_compare
502
+
503
+
504
def get_header_index(header_name, summary_compare=False):
    """
    Function Description:
        find the position of a column in the compare result header
    Parameter:
        header_name: column name to look up
        summary_compare: True  -> CompareConst.SUMMARY_COMPARE_RESULT_HEADER
                         False -> CompareConst.COMPARE_RESULT_HEADER
    Return Value:
        index of header_name in the selected header
    Exception Description:
        CompareException(INVALID_PARAM_ERROR) when the column is unknown
    """
    selected_header = (CompareConst.SUMMARY_COMPARE_RESULT_HEADER if summary_compare
                       else CompareConst.COMPARE_RESULT_HEADER)
    header = selected_header[:]
    if header_name in header:
        return header.index(header_name)
    logger.error(f"{header_name} not in data name")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
513
+
514
+
515
def convert_tuple(data):
    """Return data unchanged when it is a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(data, tuple):
        return data
    return (data,)
@@ -0,0 +1,58 @@
1
+ from msprobe.core.common.const import Const
2
+ from msprobe.core.common.log import logger
3
+ from msprobe.core.common.exceptions import MsaccException
4
+
5
+
6
class CommonConfig:
    """Task-independent options read from a JSON-derived mapping.

    Every option is fetched with ``json_config.get``; ``is_deterministic`` and
    ``enable_dataloader`` default to ``False``, all others to ``None``. The
    values are validated once at the end of construction.
    """

    def __init__(self, json_config):
        fetch = json_config.get
        self.task = fetch('task')
        self.dump_path = fetch('dump_path')
        self.rank = fetch('rank')
        self.step = fetch('step')
        self.level = fetch('level')
        self.seed = fetch('seed')
        self.acl_config = fetch('acl_config')
        self.is_deterministic = fetch('is_deterministic', False)
        self.enable_dataloader = fetch('enable_dataloader', False)
        self._check_config()

    @staticmethod
    def _report_invalid(message):
        """Pass *message* and an INVALID_PARAM_ERROR exception to the logger."""
        # NOTE(review): error_log_with_exp presumably logs and then raises the
        # supplied exception - confirm in msprobe.core.common.log.
        logger.error_log_with_exp(message, MsaccException(MsaccException.INVALID_PARAM_ERROR))

    def _check_config(self):
        """Validate the stored options, reporting each invalid one in turn.

        Checks run in the order task, rank, step, level, seed,
        is_deterministic, enable_dataloader. Falsy ``task``/``level`` values
        and ``None`` ``rank``/``step``/``seed`` values are not checked.
        """
        report = self._report_invalid

        if self.task and self.task not in Const.TASK_LIST:
            report("task is invalid, it should be one of {}".format(Const.TASK_LIST))

        list_options = (
            ("rank", "rank is invalid, it should be a list"),
            ("step", "step is invalid, it should be a list"),
        )
        for attr, message in list_options:
            value = getattr(self, attr)
            if not (value is None or isinstance(value, list)):
                report(message)

        if self.level and self.level not in Const.LEVEL_LIST:
            report("level is invalid, it should be one of {}".format(Const.LEVEL_LIST))

        if not (self.seed is None or isinstance(self.seed, int)):
            report("seed is invalid, it should be an integer")

        bool_options = (
            ("is_deterministic", "is_deterministic is invalid, it should be a boolean"),
            ("enable_dataloader", "enable_dataloader is invalid, it should be a boolean"),
        )
        for attr, message in bool_options:
            if not isinstance(getattr(self, attr), bool):
                report(message)
38
+
39
+
40
class BaseConfig:
    """Task-level options read from a JSON-derived mapping.

    Every option is fetched with ``json_config.get`` and is ``None`` when the
    key is absent. Validation is not run by the constructor; it happens only
    when ``check_config`` is called.
    """

    def __init__(self, json_config):
        fetch = json_config.get
        self.scope = fetch('scope')
        self.list = fetch('list')
        self.data_mode = fetch('data_mode')
        self.backward_input = fetch("backward_input")
        self.file_format = fetch("file_format")
        self.summary_mode = fetch("summary_mode")
        self.overflow_num = fetch("overflow_num")
        self.check_mode = fetch("check_mode")

    def check_config(self):
        """Report ``scope``, ``list`` and ``data_mode`` values that are set but not lists."""
        list_options = (
            ("scope", "scope is invalid, it should be a list"),
            ("list", "list is invalid, it should be a list"),
            ("data_mode", "data_mode is invalid, it should be a list"),
        )
        for attr, message in list_options:
            value = getattr(self, attr)
            if value is None or isinstance(value, list):
                continue
            # NOTE(review): error_log_with_exp presumably logs and then raises
            # the supplied exception - confirm in msprobe.core.common.log.
            logger.error_log_with_exp(message, MsaccException(MsaccException.INVALID_PARAM_ERROR))
58
+