mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,69 +1,76 @@
1
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
2
- from msprobe.core.common.const import Const
3
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
- from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
- from msprobe.core.common.log import logger
6
-
7
- class ApiInfo:
8
- def __init__(self, api_name):
9
- self.api_name = api_name
10
- self.forward_info = None
11
- self.backward_info = None
12
-
13
- def load_forward_info(self, forward_info_dict):
14
- self.forward_info = forward_info_dict
15
-
16
- def load_backward_info(self, backward_info_dict):
17
- self.backward_info = backward_info_dict
18
-
19
- def check_forward_info(self):
20
- return self.forward_info is not None
21
-
22
- def check_backward_info(self):
23
- return self.backward_info is not None
24
-
25
- def get_compute_element_list(self, forward_or_backward, input_or_output):
26
- '''
27
- Args:
28
- forward_or_backward: str, Union["forward", "backward"]
29
- input_or_output: str, Union["input", "output"]
30
-
31
- Return:
32
- compute_element_list: List[ComputeElement]
33
- '''
34
- mapping = {
35
- (Const.FORWARD, Const.INPUT): [self.forward_info, Const.INPUT_ARGS,
36
- f"input_args field of {self.api_name} forward api in api_info.json"],
37
- (Const.FORWARD, Const.OUTPUT): [self.forward_info, Const.OUTPUT,
38
- f"output field of {self.api_name} forward api in api_info.json"],
39
- (Const.BACKWARD, Const.INPUT): [self.backward_info, Const.INPUT,
40
- f"input field of {self.api_name} backward api in api_info.json"],
41
- (Const.BACKWARD, Const.OUTPUT): [self.backward_info, Const.OUTPUT,
42
- f"output field of {self.api_name} backward api in api_info.json"]
43
- }
44
- dict_instance, key, key_desc = mapping.get((forward_or_backward, input_or_output))
45
- compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
46
- compute_element_list = [ComputeElement(compute_element_info=compute_element_info)
47
- for compute_element_info in compute_element_info_list]
48
- return compute_element_list
49
-
50
- def get_kwargs(self):
51
- '''
52
- Return:
53
- kwargs_compute_element_dict: dict{str: ComputeElement}
54
- '''
55
- kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
56
- "input_kwargs in api_info.json", accepted_type=dict)
57
- for key_str, compute_element_info in kwargs_dict.items():
58
- if not isinstance(key_str, str):
59
- err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
60
- logger.error_log_with_exp(err_msg,
61
- ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
62
- if not isinstance(compute_element_info, (list, dict)):
63
- err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
64
- logger.error_log_with_exp(err_msg,
65
- ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
66
- kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
67
- for key_str, compute_element_info in kwargs_dict.items()}
68
- return kwargs_compute_element_dict
69
-
1
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
+ from msprobe.mindspore.common.log import logger
6
+ from msprobe.core.common.utils import is_invalid_pattern
7
+
8
+ class ApiInfo:
9
+ def __init__(self, api_name):
10
+ if not isinstance(api_name, str):
11
+ err_msg = "ApiInfo.__init__ failed: api_name is not a string"
12
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
13
+ if is_invalid_pattern(api_name):
14
+ err_msg = "ApiInfo.__init__ failed: api_name contain illegal character"
15
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
16
+ self.api_name = api_name
17
+ self.forward_info = None
18
+ self.backward_info = None
19
+
20
+ def load_forward_info(self, forward_info_dict):
21
+ self.forward_info = forward_info_dict
22
+
23
+ def load_backward_info(self, backward_info_dict):
24
+ self.backward_info = backward_info_dict
25
+
26
+ def check_forward_info(self):
27
+ return self.forward_info is not None
28
+
29
+ def check_backward_info(self):
30
+ return self.backward_info is not None
31
+
32
+ def get_compute_element_list(self, forward_or_backward, input_or_output):
33
+ '''
34
+ Args:
35
+ forward_or_backward: str, Union["forward", "backward"]
36
+ input_or_output: str, Union["input", "output"]
37
+
38
+ Return:
39
+ compute_element_list: List[ComputeElement]
40
+ '''
41
+ mapping = {
42
+ (Const.FORWARD, Const.INPUT): [self.forward_info, Const.INPUT_ARGS,
43
+ f"input_args field of {self.api_name} forward api in api_info.json"],
44
+ (Const.FORWARD, Const.OUTPUT): [self.forward_info, Const.OUTPUT,
45
+ f"output field of {self.api_name} forward api in api_info.json"],
46
+ (Const.BACKWARD, Const.INPUT): [self.backward_info, Const.INPUT,
47
+ f"input field of {self.api_name} backward api in api_info.json"],
48
+ (Const.BACKWARD, Const.OUTPUT): [self.backward_info, Const.OUTPUT,
49
+ f"output field of {self.api_name} backward api in api_info.json"]
50
+ }
51
+ dict_instance, key, key_desc = mapping.get((forward_or_backward, input_or_output))
52
+ compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
53
+ compute_element_list = [ComputeElement(compute_element_info=compute_element_info)
54
+ for compute_element_info in compute_element_info_list]
55
+ return compute_element_list
56
+
57
+ def get_kwargs(self):
58
+ '''
59
+ Return:
60
+ kwargs_compute_element_dict: dict{str: ComputeElement}
61
+ '''
62
+ kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
63
+ "input_kwargs in api_info.json", accepted_type=dict)
64
+ for key_str, compute_element_info in kwargs_dict.items():
65
+ if not isinstance(key_str, str):
66
+ err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
67
+ logger.error_log_with_exp(err_msg,
68
+ ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
69
+ if not isinstance(compute_element_info, (list, dict)):
70
+ err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
71
+ logger.error_log_with_exp(err_msg,
72
+ ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
73
+ kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
74
+ for key_str, compute_element_info in kwargs_dict.items()}
75
+ return kwargs_compute_element_dict
76
+
@@ -1,152 +1,156 @@
1
-
2
-
3
- import mindspore
4
- import torch
5
- from mindspore import ops
6
-
7
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
- from msprobe.core.common.const import Const, MsCompareConst
9
- from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
- from msprobe.core.common.log import logger
11
- from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
12
-
13
-
14
- class ApiInputAggregation:
15
- def __init__(self, inputs, kwargs, gradient_inputs) -> None:
16
- '''
17
- Args:
18
- inputs: List[ComputeElement]
19
- kwargs: dict{str: ComputeElement}
20
- gradient_inputs: Union[List[ComputeElement], None]
21
- '''
22
- self.inputs = inputs
23
- self.kwargs = kwargs
24
- self.gradient_inputs = gradient_inputs
25
-
26
- api_parent_module_mapping = {
27
- (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
28
- (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
29
- (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
30
- (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional
31
- }
32
-
33
- class ApiRunner:
34
- def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
35
- api_platform=Const.MS_FRAMEWORK):
36
- '''
37
- Args:
38
- api_input_aggregation: ApiInputAggregation
39
- api_name_str: str, e.g. "MintFunctional.relu.0"
40
- forward_or_backward: str, Union["forward", "backward"]
41
- api_platform: str, Union["mindspore", "torch"]
42
-
43
- Return:
44
- outputs: list[ComputeElement]
45
-
46
- Description:
47
- run mindspore.mint/torch api
48
- '''
49
- api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
50
- api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
51
-
52
- return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
53
-
54
- @staticmethod
55
- def get_info_from_name(api_name_str):
56
- '''
57
- Args:
58
- api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
59
-
60
- Return:
61
- api_type_str: str, Union["MintFunctional", "Mint"]
62
- api_sub_name: str, e.g. "relu"
63
- '''
64
- api_name_list = api_name_str.split(Const.SEP)
65
- if len(api_name_list) != 3:
66
- err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
67
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
68
- api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
69
- if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
70
- err_msg = f"ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
71
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
72
-
73
- return api_type_str, api_sub_name
74
-
75
- @staticmethod
76
- def get_api_instance(api_type_str, api_sub_name, api_platform):
77
- '''
78
- Args:
79
- api_type_str: str, Union["MintFunctional", "Mint"]
80
- api_sub_name: str, e.g. "relu"
81
- api_platform: str: Union["mindpore", "torch"]
82
-
83
- Return:
84
- api_instance: function object
85
-
86
- Description:
87
- get mindspore.mint/torch api fucntion
88
- mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
89
- mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
90
- '''
91
-
92
- api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
93
- module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
94
- submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
95
- full_api_name = module_str + submodule_str + api_sub_name
96
- if not hasattr(api_parent_module, api_sub_name):
97
- err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
98
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
99
-
100
- api_instance = getattr(api_parent_module, api_sub_name)
101
- if not callable(api_instance):
102
- err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
103
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
104
-
105
- return api_instance
106
-
107
@staticmethod
def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
    """Run ``api_instance`` forward or backward on the parameters held in ``api_input_aggregation``.

    Args:
        api_instance: the mindspore.mint / torch function to run.
        api_input_aggregation: object exposing ``inputs`` (list of ComputeElement),
            ``kwargs`` (dict of str -> ComputeElement) and ``gradient_inputs``
            (list of ComputeElement or None).
        forward_or_backward: ``Const.FORWARD`` runs the forward pass; any other value runs backward.
        api_platform: ``Const.MS_FRAMEWORK`` selects the mindspore path; anything else the torch path.

    Returns:
        list[ComputeElement]: forward outputs, or gradients w.r.t. the inputs for a backward run.

    Raises:
        ApiAccuracyCheckerException: (via ``logger.error_log_with_exp``) when a backward run has
            no ``gradient_inputs``.
    """
    inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                   for compute_element in api_input_aggregation.inputs)
    kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
              for key, value in api_input_aggregation.kwargs.items()}
    gradient_inputs = api_input_aggregation.gradient_inputs

    if forward_or_backward == Const.FORWARD:
        forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
        return [ComputeElement(parameter=api_res) for api_res in convert_to_tuple(forward_result)]

    if gradient_inputs is None:
        err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
        logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
    gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                            for compute_element in gradient_inputs)

    if api_platform == Const.MS_FRAMEWORK:
        # GradOperation(sens_param=True) takes the output sensitivity as the trailing argument;
        # a lone gradient is passed unwrapped (presumably to match a single-output API).
        if len(gradient_inputs) == 1:
            gradient_inputs = gradient_inputs[0]

        def api_with_kwargs(*forward_inputs):
            return api_instance(*forward_inputs, **kwargs)

        grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
        backward_result = grad_func(*inputs, gradient_inputs)  # can be single tensor or tuple
        return [ComputeElement(parameter=api_res) for api_res in convert_to_tuple(backward_result)]

    import torch  # local import: only the torch backward path needs it

    # torch refuses requires_grad on integer/bool tensors, so only mark floating-point
    # tensors and remember which inputs were marked.
    requires_grad_index = []
    for index, tensor in enumerate(inputs):
        if isinstance(tensor, torch.Tensor) and torch.is_floating_point(tensor):
            tensor.requires_grad = True
            requires_grad_index.append(index)

    forward_results = convert_to_tuple(api_instance(*inputs, **kwargs))
    # Outputs outside the autograd graph (e.g. integer index outputs) cannot be backpropagated;
    # drop them, and run a single backward over the rest so a shared graph is not freed early.
    differentiable = [(forward_res, gradient_in)
                      for forward_res, gradient_in in zip(forward_results, gradient_inputs)
                      if isinstance(forward_res, torch.Tensor) and forward_res.requires_grad]
    if differentiable:
        outputs, grads = zip(*differentiable)
        torch.autograd.backward(outputs, grad_tensors=grads)

    return [ComputeElement(parameter=inputs[index].grad) for index in requires_grad_index]
150
-
151
-
1
+
2
+
3
+ import mindspore
4
+ import torch
5
+ from mindspore import ops
6
+
7
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
+ from msprobe.core.common.const import Const, MsCompareConst
9
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
+ from msprobe.mindspore.common.log import logger
11
+ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
12
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
13
+
14
+
15
+ class ApiInputAggregation:
16
+ def __init__(self, inputs, kwargs, gradient_inputs) -> None:
17
+ '''
18
+ Args:
19
+ inputs: List[ComputeElement]
20
+ kwargs: dict{str: ComputeElement}
21
+ gradient_inputs: Union[List[ComputeElement], None]
22
+ '''
23
+ self.inputs = inputs
24
+ self.kwargs = kwargs
25
+ self.gradient_inputs = gradient_inputs
26
+
27
+ api_parent_module_mapping = {
28
+ (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
29
+ (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
30
+ (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
31
+ (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional
32
+ }
33
+
34
+
35
+ class ApiRunner:
36
+ def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
37
+ api_platform=Const.MS_FRAMEWORK):
38
+ '''
39
+ Args:
40
+ api_input_aggregation: ApiInputAggregation
41
+ api_name_str: str, e.g. "MintFunctional.relu.0"
42
+ forward_or_backward: str, Union["forward", "backward"]
43
+ api_platform: str, Union["mindspore", "torch"]
44
+
45
+ Return:
46
+ outputs: list[ComputeElement]
47
+
48
+ Description:
49
+ run mindspore.mint/torch api
50
+ '''
51
+ api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
52
+ api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
53
+
54
+ return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
55
+
56
+ @staticmethod
57
+ def get_info_from_name(api_name_str):
58
+ '''
59
+ Args:
60
+ api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
61
+
62
+ Return:
63
+ api_type_str: str, Union["MintFunctional", "Mint"]
64
+ api_sub_name: str, e.g. "relu"
65
+ '''
66
+ api_name_list = api_name_str.split(Const.SEP)
67
+ if len(api_name_list) != 3:
68
+ err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
69
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
70
+ api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
71
+ if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
72
+ err_msg = f"ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
73
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
74
+
75
+ return api_type_str, api_sub_name
76
+
77
+ @staticmethod
78
+ def get_api_instance(api_type_str, api_sub_name, api_platform):
79
+ '''
80
+ Args:
81
+ api_type_str: str, Union["MintFunctional", "Mint"]
82
+ api_sub_name: str, e.g. "relu"
83
+ api_platform: str: Union["mindpore", "torch"]
84
+
85
+ Return:
86
+ api_instance: function object
87
+
88
+ Description:
89
+ get mindspore.mint/torch api fucntion
90
+ mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
91
+ mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
92
+ '''
93
+
94
+ api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
95
+ module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
96
+ submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
97
+ full_api_name = module_str + submodule_str + api_sub_name
98
+ if not hasattr(api_parent_module, api_sub_name):
99
+ err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
100
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
101
+
102
+ api_instance = getattr(api_parent_module, api_sub_name)
103
+ if not callable(api_instance):
104
+ err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
105
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
106
+
107
+ return api_instance
108
+
109
+ @staticmethod
110
+ def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
111
+ inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
112
+ for compute_element in api_input_aggregation.inputs)
113
+ kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
114
+ for key, value in api_input_aggregation.kwargs.items()}
115
+ gradient_inputs = api_input_aggregation.gradient_inputs
116
+
117
+ if forward_or_backward == Const.FORWARD:
118
+ forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
119
+ forward_result_tuple = convert_to_tuple(forward_result)
120
+ res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
121
+ else:
122
+ if gradient_inputs is None:
123
+ err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
124
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
125
+ gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
126
+ for compute_element in gradient_inputs)
127
+ if api_platform == Const.MS_FRAMEWORK:
128
+ if len(gradient_inputs) == 1:
129
+ gradient_inputs = gradient_inputs[0]
130
+ def api_with_kwargs(*forward_inputs):
131
+ return api_instance(*forward_inputs, **kwargs)
132
+ grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
133
+ backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
134
+ backward_result_tuple = convert_to_tuple(backward_result)
135
+ res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
136
+ else:
137
+ #set requires_grad
138
+ requires_grad_index = []
139
+ for index, tensor in enumerate(inputs):
140
+ if isinstance(tensor, torch.Tensor) and \
141
+ torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
142
+ setattr(tensor, "requires_grad", True)
143
+ requires_grad_index.append(index)
144
+ forward_results = api_instance(*inputs, **kwargs)
145
+ forward_results = convert_to_tuple(forward_results)
146
+ for forward_res, gradient_in in zip(forward_results, gradient_inputs):
147
+ forward_res.backward(gradient_in)
148
+ backward_result_list = []
149
+ for index in requires_grad_index:
150
+ backward_result_list.append(getattr(inputs[index], "grad"))
151
+ res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list]
152
+
153
+ return res_compute_element_list
154
+
155
+
152
156
  api_runner = ApiRunner()