mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -0,0 +1,328 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+import math
+import torch
+import numpy
+
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
+from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, \
+    get_full_data_path, CompareException
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import Const
+
+TORCH_TYPE = ["torch.device", "torch.dtype"]
+TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16',
+              'torch.half', 'torch.bfloat16']
+NUMPY_TYPE = ["numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
+              "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
+              "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"]
+
+
+def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
+    """
+    Function Description:
+        Based on arg basic information, generate arg data
+    Parameter:
+        info: arg basic information. Dict
+        api_name: API name
+        need_grad: set Tensor grad for backward
+        convert_type: convert ori_type to dist_type flag.
+    """
+    check_object_type(info, dict)
+    data_type = info.get('type')
+    data_path = info.get('datapath', info.get('data_name'))
+    data_path = get_full_data_path(data_path, real_data_path)
+    if data_type in TENSOR_DATA_LIST:
+        if data_path:
+            data = gen_real_tensor(data_path, convert_type)
+        else:
+            data = gen_random_tensor(info, convert_type)
+        if api_name in hf_32_standard_api and data.dtype == torch.float32:
+            data = fp32_to_hf32_to_fp32(data)
+        if info.get('requires_grad') and need_grad:
+            data.requires_grad_(True)
+            temp_data = data * 1
+            data = temp_data.type_as(data)
+            data.retain_grad()
+    elif data_type.startswith("numpy"):
+        if data_type not in NUMPY_TYPE:
+            raise Exception("{} is not supported now".format(data_type))
+        data = info.get("value")
+        try:
+            data = eval(data_type)(data)
+        except Exception as err:
+            logger.error("Failed to convert the type to numpy: %s" % str(err))
+    elif data_type == "torch.Size":
+        data = torch.Size(info.get("value"))
+    else:
+        data = info.get('value')
+        if info.get("type") == "slice":
+            data = slice(*data)
+    return data
+
+
+def gen_real_tensor(data_path, convert_type):
+    """
+    Function Description:
+        Based on API data path, generate input parameters real data
+    Parameter:
+        data_path: API data path
+        convert_type: convert ori_type to dist_type flag.
+    """
+    data_path = os.path.realpath(data_path)
+    check_file_or_directory_path(data_path)
+    if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
+        error_info = f"The file: {data_path} is not a pt or numpy file."
+        raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
+    if data_path.endswith('.pt'):
+        data = torch.load(data_path).cpu()
+    else:
+        data_np = numpy.load(data_path)
+        data = torch.from_numpy(data_np)
+    if convert_type:
+        ori_dtype = Const.CONVERT.get(convert_type)[0]
+        dist_dtype = Const.CONVERT.get(convert_type)[1]
+        if str(data.dtype) == ori_dtype:
+            data = data.type(eval(dist_dtype))
+    return data
+
+
+def gen_random_tensor(info, convert_type):
+    """
+    Function Description:
+        Based on API MAX and MIN, generate input parameters random data
+    Parameter:
+        info: API data info
+        convert_type: convert ori_type to dist_type flag.
+    """
+    check_object_type(info, dict)
+    low, high = info.get('Min'), info.get('Max')
+    low_origin, high_origin = info.get('Min_origin'), info.get('Max_origin')
+    low_info = [low, low_origin]
+    high_info = [high, high_origin]
+    data_dtype = info.get('dtype')
+    shape = tuple(info.get('shape'))
+    if not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
+        error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
+        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
+    if data_dtype == "torch.bool":
+        data = gen_bool_tensor(low, high, shape)
+    else:
+        data = gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type)
+    return data
+
+
+def fp32_to_hf32_to_fp32(input_tensor):
+    # Convert the input float32 tensor to hf32, then back to float32
+    input_np = input_tensor.detach().numpy()
+    input_int = input_np.view(numpy.int32)
+    input_int = numpy.right_shift(numpy.right_shift(input_int, 11) + 1, 1)
+    input_int = numpy.left_shift(input_int, 12)
+    input_fp32 = input_int.view(numpy.float32)
+    input_hf32 = torch.from_numpy(input_fp32)
+    return input_hf32
+
+
+def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
+    """
+    Function Description:
+        Based on API basic information, generate int or float tensor
+    Parameter:
+        low_info: [low, low_origin], low is the minimum value in the tensor with inf and nan removed,
+            low_origin is the original minimum value in the tensor
+        high_info: [high, high_origin], high is the maximum value in the tensor with inf and nan removed,
+            high_origin is the original maximum value in the tensor
+        shape: The shape of Tensor
+        data_dtype: The data type of Tensor
+        convert_type: convert ori_type to dist_type flag.
+    """
+    if convert_type:
+        ori_dtype = Const.CONVERT.get(convert_type)[0]
+        if ori_dtype == data_dtype:
+            data_dtype = Const.CONVERT.get(convert_type)[1]
+    low, low_origin = low_info[0], low_info[1]
+    high, high_origin = high_info[0], high_info[1]
+    if data_dtype in FLOAT_TYPE:
+        if math.isnan(high):
+            tensor = torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=eval(data_dtype))
+            return tensor
+        # high_origin is a field of the new-format json; only when high_origin is not None and high is inf or -inf
+        # is the original tensor entirely inf or -inf
+        if high_origin and high in [float('inf'), float('-inf')]:
+            tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=eval(data_dtype))
+            tensor[-1] = low
+            return tensor
+        low_scale, high_scale = low, high
+        dtype_finfo = torch.finfo(eval(data_dtype))
+        # Adapt to old-format json where high or low is inf or -inf: scale with the dtype's max or min value
+        if high == float('inf'):
+            high_scale = dtype_finfo.max
+        elif high == float('-inf'):
+            high_scale = dtype_finfo.min
+        if low == float('inf'):
+            low_scale = dtype_finfo.max
+        elif low == float('-inf'):
+            low_scale = dtype_finfo.min
+
+        scale = high_scale - low_scale
+        rand01 = torch.rand(shape, dtype=eval(data_dtype))
+        tensor = rand01 * scale + low_scale
+    elif 'int' in data_dtype or 'long' in data_dtype:
+        low, high = int(low), int(high)
+        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
+    else:
+        logger.error('Dtype is not supported: ' + data_dtype)
+        raise NotImplementedError()
+    if tensor.nelement() == 0:
+        return tensor
+    tmp_tensor = tensor.reshape(-1)
+    if high_origin and math.isnan(high_origin):
+        if tmp_tensor.numel() <= 2:
+            tmp_tensor[0] = float('nan')
+            tmp_tensor[-1] = high
+        else:
+            tmp_tensor[0] = low
+            tmp_tensor[1] = float('nan')
+            tmp_tensor[-1] = high
+    else:
+        tmp_tensor[0] = low
+        tmp_tensor[-1] = high
+        if high_origin in [float('inf'), float('-inf')]:
+            tmp_tensor[-1] = high_origin
+        if low_origin in [float('inf'), float('-inf')]:
+            tmp_tensor[0] = low_origin
+    data = tmp_tensor.reshape(shape)
+    return data
+
+
+def gen_bool_tensor(low, high, shape):
+    """
+    Function Description:
+        Based on API basic information, generate bool tensor
+    Parameter:
+        low: The minimum value in Tensor
+        high: The max value in Tensor
+        shape: The shape of Tensor
+    """
+    low, high = int(low), int(high)
+    if low > high:
+        low, high = high, low
+    tensor = torch.randint(low, high + 1, shape)
+    data = torch.gt(tensor, 0)
+    return data
+
+
+def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
+    """
+    Function Description:
+        Based on API basic information, generate input parameters: args, for API forward running
+    Parameter:
+        api_info: API basic information. List
+        api_name: API name
+        need_grad: set Tensor grad for backward
+        convert_type: convert ori_type to dist_type flag.
+        real_data_path: the root directory for storing real data.
+    """
+    check_object_type(args_info, list)
+    args_result = []
+    for arg in args_info:
+        if isinstance(arg, (list, tuple)):
+            data = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
+        elif isinstance(arg, dict):
+            data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
+        elif arg is None:
+            data = None
+        else:
+            logger.warning(f'Warning: {arg} is not supported')
+            raise NotImplementedError()
+        args_result.append(data)
+    return args_result
+
+
+def gen_kwargs(api_info, convert_type=None, real_data_path=None):
+    """
+    Function Description:
+        Based on API basic information, generate input parameters: kwargs, for API forward running
+    Parameter:
+        api_info: API basic information. Dict
+        convert_type: convert ori_type to dist_type flag.
+        real_data_path: the root directory for storing real data.
+    """
+    check_object_type(api_info, dict)
+    kwargs_params = api_info.get("input_kwargs")
+    for key, value in kwargs_params.items():
+        if isinstance(value, (list, tuple)):
+            kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path)
+        elif value is None:
+            kwargs_params[key] = None
+        elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
+            kwargs_params[key] = gen_data(value, True, convert_type, real_data_path)
+        elif value.get('type') in TORCH_TYPE:
+            gen_torch_kwargs(kwargs_params, key, value)
+        else:
+            kwargs_params[key] = value.get('value')
+    return kwargs_params
+
+
+def gen_torch_kwargs(kwargs_params, key, value):
+    if value.get('type') != "torch.device":
+        kwargs_params[key] = eval(value.get('value'))
+
+
+def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None):
+    """
+    Function Description:
+        When kwargs value is list, generate the list of kwargs result
+    Parameter:
+        kwargs_item_value: kwargs value before to generate. List
+        convert_type: convert ori_type to dist_type flag.
+    """
+    kwargs_item_result = []
+    for item in kwargs_item_value:
+        if item.get('type') in TENSOR_DATA_LIST:
+            item_value = gen_data(item, False, convert_type, real_data_path)
+        elif item.get('type') == "torch.Size":
+            item_value = torch.Size(item.get('value'))
+        else:
+            item_value = item.get('value')
+        kwargs_item_result.append(item_value)
+    return kwargs_item_result
+
+
+def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
+    """
+    Function Description:
+        Based on API basic information, generate input parameters: args, kwargs, for API forward running
+    Parameter:
+        api_info: API basic information. Dict
+        api_name: API name
+        need_grad: set grad for backward
+        convert_type: convert ori_type to dist_type flag.
+    """
+    check_object_type(api_info, dict)
+    if convert_type and convert_type not in Const.CONVERT:
+        error_info = f"convert_type params not support {convert_type}."
+        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
+    kwargs_params = gen_kwargs(api_info, convert_type, real_data_path)
+    if api_info.get("input_args"):
+        args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path)
+    else:
+        logger.warning(f'Warning: No args in {api_info} ')
+        args_params = []
+    return args_params, kwargs_params
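
gen_api_params is the entry point the UT runner uses to rebuild the positional and keyword arguments of one API call from a dump record. A minimal sketch of driving it directly follows; the api_info dict is a hand-written illustration of the schema these functions read (field names as used by gen_data and gen_random_tensor above), not an excerpt from a real dump file.

from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params

# Hypothetical record describing a two-tensor call such as torch.add(x, y, alpha=0.5).
api_info = {
    "input_args": [
        {"type": "torch.Tensor", "dtype": "torch.float32", "shape": [2, 3],
         "Min": -1.0, "Max": 1.0, "requires_grad": True},
        {"type": "torch.Tensor", "dtype": "torch.float32", "shape": [2, 3],
         "Min": 0.0, "Max": 2.0, "requires_grad": False},
    ],
    "input_kwargs": {"alpha": {"type": "float", "value": 0.5}},
}

args, kwargs = gen_api_params(api_info, "add", need_grad=True)
# args now holds two random float32 tensors drawn from the recorded [Min, Max] ranges
# (with the endpoints pinned into the data), and kwargs is {"alpha": 0.5}.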
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -0,0 +1,203 @@
+import subprocess
+import json
+import os
+import sys
+import argparse
+import time
+import signal
+import threading
+from collections import namedtuple
+from itertools import cycle
+from tqdm import tqdm
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \
+    get_validated_details_csv_path, preprocess_forward_content
+from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
+from msprobe.pytorch.common import parse_json_info_forward_backward
+from msprobe.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \
+    check_path_before_create, create_directory
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.const import FileCheckConst
+
+
+def split_json_file(input_file, num_splits, filter_api):
+    forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
+    if filter_api:
+        forward_data = preprocess_forward_content(forward_data)
+    for data_name in list(forward_data.keys()):
+        forward_data[f"{data_name}.forward"] = forward_data.pop(data_name)
+    for data_name in list(backward_data.keys()):
+        backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
+
+    with FileOpen(input_file, 'r') as file:
+        input_data = json.load(file)
+    input_data.pop("data")
+
+    items = list(forward_data.items())
+    total_items = len(items)
+    chunk_size = total_items // num_splits
+    split_files = []
+
+    for i in range(num_splits):
+        start = i * chunk_size
+        end = (i + 1) * chunk_size if i < num_splits - 1 else total_items
+
+        split_forward_data = dict(items[start:end])
+        temp_data = {
+            **input_data,
+            "data": {
+                **split_forward_data,
+                **backward_data
+            }
+        }
+        split_filename = f"temp_part{i}.json"
+        with FileOpen(split_filename, 'w') as split_file:
+            json.dump(temp_data, split_file)
+        split_files.append(split_filename)
+
+    return split_files, total_items
+
+
+def signal_handler(signum, frame):
+    logger.warning(f'Signal handler called with signal {signum}')
+    raise KeyboardInterrupt()
+
+
+signal.signal(signal.SIGINT, signal_handler)
+signal.signal(signal.SIGTERM, signal_handler)
+
+
+ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
+                                                   'save_error_data_flag', 'jit_compile_flag', 'device_id',
+                                                   'result_csv_path', 'total_items', 'real_data_path'])
+
+
+def run_parallel_ut(config):
+    processes = []
+    device_id_cycle = cycle(config.device_id)
+    if config.save_error_data_flag:
+        logger.info("UT task error datas will be saved")
+    logger.info(f"Starting parallel UT with {config.num_splits} processes")
+    progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")
+
+    def create_cmd(api_info, dev_id):
+        dirname, filename = os.path.split(os.path.abspath(__file__))
+        run_ut_path = os.path.join(dirname, "run_ut.py")
+        cmd = [
+            sys.executable, run_ut_path,
+            '-api_info', api_info,
+            *(['-o', config.out_path] if config.out_path else []),
+            '-d', str(dev_id),
+            *(['-j'] if config.jit_compile_flag else []),
+            *(['-save_error_data'] if config.save_error_data_flag else []),
+            '-csv_path', config.result_csv_path,
+            *(['-real_data_path', config.real_data_path] if config.real_data_path else [])
+        ]
+        return cmd
+
+    def read_process_output(process):
+        try:
+            while True:
+                if process.poll() is not None:
+                    break
+                output = process.stdout.readline()
+                if output == '':
+                    break
+                if '[ERROR]' in output:
+                    print(output, end='')
+                    sys.stdout.flush()
+        except ValueError as e:
+            logger.warning(f"An error occurred while reading subprocess output: {e}")
+
+    def update_progress_bar(progress_bar, result_csv_path):
+        while any(process.poll() is None for process in processes):
+            try:
+                with open(result_csv_path, 'r') as result_file:
+                    completed_items = len(result_file.readlines()) - 1
+                progress_bar.update(completed_items - progress_bar.n)
+            except FileNotFoundError:
+                logger.warning(f"Result CSV file not found: {result_csv_path}.")
+            except Exception as e:
+                logger.error(f"An unexpected error occurred while reading result CSV: {e}")
+            time.sleep(1)
+
+    for api_info in config.api_files:
+        cmd = create_cmd(api_info, next(device_id_cycle))
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1)
+        processes.append(process)
+        threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
+
+    progress_bar_thread = threading.Thread(target=update_progress_bar, args=(progress_bar, config.result_csv_path))
+    progress_bar_thread.start()
+
+    def clean_up():
+        progress_bar.close()
+        for process in processes:
+            try:
+                process.terminate()
+                process.wait(timeout=1)
+            except subprocess.TimeoutExpired:
+                process.kill()
+        for file in config.api_files:
+            check_link(file)
+            try:
+                os.remove(file)
+            except FileNotFoundError:
+                logger.warning(f"File not found and could not be deleted: {file}")
+
+    try:
+        for process in processes:
+            process.communicate(timeout=None)
+    except KeyboardInterrupt:
+        logger.warning("Interrupted by user, terminating processes and cleaning up...")
+    except Exception as e:
+        logger.error(f"An unexpected error occurred: {e}")
+    finally:
+        if progress_bar.n < config.total_items:
+            logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.")
+        clean_up()
+        progress_bar_thread.join()
+        try:
+            comparator = Comparator(config.result_csv_path, config.result_csv_path, False)
+            comparator.print_pretest_result()
+        except FileNotFoundError as e:
+            logger.error(f"Error: {e}")
+        except Exception as e:
+            logger.error(f"An unexpected error occurred: {e}")
+
+
+def prepare_config(args):
+    check_link(args.api_info_file)
+    api_info = os.path.realpath(args.api_info_file)
+    check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX)
+    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
+    check_path_before_create(out_path)
+    create_directory(out_path)
+    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
+    out_path = out_path_checker.common_check()
+    split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
+
+    result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
+    if not args.result_csv_path:
+        details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
+        comparator = Comparator(result_csv_path, details_csv_path, False)
+    else:
+        result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
+        details_csv_path = get_validated_details_csv_path(result_csv_path)
+    logger.info(f"UT task result will be saved in {result_csv_path}")
+    logger.info(f"UT task details will be saved in {details_csv_path}")
+    return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data,
+                            args.jit_compile, args.device_id, result_csv_path,
+                            total_items, args.real_data_path)
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Run UT in parallel')
+    _run_ut_parser(parser)
+    parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64')
+    args = parser.parse_args()
+    config = prepare_config(args)
+    run_parallel_ut(config)
+
+
+if __name__ == '__main__':
+    main()
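
multi_run_ut.py is intended to be launched as a script (its flags come from run_ut.py's _run_ut_parser plus the -n/--num_splits option added in main above), but its splitting step can also be exercised on its own. A small sketch, assuming ./dump.json is an existing dump file produced by the msprobe dump step:

from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import split_json_file

# Writes temp_part0.json ... temp_part3.json into the current directory, each holding a
# slice of the forward records plus all backward records, and returns their names
# together with the total number of forward items.
split_files, total_items = split_json_file('./dump.json', num_splits=4, filter_api=True)
print(split_files, total_items)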
msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
@@ -0,0 +1,127 @@
+import argparse
+import os
+import sys
+
+try:
+    import torch_npu
+except ImportError:
+    is_gpu = True
+else:
+    is_gpu = False
+import torch
+from tqdm import tqdm
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info
+from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
+from msprobe.core.common.file_check import check_link
+from msprobe.pytorch.common.log import logger
+
+def check_tensor_overflow(x):
+    if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool:
+        if len(x.shape) == 0:
+            tensor_max = x.cpu().detach().float().numpy().tolist()
+            tensor_min = tensor_max
+        else:
+            tensor_max = torch._C._VariableFunctionsClass.max(x).cpu().detach().float().numpy().tolist()
+            tensor_min = torch._C._VariableFunctionsClass.min(x).cpu().detach().float().numpy().tolist()
+        # inf
+        if tensor_max == float('inf') or tensor_min == float('-inf'):
+            return True
+        # nan
+        elif tensor_max != tensor_max or tensor_min != tensor_min:
+            return True
+        else:
+            return False
+    elif isinstance(x, bool) or isinstance(x, int) or isinstance(x, float):
+        if x == float('inf') or x == float('-inf') or x != x:
+            return True
+        else:
+            return False
+    else:
+        return False
+
+
+def check_data_overflow(x):
+    if isinstance(x, (tuple, list)) and x:
+        for _, item in enumerate(x):
+            if check_data_overflow(item):
+                return True
+        return False
+    else:
+        return check_tensor_overflow(x)
+
+
+def run_overflow_check(forward_file):
+    logger.info("start UT test")
+    forward_content = get_json_contents(forward_file)
+    for api_full_name, api_info_dict in tqdm(forward_content.items()):
+        try:
+            run_torch_api(api_full_name, api_info_dict)
+        except Exception as err:
+            api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0]
+            if "not implemented for 'Half'" in str(err):
+                logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
+                               f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
+            elif "expected scalar type Long" in str(err):
+                logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
+                               f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
+            else:
+                logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
+
+
+def run_torch_api(api_full_name, api_info_dict):
+    torch.npu.clear_npu_overflow_flag()
+    api_type = api_full_name.split(".")[0]
+    api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0]
+    args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='')
+    if not need_grad:
+        logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
+                       % api_full_name)
+    npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
+    if kwargs.get("device"):
+        del kwargs["device"]
+    out = exec_api(api_type, api_name, args, kwargs)
+    npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs)
+    cpu_overflow = check_data_overflow(out)
+    npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
+    if cpu_overflow == npu_overflow:
+        logger.warning("The %s overflow is a normal overflow." % api_full_name)
+    else:
+        logger.warning("The %s overflow is an abnormal overflow." % api_full_name)
+    return
+
+
+def _run_overflow_check_parser(parser):
+    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="",
+                        help="<Required> The api param tool result file: generate from api param tool, "
+                             "a json file.",
+                        required=True)
+    parser.add_argument("-j", "--jit_compile", dest="jit_compile", help="<optional> whether to turn on jit compile",
+                        default=False, required=False)
+    parser.add_argument("-d", "--device", dest="device_id", type=int, help="<optional> set NPU device id to run ut",
+                        default=0, required=False)
+
+
+def _run_overflow_check(parser=None):
+    if not parser:
+        parser = argparse.ArgumentParser()
+    _run_overflow_check_parser(parser)
+    args = parser.parse_args(sys.argv[1:])
+    _run_overflow_check_command(args)
+
+
+def _run_overflow_check_command(args):
+    torch.npu.set_compile_mode(jit_compile=args.jit_compile)
+    npu_device = "npu:" + str(args.device_id)
+    check_link(args.api_info_file)
+    api_info = os.path.realpath(args.api_info_file)
+    try:
+        torch.npu.set_device(npu_device)
+    except Exception as error:
+        logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
+        raise NotImplementedError from error
+    run_overflow_check(api_info)
+
+
+if __name__ == '__main__':
+    _run_overflow_check()
+    logger.info("UT task completed.")
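
The overflow predicates above depend only on the CPU build of torch, so they can be tried standalone before wiring up an NPU run. A short sketch with made-up tensors:

import torch
from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import check_data_overflow

ok = torch.ones(2, 2)
bad = torch.tensor([1.0, float('inf')])
print(check_data_overflow([ok, bad]))  # True: one element is inf
print(check_data_overflow((ok, 3.5)))  # False: nothing overflows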