mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,69 +1,69 @@
1
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
2
- from msprobe.core.common.const import Const
3
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
- from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
- from msprobe.core.common.log import logger
6
-
7
- class ApiInfo:
8
- def __init__(self, api_name):
9
- self.api_name = api_name
10
- self.forward_info = None
11
- self.backward_info = None
12
-
13
- def load_forward_info(self, forward_info_dict):
14
- self.forward_info = forward_info_dict
15
-
16
- def load_backward_info(self, backward_info_dict):
17
- self.backward_info = backward_info_dict
18
-
19
- def check_forward_info(self):
20
- return self.forward_info is not None
21
-
22
- def check_backward_info(self):
23
- return self.backward_info is not None
24
-
25
- def get_compute_element_list(self, forward_or_backward, input_or_output):
26
- '''
27
- Args:
28
- forward_or_backward: str, Union["forward", "backward"]
29
- input_or_output: str, Union["input", "output"]
30
-
31
- Return:
32
- compute_element_list: List[ComputeElement]
33
- '''
34
- mapping = {
35
- (Const.FORWARD, Const.INPUT): [self.forward_info, Const.INPUT_ARGS,
36
- f"input_args field of {self.api_name} forward api in api_info.json"],
37
- (Const.FORWARD, Const.OUTPUT): [self.forward_info, Const.OUTPUT,
38
- f"output field of {self.api_name} forward api in api_info.json"],
39
- (Const.BACKWARD, Const.INPUT): [self.backward_info, Const.INPUT,
40
- f"input field of {self.api_name} backward api in api_info.json"],
41
- (Const.BACKWARD, Const.OUTPUT): [self.backward_info, Const.OUTPUT,
42
- f"output field of {self.api_name} backward api in api_info.json"]
43
- }
44
- dict_instance, key, key_desc = mapping.get((forward_or_backward, input_or_output))
45
- compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
46
- compute_element_list = [ComputeElement(compute_element_info=compute_element_info)
47
- for compute_element_info in compute_element_info_list]
48
- return compute_element_list
49
-
50
- def get_kwargs(self):
51
- '''
52
- Return:
53
- kwargs_compute_element_dict: dict{str: ComputeElement}
54
- '''
55
- kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
56
- "input_kwargs in api_info.json", accepted_type=dict)
57
- for key_str, compute_element_info in kwargs_dict.items():
58
- if not isinstance(key_str, str):
59
- err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
60
- logger.error_log_with_exp(err_msg,
61
- ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
62
- if not isinstance(compute_element_info, (list, dict)):
63
- err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
64
- logger.error_log_with_exp(err_msg,
65
- ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
66
- kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
67
- for key_str, compute_element_info in kwargs_dict.items()}
68
- return kwargs_compute_element_dict
69
-
1
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
+ from msprobe.mindspore.common.log import logger
6
+
7
+ class ApiInfo:
8
+ def __init__(self, api_name):
9
+ self.api_name = api_name
10
+ self.forward_info = None
11
+ self.backward_info = None
12
+
13
+ def load_forward_info(self, forward_info_dict):
14
+ self.forward_info = forward_info_dict
15
+
16
+ def load_backward_info(self, backward_info_dict):
17
+ self.backward_info = backward_info_dict
18
+
19
+ def check_forward_info(self):
20
+ return self.forward_info is not None
21
+
22
+ def check_backward_info(self):
23
+ return self.backward_info is not None
24
+
25
+ def get_compute_element_list(self, forward_or_backward, input_or_output):
26
+ '''
27
+ Args:
28
+ forward_or_backward: str, Union["forward", "backward"]
29
+ input_or_output: str, Union["input", "output"]
30
+
31
+ Return:
32
+ compute_element_list: List[ComputeElement]
33
+ '''
34
+ mapping = {
35
+ (Const.FORWARD, Const.INPUT): [self.forward_info, Const.INPUT_ARGS,
36
+ f"input_args field of {self.api_name} forward api in api_info.json"],
37
+ (Const.FORWARD, Const.OUTPUT): [self.forward_info, Const.OUTPUT,
38
+ f"output field of {self.api_name} forward api in api_info.json"],
39
+ (Const.BACKWARD, Const.INPUT): [self.backward_info, Const.INPUT,
40
+ f"input field of {self.api_name} backward api in api_info.json"],
41
+ (Const.BACKWARD, Const.OUTPUT): [self.backward_info, Const.OUTPUT,
42
+ f"output field of {self.api_name} backward api in api_info.json"]
43
+ }
44
+ dict_instance, key, key_desc = mapping.get((forward_or_backward, input_or_output))
45
+ compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
46
+ compute_element_list = [ComputeElement(compute_element_info=compute_element_info)
47
+ for compute_element_info in compute_element_info_list]
48
+ return compute_element_list
49
+
50
+ def get_kwargs(self):
51
+ '''
52
+ Return:
53
+ kwargs_compute_element_dict: dict{str: ComputeElement}
54
+ '''
55
+ kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
56
+ "input_kwargs in api_info.json", accepted_type=dict)
57
+ for key_str, compute_element_info in kwargs_dict.items():
58
+ if not isinstance(key_str, str):
59
+ err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
60
+ logger.error_log_with_exp(err_msg,
61
+ ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
62
+ if not isinstance(compute_element_info, (list, dict)):
63
+ err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
64
+ logger.error_log_with_exp(err_msg,
65
+ ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
66
+ kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
67
+ for key_str, compute_element_info in kwargs_dict.items()}
68
+ return kwargs_compute_element_dict
69
+
@@ -1,152 +1,156 @@
1
-
2
-
3
- import mindspore
4
- import torch
5
- from mindspore import ops
6
-
7
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
- from msprobe.core.common.const import Const, MsCompareConst
9
- from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
- from msprobe.core.common.log import logger
11
- from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
12
-
13
-
14
- class ApiInputAggregation:
15
- def __init__(self, inputs, kwargs, gradient_inputs) -> None:
16
- '''
17
- Args:
18
- inputs: List[ComputeElement]
19
- kwargs: dict{str: ComputeElement}
20
- gradient_inputs: Union[List[ComputeElement], None]
21
- '''
22
- self.inputs = inputs
23
- self.kwargs = kwargs
24
- self.gradient_inputs = gradient_inputs
25
-
26
- api_parent_module_mapping = {
27
- (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
28
- (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
29
- (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
30
- (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional
31
- }
32
-
33
- class ApiRunner:
34
- def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
35
- api_platform=Const.MS_FRAMEWORK):
36
- '''
37
- Args:
38
- api_input_aggregation: ApiInputAggregation
39
- api_name_str: str, e.g. "MintFunctional.relu.0"
40
- forward_or_backward: str, Union["forward", "backward"]
41
- api_platform: str, Union["mindspore", "torch"]
42
-
43
- Return:
44
- outputs: list[ComputeElement]
45
-
46
- Description:
47
- run mindspore.mint/torch api
48
- '''
49
- api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
50
- api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
51
-
52
- return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
53
-
54
- @staticmethod
55
- def get_info_from_name(api_name_str):
56
- '''
57
- Args:
58
- api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
59
-
60
- Return:
61
- api_type_str: str, Union["MintFunctional", "Mint"]
62
- api_sub_name: str, e.g. "relu"
63
- '''
64
- api_name_list = api_name_str.split(Const.SEP)
65
- if len(api_name_list) != 3:
66
- err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
67
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
68
- api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
69
- if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
70
- err_msg = f"ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
71
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
72
-
73
- return api_type_str, api_sub_name
74
-
75
- @staticmethod
76
- def get_api_instance(api_type_str, api_sub_name, api_platform):
77
- '''
78
- Args:
79
- api_type_str: str, Union["MintFunctional", "Mint"]
80
- api_sub_name: str, e.g. "relu"
81
- api_platform: str: Union["mindpore", "torch"]
82
-
83
- Return:
84
- api_instance: function object
85
-
86
- Description:
87
- get mindspore.mint/torch api fucntion
88
- mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
89
- mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
90
- '''
91
-
92
- api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
93
- module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
94
- submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
95
- full_api_name = module_str + submodule_str + api_sub_name
96
- if not hasattr(api_parent_module, api_sub_name):
97
- err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
98
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
99
-
100
- api_instance = getattr(api_parent_module, api_sub_name)
101
- if not callable(api_instance):
102
- err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
103
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
104
-
105
- return api_instance
106
-
107
- @staticmethod
108
- def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
109
- inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
110
- for compute_element in api_input_aggregation.inputs)
111
- kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
112
- for key, value in api_input_aggregation.kwargs.items()}
113
- gradient_inputs = api_input_aggregation.gradient_inputs
114
-
115
- if forward_or_backward == Const.FORWARD:
116
- forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
117
- forward_result_tuple = convert_to_tuple(forward_result)
118
- res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
119
- else:
120
- if gradient_inputs is None:
121
- err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
122
- logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
123
- gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
124
- for compute_element in gradient_inputs)
125
- if api_platform == Const.MS_FRAMEWORK:
126
- if len(gradient_inputs) == 1:
127
- gradient_inputs = gradient_inputs[0]
128
- def api_with_kwargs(*forward_inputs):
129
- return api_instance(*forward_inputs, **kwargs)
130
- grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
131
- backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
132
- backward_result_tuple = convert_to_tuple(backward_result)
133
- res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
134
- else:
135
- #set requires_grad
136
- for tensor in inputs:
137
- if hasattr(tensor, "requires_grad"):
138
- setattr(tensor, "requires_grad", True)
139
- forward_results = api_instance(*inputs, **kwargs)
140
- forward_results = convert_to_tuple(forward_results)
141
- for forward_res, gradient_in in zip(forward_results, gradient_inputs):
142
- forward_res.backward(gradient_in)
143
- backward_result_list = []
144
- for tensor in inputs:
145
- if hasattr(tensor, "grad"):
146
- backward_result_list.append(getattr(tensor, "grad"))
147
- res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list]
148
-
149
- return res_compute_element_list
150
-
151
-
1
+
2
+
3
+ import mindspore
4
+ import torch
5
+ from mindspore import ops
6
+
7
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
+ from msprobe.core.common.const import Const, MsCompareConst
9
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
+ from msprobe.mindspore.common.log import logger
11
+ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
12
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
13
+
14
+
15
class ApiInputAggregation:
    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        '''
        Bundle the arguments of one api call so they can be passed around as a single object.

        Args:
            inputs: List[ComputeElement], positional arguments of the api
            kwargs: dict{str: ComputeElement}, keyword arguments of the api
            gradient_inputs: Union[List[ComputeElement], None], gradients fed to the backward run,
                None when no gradient data is available
        '''
        # Plain value holder: the three arguments are stored as given, without copying or validation.
        self.inputs, self.kwargs, self.gradient_inputs = inputs, kwargs, gradient_inputs
26
+
27
# Lookup table: (api type, framework) -> module object that hosts the api callable.
# Each row below is (api type, framework, parent module); insertion order is preserved.
api_parent_module_mapping = {
    (api_type, framework): parent_module
    for api_type, framework, parent_module in (
        (MsCompareConst.MINT, Const.MS_FRAMEWORK, mindspore.mint),
        (MsCompareConst.MINT, Const.PT_FRAMEWORK, torch),
        (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK, mindspore.mint.nn.functional),
        (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK, torch.nn.functional),
    )
}
33
+
34
+
35
class ApiRunner:
    '''
    Callable helper that resolves an api name taken from a dump file to the matching
    mindspore.mint / torch function and runs it forward or backward.
    '''

    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            outputs: list[ComputeElement]

        Description:
            run mindspore.mint/torch api
        '''
        api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str):
        '''
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"

        Return:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"

        Description:
            split the name on Const.SEP; exactly three fields are expected and only the first two are used.
            Malformed names and unsupported api types are reported through logger.error_log_with_exp
            (presumably raises the ApiAccuracyCheckerException passed to it -- confirm in the logger).
        '''
        api_name_list = api_name_str.split(Const.SEP)
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
            err_msg = "ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        '''
        Args:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
            api_platform: str: Union["mindspore", "torch"]

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
        '''
        # An unknown (api type, platform) pair yields None here, which then fails the hasattr check below.
        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))

        # Human-readable dotted name, used only in error messages.
        module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
        submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
        full_api_name = module_str + submodule_str + api_sub_name

        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        '''
        Args:
            api_instance: function object returned by get_api_instance
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str, Const.FORWARD runs the api itself, any other value runs backward
            api_platform: str, Const.MS_FRAMEWORK selects the mindspore path, any other value the torch path

        Return:
            res_compute_element_list: list[ComputeElement]
                forward: one element per api output
                backward (mindspore): one element per gradient returned by GradOperation
                backward (torch): one element per input tensor that was switched to requires_grad

        Description:
            materialize the ComputeElement arguments for the target platform and execute the api
        '''
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}
        gradient_inputs = api_input_aggregation.gradient_inputs

        if forward_or_backward == Const.FORWARD:
            forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
            return [ComputeElement(parameter=api_res) for api_res in convert_to_tuple(forward_result)]

        if gradient_inputs is None:
            err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                for compute_element in gradient_inputs)

        if api_platform == Const.MS_FRAMEWORK:
            # A lone gradient is passed as a bare tensor instead of a 1-tuple.
            if len(gradient_inputs) == 1:
                gradient_inputs = gradient_inputs[0]

            # GradOperation differentiates w.r.t. positional inputs, so kwargs are bound via a closure.
            def api_with_kwargs(*forward_inputs):
                return api_instance(*forward_inputs, **kwargs)

            # sens_param=True: the gradient of the outputs is supplied as the trailing argument.
            grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
            backward_result = grad_func(*inputs, gradient_inputs)  # can be single tensor or tuple
            backward_results = convert_to_tuple(backward_result)
        else:
            # Enable gradient tracking only on torch tensors whose dtype string is in float_dtype_str_list,
            # and remember their positions so the gradients can be collected afterwards.
            requires_grad_index = []
            for index, tensor in enumerate(inputs):
                if isinstance(tensor, torch.Tensor) and \
                        torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
                    tensor.requires_grad = True
                    requires_grad_index.append(index)

            forward_results = convert_to_tuple(api_instance(*inputs, **kwargs))
            for forward_res, gradient_in in zip(forward_results, gradient_inputs):
                # Outputs of one api call share an autograd graph; without retain_graph the first
                # backward() frees it and the next output fails with "backward through the graph a second time".
                forward_res.backward(gradient_in, retain_graph=True)
            backward_results = [inputs[index].grad for index in requires_grad_index]

        return [ComputeElement(parameter=api_res) for api_res in backward_results]
154
+
155
+
152
156
  api_runner = ApiRunner()