mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,332 +1,370 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
- import math
20
- import torch
21
- import numpy
22
-
23
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
24
- from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
25
- CompareException
26
- from msprobe.core.common.file_check import FileChecker
27
- from msprobe.pytorch.common.log import logger
28
- from msprobe.core.common.const import Const, FileCheckConst
29
-
30
# "type" tags (as stored in the dump json) for torch.device / torch.dtype arguments.
TORCH_TYPE = [
    "torch.device",
    "torch.dtype",
]
# "type" tags that mark an argument as a tensor.
TENSOR_DATA_LIST = [
    "torch.Tensor",
    "torch.nn.parameter.Parameter",
]
# Floating-point dtype names routed to the float path of gen_common_tensor.
FLOAT_TYPE = [
    'torch.float32', 'torch.float', 'torch.float64', 'torch.double',
    'torch.float16', 'torch.half', 'torch.bfloat16',
]
# Whitelist of numpy scalar type names that gen_data is allowed to construct.
NUMPY_TYPE = [
    "numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64",
    "numpy.uint8", "numpy.uint16", "numpy.uint32", "numpy.uint64",
    "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128",
    "numpy.complex64", "numpy.complex128", "numpy.complex256",
    "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_",
]
37
-
38
-
39
def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
    """
    Function Description:
        Based on arg basic information, generate arg data
    Parameter:
        info: arg basic information. Dict
        api_name: API name
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        A tensor for tensor-typed infos (loaded from the data path when one is
        available, otherwise randomly generated), a numpy scalar for whitelisted
        numpy types, a torch.Size, a slice, or the raw "value" otherwise.
    Raises:
        Exception: if the info has a numpy type outside NUMPY_TYPE.
    """
    check_object_type(info, dict)
    data_type = info.get('type')
    data_path = info.get('datapath', info.get('data_name'))
    data_path = get_full_data_path(data_path, real_data_path)
    if data_type in TENSOR_DATA_LIST:
        if data_path:
            data = gen_real_tensor(data_path, convert_type)
        else:
            data = gen_random_tensor(info, convert_type)
        if api_name in hf_32_standard_api and data.dtype == torch.float32:
            data = fp32_to_hf32_to_fp32(data)
        if info.get('requires_grad') and need_grad:
            data.requires_grad_(True)
            # The returned tensor is a non-leaf derived from the leaf; retain_grad()
            # keeps its .grad populated after backward.
            temp_data = data * 1
            data = temp_data.type_as(data)
            data.retain_grad()
    elif isinstance(data_type, str) and data_type.startswith("numpy"):
        # Guarded with isinstance: a missing "type" used to raise AttributeError here.
        if data_type not in NUMPY_TYPE:
            raise Exception("{} is not supported now".format(data_type))
        data = info.get("value")
        try:
            # Resolve the whitelisted name via attribute lookup instead of eval().
            numpy_type = getattr(numpy, data_type.split(".", 1)[1])
            data = numpy_type(data)
        except Exception as err:
            # Best effort: on failure the raw json value is returned unchanged.
            logger.error("Failed to convert the type to numpy: %s" % str(err))
    elif data_type == "torch.Size":
        data = torch.Size(info.get("value"))
    else:
        data = info.get('value')
        if data_type == "slice":
            data = slice(*data)
    return data
80
-
81
-
82
def gen_real_tensor(data_path, convert_type):
    """
    Function Description:
        Load the real input data stored at data_path (a .pt or .npy file) as a tensor,
        optionally casting it according to convert_type.
    Parameter:
        data_path: API data path
        convert_type: key into Const.CONVERT; when set and the loaded tensor's dtype
            equals the pair's source dtype, the tensor is cast to the target dtype.
    Raises:
        CompareException: if the checked file is neither a .pt nor a .npy file.
    """
    checker = FileChecker(os.path.realpath(data_path), FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
    checked_path = checker.common_check()
    is_pt_file = checked_path.endswith('.pt')
    if not (is_pt_file or checked_path.endswith('.npy')):
        raise CompareException(CompareException.INVALID_FILE_ERROR,
                               f"The file: {checked_path} is not a pt or numpy file.")
    if is_pt_file:
        # NOTE(review): torch.load is pickle-based; only load dump files from a trusted source.
        tensor = torch.load(checked_path, map_location=torch.device('cpu'))
    else:
        tensor = torch.from_numpy(numpy.load(checked_path))
    if convert_type:
        conversion = Const.CONVERT.get(convert_type)
        source_dtype, target_dtype = conversion[0], conversion[1]
        if str(tensor.dtype) == source_dtype:
            tensor = tensor.type(eval(target_dtype))
    return tensor
107
-
108
-
109
def gen_random_tensor(info, convert_type):
    """
    Function Description:
        Generate random input data bounded by the recorded Min/Max of an API argument.
    Parameter:
        info: API data info (dict with Min, Max, Min_origin, Max_origin, dtype, shape)
        convert_type: convert ori_type to dist_type flag.
    Raises:
        CompareException: if Min or Max is not an int or float.
    """
    check_object_type(info, dict)
    low, high = info.get('Min'), info.get('Max')
    low_info = [low, info.get('Min_origin')]
    high_info = [high, info.get('Max_origin')]
    dtype_name = info.get('dtype')
    shape = tuple(info.get('shape'))
    bounds_are_numeric = isinstance(low, (int, float)) and isinstance(high, (int, float))
    if not bounds_are_numeric:
        raise CompareException(CompareException.INVALID_PARAM_ERROR,
                               f'Data info Min: {low} , Max: {high}, info type must be int or float.')
    if dtype_name == "torch.bool":
        return gen_bool_tensor(low, high, shape)
    return gen_common_tensor(low_info, high_info, shape, dtype_name, convert_type)
132
-
133
-
134
def fp32_to_hf32_to_fp32(input_tensor):
    """
    Round a float32 tensor to hf32 precision and return it as a float32 tensor again.

    The float32 bit pattern is reinterpreted as int32. The low 12 bits are dropped
    with round-half-up (shift right by 11, add 1, shift right by 1). Shifting back
    left by 12 zero-fills them.
    """
    as_bits = input_tensor.detach().numpy().view(numpy.int32)
    rounded = numpy.right_shift(numpy.right_shift(as_bits, 11) + 1, 1)
    truncated_bits = numpy.left_shift(rounded, 12)
    return torch.from_numpy(truncated_bits.view(numpy.float32))
143
-
144
-
145
def _torch_dtype_from_name(dtype_name):
    """
    Resolve a dtype name such as "torch.float16" to its torch.dtype object.

    Uses attribute lookup on the torch module instead of eval(), so a dtype string
    read from the dump json cannot execute code. Returns None when the name does
    not denote a torch.dtype.
    """
    prefix = "torch."
    if not isinstance(dtype_name, str) or not dtype_name.startswith(prefix):
        return None
    candidate = getattr(torch, dtype_name[len(prefix):], None)
    return candidate if isinstance(candidate, torch.dtype) else None


def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
    """
    Function Description:
        Based on API basic information, generate int or float tensor
    Parameter:
        low_info: [low, low_origin], low is the minimum value in the tensor removed inf and nan,
            low_origin is the original minimum value in the tensor
        high_info: [high, high_origin], high is the maximum value in the tensor removed inf and nan,
            high_origin is the original maximum value in the tensor
        shape: The shape of Tensor
        data_dtype: The data type of Tensor, as a name such as "torch.float32"
        convert_type: convert ori_type to dist_type flag.
    Raises:
        NotImplementedError: if data_dtype is neither in FLOAT_TYPE nor a resolvable
            torch int/long dtype name.
    """
    if convert_type:
        ori_dtype = Const.CONVERT.get(convert_type)[0]
        if ori_dtype == data_dtype:
            data_dtype = Const.CONVERT.get(convert_type)[1]
    low, low_origin = low_info[0], low_info[1]
    high, high_origin = high_info[0], high_info[1]
    if data_dtype in FLOAT_TYPE:
        torch_dtype = _torch_dtype_from_name(data_dtype)
        if math.isnan(high):
            tensor = torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=torch_dtype)
            return tensor
        # high_origin only exists in the newer json format. When it is set and high is
        # inf or -inf, the original tensor consisted entirely of inf or -inf.
        if high_origin and high in [float('inf'), float('-inf')]:
            tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=torch_dtype)
            # Write low into the last element in flattened order, as the general path
            # below does. Indexing the unflattened tensor overwrote a whole slice of an
            # N-d tensor and raised IndexError on 0-d or empty tensors.
            if tensor.numel() > 0:
                tensor.view(-1)[-1] = low
            return tensor
        low_scale, high_scale = low, high
        dtype_finfo = torch.finfo(torch_dtype)
        # Older json files may store inf/-inf in high/low; substitute the dtype's
        # max/min for them before scaling.
        if high == float('inf'):
            high_scale = dtype_finfo.max
        elif high == float('-inf'):
            high_scale = dtype_finfo.min
        if low == float('inf'):
            low_scale = dtype_finfo.max
        elif low == float('-inf'):
            low_scale = dtype_finfo.min

        scale = high_scale - low_scale
        rand01 = torch.rand(shape, dtype=torch_dtype)
        tensor = rand01 * scale + low_scale
    elif 'int' in data_dtype or 'long' in data_dtype:
        torch_dtype = _torch_dtype_from_name(data_dtype)
        if torch_dtype is None:
            # Previously eval()'d verbatim; reject anything that is not a torch dtype.
            logger.error('Dtype is not supported: ' + data_dtype)
            raise NotImplementedError()
        low, high = int(low), int(high)
        tensor = torch.randint(low, high + 1, shape, dtype=torch_dtype)
    else:
        logger.error('Dtype is not supported: ' + data_dtype)
        raise NotImplementedError()
    if tensor.nelement() == 0:
        return tensor
    # Pin the extreme values (and a nan marker when the original had nan) at fixed
    # flattened positions so the generated data reproduces the recorded statistics.
    tmp_tensor = tensor.reshape(-1)
    if high_origin and math.isnan(high_origin):
        if tmp_tensor.numel() <= 2:
            tmp_tensor[0] = float('nan')
            tmp_tensor[-1] = high
        else:
            tmp_tensor[0] = low
            tmp_tensor[1] = float('nan')
            tmp_tensor[-1] = high
    else:
        tmp_tensor[0] = low
        tmp_tensor[-1] = high
        if high_origin in [float('inf'), float('-inf')]:
            tmp_tensor[-1] = high_origin
        if low_origin in [float('inf'), float('-inf')]:
            tmp_tensor[0] = low_origin
    data = tmp_tensor.reshape(shape)
    return data
214
-
215
-
216
def gen_bool_tensor(low, high, shape):
    """
    Function Description:
        Generate a random bool tensor: draw integers uniformly from [low, high]
        (bounds truncated to int and ordered) and mark the positive ones True.
    Parameter:
        low: The minimum value in Tensor
        high: The max value in Tensor
        shape: The shape of Tensor
    """
    lower, upper = sorted((int(low), int(high)))
    random_ints = torch.randint(lower, upper + 1, shape)
    return random_ints > 0
231
-
232
-
233
def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: args, for API forward running
    Parameter:
        args_info: per-argument information. List; nested lists/tuples are expanded recursively
        api_name: API name
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Raises:
        NotImplementedError: for an entry that is not a list, tuple, dict or None.
    """
    check_object_type(args_info, list)

    def _build_one(arg):
        if arg is None:
            return None
        if isinstance(arg, (list, tuple)):
            return gen_args(arg, api_name, need_grad, convert_type, real_data_path)
        if isinstance(arg, dict):
            return gen_data(arg, api_name, need_grad, convert_type, real_data_path)
        logger.warning(f'Warning: {arg} is not supported')
        raise NotImplementedError()

    return [_build_one(arg) for arg in args_info]
258
-
259
-
260
def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: kwargs, for API forward running
    Parameter:
        api_info: API basic information. Dict
        api_name: API name
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        A new dict mapping each kwarg name to its generated value. A missing
        "input_kwargs" entry yields an empty dict.
    """
    check_object_type(api_info, dict)
    # Work on a shallow copy: the previous version overwrote api_info["input_kwargs"]
    # in place. A second call on the same api_info then crashed on already-generated
    # values.
    kwargs_params = dict(api_info.get("input_kwargs") or {})
    for key, value in kwargs_params.items():
        if isinstance(value, (list, tuple)):
            kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
        elif value is None:
            kwargs_params[key] = None
        else:
            value_type = value.get('type')
            # isinstance guard: a missing "type" used to crash on .startswith.
            is_numpy = isinstance(value_type, str) and value_type.startswith("numpy")
            if value_type in TENSOR_DATA_LIST or is_numpy:
                kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
            elif value_type in TORCH_TYPE:
                gen_torch_kwargs(kwargs_params, key, value)
            else:
                kwargs_params[key] = value.get('value')
    return kwargs_params
284
-
285
-
286
def gen_torch_kwargs(kwargs_params, key, value):
    """
    Replace kwargs_params[key] with the torch.dtype described by value.

    Parameter:
        kwargs_params: kwargs dict being built; updated in place
        key: kwarg name
        value: dump info dict with "type" ("torch.dtype" or "torch.device") and "value"
    Raises:
        ValueError: if a non-device value does not name a torch.dtype (e.g. "torch.float16").

    torch.device entries are left untouched, as before. The value string comes from
    the dump json, so it is resolved by attribute lookup rather than eval(); the
    previous eval() could execute arbitrary code.
    """
    if value.get('type') == "torch.device":
        return
    dtype_name = value.get('value')
    prefix = "torch."
    resolved = None
    if isinstance(dtype_name, str) and dtype_name.startswith(prefix):
        resolved = getattr(torch, dtype_name[len(prefix):], None)
    if not isinstance(resolved, torch.dtype):
        raise ValueError(f"Unsupported torch.dtype value: {dtype_name!r}")
    kwargs_params[key] = resolved
289
-
290
-
291
def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
    """
    Function Description:
        When kwargs value is list, generate the list of kwargs result
    Parameter:
        kwargs_item_value: kwargs value before to generate. List
        api_name: API name
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        A list with one generated value per item. None items stay None; nested
        lists/tuples are expanded recursively (consistent with gen_args). Both
        cases previously raised AttributeError.
    """
    kwargs_item_result = []
    for item in kwargs_item_value:
        if item is None:
            item_value = None
        elif isinstance(item, (list, tuple)):
            item_value = gen_list_kwargs(item, api_name, convert_type, real_data_path)
        elif item.get('type') in TENSOR_DATA_LIST:
            # Tensors inside list kwargs are generated without requires_grad.
            item_value = gen_data(item, api_name, False, convert_type, real_data_path)
        elif item.get('type') == "torch.Size":
            item_value = torch.Size(item.get('value'))
        else:
            item_value = item.get('value')
        kwargs_item_result.append(item_value)
    return kwargs_item_result
310
-
311
-
312
def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: args, kwargs, for API forward running
    Parameter:
        api_info: API basic information. Dict
        api_name: API name
        need_grad: set grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        (args_params, kwargs_params); args_params is [] when api_info has no input_args.
    Raises:
        CompareException: if convert_type is set but is not a key of Const.CONVERT.
    """
    check_object_type(api_info, dict)
    if convert_type and convert_type not in Const.CONVERT:
        raise CompareException(CompareException.INVALID_PARAM_ERROR,
                               f"convert_type params not support {convert_type}.")
    # kwargs are generated before args, as in the original ordering.
    kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
    input_args = api_info.get("input_args")
    if not input_args:
        logger.warning(f'Warning: No args in {api_info} ')
        return [], kwargs_params
    args_params = gen_args(input_args, api_name, need_grad, convert_type, real_data_path)
    return args_params, kwargs_params
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os
19
+ import math
20
+ import torch
21
+ import numpy
22
+
23
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
24
+ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
25
+ CompareException, get_module_and_atttribute_name, get_attribute
26
+ from msprobe.core.common.file_utils import FileChecker, load_npy
27
+ from msprobe.pytorch.common.log import logger
28
+ from msprobe.pytorch.common.utils import load_pt
29
+ from msprobe.core.common.const import Const, FileCheckConst, CompareConst
30
+
31
# Kwarg type tags that gen_kwargs hands to gen_torch_kwargs.
TORCH_TYPE = ["torch.device", "torch.dtype"]
# Type tags that gen_data treats as tensors (loaded from a file or generated randomly).
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
# dtype names (including aliases) that gen_common_tensor handles as floating point.
FLOAT_TYPE = [
    'torch.float32', 'torch.float',
    'torch.float64', 'torch.double',
    'torch.float16', 'torch.half',
    'torch.bfloat16',
]
# The only "numpy*" type tags gen_data accepts; other numpy tags are rejected.
NUMPY_TYPE = [
    f"numpy.{scalar_name}"
    for scalar_name in (
        "int8", "int16", "int32", "int64",
        "uint8", "uint16", "uint32", "uint64",
        "float16", "float32", "float64", "float128",
        "complex64", "complex128", "complex256",
        "bool_", "string_", "bytes_", "unicode_",
    )
]
47
+
48
+
49
def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
    """
    Function Description:
        Rebuild one argument value from its dumped description
    Parameter:
        info: arg basic information. Dict with a 'type' entry plus type-specific entries
        api_name: API name; APIs in hf_32_standard_api get float32 tensors rounded to HF32
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        A tensor (real data when a data path is recorded, random otherwise), a numpy scalar,
        a torch.Size, a slice, Ellipsis, or the recorded 'value' for any other type.
    Raises:
        Exception: the type starts with "numpy" but is not listed in NUMPY_TYPE.
    """
    check_object_type(info, dict)
    arg_type = info.get('type')
    full_path = get_full_data_path(info.get('datapath', info.get('data_name')), real_data_path)

    if arg_type in TENSOR_DATA_LIST:
        tensor = gen_real_tensor(full_path, convert_type) if full_path else gen_random_tensor(info, convert_type)
        if api_name in hf_32_standard_api and tensor.dtype == torch.float32:
            tensor = fp32_to_hf32_to_fp32(tensor)
        if info.get('requires_grad') and need_grad:
            tensor.requires_grad_(True)
            # `tensor * 1` yields a non-leaf tensor, so retain_grad() is needed for its .grad to be kept.
            tensor = (tensor * 1).type_as(tensor)
            tensor.retain_grad()
        return tensor

    if arg_type.startswith("numpy"):
        if arg_type not in NUMPY_TYPE:
            raise Exception("{} is not supported now".format(arg_type))
        raw_value = info.get("value")
        try:
            module_name, attribute_name = get_module_and_atttribute_name(arg_type)
            raw_value = get_attribute(module_name, attribute_name)(raw_value)
        except Exception as err:
            # Best effort: on failure the unconverted recorded value is returned.
            logger.error("Failed to convert the type to numpy: %s" % str(err))
        return raw_value

    if arg_type == "torch.Size":
        return torch.Size(info.get("value"))

    raw_value = info.get('value')
    if arg_type == "slice":
        return slice(*raw_value)
    if arg_type == "ellipsis":
        return ...
    return raw_value
93
+
94
+
95
def gen_real_tensor(data_path, convert_type):
    """
    Function Description:
        Load a real input tensor from a dumped .pt or .npy file
    Parameter:
        data_path: API data path; resolved with os.path.realpath and validated as a readable file
        convert_type: convert ori_type to dist_type flag; when set and the loaded dtype equals the
                      source dtype of Const.CONVERT[convert_type], the tensor is cast to the target dtype
    Return:
        The loaded tensor (.pt files are loaded onto the CPU).
    Raises:
        CompareException: the file extension is neither .pt nor .npy.
    """
    checker = FileChecker(os.path.realpath(data_path), FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
    checked_path = checker.common_check()
    if checked_path.endswith('.pt'):
        tensor = load_pt(checked_path, to_cpu=True)
    elif checked_path.endswith('.npy'):
        tensor = torch.from_numpy(load_npy(checked_path))
    else:
        error_info = f"The file: {checked_path} is not a pt or numpy file."
        raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)

    if not convert_type:
        return tensor
    conversion = Const.CONVERT.get(convert_type)
    source_dtype, target_dtype = conversion[0], conversion[1]
    module_name, attribute_name = get_module_and_atttribute_name(target_dtype)
    if str(tensor.dtype) == source_dtype:
        tensor = tensor.type(get_attribute(module_name, attribute_name))
    return tensor
121
+
122
+
123
def gen_random_tensor(info, convert_type):
    """
    Function Description:
        Generate random input data bounded by the recorded Min/Max statistics
    Parameter:
        info: API data info. Dict with 'Min', 'Max', optional 'Min_except_inf_nan' /
              'Max_except_inf_nan', 'dtype' and 'shape'
        convert_type: convert ori_type to dist_type flag.
    Return:
        A bool tensor for dtype "torch.bool", otherwise the tensor from gen_common_tensor.
    Raises:
        CompareException: the effective Min or Max is not an int or float.
    """
    check_object_type(info, dict)

    min_raw = info.get('Min')
    max_raw = info.get('Max')
    # Prefer the inf/nan-free statistics when the dump provides them.
    min_finite = info.get('Min_except_inf_nan', min_raw)
    max_finite = info.get('Max_except_inf_nan', max_raw)
    dtype_name = info.get('dtype')
    shape = tuple(info.get('shape'))

    numeric_types = (int, float)
    if not (isinstance(min_finite, numeric_types) and isinstance(max_finite, numeric_types)):
        error_info = f'Data info Min: {min_finite} , Max: {max_finite}, info type must be int or float.'
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)

    if dtype_name == "torch.bool":
        return gen_bool_tensor(min_finite, max_finite, shape)
    return gen_common_tensor([min_finite, min_raw], [max_finite, max_raw], shape, dtype_name, convert_type)
150
+
151
+
152
def fp32_to_hf32_to_fp32(input_tensor):
    """
    Round a float32 tensor to HF32 precision and return the result as a new float32 tensor.

    The float32 bit pattern is reinterpreted as int32. Half of the lowest kept bit (bit 11)
    is added, and then the low 12 bits are cleared. This leaves the upper 20 bits:
    1 sign, 8 exponent and 11 mantissa bits. The input tensor is not modified.
    """
    raw_bits = input_tensor.detach().numpy().view(numpy.int32)
    rounded_bits = numpy.right_shift(numpy.right_shift(raw_bits, 11) + 1, 1)
    hf32_bits = numpy.left_shift(rounded_bits, 12)
    return torch.from_numpy(hf32_bits.view(numpy.float32))
161
+
162
+
163
def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
    """
    Function Description:
        Based on API basic information, generate int or float tensor
    Parameter:
        low_info: [low, low_origin], low is the minimum value in the tensor removed inf and nan,
                  low_origin is the original minimum value in the tensor
        high_info: [high, high_origin], high is the maximum value in the tensor removed inf and nan,
                   high_origin is the original maximum value in the tensor
        shape: The shape of Tensor
        data_dtype: The data type name of Tensor, e.g. "torch.float32"
        convert_type: convert ori_type to dist_type flag.
    Return:
        Tensor of the requested shape. For non-empty tensors, the first and last elements
        (in flattened order) are pinned to low/high, or to the inf/-inf/nan markers found in
        low_origin/high_origin.
    Raises:
        NotImplementedError: data_dtype is neither in FLOAT_TYPE nor an int/long type.
    """
    if convert_type:
        ori_dtype = Const.CONVERT.get(convert_type)[0]
        if ori_dtype == data_dtype:
            data_dtype = Const.CONVERT.get(convert_type)[1]
    low, low_origin = low_info[0], low_info[1]
    high, high_origin = high_info[0], high_info[1]
    module_name, attribute_name = get_module_and_atttribute_name(data_dtype)
    dtype = get_attribute(module_name, attribute_name)
    if data_dtype in FLOAT_TYPE:
        if math.isnan(high):
            tensor = torch.full(shape, float('nan'), dtype=dtype)
            return tensor
        # high_origin exists only in newer json dumps. When it is set and high is inf or -inf,
        # every element of the original tensor was inf or -inf.
        if high_origin and high in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
            tensor = torch.full(shape, high, dtype=dtype)
            # Pin one element (the last in flattened order) to low.
            # Indexing the flat view also works for 0-dim tensors, where tensor[-1] raises.
            # For multi-dim tensors it avoids overwriting a whole slice.
            # Empty tensors have nothing to pin.
            if tensor.numel() > 0:
                tensor.view(-1)[-1] = low
            return tensor
        low_scale, high_scale = low, high
        dtype_finfo = torch.finfo(dtype)
        # Compatibility with old json dumps where high/low are recorded as inf or -inf:
        # substitute the dtype's max or min value for scaling.
        if high == float(CompareConst.INF):
            high_scale = dtype_finfo.max
        elif high == float(CompareConst.NEG_INF):
            high_scale = dtype_finfo.min
        if low == float(CompareConst.INF):
            low_scale = dtype_finfo.max
        elif low == float(CompareConst.NEG_INF):
            low_scale = dtype_finfo.min

        scale = high_scale - low_scale
        rand01 = torch.rand(shape, dtype=dtype)
        if abs(scale) > dtype_finfo.max:
            # The span (e.g. finfo.max - finfo.min) is not representable in dtype, so
            # rand01 * scale would overflow to inf. Interpolate between the bounds instead.
            # Each product stays within the dtype's finite range, and so does their sum.
            tensor = rand01 * high_scale + (1 - rand01) * low_scale
        else:
            tensor = rand01 * scale + low_scale
    elif 'int' in data_dtype or 'long' in data_dtype:
        low, high = int(low), int(high)
        # randint's upper bound is exclusive, hence high + 1.
        tensor = torch.randint(low, high + 1, shape, dtype=dtype)
    else:
        logger.error('Dtype is not supported: ' + data_dtype)
        raise NotImplementedError()
    if tensor.nelement() == 0:
        return tensor
    tmp_tensor = tensor.reshape(-1)
    if high_origin and math.isnan(high_origin):
        # Keep a nan in the tensor when the recorded original max was nan.
        if tmp_tensor.numel() <= 2:
            tmp_tensor[0] = float('nan')
            tmp_tensor[-1] = high
        else:
            tmp_tensor[0] = low
            tmp_tensor[1] = float('nan')
            tmp_tensor[-1] = high
    else:
        tmp_tensor[0] = low
        tmp_tensor[-1] = high
        if high_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
            tmp_tensor[-1] = high_origin
        if low_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
            tmp_tensor[0] = low_origin
    data = tmp_tensor.reshape(shape)
    return data
234
+
235
+
236
def gen_bool_tensor(low, high, shape):
    """
    Function Description:
        Generate a random bool tensor from the recorded min/max values
    Parameter:
        low: The minimum value in Tensor (converted with int())
        high: The max value in Tensor (converted with int()); the two bounds may come in either order
        shape: The shape of Tensor
    Return:
        Bool tensor that is True wherever a random integer drawn from
        [min(low, high), max(low, high)] is greater than 0.
    """
    lower, upper = sorted((int(low), int(high)))
    # randint's upper bound is exclusive, hence the + 1.
    random_ints = torch.randint(lower, upper + 1, shape)
    return random_ints > 0
251
+
252
+
253
def gen_args(args_info, api_name, func_options):
    """
    Function Description:
        Based on API basic information, generate input parameters: args, for API forward running
    Parameter:
        args_info: API positional-argument information. List; nested lists/tuples are generated recursively
        api_name: API name
        func_options: Dict of generation options (the caller's dict is not modified):
            need_grad: set Tensor grad for backward (default True)
            convert_type: convert ori_type to dist_type flag (default None)
            real_data_path: the root directory for storing real data (default None)
            depth: current nesting depth (default 0)
    Return:
        List of generated argument values, mirroring the nesting of args_info with lists.
    Raises:
        CompareException: nesting depth exceeds Const.MAX_DEPTH.
        NotImplementedError: an argument is not a list, tuple, dict or None.
    """
    check_object_type(args_info, list)
    args_result = []

    need_grad = func_options.get('need_grad', True)
    convert_type = func_options.get('convert_type', None)
    real_data_path = func_options.get('real_data_path', None)
    depth = func_options.get('depth', 0)

    if depth > Const.MAX_DEPTH:
        logger.error("The depth of args is too large, please check the input args.")
        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)

    # Nested calls get their own copy with depth + 1.
    # Writing into func_options instead would leak the increased depth back to the caller.
    nested_options = dict(func_options, depth=depth + 1)
    for arg in args_info:
        if isinstance(arg, (list, tuple)):
            data = gen_args(arg, api_name, nested_options)
        elif isinstance(arg, dict):
            data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
        elif arg is None:
            data = None
        else:
            logger.warning(f'Warning: {arg} is not supported')
            raise NotImplementedError()
        args_result.append(data)
    return args_result
289
+
290
+
291
def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: kwargs, for API forward running
    Parameter:
        api_info: API basic information. Dict; its optional "input_kwargs" entry maps kwarg name to description
        api_name: API name
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        The api_info["input_kwargs"] dict, with generated values written back into it in place.
        If no "input_kwargs" entry is recorded, a new empty dict is returned.
    """
    check_object_type(api_info, dict)
    kwargs_params = api_info.get("input_kwargs")
    if kwargs_params is None:
        # No kwargs were recorded for this API.
        return {}
    for key, value in kwargs_params.items():
        if isinstance(value, (list, tuple)):
            kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
            continue
        if value is None:
            kwargs_params[key] = None
            continue
        value_type = value.get('type')
        # A description without a string 'type' falls through to its recorded 'value'.
        is_numpy = isinstance(value_type, str) and value_type.startswith("numpy")
        if value_type in TENSOR_DATA_LIST or is_numpy:
            kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
        elif value_type in TORCH_TYPE:
            gen_torch_kwargs(kwargs_params, key, value)
        else:
            kwargs_params[key] = value.get('value')
    return kwargs_params
315
+
316
+
317
def gen_torch_kwargs(kwargs_params, key, value):
    """
    Function Description:
        Resolve a torch-typed kwarg (a type listed in TORCH_TYPE) into the actual torch object
    Parameter:
        kwargs_params: kwargs dict being built; updated in place
        key: kwarg name
        value: dumped description of the kwarg, with 'type' and 'value' entries
    Note:
        "torch.device" entries are skipped, so kwargs_params[key] keeps whatever it already held.
        Any other value (presumably a torch.dtype recorded by name) is resolved through
        get_module_and_atttribute_name / get_attribute.
    """
    if value.get('type') == "torch.device":
        return
    dotted_name = value.get('value')
    module_name, attribute_name = get_module_and_atttribute_name(dotted_name)
    kwargs_params[key] = get_attribute(module_name, attribute_name)
321
+
322
+
323
def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
    """
    Function Description:
        When kwargs value is list, generate the list of kwargs result
    Parameter:
        kwargs_item_value: kwargs value before to generate. List of item descriptions (dicts) or None
        api_name: API name
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        List with one generated value per item: tensors for tensor items, torch.Size for
        "torch.Size" items, None for None items, otherwise the item's recorded 'value'.
    """
    kwargs_item_result = []
    for item in kwargs_item_value:
        if item is None:
            # Keep None entries as None, the same way gen_args and gen_kwargs treat None.
            item_value = None
        elif item.get('type') in TENSOR_DATA_LIST:
            item_value = gen_data(item, api_name, False, convert_type, real_data_path)
        elif item.get('type') == "torch.Size":
            item_value = torch.Size(item.get('value'))
        else:
            item_value = item.get('value')
        kwargs_item_result.append(item_value)
    return kwargs_item_result
342
+
343
+
344
def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Build the (args, kwargs) pair used to run an API forward from its dumped description
    Parameter:
        api_info: API basic information. Dict with optional "input_args" / "input_kwargs"
        api_name: API name
        need_grad: set grad for backward
        convert_type: convert ori_type to dist_type flag; must be a key of Const.CONVERT when given
        real_data_path: the root directory for storing real data.
    Return:
        Tuple (args_params, kwargs_params); args_params is [] when no input args were recorded.
    Raises:
        CompareException: convert_type is given but not supported.
    """
    check_object_type(api_info, dict)
    if convert_type and convert_type not in Const.CONVERT:
        error_info = f"convert_type params not support {convert_type}."
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
    kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
    input_args = api_info.get("input_args")
    if not input_args:
        logger.warning(f'Warning: No args in {api_info} ')
        return [], kwargs_params
    options = dict(need_grad=need_grad, convert_type=convert_type, real_data_path=real_data_path, depth=0)
    return gen_args(input_args, api_name, options), kwargs_params