mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -0,0 +1,152 @@
1
+
2
+
3
+ import mindspore
4
+ import torch
5
+ from mindspore import ops
6
+
7
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
+ from msprobe.core.common.const import Const, MsCompareConst
9
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
+ from msprobe.core.common.log import logger
11
+ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
12
+
13
+
14
class ApiInputAggregation:
    """Container bundling everything ApiRunner needs to invoke one API.

    The arguments are stored as given; nothing is validated or copied.
    """

    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        '''
        Args:
            inputs: List[ComputeElement], unpacked positionally into the API call
            kwargs: dict{str: ComputeElement}, unpacked as keyword arguments of the API call
            gradient_inputs: Union[List[ComputeElement], None], upstream gradients used
                only when the backward pass is run; None when unavailable
        '''
        self.gradient_inputs = gradient_inputs
        self.kwargs = kwargs
        self.inputs = inputs
25
+
26
# (api type, framework) -> module object that owns the API functions.
# mindspore.mint.<name> is paired with torch.<name>, and
# mindspore.mint.nn.functional.<name> is paired with torch.nn.functional.<name>.
api_parent_module_mapping = {
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
}
32
+
33
class ApiRunner:
    """Callable that runs a mindspore.mint / torch API on ComputeElement inputs.

    The API is resolved from a dumped name such as "MintFunctional.relu.0". The same
    name selects either the mindspore.mint function or its torch counterpart,
    depending on ``api_platform``.
    """

    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            outputs: list[ComputeElement]

        Description:
            run mindspore.mint/torch api
        '''
        api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str):
        '''
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"

        Return:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"

        Raises (via logger.error_log_with_exp):
            ApiAccuracyCheckerException(WrongValue) if the name does not have exactly three
            Const.SEP-separated parts or the type is neither Mint nor MintFunctional.
        '''
        api_name_list = api_name_str.split(Const.SEP)
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
            err_msg = "ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        '''
        Args:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
            api_platform: str, Union["mindspore", "torch"]

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
        '''
        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
        # full dotted name is only used to build readable error messages
        module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
        submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
        full_api_name = module_str + submodule_str + api_sub_name
        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        '''
        Args:
            api_instance: callable returned by get_api_instance
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str, Const.FORWARD runs the API; any other value runs backward
            api_platform: str, Union["mindspore", "torch"]

        Return:
            list[ComputeElement]: the forward outputs, or, for backward, the gradients
            with respect to the positional inputs
        '''
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}

        if forward_or_backward == Const.FORWARD:
            forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
            result_tuple = convert_to_tuple(forward_result)
        else:
            gradient_inputs = api_input_aggregation.gradient_inputs
            if gradient_inputs is None:
                err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
                logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
            gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                    for compute_element in gradient_inputs)
            if api_platform == Const.MS_FRAMEWORK:
                result_tuple = ApiRunner._run_mindspore_backward(api_instance, inputs, kwargs, gradient_inputs)
            else:
                result_tuple = ApiRunner._run_torch_backward(api_instance, inputs, kwargs, gradient_inputs)

        return [ComputeElement(parameter=api_res) for api_res in result_tuple]

    @staticmethod
    def _run_mindspore_backward(api_instance, inputs, kwargs, gradient_inputs):
        """Return the gradients of api_instance w.r.t. all positional inputs via ops.GradOperation."""
        if len(gradient_inputs) == 1:
            # a single upstream gradient is passed as a bare tensor rather than a 1-tuple
            gradient_inputs = gradient_inputs[0]

        def api_with_kwargs(*forward_inputs):
            return api_instance(*forward_inputs, **kwargs)

        grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
        backward_result = grad_func(*inputs, gradient_inputs)  # can be single tensor or tuple
        return convert_to_tuple(backward_result)

    @staticmethod
    def _run_torch_backward(api_instance, inputs, kwargs, gradient_inputs):
        """Run the torch API forward and backward and return the .grad of every input that has one."""
        # mark inputs as requiring grad so autograd records the forward pass
        for tensor in inputs:
            if hasattr(tensor, "requires_grad"):
                setattr(tensor, "requires_grad", True)
        forward_results = convert_to_tuple(api_instance(*inputs, **kwargs))
        # Back-propagate all (output, upstream gradient) pairs in ONE autograd call.
        # Calling .backward() once per output frees the shared graph after the first
        # call, which makes multi-output APIs fail on the second output.
        output_grad_pairs = list(zip(forward_results, gradient_inputs))
        if output_grad_pairs:
            outputs, grads = zip(*output_grad_pairs)
            torch.autograd.backward(list(outputs), grad_tensors=list(grads))
        return [getattr(tensor, "grad") for tensor in inputs if hasattr(tensor, "grad")]


# module-level shared runner instance
api_runner = ApiRunner()
@@ -0,0 +1,197 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import mindspore
4
+ import torch
5
+ import numpy as np
6
+
7
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
8
+ from msprobe.core.common.log import logger
9
+ from msprobe.core.common.const import CompareConst, MsCompareConst
10
+
11
class CompareResult:
    """Outcome of one comparison: the metric value, its pass status and a message."""

    def __init__(self, compare_value, pass_status, err_msg):
        # compare_value is None when the comparison was skipped (see BaseCompareAlgorithm.__call__)
        self.err_msg = err_msg
        self.pass_status = pass_status
        self.compare_value = compare_value
16
+
17
+
18
class BaseCompareAlgorithm(ABC):
    """Base class for metrics that compare a bench ComputeElement with a tested one.

    Subclasses set ``self.compare_algorithm_name`` and implement ``check_validity``,
    ``run_compare`` and ``check_pass``. ``__call__`` runs them in that order and
    attaches the message registered in ``err_msg_mapping`` for the resulting status.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = None
        # algorithm name -> pass status -> message stored in the CompareResult
        self.err_msg_mapping = {
            CompareConst.COSINE: {
                CompareConst.PASS: "",
                CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ",
            },
            CompareConst.MAX_ABS_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "max absolute difference is greater than "
                                    f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
            },
            CompareConst.MAX_RELATIVE_ERR: {
                CompareConst.PASS: "",
                # worded like the absolute-error messages so failures/skips are explained
                CompareConst.ERROR: "max relative difference is greater than "
                                    f"threshold: {CompareConst.MAX_RELATIVE_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max relative difference, skip comparing ",
            },
        }

    def __call__(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_result: CompareResult; compare_value is None and pass_status is
            CompareConst.SKIP when check_validity rejects the inputs
        '''
        if self.check_validity(bench_compute_element, tested_compute_element):
            compare_value = self.run_compare(bench_compute_element, tested_compute_element)
            pass_status = self.check_pass(compare_value)
        else:
            logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.")
            compare_value = None
            pass_status = CompareConst.SKIP

        err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status)

        compare_result = CompareResult(compare_value, pass_status, err_msg)
        return compare_result

    @staticmethod
    def convert_to_np_float64_ndarray(tensor):
        '''
        Convert a mindspore.Tensor or torch.Tensor to a float64 numpy ndarray.

        Raises (via logger.error_log_with_exp):
            ApiAccuracyCheckerException(UnsupportType) for any other input type.
        '''
        if isinstance(tensor, mindspore.Tensor):
            ndarray = tensor.astype(mindspore.float64).numpy()
        elif isinstance(tensor, torch.Tensor):
            # detach(): .numpy() refuses tensors that require grad;
            # cpu(): .numpy() only works on CPU tensors
            ndarray = tensor.detach().cpu().to(torch.float64, copy=True).numpy()
        else:
            err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
                      "input is not mindspore.Tensor or torch.Tensor"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        return ndarray

    @staticmethod
    def check_two_tensor(bench_compute_element, tested_compute_element):
        '''Return True when both parameters are mindspore/torch tensors and their shapes are equal.'''
        bench_parameter = bench_compute_element.get_parameter()
        tested_parameter = tested_compute_element.get_parameter()

        bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor))
        tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor))
        shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape()
        return bench_is_tensor and tested_is_tensor and shape_same

    @abstractmethod
    def check_validity(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            check_res: boolean
        '''
        raise NotImplementedError

    @abstractmethod
    def run_compare(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_value: float/int
        '''
        raise NotImplementedError

    @abstractmethod
    def check_pass(self, compare_value):
        '''
        Args:
            compare_value: float/int

        Return:
            pass_status: str
        '''
        raise NotImplementedError
119
+
120
+
121
class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm):
    """Cosine similarity of the bench and tested tensors, smoothed by MsCompareConst.EPSILON."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.COSINE

    def check_validity(self, bench_compute_element, tested_compute_element):
        # only two tensors of identical shape can be compared
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_arr = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_arr = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        norm_product = np.linalg.norm(bench_arr) * np.linalg.norm(tested_arr)
        inner = np.dot(bench_arr.flatten(), tested_arr.flatten())
        # EPSILON in numerator and denominator keeps the ratio defined for all-zero inputs
        return (MsCompareConst.EPSILON + inner) / (MsCompareConst.EPSILON + norm_product)

    def check_pass(self, compare_value):
        return CompareConst.PASS if compare_value > CompareConst.COS_THRESHOLD else CompareConst.ERROR
144
+
145
+
146
class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise absolute difference between the bench and tested tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_ABS_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        if bench_ndarray.size == 0:
            # np.max raises ValueError on a zero-size array. Two equal-shaped empty tensors
            # pass check_validity and have no differing element, so report zero difference.
            return np.float64(0.0)
        max_absolute_diff = np.max(np.abs(bench_ndarray - tested_ndarray))
        return max_absolute_diff

    def check_pass(self, compare_value):
        if compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR
166
+
167
+
168
class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise relative difference |bench - tested| / |bench|.

    Zero bench elements are replaced by MsCompareConst.EPSILON in the denominator
    to avoid division by zero.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        if bench_ndarray.size == 0:
            # np.max raises ValueError on a zero-size array. Two equal-shaped empty tensors
            # pass check_validity and have no differing element, so report zero difference.
            return np.float64(0.0)
        abs_diff = np.abs(bench_ndarray - tested_ndarray)
        bench_ndarray_nonzero = np.abs(bench_ndarray) + (bench_ndarray == 0) * MsCompareConst.EPSILON
        max_relative_diff = np.max(abs_diff / bench_ndarray_nonzero)
        return max_relative_diff

    def check_pass(self, compare_value):
        if compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR
190
+
191
+
192
+
193
# Registry of comparison metrics keyed by their CompareConst name.
# Each value is a single shared instance of the corresponding algorithm class.
compare_algorithms = {
    CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
    CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
    CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
}
@@ -0,0 +1,224 @@
1
+ import os
2
+
3
+ import mindspore
4
+ import torch
5
+ import numpy as np
6
+
7
+ from msprobe.core.common.log import logger
8
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
9
+ from msprobe.core.common.utils import load_npy
10
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
11
+ ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
12
+ dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
13
+ dtype_str_to_torch_dtype, type_to_api_info_type_str,
14
+ DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
15
+ MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
16
+ int_dtype_str_list)
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
19
+
20
+
21
class MstensorMetaData:
    """
    Plain record of a mindspore tensor's metadata, kept instead of its data.

    Attributes:
        dtype_str: dtype string of the tensor.
        npy_path: path of the dumped .npy file, or None.
        maximum: recorded maximum value, or None.
        minimum: recorded minimum value, or None.
        shape: tensor shape exactly as passed in.
    """

    def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
        self.dtype_str, self.npy_path = dtype_str, npy_path
        self.maximum, self.minimum = maximum, minimum
        self.shape = shape
28
+
29
class ComputeElement:
    '''
    One argument or output of an API call under accuracy check.

    Built either from a runtime value (``parameter``) or from one entry of
    api_info.json (``compute_element_info``, a list or dict). Every init path sets
    self.parameter, self.shape (tuple) and self.dtype_str. A mindspore-tensor dict
    entry is recorded only as MstensorMetaData; its data is loaded from npy or
    generated when get_parameter() is called. Tensors nested inside a list entry
    are materialized at init, because the tuple is built via get_parameter().
    '''

    def __init__(self, compute_element_info=None, parameter=None):
        # Runtime types accepted as-is: the types known to type_to_api_info_type_str,
        # plus torch tensors and tuples.
        self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + (torch.Tensor, tuple)
        if parameter is not None:
            self._init_with_parameter(parameter)
        elif isinstance(compute_element_info, (list, dict)):
            self._init_from_compute_element_info(compute_element_info)
        else:
            logger.error_log_with_exp(
                "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
                ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    @staticmethod
    def transfer_to_torch_tensor(ms_tensor):
        '''
        Convert a mindspore tensor to a torch tensor of the matching dtype.

        Args:
            ms_tensor: mindspore.Tensor
        Return:
            torch_tensor: torch.Tensor
        '''
        ms_dtype = ms_tensor.dtype
        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
        if dtype_str not in dtype_str_to_torch_dtype:
            err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)

        # Widen to a numpy-representable dtype first, then cast to the target dtype.
        if dtype_str in float_dtype_str_list:
            middle_dtype = mindspore.float64
        elif dtype_str in int_dtype_str_list:
            middle_dtype = mindspore.int64
        else:
            # NOTE(review): torch.from_numpy rejects uint64 arrays on older torch
            # releases; confirm which dtypes reach this branch.
            middle_dtype = mindspore.uint64
        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
        torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
        return torch_tensor

    @staticmethod
    def transfer_to_mindspore_tensor(torch_tensor):
        '''
        Convert a torch tensor to a mindspore tensor of the matching dtype.

        Args:
            torch_tensor: torch.Tensor

        Return:
            ms_tensor: mindspore.Tensor
        '''
        torch_dtype = torch_tensor.dtype
        dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
        if dtype_str not in dtype_str_to_ms_dtype:
            err_msg = \
                f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)

        if dtype_str in float_dtype_str_list:
            middle_dtype = torch.float64
        elif dtype_str in int_dtype_str_list:
            middle_dtype = torch.int64
        else:
            # Previously middle_dtype was left unbound here (UnboundLocalError).
            # Other dtypes (presumably bool and similar) are converted without widening.
            middle_dtype = torch_dtype
        np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
        ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
        return ms_tensor

    @staticmethod
    def convert_inf_to_real_num(value, dtype_str):
        '''
        Replace +inf / -inf with the largest / smallest finite value of a float dtype.

        Args:
            value: int or float
            dtype_str: a dtype string (looked up in dtype_str_to_np_dtype, falling back
                to DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE), or None, or a numpy dtype used directly

        Return:
            value unchanged unless it is +/-inf, otherwise np.finfo(dtype).max / .min
        '''
        if value == float("inf") or value == float("-inf"):
            if dtype_str is None or isinstance(dtype_str, str):
                np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
            else:
                # Already a numpy dtype (as passed by _construct_ndarray); looking it up
                # in the str-keyed table would silently yield the default dtype.
                np_dtype = dtype_str
            finfo = np.finfo(np_dtype)
            value = finfo.max if value == float("inf") else finfo.min
        return value

    def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
        '''
        Return the stored parameter, materializing MstensorMetaData if needed.

        Args:
            get_origin: boolean; when True the value is returned without framework transfer
            tensor_platform: Const.MS_FRAMEWORK or Const.PT_FRAMEWORK; when get_origin is False,
                a tensor of the other framework is transferred to this one

        Return:
            parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
        '''
        if isinstance(self.parameter, self.supported_parameter_type):
            parameter_tmp = self.parameter
        elif isinstance(self.parameter, MstensorMetaData):
            mstensor_meta_data = self.parameter
            ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
            if global_context.get_is_constructed():
                # Build synthetic data from the recorded shape and Max/Min.
                np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
                ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
                                                  mstensor_meta_data.minimum, np_dtype)
            else:
                ndarray = load_npy(mstensor_meta_data.npy_path)
            parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
        else:
            err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
                      "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

        # if necessary, do transfer
        if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
            parameter = self.transfer_to_torch_tensor(parameter_tmp)
        elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
            parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
        else:
            parameter = parameter_tmp

        return parameter

    def get_shape(self):
        return self.shape

    def get_dtype(self):
        return self.dtype_str

    def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
        '''
        Generate deterministic random data of the given shape and numpy dtype.

        Bool arrays are drawn as rand() > 0.5; other dtypes uniformly between
        minimum and maximum, with +/-inf bounds clamped to np_dtype's finite range.
        '''
        shape = tuple(shape)
        # A local RandomState(42) yields the same stream as np.random.seed(42)
        # followed by np.random.* calls, without mutating numpy's global RNG state.
        rng = np.random.RandomState(42)
        if np_dtype == np.bool_:
            ndarray = rng.rand(*shape) > 0.5
        else:
            maximum = self.convert_inf_to_real_num(maximum, np_dtype)
            minimum = self.convert_inf_to_real_num(minimum, np_dtype)
            ndarray = rng.uniform(minimum, maximum, shape).astype(np_dtype)
        return ndarray

    def _init_from_compute_element_info(self, compute_element_info):
        '''
        Args:
            compute_element_info: Union[list, dict]

        Return:
            void

        init member attributes: self.shape, self.dtype_str, self.parameter
        (whether tensor data is constructed or loaded follows global_context.get_is_constructed())
        '''
        if isinstance(compute_element_info, list):
            self.shape = tuple()
            self.dtype_str = TUPLE_TYPE_STR
            self.parameter = tuple(ComputeElement(compute_element_info=sub_info).get_parameter()
                                   for sub_info in compute_element_info)
        else:
            type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
                                                    accepted_type=str, accepted_value=api_info_type_str_to_type.keys())

            if type_str == MINDSPORE_TENSOR_TYPE_STR:
                self._init_from_mstensor_compute_element_info(compute_element_info)
            else:  # type_str in ("slice", "int", "float", "bool")
                value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
                self.shape = tuple()
                self.dtype_str = type_str
                self.parameter = slice(*tuple(value)) if type_str == "slice" else value

    def _init_from_mstensor_compute_element_info(self, compute_element_info):
        '''
        do not load real tensor, only record meta data
        '''
        dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
                                                 accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
        shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
                                             accepted_type=(list,))
        if global_context.get_is_constructed():
            # Constructed mode: keep Max/Min to synthesize data later; no npy file.
            maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
                                                   accepted_type=(int, float))
            minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
                                                   accepted_type=(int, float))

            npy_path = None
        else:
            maximum, minimum = None, None
            data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
                                                     "data_name field in api_info.json", accepted_type=(str,))
            npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
        mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
        self.parameter = mstensor_meta_data
        self.dtype_str = dtype_str
        self.shape = tuple(shape)

    def _init_with_parameter(self, parameter):
        self.parameter = parameter
        if not isinstance(parameter, self.supported_parameter_type):
            err_msg = "ComputeElement._init_with_parameter failed: " \
                      "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        if isinstance(parameter, mindspore.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, torch.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
        else:
            self.shape = tuple()
            self.dtype_str = \
                TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
@@ -0,0 +1,16 @@
1
+ from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
2
+
3
def add_api_accuracy_checker_argument(parser):
    """
    Register the API accuracy checker's command-line options on *parser*.

    Options:
        -api_info / --api_info_file (required, str): api info json file to check.
        -o / --out_path (optional, str, default "./"): output path for the result files.
    """
    api_info_help = ("<Required> The api param tool result file: generate from api param tool, "
                     "a json file.")
    out_path_help = "<optional> The ut task result out path."

    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str,
                        required=True, help=api_info_help)
    parser.add_argument("-o", "--out_path", dest="out_path", type=str,
                        required=False, default="./", help=out_path_help)
9
+
10
+
11
def api_checker_main(args):
    """
    CLI entry point: parse args.api_info_file, run all checks, then write the
    detail CSV and the result CSV into args.out_path (in that order).
    """
    checker = ApiAccuracyChecker()
    checker.parse(args.api_info_file)
    checker.run_and_compare()

    out_path = args.out_path
    for write_csv in (checker.to_detail_csv, checker.to_result_csv):
        write_csv(out_path)