mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,224 +1,239 @@
1
- import os
2
-
3
- import mindspore
4
- import torch
5
- import numpy as np
6
-
7
- from msprobe.core.common.log import logger
8
- from msprobe.core.common.exceptions import ApiAccuracyCheckerException
9
- from msprobe.core.common.utils import load_npy
10
- from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
11
- ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
12
- dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
13
- dtype_str_to_torch_dtype, type_to_api_info_type_str,
14
- DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
15
- MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
16
- int_dtype_str_list)
17
- from msprobe.core.common.const import Const
18
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
19
-
20
-
21
- class MstensorMetaData:
22
- def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
23
- self.dtype_str = dtype_str
24
- self.npy_path = npy_path
25
- self.maximum = maximum
26
- self.minimum = minimum
27
- self.shape = shape
28
-
29
- class ComputeElement:
30
- def __init__(self, compute_element_info=None, parameter=None):
31
- self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
32
- if parameter is not None:
33
- self._init_with_parameter(parameter)
34
- elif isinstance(compute_element_info, (list, dict)):
35
- self._init_from_compute_element_info(compute_element_info)
36
- else:
37
- logger.error_log_with_exp(
38
- "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
39
- ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
40
-
41
- @staticmethod
42
- def transfer_to_torch_tensor(ms_tensor):
43
- '''
44
- Args:
45
- ms_tensor: mindspore.Tensor
46
- Return:
47
- torch_tensor: torch.Tensor
48
- '''
49
- ms_dtype = ms_tensor.dtype
50
- dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
51
- if dtype_str not in dtype_str_to_torch_dtype:
52
- err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
53
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
54
- else:
55
- torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
56
-
57
- if dtype_str in float_dtype_str_list:
58
- middle_dtype = mindspore.float64
59
- elif dtype_str in int_dtype_str_list:
60
- middle_dtype = mindspore.int64
61
- else:
62
- middle_dtype = mindspore.uint64
63
- np_ndarray = ms_tensor.astype(middle_dtype).numpy()
64
- torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
65
- return torch_tensor
66
-
67
- @staticmethod
68
- def transfer_to_mindspore_tensor(torch_tensor):
69
- '''
70
- Args:
71
- torch_tensor: torch.Tensor
72
-
73
- Return:
74
- ms_tensor: mindspore.Tensor
75
- '''
76
- torch_dtype = torch_tensor.dtype
77
- dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
78
- if dtype_str not in dtype_str_to_ms_dtype:
79
- err_msg = \
80
- f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
81
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
82
- else:
83
- ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
84
-
85
- if dtype_str in float_dtype_str_list:
86
- middle_dtype = torch.float64
87
- elif dtype_str in int_dtype_str_list:
88
- middle_dtype = torch.int64
89
- np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
90
- ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
91
- return ms_tensor
92
-
93
- @staticmethod
94
- def convert_inf_to_real_num(value, dtype_str):
95
- if value == float("inf"):
96
- np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
97
- value = np.finfo(np_dtype).max
98
- elif value == float("-inf"):
99
- np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
100
- value = np.finfo(np_dtype).min
101
- return value
102
-
103
- def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
104
- '''
105
- Args:
106
- get_origin: boolean
107
- get_mindspore_tensor: boolean
108
-
109
- Return:
110
- parameter: Union[int, float, str, slice,tuple, torch.Tensor, mindspore.Tensor]
111
- '''
112
- if isinstance(self.parameter, self.supported_parameter_type):
113
- parameter_tmp = self.parameter
114
- elif isinstance(self.parameter, MstensorMetaData):
115
- mstensor_meta_data = self.parameter
116
- ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
117
- if global_context.get_is_constructed():
118
- np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
119
- ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
120
- mstensor_meta_data.minimum, np_dtype)
121
- else:
122
- ndarray = load_npy(mstensor_meta_data.npy_path)
123
- parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
124
- else:
125
- err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
126
- "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
127
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
128
-
129
- # if necessary, do transfer
130
- if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
131
- parameter = self.transfer_to_torch_tensor(parameter_tmp)
132
- elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform ==Const.MS_FRAMEWORK:
133
- parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
134
- else:
135
- parameter = parameter_tmp
136
-
137
- return parameter
138
-
139
- def get_shape(self):
140
- return self.shape
141
-
142
- def get_dtype(self):
143
- return self.dtype_str
144
-
145
- def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
146
- shape = tuple(shape)
147
- np.random.seed(42)
148
- if np_dtype == np.bool_:
149
- ndarray = np.random.rand(*shape) > 0.5
150
- else:
151
- maximum = self.convert_inf_to_real_num(maximum, np_dtype)
152
- minimum = self.convert_inf_to_real_num(minimum, np_dtype)
153
- ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype)
154
- return ndarray
155
-
156
- def _init_from_compute_element_info(self, compute_element_info):
157
- '''
158
- Args:
159
- compute_element_info: Union[list, dict]
160
- is_constructed: boolean
161
-
162
- Return:
163
- void
164
-
165
- init member attributes: self.shape, self.dtype_str, self.parameter
166
- '''
167
- if isinstance(compute_element_info, list):
168
- self.shape = tuple()
169
- self.dtype_str = TUPLE_TYPE_STR
170
- self.parameter = tuple(ComputeElement(compute_element_info=sub_info).get_parameter()
171
- for sub_info in compute_element_info)
172
- else:
173
- type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
174
- accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
175
-
176
- if type_str == MINDSPORE_TENSOR_TYPE_STR:
177
- self._init_from_mstensor_compute_element_info(compute_element_info)
178
- else: # type_str in ("slice", "int", "float", "bool")
179
- value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
180
- self.shape = tuple()
181
- self.dtype_str = type_str
182
- self.parameter = slice(*tuple(value)) if type_str == "slice" else value
183
-
184
- def _init_from_mstensor_compute_element_info(self, compute_element_info):
185
- '''
186
- do not load real tensor, only record meta data
187
- '''
188
- dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
189
- accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
190
- shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
191
- accepted_type=(list,))
192
- if global_context.get_is_constructed():
193
- maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
194
- accepted_type=(int, float))
195
- minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
196
- accepted_type=(int, float))
197
-
198
- npy_path = None
199
- else:
200
- maximum, minimum = None, None
201
- data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
202
- "data_name field in api_info.json", accepted_type=(str,))
203
- npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
204
- mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
205
- self.parameter = mstensor_meta_data
206
- self.dtype_str = dtype_str
207
- self.shape = tuple(shape)
208
-
209
- def _init_with_parameter(self, parameter):
210
- self.parameter = parameter
211
- if not isinstance(parameter, self.supported_parameter_type):
212
- err_msg = "ComputeElement._init_with_parameter failed: " \
213
- "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
214
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
215
- if isinstance(parameter, mindspore.Tensor):
216
- self.shape = tuple(parameter.shape)
217
- self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
218
- elif isinstance(parameter, torch.Tensor):
219
- self.shape = tuple(parameter.shape)
220
- self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
221
- else:
222
- self.shape = tuple()
223
- self.dtype_str = \
1
+ import os
2
+
3
+ import mindspore
4
+ import torch
5
+ import numpy as np
6
+
7
+ from msprobe.mindspore.common.log import logger
8
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
9
+ from msprobe.core.common.file_utils import load_npy
10
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
11
+ ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
12
+ dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
13
+ dtype_str_to_torch_dtype, type_to_api_info_type_str,
14
+ DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
15
+ MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
16
+ int_dtype_str_list)
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
19
+
20
+
21
+ class MstensorMetaData:
22
+ def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
23
+ self.dtype_str = dtype_str
24
+ self.npy_path = npy_path
25
+ self.maximum = maximum
26
+ self.minimum = minimum
27
+ self.shape = shape
28
+
29
+ class ComputeElement:
30
+ def __init__(self, compute_element_info=None, parameter=None):
31
+ self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
32
+ if parameter is not None:
33
+ self._init_with_parameter(parameter)
34
+ elif isinstance(compute_element_info, (list, dict)):
35
+ self._init_from_compute_element_info(compute_element_info)
36
+ elif compute_element_info is None:
37
+ self._init_from_null_compute_element_info()
38
+ else:
39
+ logger.error_log_with_exp(
40
+ "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
41
+ ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
42
+
43
+ @staticmethod
44
+ def transfer_to_torch_tensor(ms_tensor):
45
+ '''
46
+ Args:
47
+ ms_tensor: mindspore.Tensor
48
+ Return:
49
+ torch_tensor: torch.Tensor
50
+ '''
51
+ ms_dtype = ms_tensor.dtype
52
+ dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
53
+ if dtype_str not in dtype_str_to_torch_dtype:
54
+ err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
55
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
56
+ else:
57
+ torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
58
+
59
+ if dtype_str in float_dtype_str_list:
60
+ middle_dtype = mindspore.float64
61
+ elif dtype_str in int_dtype_str_list:
62
+ middle_dtype = mindspore.int64
63
+ else:
64
+ middle_dtype = mindspore.uint64
65
+ np_ndarray = ms_tensor.astype(middle_dtype).numpy()
66
+ torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
67
+ return torch_tensor
68
+
69
+ @staticmethod
70
+ def transfer_to_mindspore_tensor(torch_tensor):
71
+ '''
72
+ Args:
73
+ torch_tensor: torch.Tensor
74
+
75
+ Return:
76
+ ms_tensor: mindspore.Tensor
77
+ '''
78
+ torch_dtype = torch_tensor.dtype
79
+ dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
80
+ if dtype_str not in dtype_str_to_ms_dtype:
81
+ err_msg = \
82
+ f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
83
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
84
+ else:
85
+ ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
86
+
87
+ if dtype_str in float_dtype_str_list:
88
+ middle_dtype = torch.float64
89
+ elif dtype_str in int_dtype_str_list:
90
+ middle_dtype = torch.int64
91
+ np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
92
+ ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
93
+ return ms_tensor
94
+
95
+ @staticmethod
96
+ def convert_inf_to_real_num(value, dtype_str):
97
+ if value == float("inf"):
98
+ np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
99
+ value = np.finfo(np_dtype).max
100
+ elif value == float("-inf"):
101
+ np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
102
+ value = np.finfo(np_dtype).min
103
+ return value
104
+
105
def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
    '''
    Materialize the stored parameter, optionally converting tensors across frameworks.

    Tuples are resolved element by element with the same options. MstensorMetaData is
    turned into a mindspore.Tensor in one of two ways:
        - in constructed mode, from random data drawn within the recorded Min/Max;
        - otherwise, by loading the recorded npy file.

    Args:
        get_origin: boolean. When True, tensors are returned in the framework they are
            held in. When False, mindspore tensors are converted to torch if
            tensor_platform is Const.PT_FRAMEWORK, and torch tensors to mindspore if it
            is Const.MS_FRAMEWORK.
        tensor_platform: str, Union["mindspore", "pytorch"]

    Return:
        parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
    '''
    stored = self.parameter
    if stored is None:
        return stored
    if isinstance(stored, tuple):
        return tuple(element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform)
                     for element in stored)

    if isinstance(stored, self.supported_parameter_type):
        resolved = stored
    elif isinstance(stored, MstensorMetaData):
        ms_dtype = dtype_str_to_ms_dtype.get(stored.dtype_str)
        if global_context.get_is_constructed():
            # Constructed mode records no npy path, so synthesize data from Min/Max.
            np_dtype = dtype_str_to_np_dtype.get(stored.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
            data = self._construct_ndarray(stored.shape, stored.maximum, stored.minimum, np_dtype)
        else:
            data = load_npy(stored.npy_path)
        resolved = mindspore.Tensor(data, dtype=ms_dtype)
    else:
        err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
                  "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
        logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    if not get_origin:
        if isinstance(resolved, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
            return self.transfer_to_torch_tensor(resolved)
        if isinstance(resolved, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
            return self.transfer_to_mindspore_tensor(resolved)
    return resolved
145
+
146
def get_shape(self):
    '''Return the shape tuple recorded at initialization (an empty tuple for non-tensor parameters).'''
    recorded_shape = self.shape
    return recorded_shape
148
+
149
def get_dtype(self):
    '''Return the dtype string stored in self.dtype_str by the _init_* helpers.'''
    recorded_dtype = self.dtype_str
    return recorded_dtype
151
+
152
def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
    '''
    Build a reproducible random ndarray standing in for tensor data.

    The global numpy RNG is reseeded with 42 on every call, so identical arguments
    give identical arrays. How values are drawn depends on np_dtype:
        - np.bool_: drawn as rand(...) > 0.5;
        - any other dtype: drawn with np.random.uniform between the bounds, then cast
          to np_dtype. Infinite bounds are first replaced through
          convert_inf_to_real_num.
    '''
    dims = tuple(shape)
    np.random.seed(42)
    if np_dtype == np.bool_:
        return np.random.rand(*dims) > 0.5
    upper = self.convert_inf_to_real_num(maximum, np_dtype)
    lower = self.convert_inf_to_real_num(minimum, np_dtype)
    samples = np.random.uniform(lower, upper, dims)
    return samples.astype(np_dtype)
162
+
163
def _init_from_null_compute_element_info(self):
    '''
    Initialize an element that carries no parameter. Presumably this is used for
    absent/null api_info entries; confirm with the caller.

    Sets parameter to None, shape to an empty tuple and the dtype string to "None".
    The dtype string is stored in self.dtype_str, the attribute get_dtype() and the
    other _init_* helpers use; this initializer previously set only self.dtype.
    '''
    self.parameter = None
    self.shape = tuple()
    self.dtype_str = "None"
    # Keep the old attribute name for backward compatibility with existing readers.
    self.dtype = "None"
167
+
168
def _init_from_compute_element_info(self, compute_element_info):
    '''
    Populate self.shape, self.dtype_str and self.parameter from one api_info.json entry.

    Args:
        compute_element_info: Union[list, dict]. The accepted forms are:
            - a list: a tuple of nested entries, each wrapped in a ComputeElement;
            - a dict with a validated "type" field. A mindspore tensor type is recorded
              as metadata only. Any other type reads the "value" field, and "slice"
              values are unpacked into a slice object.

    Return:
        void
    '''
    if not isinstance(compute_element_info, list):
        type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
                                                accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
        if type_str == MINDSPORE_TENSOR_TYPE_STR:
            self._init_from_mstensor_compute_element_info(compute_element_info)
            return
        # Scalar-like entries, e.g. "slice", "int", "float", "bool".
        value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
        self.shape = tuple()
        self.dtype_str = type_str
        if type_str == "slice":
            self.parameter = slice(*value)
        else:
            self.parameter = value
        return

    self.shape = tuple()
    self.dtype_str = TUPLE_TYPE_STR
    self.parameter = tuple(ComputeElement(compute_element_info=entry) for entry in compute_element_info)
194
+
195
def _init_from_mstensor_compute_element_info(self, compute_element_info):
    '''
    Record a mindspore tensor entry as MstensorMetaData without loading any tensor data.

    The "dtype" and "shape" fields are always validated. The remaining fields depend on
    global_context:
        - constructed mode: "Max" and "Min" are required (later used to synthesize
          data), and no npy path is recorded;
        - otherwise: "data_name" is required and joined onto the dump data directory.
    '''
    dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
                                             accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
    shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
                                         accepted_type=(list,))
    maximum = minimum = npy_path = None
    if global_context.get_is_constructed():
        maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
                                               accepted_type=(int, float))
        minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
                                               accepted_type=(int, float))
    else:
        data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
                                                 "data_name field in api_info.json", accepted_type=(str,))
        npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)

    self.parameter = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
    self.dtype_str = dtype_str
    self.shape = tuple(shape)
219
+
220
def _init_with_parameter(self, parameter):
    '''
    Initialize directly from an in-memory value rather than from api_info.json.

    Unsupported types are reported through logger.error_log_with_exp. Supported values
    are recorded as follows:
        - mindspore / torch tensors: their shape and mapped dtype string;
        - tuples: wrapped element-wise into ComputeElements, marked with TUPLE_TYPE_STR;
        - other values: an empty shape and the api_info type string of their type.
    '''
    self.parameter = parameter
    if not isinstance(parameter, self.supported_parameter_type):
        err_msg = "ComputeElement._init_with_parameter failed: " \
                  "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
        logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    if isinstance(parameter, (mindspore.Tensor, torch.Tensor)):
        self.shape = tuple(parameter.shape)
        if isinstance(parameter, mindspore.Tensor):
            dtype_lookup = ms_dtype_to_dtype_str
        else:
            dtype_lookup = torch_dtype_to_dtype_str
        self.dtype_str = dtype_lookup.get(parameter.dtype)
    elif isinstance(parameter, tuple):
        self.shape = tuple()
        self.dtype_str = TUPLE_TYPE_STR
        self.parameter = tuple(ComputeElement(parameter=item) for item in parameter)
    else:
        # Tuples were handled above, so only the scalar type lookup is needed here.
        self.shape = tuple()
        self.dtype_str = type_to_api_info_type_str.get(type(parameter))
@@ -1,16 +1,9 @@
1
- from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
2
-
3
- def add_api_accuracy_checker_argument(parser):
4
- parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
5
- help="<Required> The api param tool result file: generate from api param tool, "
6
- "a json file.")
7
- parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
8
- help="<optional> The ut task result out path.")
9
-
10
-
11
- def api_checker_main(args):
12
- api_accuracy_checker = ApiAccuracyChecker()
13
- api_accuracy_checker.parse(args.api_info_file)
14
- api_accuracy_checker.run_and_compare()
15
- api_accuracy_checker.to_detail_csv(args.out_path)
1
+ from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
2
+
3
+
4
def api_checker_main(args):
    '''
    Run the MindSpore API accuracy checker end to end.

    Parses args.api_info_file, runs and compares every API, then writes the detail CSV
    and the result CSV into args.out_path. args is presumably an argparse.Namespace;
    only these two attributes are read.
    '''
    checker = ApiAccuracyChecker()
    checker.parse(args.api_info_file)
    checker.run_and_compare()
    output_dir = args.out_path
    checker.to_detail_csv(output_dir)
    checker.to_result_csv(output_dir)