mindstudio-probe 1.0.3-py3-none-any.whl → 1.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -1,332 +1,335 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
-
- import os
- import math
- import torch
- import numpy
-
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
- from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
-     CompareException
- from msprobe.core.common.file_check import FileChecker
- from msprobe.pytorch.common.log import logger
- from msprobe.core.common.const import Const, FileCheckConst
-
- TORCH_TYPE = ["torch.device", "torch.dtype"]
- TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
- FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16',
-               'torch.half', 'torch.bfloat16']
- NUMPY_TYPE = ["numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
-               "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
-               "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"]
-
-
- def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
-     """
-     Function Description:
-         Based on arg basic information, generate arg data
-     Parameter:
-         info: arg basic information. Dict
-         api_name: API name
-         need_grad: set Tensor grad for backward
-         convert_type: convert ori_type to dist_type flag.
-     """
-     check_object_type(info, dict)
-     data_type = info.get('type')
-     data_path = info.get('datapath', info.get('data_name'))
-     data_path = get_full_data_path(data_path, real_data_path)
-     if data_type in TENSOR_DATA_LIST:
-         if data_path:
-             data = gen_real_tensor(data_path, convert_type)
-         else:
-             data = gen_random_tensor(info, convert_type)
-         if api_name in hf_32_standard_api and data.dtype == torch.float32:
-             data = fp32_to_hf32_to_fp32(data)
-         if info.get('requires_grad') and need_grad:
-             data.requires_grad_(True)
-             temp_data = data * 1
-             data = temp_data.type_as(data)
-             data.retain_grad()
-     elif data_type.startswith("numpy"):
-         if data_type not in NUMPY_TYPE:
-             raise Exception("{} is not supported now".format(data_type))
-         data = info.get("value")
-         try:
-             data = eval(data_type)(data)
-         except Exception as err:
-             logger.error("Failed to convert the type to numpy: %s" % str(err))
-     elif data_type == "torch.Size":
-         data = torch.Size(info.get("value"))
-     else:
-         data = info.get('value')
-         if info.get("type") == "slice":
-             data = slice(*data)
-     return data
-
-
- def gen_real_tensor(data_path, convert_type):
-     """
-     Function Description:
-         Based on API data path, generate input parameters real data
-     Parameter:
-         data_path: API data path
-         convert_type: convert ori_type to dist_type flag.
-     """
-     data_path = os.path.realpath(data_path)
-     data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
-     data_path = data_path_checker.common_check()
-     if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
-         error_info = f"The file: {data_path} is not a pt or numpy file."
-         raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
-     if data_path.endswith('.pt'):
-         data = torch.load(data_path, map_location=torch.device('cpu'))
-     else:
-         data_np = numpy.load(data_path)
-         data = torch.from_numpy(data_np)
-     if convert_type:
-         ori_dtype = Const.CONVERT.get(convert_type)[0]
-         dist_dtype = Const.CONVERT.get(convert_type)[1]
-         if str(data.dtype) == ori_dtype:
-             data = data.type(eval(dist_dtype))
-     return data
-
-
- def gen_random_tensor(info, convert_type):
-     """
-     Function Description:
-         Based on API MAX and MIN, generate input parameters random data
-     Parameter:
-         info: API data info
-         convert_type: convert ori_type to dist_type flag.
-     """
-     check_object_type(info, dict)
-     low, high = info.get('Min'), info.get('Max')
-     low_origin, high_origin = info.get('Min_origin'), info.get('Max_origin')
-     low_info = [low, low_origin]
-     high_info = [high, high_origin]
-     data_dtype = info.get('dtype')
-     shape = tuple(info.get('shape'))
-     if not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
-         error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
-         raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
-     if data_dtype == "torch.bool":
-         data = gen_bool_tensor(low, high, shape)
-     else:
-         data = gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type)
-     return data
-
-
- def fp32_to_hf32_to_fp32(input_tensor):
-     # 将输入的float32 tensor转为hf32 tensor,再转为float32 tensor
-     input_np = input_tensor.detach().numpy()
-     input_int = input_np.view(numpy.int32)
-     input_int = numpy.right_shift(numpy.right_shift(input_int, 11) + 1, 1)
-     input_int = numpy.left_shift(input_int, 12)
-     input_fp32 = input_int.view(numpy.float32)
-     input_hf32 = torch.from_numpy(input_fp32)
-     return input_hf32
-
-
- def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
-     """
-     Function Description:
-         Based on API basic information, generate int or float tensor
-     Parameter:
-         low_info: [low, low_origin], low is the minimum value in the tensor removed inf and nan,
-         low_origin is the original minimum value in the tensor
-         high_info: [high, high_origin], high is the maximum value in the tensor removed inf and nan,
-         high_origin is the original maximum value in the tensor
-         shape:The shape of Tensor
-         data_dtype: The data type of Tensor
-         convert_type: convert ori_type to dist_type flag.
-     """
-     if convert_type:
-         ori_dtype = Const.CONVERT.get(convert_type)[0]
-         if ori_dtype == data_dtype:
-             data_dtype = Const.CONVERT.get(convert_type)[1]
-     low, low_origin = low_info[0], low_info[1]
-     high, high_origin = high_info[0], high_info[1]
-     if data_dtype in FLOAT_TYPE:
-         if math.isnan(high):
-             tensor = torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=eval(data_dtype))
-             return tensor
-         #high_origin为新版json中的属性,只有当high_origin不为None,且high为inf或-inf时,原tensor全为inf或-inf
-         if high_origin and high in [float('inf'), float('-inf')]:
-             tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=eval(data_dtype))
-             tensor[-1] = low
-             return tensor
-         low_scale, high_scale = low, high
-         dtype_finfo = torch.finfo(eval(data_dtype))
-         #适配老版json high和low为inf或-inf的情况,取dtype的最大值或最小值进行放缩
-         if high == float('inf'):
-             high_scale = dtype_finfo.max
-         elif high == float('-inf'):
-             high_scale = dtype_finfo.min
-         if low == float('inf'):
-             low_scale = dtype_finfo.max
-         elif low == float('-inf'):
-             low_scale = dtype_finfo.min
-
-         scale = high_scale - low_scale
-         rand01 = torch.rand(shape, dtype=eval(data_dtype))
-         tensor = rand01 * scale + low_scale
-     elif 'int' in data_dtype or 'long' in data_dtype:
-         low, high = int(low), int(high)
-         tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
-     else:
-         logger.error('Dtype is not supported: ' + data_dtype)
-         raise NotImplementedError()
-     if tensor.nelement() == 0:
-         return tensor
-     tmp_tensor = tensor.reshape(-1)
-     if high_origin and math.isnan(high_origin):
-         if tmp_tensor.numel() <= 2:
-             tmp_tensor[0] = float('nan')
-             tmp_tensor[-1] = high
-         else:
-             tmp_tensor[0] = low
-             tmp_tensor[1] = float('nan')
-             tmp_tensor[-1] = high
-     else:
-         tmp_tensor[0] = low
-         tmp_tensor[-1] = high
-         if high_origin in [float('inf'), float('-inf')]:
-             tmp_tensor[-1] = high_origin
-         if low_origin in [float('inf'), float('-inf')]:
-             tmp_tensor[0] = low_origin
-     data = tmp_tensor.reshape(shape)
-     return data
-
-
- def gen_bool_tensor(low, high, shape):
-     """
-     Function Description:
-         Based on API basic information, generate bool tensor
-     Parameter:
-         low: The minimum value in Tensor
-         high: The max value in Tensor
-         shape:The shape of Tensor
-     """
-     low, high = int(low), int(high)
-     if low > high:
-         low, high = high, low
-     tensor = torch.randint(low, high + 1, shape)
-     data = torch.gt(tensor, 0)
-     return data
-
-
- def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
-     """
-     Function Description:
-         Based on API basic information, generate input parameters: args, for API forward running
-     Parameter:
-         api_info: API basic information. List
-         api_name: API name
-         need_grad: set Tensor grad for backward
-         convert_type: convert ori_type to dist_type flag.
-         real_data_path: the root directory for storing real data.
-     """
-     check_object_type(args_info, list)
-     args_result = []
-     for arg in args_info:
-         if isinstance(arg, (list, tuple)):
-             data = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
-         elif isinstance(arg, dict):
-             data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
-         elif arg is None:
-             data = None
-         else:
-             logger.warning(f'Warning: {arg} is not supported')
-             raise NotImplementedError()
-         args_result.append(data)
-     return args_result
-
-
- def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
-     """
-     Function Description:
-         Based on API basic information, generate input parameters: kwargs, for API forward running
-     Parameter:
-         api_info: API basic information. Dict
-         api_name: API name
-         convert_type: convert ori_type to dist_type flag.
-         real_data_path: the root directory for storing real data.
-     """
-     check_object_type(api_info, dict)
-     kwargs_params = api_info.get("input_kwargs")
-     for key, value in kwargs_params.items():
-         if isinstance(value, (list, tuple)):
-             kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
-         elif value is None:
-             kwargs_params[key] = None
-         elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
-             kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
-         elif value.get('type') in TORCH_TYPE:
-             gen_torch_kwargs(kwargs_params, key, value)
-         else:
-             kwargs_params[key] = value.get('value')
-     return kwargs_params
-
-
- def gen_torch_kwargs(kwargs_params, key, value):
-     if value.get('type') != "torch.device":
-         kwargs_params[key] = eval(value.get('value'))
-
-
- def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
-     """
-     Function Description:
-         When kwargs value is list, generate the list of kwargs result
-     Parameter:
-         kwargs_item_value: kwargs value before to generate. List
-         api_name: API name
-         convert_type: convert ori_type to dist_type flag.
-     """
-     kwargs_item_result = []
-     for item in kwargs_item_value:
-         if item.get('type') in TENSOR_DATA_LIST:
-             item_value = gen_data(item, api_name, False, convert_type, real_data_path)
-         elif item.get('type') == "torch.Size":
-             item_value = torch.Size(item.get('value'))
-         else:
-             item_value = item.get('value')
-         kwargs_item_result.append(item_value)
-     return kwargs_item_result
-
-
- def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
-     """
-     Function Description:
-         Based on API basic information, generate input parameters: args, kwargs, for API forward running
-     Parameter:
-         api_info: API basic information. Dict
-         api_name: API name
-         need_grad: set grad for backward
-         convert_type: convert ori_type to dist_type flag.
-     """
-     check_object_type(api_info, dict)
-     if convert_type and convert_type not in Const.CONVERT:
-         error_info = f"convert_type params not support {convert_type}."
-         raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
-     kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
-     if api_info.get("input_args"):
-         args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path)
-     else:
-         logger.warning(f'Warning: No args in {api_info} ')
-         args_params = []
-     return args_params, kwargs_params
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+
+ import os
+ import math
+ import torch
+ import numpy
+
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
+ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
+     CompareException
+ from msprobe.core.common.file_utils import FileChecker, load_npy
+ from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import load_pt
+ from msprobe.core.common.const import Const, FileCheckConst
+
+ TORCH_TYPE = ["torch.device", "torch.dtype"]
+ TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+ FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16',
+               'torch.half', 'torch.bfloat16']
+ NUMPY_TYPE = ["numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
+               "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
+               "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"]
+
+
+ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
+     """
+     Function Description:
+         Based on arg basic information, generate arg data
+     Parameter:
+         info: arg basic information. Dict
+         api_name: API name
+         need_grad: set Tensor grad for backward
+         convert_type: convert ori_type to dist_type flag.
+     """
+     check_object_type(info, dict)
+     data_type = info.get('type')
+     data_path = info.get('datapath', info.get('data_name'))
+     data_path = get_full_data_path(data_path, real_data_path)
+     if data_type in TENSOR_DATA_LIST:
+         if data_path:
+             data = gen_real_tensor(data_path, convert_type)
+         else:
+             data = gen_random_tensor(info, convert_type)
+         if api_name in hf_32_standard_api and data.dtype == torch.float32:
+             data = fp32_to_hf32_to_fp32(data)
+         if info.get('requires_grad') and need_grad:
+             data.requires_grad_(True)
+             temp_data = data * 1
+             data = temp_data.type_as(data)
+             data.retain_grad()
+     elif data_type.startswith("numpy"):
+         if data_type not in NUMPY_TYPE:
+             raise Exception("{} is not supported now".format(data_type))
+         data = info.get("value")
+         try:
+             data = eval(data_type)(data)
+         except Exception as err:
+             logger.error("Failed to convert the type to numpy: %s" % str(err))
+     elif data_type == "torch.Size":
+         data = torch.Size(info.get("value"))
+     else:
+         data = info.get('value')
+         if info.get("type") == "slice":
+             data = slice(*data)
+         if info.get("type") == "ellipsis":
+             data = ...
+     return data
+
+
+ def gen_real_tensor(data_path, convert_type):
+     """
+     Function Description:
+         Based on API data path, generate input parameters real data
+     Parameter:
+         data_path: API data path
+         convert_type: convert ori_type to dist_type flag.
+     """
+     data_path = os.path.realpath(data_path)
+     data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
+     data_path = data_path_checker.common_check()
+     if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
+         error_info = f"The file: {data_path} is not a pt or numpy file."
+         raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
+     if data_path.endswith('.pt'):
+         data = load_pt(data_path, to_cpu=True)
+     else:
+         data_np = load_npy(data_path)
+         data = torch.from_numpy(data_np)
+     if convert_type:
+         ori_dtype = Const.CONVERT.get(convert_type)[0]
+         dist_dtype = Const.CONVERT.get(convert_type)[1]
+         if str(data.dtype) == ori_dtype:
+             data = data.type(eval(dist_dtype))
+     return data
+
+
+ def gen_random_tensor(info, convert_type):
+     """
+     Function Description:
+         Based on API MAX and MIN, generate input parameters random data
+     Parameter:
+         info: API data info
+         convert_type: convert ori_type to dist_type flag.
+     """
+     check_object_type(info, dict)
+     low, high = info.get('Min'), info.get('Max')
+     low_origin, high_origin = info.get('Min_origin'), info.get('Max_origin')
+     low_info = [low, low_origin]
+     high_info = [high, high_origin]
+     data_dtype = info.get('dtype')
+     shape = tuple(info.get('shape'))
+     if not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
+         error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
+         raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
+     if data_dtype == "torch.bool":
+         data = gen_bool_tensor(low, high, shape)
+     else:
+         data = gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type)
+     return data
+
+
+ def fp32_to_hf32_to_fp32(input_tensor):
+     # 将输入的float32 tensor转为hf32 tensor,再转为float32 tensor
+     input_np = input_tensor.detach().numpy()
+     input_int = input_np.view(numpy.int32)
+     input_int = numpy.right_shift(numpy.right_shift(input_int, 11) + 1, 1)
+     input_int = numpy.left_shift(input_int, 12)
+     input_fp32 = input_int.view(numpy.float32)
+     input_hf32 = torch.from_numpy(input_fp32)
+     return input_hf32
+
+
+ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
+     """
+     Function Description:
+         Based on API basic information, generate int or float tensor
+     Parameter:
+         low_info: [low, low_origin], low is the minimum value in the tensor removed inf and nan,
+         low_origin is the original minimum value in the tensor
+         high_info: [high, high_origin], high is the maximum value in the tensor removed inf and nan,
+         high_origin is the original maximum value in the tensor
+         shape:The shape of Tensor
+         data_dtype: The data type of Tensor
+         convert_type: convert ori_type to dist_type flag.
+     """
+     if convert_type:
+         ori_dtype = Const.CONVERT.get(convert_type)[0]
+         if ori_dtype == data_dtype:
+             data_dtype = Const.CONVERT.get(convert_type)[1]
+     low, low_origin = low_info[0], low_info[1]
+     high, high_origin = high_info[0], high_info[1]
+     if data_dtype in FLOAT_TYPE:
+         if math.isnan(high):
+             tensor = torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=eval(data_dtype))
+             return tensor
+         #high_origin为新版json中的属性,只有当high_origin不为None,且high为inf或-inf时,原tensor全为inf或-inf
+         if high_origin and high in [float('inf'), float('-inf')]:
+             tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=eval(data_dtype))
+             tensor[-1] = low
+             return tensor
+         low_scale, high_scale = low, high
+         dtype_finfo = torch.finfo(eval(data_dtype))
+         #适配老版json high和low为inf或-inf的情况,取dtype的最大值或最小值进行放缩
+         if high == float('inf'):
+             high_scale = dtype_finfo.max
+         elif high == float('-inf'):
+             high_scale = dtype_finfo.min
+         if low == float('inf'):
+             low_scale = dtype_finfo.max
+         elif low == float('-inf'):
+             low_scale = dtype_finfo.min
+
+         scale = high_scale - low_scale
+         rand01 = torch.rand(shape, dtype=eval(data_dtype))
+         tensor = rand01 * scale + low_scale
+     elif 'int' in data_dtype or 'long' in data_dtype:
+         low, high = int(low), int(high)
+         tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
+     else:
+         logger.error('Dtype is not supported: ' + data_dtype)
+         raise NotImplementedError()
+     if tensor.nelement() == 0:
+         return tensor
+     tmp_tensor = tensor.reshape(-1)
+     if high_origin and math.isnan(high_origin):
+         if tmp_tensor.numel() <= 2:
+             tmp_tensor[0] = float('nan')
+             tmp_tensor[-1] = high
+         else:
+             tmp_tensor[0] = low
+             tmp_tensor[1] = float('nan')
+             tmp_tensor[-1] = high
+     else:
+         tmp_tensor[0] = low
+         tmp_tensor[-1] = high
+         if high_origin in [float('inf'), float('-inf')]:
+             tmp_tensor[-1] = high_origin
+         if low_origin in [float('inf'), float('-inf')]:
+             tmp_tensor[0] = low_origin
+     data = tmp_tensor.reshape(shape)
+     return data
+
+
+ def gen_bool_tensor(low, high, shape):
+     """
+     Function Description:
+         Based on API basic information, generate bool tensor
+     Parameter:
+         low: The minimum value in Tensor
+         high: The max value in Tensor
+         shape:The shape of Tensor
+     """
+     low, high = int(low), int(high)
+     if low > high:
+         low, high = high, low
+     tensor = torch.randint(low, high + 1, shape)
+     data = torch.gt(tensor, 0)
+     return data
+
+
+ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
+     """
+     Function Description:
+         Based on API basic information, generate input parameters: args, for API forward running
+     Parameter:
+         api_info: API basic information. List
+         api_name: API name
+         need_grad: set Tensor grad for backward
+         convert_type: convert ori_type to dist_type flag.
+         real_data_path: the root directory for storing real data.
+     """
+     check_object_type(args_info, list)
+     args_result = []
+     for arg in args_info:
+         if isinstance(arg, (list, tuple)):
+             data = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
+         elif isinstance(arg, dict):
+             data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
+         elif arg is None:
+             data = None
+         else:
+             logger.warning(f'Warning: {arg} is not supported')
+             raise NotImplementedError()
+         args_result.append(data)
+     return args_result
+
+
+ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
+     """
+     Function Description:
+         Based on API basic information, generate input parameters: kwargs, for API forward running
+     Parameter:
+         api_info: API basic information. Dict
+         api_name: API name
+         convert_type: convert ori_type to dist_type flag.
+         real_data_path: the root directory for storing real data.
+     """
+     check_object_type(api_info, dict)
+     kwargs_params = api_info.get("input_kwargs")
+     for key, value in kwargs_params.items():
+         if isinstance(value, (list, tuple)):
+             kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
+         elif value is None:
+             kwargs_params[key] = None
+         elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
+             kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
+         elif value.get('type') in TORCH_TYPE:
+             gen_torch_kwargs(kwargs_params, key, value)
+         else:
+             kwargs_params[key] = value.get('value')
+     return kwargs_params
+
+
+ def gen_torch_kwargs(kwargs_params, key, value):
+     if value.get('type') != "torch.device":
+         kwargs_params[key] = eval(value.get('value'))
+
+
+ def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
+     """
+     Function Description:
+         When kwargs value is list, generate the list of kwargs result
+     Parameter:
+         kwargs_item_value: kwargs value before to generate. List
+         api_name: API name
+         convert_type: convert ori_type to dist_type flag.
+     """
+     kwargs_item_result = []
+     for item in kwargs_item_value:
+         if item.get('type') in TENSOR_DATA_LIST:
+             item_value = gen_data(item, api_name, False, convert_type, real_data_path)
+         elif item.get('type') == "torch.Size":
+             item_value = torch.Size(item.get('value'))
+         else:
+             item_value = item.get('value')
+         kwargs_item_result.append(item_value)
+     return kwargs_item_result
+
+
+ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
+     """
+     Function Description:
+         Based on API basic information, generate input parameters: args, kwargs, for API forward running
+     Parameter:
+         api_info: API basic information. Dict
+         api_name: API name
+         need_grad: set grad for backward
+         convert_type: convert ori_type to dist_type flag.
+     """
+     check_object_type(api_info, dict)
+     if convert_type and convert_type not in Const.CONVERT:
+         error_info = f"convert_type params not support {convert_type}."
+         raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
+     kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
+     if api_info.get("input_args"):
+         args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path)
+     else:
+         logger.warning(f'Warning: No args in {api_info} ')
+         args_params = []
+     return args_params, kwargs_params
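
A note for readers skimming the hunk above: the least obvious piece of this module is `fp32_to_hf32_to_fp32`, which `gen_data` applies when the API under test appears in `hf_32_standard_api` and the generated input is float32. It snaps each float32 value onto a coarser grid by rounding away the low mantissa bits, presumably so the CPU reference input carries the same effective precision the NPU kernel computes with. The following is a minimal NumPy-only sketch of the same bit manipulation; the function name and sample values are illustrative and not part of msprobe.

```python
import numpy as np

# Hypothetical standalone reproduction of the rounding done by fp32_to_hf32_to_fp32
# in the diff above; round_fp32_to_hf32_grid is not an msprobe API.
def round_fp32_to_hf32_grid(values):
    """Snap float32 values onto a coarser reduced-precision grid, returned as float32."""
    as_int = np.asarray(values, dtype=np.float32).view(np.int32)
    # Drop the low 12 bits of each word with rounding: shift right 11, add 1,
    # shift right 1 more, then shift left 12 so those bits come back as zeros.
    rounded = np.left_shift(np.right_shift(np.right_shift(as_int, 11) + 1, 1), 12)
    return rounded.view(np.float32)

print(round_fp32_to_hf32_grid([1.0000001, 3.1415927]))  # nearby inputs collapse onto fewer representable values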