mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,328 +1,335 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
- import math
20
- import torch
21
- import numpy
22
-
23
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
24
- from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, \
25
- get_full_data_path, CompareException
26
- from msprobe.pytorch.common.log import logger
27
- from msprobe.core.common.const import Const
28
-
29
# Type tags whose kwargs are turned into torch objects by gen_torch_kwargs.
TORCH_TYPE = ["torch.device", "torch.dtype"]
# Type tags whose entries are materialised as tensors by gen_data.
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
# Floating-point dtype strings handled by the float branch of gen_common_tensor.
FLOAT_TYPE = [
    'torch.float32', 'torch.float', 'torch.float64', 'torch.double',
    'torch.float16', 'torch.half', 'torch.bfloat16',
]
# Whitelist of numpy type tags that gen_data is allowed to convert values into.
NUMPY_TYPE = [
    "numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64",
    "numpy.uint8", "numpy.uint16", "numpy.uint32", "numpy.uint64",
    "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128",
    "numpy.complex64", "numpy.complex128", "numpy.complex256",
    "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_",
]
36
-
37
-
38
def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
    """
    Function Description:
        Based on arg basic information, generate arg data.
    Parameter:
        info: arg basic information. Dict
        api_name: API name. float32 tensors of APIs listed in hf_32_standard_api are
            rounded through HF32 precision.
        need_grad: set Tensor grad for backward (only when info['requires_grad'] is truthy).
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        A torch tensor, a numpy value, a torch.Size, a slice, or the raw info['value'].
    Raise:
        Exception: if info['type'] names a numpy type that is not in NUMPY_TYPE.
    """
    check_object_type(info, dict)
    data_type = info.get('type')
    data_path = get_full_data_path(info.get('datapath', info.get('data_name')), real_data_path)
    if data_type in TENSOR_DATA_LIST:
        if data_path:
            data = gen_real_tensor(data_path, convert_type)
        else:
            data = gen_random_tensor(info, convert_type)
        if api_name in hf_32_standard_api and data.dtype == torch.float32:
            data = fp32_to_hf32_to_fp32(data)
        if info.get('requires_grad') and need_grad:
            data.requires_grad_(True)
            # Derive a non-leaf copy from the leaf; retain_grad() keeps its .grad after backward.
            temp_data = data * 1
            data = temp_data.type_as(data)
            data.retain_grad()
    elif isinstance(data_type, str) and data_type.startswith("numpy"):
        # The isinstance guard keeps a dict without 'type' from crashing on None.startswith.
        if data_type not in NUMPY_TYPE:
            raise Exception("{} is not supported now".format(data_type))
        data = info.get("value")
        try:
            # data_type is whitelisted above, so resolve it as a numpy attribute instead of eval().
            data = getattr(numpy, data_type[len("numpy."):])(data)
        except Exception as err:
            # Best effort: keep the raw value if this numpy build lacks the type or conversion fails.
            logger.error("Failed to convert the type to numpy: %s" % str(err))
    elif data_type == "torch.Size":
        data = torch.Size(info.get("value"))
    else:
        data = info.get('value')
        if data_type == "slice":
            data = slice(*data)
    return data
79
-
80
-
81
def gen_real_tensor(data_path, convert_type):
    """
    Function Description:
        Load the real input data dumped for an API from a .pt or .npy file.
    Parameter:
        data_path: API data path
        convert_type: convert ori_type to dist_type flag. When set and the loaded tensor's
            dtype string equals the flag's source dtype, the tensor is cast to the target dtype.
    Return:
        torch.Tensor (moved to CPU when loaded from a .pt file).
    Raise:
        CompareException: if the file is neither a .pt nor a .npy file.
    """
    resolved_path = os.path.realpath(data_path)
    check_file_or_directory_path(resolved_path)
    is_pt = resolved_path.endswith('.pt')
    is_npy = resolved_path.endswith('.npy')
    if not (is_pt or is_npy):
        error_info = f"The file: {resolved_path} is not a pt or numpy file."
        raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
    if is_pt:
        # NOTE(review): torch.load unpickles the file and can run code embedded in an untrusted
        # dump; only point real_data_path at data you trust.
        data = torch.load(resolved_path).cpu()
    else:
        data = torch.from_numpy(numpy.load(resolved_path))
    if convert_type:
        conversion = Const.CONVERT.get(convert_type)
        ori_dtype, dist_dtype = conversion[0], conversion[1]
        if str(data.dtype) == ori_dtype:
            data = data.type(eval(dist_dtype))
    return data
105
-
106
-
107
def gen_random_tensor(info, convert_type):
    """
    Function Description:
        Generate a random tensor whose values span the recorded Min/Max of an API input.
    Parameter:
        info: API data info; reads 'Min', 'Max', 'Min_origin', 'Max_origin', 'dtype' and 'shape'.
        convert_type: convert ori_type to dist_type flag.
    Return:
        torch.Tensor from gen_bool_tensor for "torch.bool", otherwise from gen_common_tensor.
    Raise:
        CompareException: if Min or Max is not an int or float.
    """
    check_object_type(info, dict)
    low, high = info.get('Min'), info.get('Max')
    low_info = [low, info.get('Min_origin')]
    high_info = [high, info.get('Max_origin')]
    data_dtype = info.get('dtype')
    shape = tuple(info.get('shape'))
    numeric_types = (int, float)
    if not (isinstance(low, numeric_types) and isinstance(high, numeric_types)):
        error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
    if data_dtype == "torch.bool":
        return gen_bool_tensor(low, high, shape)
    return gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type)
130
-
131
-
132
def fp32_to_hf32_to_fp32(input_tensor):
    """
    Round a float32 tensor through HF32 precision and return the result as float32.

    The low 12 bits of every float32 bit pattern are cleared, rounding on bit 11 of the
    pattern (ties round away from zero in magnitude), leaving 11 explicit mantissa bits.

    Parameter:
        input_tensor: float32 tensor; it is detached and converted with Tensor.numpy(),
            so it is expected to live on the CPU.
    Return:
        A new float32 tensor holding the rounded values.
    """
    raw_bits = input_tensor.detach().numpy().view(numpy.int32)
    bumped = numpy.right_shift(raw_bits, 11) + 1
    rounded_bits = numpy.left_shift(numpy.right_shift(bumped, 1), 12)
    return torch.from_numpy(rounded_bits.view(numpy.float32))
141
-
142
-
143
def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
    """
    Function Description:
        Based on API basic information, generate an int or float tensor.
    Parameter:
        low_info: [low, low_origin], low is the minimum value in the tensor with inf and nan removed,
            low_origin is the original minimum value in the tensor
        high_info: [high, high_origin], high is the maximum value in the tensor with inf and nan removed,
            high_origin is the original maximum value in the tensor
        shape: The shape of Tensor
        data_dtype: The data type of Tensor (a dtype string such as "torch.float32")
        convert_type: convert ori_type to dist_type flag.
    Return:
        torch.Tensor of the requested shape; its first/last flattened elements are pinned to the
        recorded extremes (and nan/inf markers) when the tensor is non-empty.
    Raise:
        NotImplementedError: if data_dtype is neither a float type nor an int/long dtype string.
    """
    if convert_type:
        ori_dtype = Const.CONVERT.get(convert_type)[0]
        if ori_dtype == data_dtype:
            data_dtype = Const.CONVERT.get(convert_type)[1]
    low, low_origin = low_info[0], low_info[1]
    high, high_origin = high_info[0], high_info[1]
    # NOTE(review): data_dtype comes from dump json and is passed to eval(); only load trusted dumps.
    if data_dtype in FLOAT_TYPE:
        if math.isnan(high):
            tensor = torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=eval(data_dtype))
            return tensor
        # high_origin only exists in newer dump json. When it is set and high itself is +/-inf,
        # the original tensor was made up entirely of +/-inf values.
        if high_origin and high in [float('inf'), float('-inf')]:
            tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=eval(data_dtype))
            # Pin only the last flattened element to low. Indexing tensor[-1] directly would
            # overwrite a whole slice for multi-dim shapes and raise for 0-d or empty tensors.
            if tensor.numel() > 0:
                tensor.view(-1)[-1] = low
            return tensor
        low_scale, high_scale = low, high
        dtype_finfo = torch.finfo(eval(data_dtype))
        # Older dump json may record high/low as +/-inf; scale with the dtype's max/min instead.
        if high == float('inf'):
            high_scale = dtype_finfo.max
        elif high == float('-inf'):
            high_scale = dtype_finfo.min
        if low == float('inf'):
            low_scale = dtype_finfo.max
        elif low == float('-inf'):
            low_scale = dtype_finfo.min

        scale = high_scale - low_scale
        rand01 = torch.rand(shape, dtype=eval(data_dtype))
        tensor = rand01 * scale + low_scale
    elif isinstance(data_dtype, str) and ('int' in data_dtype or 'long' in data_dtype):
        low, high = int(low), int(high)
        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
    else:
        # str() so a missing/non-string dtype reaches the intended NotImplementedError.
        logger.error('Dtype is not supported: ' + str(data_dtype))
        raise NotImplementedError()
    if tensor.nelement() == 0:
        return tensor
    tmp_tensor = tensor.reshape(-1)
    if high_origin and math.isnan(high_origin):
        # The original tensor contained nan: place a nan next to the finite extremes.
        if tmp_tensor.numel() <= 2:
            tmp_tensor[0] = float('nan')
            tmp_tensor[-1] = high
        else:
            tmp_tensor[0] = low
            tmp_tensor[1] = float('nan')
            tmp_tensor[-1] = high
    else:
        tmp_tensor[0] = low
        tmp_tensor[-1] = high
        if high_origin in [float('inf'), float('-inf')]:
            tmp_tensor[-1] = high_origin
        if low_origin in [float('inf'), float('-inf')]:
            tmp_tensor[0] = low_origin
    data = tmp_tensor.reshape(shape)
    return data
212
-
213
-
214
def gen_bool_tensor(low, high, shape):
    """
    Function Description:
        Generate a random bool tensor from integers sampled in [min(low, high), max(low, high)].
    Parameter:
        low: The minimum value in Tensor (truncated with int())
        high: The max value in Tensor (truncated with int())
        shape: The shape of Tensor
    Return:
        Bool tensor that is True wherever the sampled integer is greater than 0.
    """
    lower, upper = sorted((int(low), int(high)))
    samples = torch.randint(lower, upper + 1, shape)
    return torch.gt(samples, 0)
229
-
230
-
231
def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: args, for API forward running.
    Parameter:
        args_info: API positional-argument information. List; nested lists recurse,
            dicts go through gen_data and None stays None.
        api_name: API name
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        List of generated arguments, mirroring the structure of args_info.
    Raise:
        NotImplementedError: for an element that is not a list, tuple, dict or None.
    """
    check_object_type(args_info, list)
    generated = []
    for item in args_info:
        if item is None:
            generated.append(None)
            continue
        if isinstance(item, (list, tuple)):
            generated.append(gen_args(item, api_name, need_grad, convert_type, real_data_path))
        elif isinstance(item, dict):
            generated.append(gen_data(item, api_name, need_grad, convert_type, real_data_path))
        else:
            logger.warning(f'Warning: {item} is not supported')
            raise NotImplementedError()
    return generated
256
-
257
-
258
def gen_kwargs(api_info, convert_type=None, real_data_path=None, api_name=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: kwargs, for API forward running.
        The entries of api_info["input_kwargs"] are replaced in place and that dict is returned.
    Parameter:
        api_info: API basic information. Dict
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
        api_name: API name forwarded to gen_data (used for its HF32 check). Optional and last,
            so existing positional callers keep working.
    Return:
        Dict of generated keyword arguments ({} when api_info has no "input_kwargs").
    """
    check_object_type(api_info, dict)
    kwargs_params = api_info.get("input_kwargs")
    if kwargs_params is None:
        return {}
    for key, value in kwargs_params.items():
        if isinstance(value, (list, tuple)):
            kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path)
        elif value is None:
            kwargs_params[key] = None
        else:
            value_type = value.get('type')
            is_numpy = isinstance(value_type, str) and value_type.startswith("numpy")
            if value_type in TENSOR_DATA_LIST or is_numpy:
                # Pass every gen_data argument in its own slot; previously api_name was omitted,
                # shifting need_grad/convert_type/real_data_path by one position.
                kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
            elif value_type in TORCH_TYPE:
                gen_torch_kwargs(kwargs_params, key, value)
            else:
                kwargs_params[key] = value.get('value')
    return kwargs_params
281
-
282
-
283
def gen_torch_kwargs(kwargs_params, key, value):
    """
    Replace kwargs_params[key] with the torch object named by value['value'].

    Entries whose recorded type is "torch.device" are left untouched.
    """
    if value.get('type') == "torch.device":
        return
    # NOTE(review): eval() executes the 'value' string read from the API info. If dump files can come from
    # an untrusted source, consider a whitelist lookup such as getattr(torch, name) instead.
    dtype_expr = value.get('value')
    kwargs_params[key] = eval(dtype_expr)
286
-
287
-
288
def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None, api_name=None):
    """
    Function Description:
        When kwargs value is list, generate the list of kwargs result
    Parameter:
        kwargs_item_value: kwargs value before to generate. List
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
        api_name: API name forwarded to gen_data (optional, defaults to None).
    Return:
        list of generated values; None items are kept as None
    """
    kwargs_item_result = []
    for item in kwargs_item_value:
        if item is None:
            # None entries previously crashed on item.get(...).
            item_value = None
        elif item.get('type') == "torch.Size":
            item_value = torch.Size(item.get('value'))
        elif item.get('type') in TENSOR_DATA_LIST:
            # api_name goes in the second position, matching how gen_args calls gen_data; the previous call
            # gen_data(item, False, convert_type, ...) shifted every argument by one.
            item_value = gen_data(item, api_name, False, convert_type, real_data_path)
        else:
            item_value = item.get('value')
        kwargs_item_result.append(item_value)
    return kwargs_item_result
306
-
307
-
308
def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Build the (args, kwargs) pair used to run an API forward pass from its recorded basic information.
    Parameter:
        api_info: API basic information. Dict
        api_name: API name
        need_grad: set grad for backward
        convert_type: convert ori_type to dist_type flag; must be a key of Const.CONVERT when given.
        real_data_path: the root directory for storing real data.
    Return:
        (args_params, kwargs_params); args_params is [] when api_info has no input_args
    Raises:
        CompareException: if convert_type is given but not supported
    """
    check_object_type(api_info, dict)
    if convert_type and convert_type not in Const.CONVERT:
        raise CompareException(CompareException.INVALID_PARAM_ERROR,
                               f"convert_type params not support {convert_type}.")
    kwargs_params = gen_kwargs(api_info, convert_type, real_data_path)
    input_args = api_info.get("input_args")
    if not input_args:
        logger.warning(f'Warning: No args in {api_info} ')
        return [], kwargs_params
    args_params = gen_args(input_args, api_name, need_grad, convert_type, real_data_path)
    return args_params, kwargs_params
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ import os
19
+ import math
20
+ import torch
21
+ import numpy
22
+
23
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
24
+ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
25
+ CompareException
26
+ from msprobe.core.common.file_utils import FileChecker, load_npy
27
+ from msprobe.pytorch.common.log import logger
28
+ from msprobe.pytorch.common.utils import load_pt
29
+ from msprobe.core.common.const import Const, FileCheckConst
30
+
31
# Recorded type names that are turned into torch objects when building kwargs.
TORCH_TYPE = [
    "torch.device",
    "torch.dtype",
]
# Recorded type names that describe tensor data.
TENSOR_DATA_LIST = [
    "torch.Tensor",
    "torch.nn.parameter.Parameter",
]
# Floating-point dtype names for which random data is generated from a scaled uniform sample.
FLOAT_TYPE = [
    'torch.float32', 'torch.float',
    'torch.float64', 'torch.double',
    'torch.float16', 'torch.half',
    'torch.bfloat16',
]
# Whitelist of numpy scalar type names accepted in the recorded info.
NUMPY_TYPE = [
    "numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64",
    "numpy.uint8", "numpy.uint16", "numpy.uint32", "numpy.uint64",
    "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128",
    "numpy.complex64", "numpy.complex128", "numpy.complex256",
    "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_",
]
38
+
39
+
40
def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
    """
    Function Description:
        Based on arg basic information, generate arg data.
        - Tensor types: a real tensor loaded from the recorded data path, or a random tensor built from the
          recorded statistics. For APIs listed in hf_32_standard_api, float32 tensors are rounded to HF32
          precision.
        - numpy types (must be in NUMPY_TYPE): the recorded value converted to that numpy type. If the
          conversion fails, the error is logged and the raw value is returned.
        - torch.Size, slice and ellipsis are rebuilt as the matching Python objects.
        - Anything else (including a missing type) returns the recorded 'value'.
    Parameter:
        info: arg basic information. Dict
        api_name: API name
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        the generated argument
    Raises:
        Exception: if the recorded numpy type is not in NUMPY_TYPE
    """
    check_object_type(info, dict)
    data_type = info.get('type')
    data_path = info.get('datapath', info.get('data_name'))
    data_path = get_full_data_path(data_path, real_data_path)
    if data_type in TENSOR_DATA_LIST:
        if data_path:
            data = gen_real_tensor(data_path, convert_type)
        else:
            data = gen_random_tensor(info, convert_type)
        if api_name in hf_32_standard_api and data.dtype == torch.float32:
            data = fp32_to_hf32_to_fp32(data)
        if info.get('requires_grad') and need_grad:
            data.requires_grad_(True)
            # data * 1 produces a non-leaf tensor; retain_grad() makes backward() populate its .grad.
            # NOTE(review): the reason a non-leaf copy is needed is not visible here -- confirm with callers.
            temp_data = data * 1
            data = temp_data.type_as(data)
            data.retain_grad()
    elif isinstance(data_type, str) and data_type.startswith("numpy"):
        # The isinstance guard avoids an AttributeError when the info has no 'type'.
        if data_type not in NUMPY_TYPE:
            raise Exception("{} is not supported now".format(data_type))
        data = info.get("value")
        try:
            # data_type is whitelisted above, so resolve the numpy attribute directly instead of eval().
            # A type missing from the installed numpy still raises AttributeError, handled below.
            numpy_type = getattr(numpy, data_type.split('.', 1)[1])
            data = numpy_type(data)
        except Exception as err:
            # Best effort: keep the raw recorded value if the conversion fails.
            logger.error("Failed to convert the type to numpy: %s" % str(err))
    elif data_type == "torch.Size":
        data = torch.Size(info.get("value"))
    else:
        data = info.get('value')
    if data_type == "slice":
        data = slice(*data)
    if data_type == "ellipsis":
        data = ...
    return data
83
+
84
+
85
def gen_real_tensor(data_path, convert_type):
    """
    Function Description:
        Load the real input tensor stored at data_path (.pt via load_pt onto CPU, .npy via load_npy), after
        validating the path with FileChecker. If convert_type is set and the loaded dtype equals its source
        dtype, the tensor is cast to the target dtype.
    Parameter:
        data_path: API data path
        convert_type: convert ori_type to dist_type flag.
    Return:
        the loaded (and possibly converted) tensor
    Raises:
        CompareException: if the file is neither a .pt nor a .npy file
    """
    checker = FileChecker(os.path.realpath(data_path), FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
    checked_path = checker.common_check()
    is_pt_file = checked_path.endswith('.pt')
    if not (is_pt_file or checked_path.endswith('.npy')):
        error_info = f"The file: {checked_path} is not a pt or numpy file."
        raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
    if is_pt_file:
        data = load_pt(checked_path, to_cpu=True)
    else:
        data = torch.from_numpy(load_npy(checked_path))
    if not convert_type:
        return data
    convert_pair = Const.CONVERT.get(convert_type)
    ori_dtype, dist_dtype = convert_pair[0], convert_pair[1]
    if str(data.dtype) == ori_dtype:
        data = data.type(eval(dist_dtype))
    return data
110
+
111
+
112
def gen_random_tensor(info, convert_type):
    """
    Function Description:
        Generate random input data from the statistics recorded in the API info (keys read: 'Min', 'Max',
        'Min_origin', 'Max_origin', 'dtype', 'shape'). torch.bool data is produced by gen_bool_tensor,
        every other dtype by gen_common_tensor.
    Parameter:
        info: API data info
        convert_type: convert ori_type to dist_type flag.
    Return:
        the generated tensor
    Raises:
        CompareException: if Min or Max is not an int or float
    """
    check_object_type(info, dict)
    low = info.get('Min')
    high = info.get('Max')
    low_info = [low, info.get('Min_origin')]
    high_info = [high, info.get('Max_origin')]
    data_dtype = info.get('dtype')
    shape = tuple(info.get('shape'))
    bounds_are_numeric = isinstance(low, (int, float)) and isinstance(high, (int, float))
    if not bounds_are_numeric:
        error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
    if data_dtype == "torch.bool":
        return gen_bool_tensor(low, high, shape)
    return gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type)
135
+
136
+
137
def fp32_to_hf32_to_fp32(input_tensor):
    """
    Round a float32 tensor to HF32 precision and return the result as a float32 tensor.

    The float32 bit pattern is reinterpreted as int32 and the 12 low mantissa bits are dropped, rounding on
    the highest dropped bit (shift right 11, add 1, shift right 1, shift left 12). Only the upper 20 bits
    (sign, exponent and 11 mantissa bits) survive. The result is a new tensor built from a numpy array of
    the detached input.
    """
    raw_bits = input_tensor.detach().numpy().view(numpy.int32)
    kept_bits = numpy.right_shift(raw_bits, 11)
    kept_bits = numpy.right_shift(kept_bits + 1, 1)
    hf32_bits = numpy.left_shift(kept_bits, 12)
    return torch.from_numpy(hf32_bits.view(numpy.float32))
146
+
147
+
148
def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
    """
    Function Description:
        Based on API basic information, generate a random int or float tensor with values in [low, high],
        then overwrite its first/last elements (in flattened order) so the recorded extremes, including
        inf/nan from the original tensor, are reproduced.
    Parameter:
        low_info: [low, low_origin], low is the minimum value in the tensor with inf and nan removed,
            low_origin is the original minimum value in the tensor
        high_info: [high, high_origin], high is the maximum value in the tensor with inf and nan removed,
            high_origin is the original maximum value in the tensor
        shape: The shape of Tensor
        data_dtype: The data type of Tensor, as a string such as "torch.float32"
        convert_type: convert ori_type to dist_type flag.
    Return:
        the generated tensor
    Raises:
        NotImplementedError: if data_dtype is neither in FLOAT_TYPE nor an int/long dtype name
    """
    if convert_type:
        ori_dtype = Const.CONVERT.get(convert_type)[0]
        if ori_dtype == data_dtype:
            data_dtype = Const.CONVERT.get(convert_type)[1]
    low, low_origin = low_info[0], low_info[1]
    high, high_origin = high_info[0], high_info[1]
    if data_dtype in FLOAT_TYPE:
        # data_dtype is restricted to the FLOAT_TYPE whitelist here, so eval only sees known dtype names.
        torch_dtype = eval(data_dtype)
        # NOTE(review): torch._C._VariableFunctionsClass.full is presumably used instead of torch.full to
        # bypass any wrapping of torch.full by the dump hooks -- confirm.
        if math.isnan(high):
            return torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=torch_dtype)
        # high_origin only exists in newer json files. When it is set and high is inf/-inf, the original
        # tensor was entirely inf/-inf.
        if high_origin and high in [float('inf'), float('-inf')]:
            tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=torch_dtype)
            # Overwrite one element in flattened order. Plain tensor[-1] would replace a whole slice for
            # multi-dim shapes and raise for 0-dim or empty tensors.
            if tensor.nelement() > 0:
                tensor.view(-1)[-1] = low
            return tensor
        low_scale, high_scale = low, high
        dtype_finfo = torch.finfo(torch_dtype)
        # Old json files may record high/low as inf/-inf; use the dtype's max/min for scaling instead.
        if high == float('inf'):
            high_scale = dtype_finfo.max
        elif high == float('-inf'):
            high_scale = dtype_finfo.min
        if low == float('inf'):
            low_scale = dtype_finfo.max
        elif low == float('-inf'):
            low_scale = dtype_finfo.min

        scale = high_scale - low_scale
        rand01 = torch.rand(shape, dtype=torch_dtype)
        if abs(scale) <= dtype_finfo.max:
            tensor = rand01 * scale + low_scale
        else:
            # The range width does not fit in the dtype (e.g. low=-inf, high=inf in an old json), so
            # rand01 * scale would overflow and fill the tensor with inf/nan. Interpolate between the bounds
            # instead; each term stays within the dtype's finite range.
            tensor = rand01 * high_scale + (1 - rand01) * low_scale
    elif 'int' in data_dtype or 'long' in data_dtype:
        low, high = int(low), int(high)
        # NOTE(review): data_dtype comes from the dump info and is only checked for the substrings
        # "int"/"long" before eval(). If dump files may be untrusted, replace this with a whitelist or
        # getattr(torch, ...) lookup.
        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
    else:
        logger.error('Dtype is not supported: ' + data_dtype)
        raise NotImplementedError()
    if tensor.nelement() == 0:
        return tensor
    tmp_tensor = tensor.reshape(-1)
    if high_origin and math.isnan(high_origin):
        # Reproduce a nan from the original tensor while keeping the finite extremes where space allows.
        if tmp_tensor.numel() <= 2:
            tmp_tensor[0] = float('nan')
            tmp_tensor[-1] = high
        else:
            tmp_tensor[0] = low
            tmp_tensor[1] = float('nan')
            tmp_tensor[-1] = high
    else:
        tmp_tensor[0] = low
        tmp_tensor[-1] = high
        if high_origin in [float('inf'), float('-inf')]:
            tmp_tensor[-1] = high_origin
        if low_origin in [float('inf'), float('-inf')]:
            tmp_tensor[0] = low_origin
    data = tmp_tensor.reshape(shape)
    return data
217
+
218
+
219
def gen_bool_tensor(low, high, shape):
    """
    Function Description:
        Generate a random boolean tensor: integers are sampled uniformly between int(low) and int(high)
        inclusive (the bounds are swapped if given in descending order) and compared against zero.
    Parameter:
        low: one bound of the integer sampling range
        high: the other bound of the integer sampling range
        shape: the shape of the returned tensor
    Return:
        a torch.bool tensor that is True wherever the sampled integer is greater than zero
    """
    bounds = [int(low), int(high)]
    if bounds[0] > bounds[1]:
        bounds.reverse()
    random_ints = torch.randint(bounds[0], bounds[1] + 1, shape)
    return torch.gt(random_ints, 0)
234
+
235
+
236
def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate the positional input parameters (args) for the API forward run.
        Nested lists are processed recursively and returned as lists.
    Parameter:
        args_info: list describing the positional arguments of the API
        api_name: API name
        need_grad: set Tensor grad for backward
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        list of generated arguments, in the same order as args_info
    Raises:
        NotImplementedError: if an element of args_info is not a list, tuple, dict or None
    """
    check_object_type(args_info, list)
    args_result = []
    for arg in args_info:
        if isinstance(arg, (list, tuple)):
            # NOTE(review): the recursive call runs check_object_type(arg, list); a tuple element would
            # presumably be rejected there. JSON-loaded info should only contain lists -- confirm.
            data = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
        elif isinstance(arg, dict):
            data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
        elif arg is None:
            data = None
        else:
            logger.warning(f'Warning: {arg} is not supported')
            raise NotImplementedError(f"Unsupported argument info of type {type(arg).__name__}: {arg}")
        args_result.append(data)
    return args_result
261
+
262
+
263
def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
    """
    Function Description:
        Based on API basic information, generate input parameters: kwargs, for API forward running
    Parameter:
        api_info: API basic information. Dict
        api_name: API name
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        a new dict mapping each keyword name to its generated value (empty if api_info has no input_kwargs)
    """
    check_object_type(api_info, dict)
    # Work on a shallow copy: assigning generated values into api_info["input_kwargs"] itself would make a
    # second call on the same api_info see tensors instead of info dicts and fail on value.get(...).
    kwargs_params = dict(api_info.get("input_kwargs") or {})
    for key, value in kwargs_params.items():
        if isinstance(value, (list, tuple)):
            kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
        elif value is None:
            kwargs_params[key] = None
        else:
            value_type = value.get('type')
            # Guard against a missing 'type', which previously raised on None.startswith.
            is_numpy = isinstance(value_type, str) and value_type.startswith("numpy")
            if value_type in TENSOR_DATA_LIST or is_numpy:
                kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
            elif value_type in TORCH_TYPE:
                gen_torch_kwargs(kwargs_params, key, value)
            else:
                kwargs_params[key] = value.get('value')
    return kwargs_params
287
+
288
+
289
def gen_torch_kwargs(kwargs_params, key, value):
    """
    Replace kwargs_params[key] with the torch object named by value['value'].

    Entries whose recorded type is "torch.device" are left untouched.
    """
    if value.get('type') == "torch.device":
        return
    # NOTE(review): eval() executes the 'value' string read from the API info. If dump files can come from
    # an untrusted source, consider a whitelist lookup such as getattr(torch, name) instead.
    dtype_expr = value.get('value')
    kwargs_params[key] = eval(dtype_expr)
292
+
293
+
294
def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
    """
    Function Description:
        When kwargs value is list, generate the list of kwargs result
    Parameter:
        kwargs_item_value: kwargs value before to generate. List
        api_name: API name
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        list of generated values; None items are kept as None
    """
    kwargs_item_result = []
    for item in kwargs_item_value:
        if item is None:
            # None entries previously crashed on item.get(...).
            item_value = None
        elif item.get('type') == "torch.Size":
            item_value = torch.Size(item.get('value'))
        elif item.get('type') in TENSOR_DATA_LIST:
            item_value = gen_data(item, api_name, False, convert_type, real_data_path)
        else:
            item_value = item.get('value')
        kwargs_item_result.append(item_value)
    return kwargs_item_result
313
+
314
+
315
def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
    """
    Function Description:
        Build the (args, kwargs) pair used to run an API forward pass from its recorded basic information.
    Parameter:
        api_info: API basic information. Dict
        api_name: API name
        need_grad: set grad for backward
        convert_type: convert ori_type to dist_type flag; must be a key of Const.CONVERT when given.
        real_data_path: the root directory for storing real data.
    Return:
        (args_params, kwargs_params); args_params is [] when api_info has no input_args
    Raises:
        CompareException: if convert_type is given but not supported
    """
    check_object_type(api_info, dict)
    if convert_type and convert_type not in Const.CONVERT:
        raise CompareException(CompareException.INVALID_PARAM_ERROR,
                               f"convert_type params not support {convert_type}.")
    kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
    input_args = api_info.get("input_args")
    if not input_args:
        logger.warning(f'Warning: No args in {api_info} ')
        return [], kwargs_params
    args_params = gen_args(input_args, api_name, need_grad, convert_type, real_data_path)
    return args_params, kwargs_params