mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -0,0 +1,365 @@
+import json
+import os
+import math
+from enum import Enum, auto
+import torch
+try:
+    import torch_npu
+except ImportError:
+    pass
+from tabulate import tabulate
+
+TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+TORCH_BOOL_TYPE = ["torch.bool"]
+TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
+                  "torch.int64", "torch.long"]
+TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
+                    "torch.float64", "torch.double"]
+TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
+RAISE_PRECISION = {{
+    "torch.float16": torch.float32,
+    "torch.half": torch.float32,
+    "torch.bfloat16": torch.float32,
+    "torch.float32": torch.float64,
+    "torch.float": torch.float64
+}}
+THOUSANDTH_THRESHOLDING = 0.001
+BACKWARD = 'backward'
+
+class CompareStandard(Enum):
+    BINARY_EQUALITY_STANDARD = auto()
+    ABSOLUTE_THRESHOLD_STANDARD = auto()
+    ULP_ERROR_STANDARD = auto()
+    BENCHMARK_STANDARD = auto()
+    THOUSANDTH_STANDARD = auto()
+
+def load_pt(pt_path, to_cpu=False):
+    pt_path = os.path.realpath(pt_path)
+    try:
+        if to_cpu:
+            pt = torch.load(pt_path, map_location=torch.device("cpu"))
+        else:
+            pt = torch.load(pt_path)
+    except Exception as e:
+        raise RuntimeError(f"load pt file {{pt_path}} failed") from e
+    return pt
+
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch_npu.npu.is_available():
+        device = torch.device("npu")
+    else:
+        raise Exception("Error: This device is not NPU or GPU!")
+    return device
+
+
+def generate_bool_tensor(low, high, shape):
+    low, high = int(low), int(high)
+    tensor = torch.randint(low, high + 1, shape)
+    bool_tensor = torch.gt(tensor, 0)
+    return bool_tensor
+
+
+def generate_numerical_tensor(low, high, shape, data_dtype):
+    if data_dtype in TORCH_FLOAT_TYPE:
+        scale = high - low
+        rand01 = torch.rand(shape, dtype=eval(data_dtype))
+        tensor = rand01 * scale + low
+    elif data_dtype in TORCH_INT_TYPE:
+        low, high = int(low), int(high)
+        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
+    else:
+        raise NotImplementedError(f"{{data_dtype}} is not supported!")
+    if torch.numel(tensor) == 0:
+        return tensor
+    tmp_tensor = tensor.reshape(-1)
+    tmp_tensor[0] = low
+    tmp_tensor[-1] = high
+    data = tmp_tensor.reshape(shape)
+    return data
+
+
+def generate_random_tensor(info):
+    low, high = info.get('Min'), info.get('Max')
+    data_dtype = info.get('dtype')
+    shape = tuple(info.get('shape'))
+    if data_dtype == "torch.bool":
+        data = generate_bool_tensor(low, high, shape)
+    else:
+        data = generate_numerical_tensor(low, high, shape, data_dtype)
+    return data
+
+
+def generate_real_tensor(data_path):
+    data_path = os.path.realpath(data_path)
+    data = load_pt(data_path, to_cpu = True)
+    return data
+
+
+def generate_data(info):
+    data_type = info.get("type")
+    data_path = info.get("data_name")
+    data_grad = info.get("requires_grad")
+    if data_type in TENSOR_DATA_LIST:
+        if data_path:
+            data = generate_real_tensor(data_path)
+        else:
+            data = generate_random_tensor(info)
+    else:
+        data = info.get("value")
+    if data_grad == True:
+        data.requires_grad_(True)
+    return data
+
+
+def get_input(propagation):
+    {args_element_assignment}
+    args_device = [{args_list_generator_device}]
+    args_bench = [{args_list_generator_bench}]
+    {kwargs_value_assignment}
+    kwargs_device = {{{kwargs_dict_generator_device}}}
+    kwargs_bench = {{{kwargs_dict_generator_bench}}}
+    {args_element_assignment_backward}
+    args_device_backward = [{args_list_generator_device_backward}]
+    args_bench_backward = [{args_list_generator_bench_backward}]
+    if propagation == BACKWARD:
+        return args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward
+    return args_device, kwargs_device, args_bench, kwargs_bench
+
+def exec_api(args, kwargs, args_grad_input, propagation):
+    output = {api_type}.{api_name}(*args, **kwargs)
+    if propagation == BACKWARD:
+        args_input_tensor = [tensor for tensor in args if isinstance(tensor, torch.Tensor) and tensor.requires_grad]
+        args_input_tensor.extend(
+            [value for value in kwargs.values() if isinstance(value, torch.Tensor) and value.requires_grad])
+        output_backward = torch.autograd.grad(outputs=output, inputs=args_input_tensor, grad_outputs=args_grad_input)
+        return output_backward
+    return output
+
+def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
+    out_bench = out_bench.to(out_device.dtype)
+    min = torch.finfo(out_device.dtype).min
+    max = torch.finfo(out_device.dtype).max
+    bench_clip = torch.clamp(out_bench, min=min, max=max)
+    device_clip = torch.clamp(out_device, min=min, max=max)
+    clipped_abs_ae = torch.abs(device_clip - bench_clip)
+    clipped_re = clipped_abs_ae / abs_bench_with_eps
+    pass_mask = torch.less_equal(clipped_re, rtol)
+    both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
+    pass_mask = torch.logical_or(pass_mask, both_nan_mask)
+    not_pass_mask = torch.logical_not(pass_mask)
+    not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask)
+    inf_nan_err_cnt = torch.sum(not_pass_mask)
+    return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask)
+
+
+def compute_rmse(abs_err, normal_value_mask):
+    if torch.sum(normal_value_mask) == 0:
+        return 0
+    else:
+        masked_ae = torch.where(normal_value_mask, abs_err, 0)
+        mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask)
+        rmse = torch.sqrt(mse)
+        return rmse
+
+
+def compute_error_balance(out_device, out_bench):
+    larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0))
+    smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0))
+    if torch.numel(out_bench) == 0:
+        raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
+    error_balance = abs(larger_count - smaller_count) / torch.numel(out_bench)
+    return error_balance
+
+
+def compare_tensor(out_device, out_bench, api_name):
+    if out_device.shape != out_bench.shape:
+        print("ERROR: shape of out_device and out_bench is not equal!")
+        return None
+    if torch.numel(out_bench) == 0:
+        print("Both out_device and out_bench have zero elements.")
+        return None
+    dtype_device = out_device.dtype
+    dtype_bench = out_bench.dtype
+    headers = ["Metric", "Value"]
+    table = [
+        ["Shape", out_bench.shape],
+        ["Dtype of out_device", out_device.dtype],
+        ["Dtype of out_bench", out_bench.dtype]
+    ]
+    if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \
+            or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \
+            or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE:
+        out_device = out_device.to(torch.device("cpu"))
+        if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD:
+            error_number = torch.sum(out_device != out_bench).item()
+            if torch.numel(out_bench) == 0:
+                raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
+            error_rate = error_number / torch.numel(out_bench)
+            table.append(["Compare Standard", "Binary Equality Standard"])
+            table.append(["Error Rate", error_rate])
+        else:
+            abs_err = torch.abs(out_device - out_bench)
+            abs_bench = torch.abs(out_bench)
+            if dtype_bench == torch.float32:
+                eps = 2 ** -23
+            if dtype_bench == torch.float64:
+                eps = 2 ** -52
+            abs_bench_with_eps = abs_bench + eps
+            rel_err = torch.abs(abs_err / abs_bench_with_eps)
+            device_finite_mask = torch.isfinite(out_device)
+            bench_finite_mask = torch.isfinite(out_bench.to(dtype_device))
+            both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask)
+            inf_nan_mask = torch.logical_not(both_finite_mask)
+            if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD:
+                if dtype_device == torch.float16:
+                    rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5
+                elif dtype_device == torch.bfloat16:
+                    rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5
+                else:
+                    rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9
+                small_value_mask = torch.less_equal(abs_bench, small_value)
+                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+                inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol)
+                rel_err_mask = torch.greater(rel_err, rtol)
+                rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask)
+                if torch.sum(normal_value_mask) == 0:
+                    rel_err_proportion = 0
+                else:
+                    rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask)
+                abs_err_mask = torch.greater(abs_err, small_value_atol)
+                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+                if torch.sum(small_value_mask) == 0:
+                    abs_err_proportion = 0
+                else:
+                    abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+                table.append(["Compare Standard", "Absolute Threshold Standard"])
+                table.append(["Relative Error Ratio", rel_err_proportion])
+                table.append(["Absolute Error Ratio", abs_err_proportion])
+            elif compare_standard == CompareStandard.ULP_ERROR_STANDARD:
+                if dtype_device == torch.float16:
+                    min_eb, exponent_num = -14, 10
+                elif dtype_device == torch.bfloat16:
+                    min_eb, exponent_num = -126, 7
+                else:
+                    min_eb, exponent_num = -126, 23
+                eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench)))
+                eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape))
+                if dtype_device == torch.float32:
+                    ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64)
+                else:
+                    ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32)
+                ulp_err = torch.abs(ulp_err)
+                max_ulp_err = torch.max(ulp_err)
+                mean_ulp_err = torch.mean(ulp_err)
+                if torch.numel(out_bench) == 0:
+                    raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
+                if dtype_device == torch.float32:
+                    ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench)
+                else:
+                    ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench)
+                table.append(["Compare Standard", "ULP error Standard"])
+                table.append(["Maximum ULP Error", max_ulp_err])
+                table.append(["Mean ULP Error", mean_ulp_err])
+                table.append(["ULP Error Proportion", ulp_err_proportion])
+            elif compare_standard == CompareStandard.THOUSANDTH_STANDARD:
+                rel_err_origin = torch.abs(abs_err / abs_bench_with_eps)
+                if torch.numel(rel_err_origin) == 0:
+                    thousand_res = 1
+                else:
+                    thousand_res = torch.divide(torch.sum(rel_err < THOUSANDTH_THRESHOLDING), torch.numel(rel_err_origin))
+                thousand_status = thousand_res > (1 - THOUSANDTH_THRESHOLDING)
+                table.append(["Compare Standard", "Thousandth Standard"])
+                table.append(["Thousandth ratio", thousand_res])
+            else:
+                if dtype_device == torch.float16:
+                    small_value, small_value_atol = 1.0e-3, 1.0e-5
+                elif dtype_device == torch.bfloat16:
+                    small_value, small_value_atol = 1.0e-3, 1.0e-5
+                else:
+                    small_value, small_value_atol = 1.0e-6, 1.0e-9
+                small_value_mask = torch.less_equal(abs_bench, small_value)
+                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+                abs_err_mask = torch.greater(abs_err, small_value_atol)
+                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+                if torch.sum(small_value_mask) == 0:
+                    small_value_err_proportion = 0
+                else:
+                    small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+                rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape))
+                if torch.max(rel_err) >= 0:
+                    max_rel_err = torch.max(rel_err)
+                else:
+                    max_rel_err = 0
+                if torch.sum(normal_value_mask) == 0:
+                    mean_rel_err = 0
+                else:
+                    mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask)
+                rmse = compute_rmse(abs_err, normal_value_mask)
+                error_balance = compute_error_balance(out_device, out_bench)
+                table.append(["Compare Standard", "Benchmark Standard"])
+                table.append(["Small Value Error Proportion", small_value_err_proportion])
+                table.append(["Maximum Relative Error", max_rel_err])
+                table.append(["Mean Relative Error", mean_rel_err])
+                table.append(["Root Mean Squared Error", rmse])
+                table.append(["Error Balance", error_balance])
+    else:
+        print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.")
+        return None
+    print(tabulate(table, headers, tablefmt='grid'))
+    return None
+
+
+def compare_element(out_device, out_bench, api_name):
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench is not the same type!")
+        return None
+    if isinstance(out_bench, torch.Tensor):
+        compare_tensor(out_device, out_bench, api_name)
+    elif isinstance(out_bench, (bool, int, float, str)):
+        if out_device == out_bench:
+            print("PASS: out_device and out_bench equals.")
+        else:
+            print("ERROR: out_device and out_bench is not equal!")
+    else:
+        print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.")
+    return None
+
+
+def compare(out_device, out_bench, api_name):
+    print("Compare result:")
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench is not the same type!")
+        return None
+    if isinstance(out_bench, (list, tuple)):
+        if len(out_device) != len(out_bench):
+            print("ERROR: len of out_device and out_bench is different!")
+            return None
+        for index, _ in enumerate(out_bench):
+            print(f"index {{index}}:")
+            compare_element(out_device[index], out_bench[index], api_name)
+    else:
+        compare_element(out_device, out_bench, api_name)
+
+if __name__ == "__main__":
+    device = get_device()
+    api_name = "{api_name}"
+    propagation = "{propagation}"
+    compare_standard = {compare_standard}
+    torch.manual_seed({random_seed})
+    for i in range({iter_times}):
+        print(f"iter: {{i}}:")
+        if propagation == BACKWARD:
+            args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward = get_input(propagation)
+            output_device = exec_api(args_device, kwargs_device, args_device_backward, propagation)
+            output_bench = exec_api(args_bench, kwargs_bench, args_bench_backward, propagation)
+            compare(output_device, output_bench, api_name)
+        else:
+            args_device, kwargs_device, args_bench, kwargs_bench = get_input(propagation)
+            output_device = exec_api(args_device, kwargs_device, None, propagation)
+            output_bench = exec_api(args_bench, kwargs_bench, None, propagation)
+            compare(output_device, output_bench, api_name)
+    print("Compare finished.")
msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""

 import os
 import math
@@ -22,19 +22,28 @@ import numpy

 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
 from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
-    CompareException
+    CompareException, get_module_and_atttribute_name, get_attribute
 from msprobe.core.common.file_utils import FileChecker, load_npy
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import load_pt
-from msprobe.core.common.const import Const, FileCheckConst
+from msprobe.core.common.const import Const, FileCheckConst, CompareConst

 TORCH_TYPE = ["torch.device", "torch.dtype"]
 TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
-FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16',
-              'torch.half', 'torch.bfloat16']
-NUMPY_TYPE = ["numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
-              "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
-              "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"]
+FLOAT_TYPE = [
+    'torch.float32',
+    'torch.float',
+    'torch.float64',
+    'torch.double',
+    'torch.float16',
+    'torch.half',
+    'torch.bfloat16'
+]
+NUMPY_TYPE = [
+    "numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
+    "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
+    "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"
+]


 def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
@@ -68,7 +77,8 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
             raise Exception("{} is not supported now".format(data_type))
         data = info.get("value")
         try:
-            data = eval(data_type)(data)
+            module_name, attribute_name = get_module_and_atttribute_name(data_type)
+            data = get_attribute(module_name, attribute_name)(data)
         except Exception as err:
             logger.error("Failed to convert the type to numpy: %s" % str(err))
     elif data_type == "torch.Size":
@@ -104,8 +114,9 @@ def gen_real_tensor(data_path, convert_type):
     if convert_type:
         ori_dtype = Const.CONVERT.get(convert_type)[0]
         dist_dtype = Const.CONVERT.get(convert_type)[1]
+        module_name, attribute_name = get_module_and_atttribute_name(dist_dtype)
         if str(data.dtype) == ori_dtype:
-            data = data.type(eval(dist_dtype))
+            data = data.type(get_attribute(module_name, attribute_name))
     return data


@@ -118,13 +129,22 @@ def gen_random_tensor(info, convert_type):
        convert_type: convert ori_type to dist_type flag.
    """
    check_object_type(info, dict)
-    low, high = info.get('Min'), info.get('Max')
-    low_origin, high_origin = info.get('Min_origin'), info.get('Max_origin')
+
+    low_origin = info.get('Min')
+    low = info.get('Min_except_inf_nan', low_origin)
+    high_origin = info.get('Max')
+    high = info.get('Max_except_inf_nan', high_origin)
+
    low_info = [low, low_origin]
    high_info = [high, high_origin]
    data_dtype = info.get('dtype')
    shape = tuple(info.get('shape'))
-    if not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
+    if 0 in shape:
+        low, low_origin = 0, 0
+        high, high_origin = 0, 0
+        low_info = [low, low_origin]
+        high_info = [high, high_origin]
+    elif not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
        error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
    if data_dtype == "torch.bool":
@@ -164,33 +184,35 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
         data_dtype = Const.CONVERT.get(convert_type)[1]
     low, low_origin = low_info[0], low_info[1]
     high, high_origin = high_info[0], high_info[1]
-    if data_dtype in FLOAT_TYPE:
+    module_name, attribute_name = get_module_and_atttribute_name(data_dtype)
+    dtype = get_attribute(module_name, attribute_name)
+    if data_dtype in FLOAT_TYPE:
         if math.isnan(high):
-            tensor = torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=eval(data_dtype))
+            tensor = torch.full(shape, float('nan'), dtype=dtype)
             return tensor
         # high_origin is a field of the new json format; only when high_origin is not None and high is inf or -inf is the original tensor entirely inf or -inf
-        if high_origin and high in [float('inf'), float('-inf')]:
-            tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=eval(data_dtype))
+        if high_origin and high in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
+            tensor = torch.full(shape, high, dtype=dtype)
             tensor[-1] = low
             return tensor
         low_scale, high_scale = low, high
-        dtype_finfo = torch.finfo(eval(data_dtype))
+        dtype_finfo = torch.finfo(dtype)
         # compatibility with the old json format: when high or low is inf or -inf, scale with the dtype's max or min value
-        if high == float('inf'):
+        if high == float(CompareConst.INF):
             high_scale = dtype_finfo.max
-        elif high == float('-inf'):
+        elif high == float(CompareConst.NEG_INF):
             high_scale = dtype_finfo.min
-        if low == float('inf'):
+        if low == float(CompareConst.INF):
             low_scale = dtype_finfo.max
-        elif low == float('-inf'):
+        elif low == float(CompareConst.NEG_INF):
             low_scale = dtype_finfo.min

         scale = high_scale - low_scale
-        rand01 = torch.rand(shape, dtype=eval(data_dtype))
+        rand01 = torch.rand(shape, dtype=dtype)
         tensor = rand01 * scale + low_scale
     elif 'int' in data_dtype or 'long' in data_dtype:
         low, high = int(low), int(high)
-        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
+        tensor = torch.randint(low, high + 1, shape, dtype=dtype)
     else:
         logger.error('Dtype is not supported: ' + data_dtype)
         raise NotImplementedError()
@@ -208,9 +230,9 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
     else:
         tmp_tensor[0] = low
         tmp_tensor[-1] = high
-        if high_origin in [float('inf'), float('-inf')]:
+        if high_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
             tmp_tensor[-1] = high_origin
-        if low_origin in [float('inf'), float('-inf')]:
+        if low_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
             tmp_tensor[0] = low_origin
     data = tmp_tensor.reshape(shape)
     return data
@@ -233,7 +255,7 @@ def gen_bool_tensor(low, high, shape):
     return data


-def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
+def gen_args(args_info, api_name, func_options):
    """
    Function Description:
        Based on API basic information, generate input parameters: args, for API forward running
@@ -246,9 +268,20 @@
    """
    check_object_type(args_info, list)
    args_result = []
+
+    need_grad = func_options.get('need_grad', True)
+    convert_type = func_options.get('convert_type', None)
+    real_data_path = func_options.get('real_data_path', None)
+    depth = func_options.get('depth', 0)
+
+    if depth > Const.MAX_DEPTH:
+        logger.error("The depth of args is too large, please check the input args.")
+        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+
    for arg in args_info:
        if isinstance(arg, (list, tuple)):
-            data = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
+            func_options['depth'] = depth + 1
+            data = gen_args(arg, api_name, func_options)
        elif isinstance(arg, dict):
            data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
        elif arg is None:
@@ -288,7 +321,8 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):

 def gen_torch_kwargs(kwargs_params, key, value):
     if value.get('type') != "torch.device":
-        kwargs_params[key] = eval(value.get('value'))
+        module_name, attribute_name = get_module_and_atttribute_name(value.get('value'))
+        kwargs_params[key] = get_attribute(module_name, attribute_name)


 def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
@@ -327,8 +361,14 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
        error_info = f"convert_type params not support {convert_type}."
        raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
    kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
+    func_options = {
+        'need_grad': need_grad,
+        'convert_type': convert_type,
+        'real_data_path': real_data_path,
+        'depth': 0
+    }
    if api_info.get("input_args"):
-        args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path)
+        args_params = gen_args(api_info.get("input_args"), api_name, func_options)
    else:
        logger.warning(f'Warning: No args in {api_info} ')
        args_params = []
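
Note on the data_generate.py hunks above: a recurring change is the removal of eval() on type strings (eval(data_type), eval(dist_dtype), eval(value.get('value'))) in favour of the get_module_and_atttribute_name and get_attribute helpers imported from msprobe.pytorch.api_accuracy_checker.common.utils. Their implementation is not part of this diff; the sketch below only illustrates the general pattern implied by the call sites (split the dotted name, then resolve the attribute from an allow-listed module) and is an assumption, not the library's actual code.

# Hypothetical sketch of an eval-free type lookup; the real helpers in
# msprobe.pytorch.api_accuracy_checker.common.utils may differ.
import importlib

ALLOWED_MODULES = {"torch", "numpy"}  # assumed allow-list


def get_module_and_atttribute_name(type_str):
    # "torch.float32" -> ("torch", "float32")
    module_name, _, attribute_name = type_str.partition(".")
    return module_name, attribute_name


def get_attribute(module_name, attribute_name):
    if module_name not in ALLOWED_MODULES:
        raise ValueError(f"module '{module_name}' is not allowed")
    module = importlib.import_module(module_name)
    return getattr(module, attribute_name)


# e.g. get_attribute(*get_module_and_atttribute_name("torch.float16")) returns torch.float16
# without passing untrusted strings to eval().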