mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,246 +1,255 @@
1
- import json
2
- import os
3
-
4
- from msprobe.core.common.file_check import FileOpen
5
- from msprobe.core.common.utils import write_csv, add_time_as_suffix
6
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
- from msprobe.core.common.log import logger
8
- from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
- from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
- from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
11
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
12
-
13
-
14
- class BasicInfoAndStatus:
15
- def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
16
- self.api_name = api_name
17
- self.bench_dtype = bench_dtype
18
- self.tested_dtype = tested_dtype
19
- self.shape = shape
20
- self.status = status
21
- self.err_msg = err_msg
22
-
23
- class ResultCsvEntry:
24
- def __init__(self) -> None:
25
- self.forward_pass_status = None
26
- self.backward_pass_status = None
27
- self.forward_err_msg = ""
28
- self.backward_err_msg = ""
29
- self.overall_err_msg = None
30
-
31
-
32
- class ApiAccuracyChecker:
33
- def __init__(self):
34
- self.api_infos = dict()
35
- self.results = dict()
36
-
37
- @staticmethod
38
- def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
39
- '''
40
- Args:
41
- api_info: ApiInfo
42
- api_name_str: str
43
- api_input_aggregation: ApiInputAggregation
44
- forward_or_backward: str: Union["forward", "backward"]
45
-
46
- Return:
47
- output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]
48
-
49
- Description:
50
- get mindspore api output, run torch api and get output.
51
- compare output.
52
- record compare result.
53
- '''
54
- # get output
55
- if global_context.get_is_constructed():
56
- # constructed situation, need use constructed input to run mindspore api getting tested_output
57
- tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
58
- else:
59
- tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
60
- bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
61
-
62
- # compare output
63
- output_list = []
64
- for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
65
- api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
66
- bench_dtype = bench_out.get_dtype()
67
- tested_dtype = tested_out.get_dtype()
68
- shape = bench_out.get_shape()
69
-
70
- compare_result_dict = dict()
71
- for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
72
- compare_result = compare_algorithm(bench_out, tested_out)
73
- compare_result_dict[compare_algorithm_name] = compare_result
74
-
75
- if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
76
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
77
- status = CompareConst.PASS
78
- err_msg = ""
79
- else:
80
- status = CompareConst.ERROR
81
- err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
82
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
83
- basic_info_status = \
84
- BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
85
- output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
86
- return output_list
87
-
88
- def parse(self, api_info_path):
89
- with FileOpen(api_info_path, "r") as f:
90
- api_info_dict = json.load(f)
91
-
92
- # init global context
93
- task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
94
- "task field in api_info.json",accepted_type=str,
95
- accepted_value=(MsCompareConst.STATISTICS_TASK,
96
- MsCompareConst.TENSOR_TASK))
97
- is_constructed = task == MsCompareConst.STATISTICS_TASK
98
- if not is_constructed:
99
- dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
100
- "dump_data_dir field in api_info.json", accepted_type=str)
101
- else:
102
- dump_data_dir = ""
103
- global_context.init(is_constructed, dump_data_dir)
104
-
105
- api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
106
- "data field in api_info.json", accepted_type=dict)
107
- for api_name, api_info in api_info_data.items():
108
- is_mint = api_name.split(Const.SEP)[0] in \
109
- (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
110
- if not is_mint:
111
- continue
112
- forbackward_str = api_name.split(Const.SEP)[-1]
113
- if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
114
- logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
115
- api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
116
- if api_name not in self.api_infos:
117
- self.api_infos[api_name] = ApiInfo(api_name)
118
-
119
- if forbackward_str == Const.FORWARD:
120
- self.api_infos[api_name].load_forward_info(api_info)
121
- else:
122
- self.api_infos[api_name].load_backward_info(api_info)
123
-
124
- def run_and_compare(self):
125
- for api_name_str, api_info in self.api_infos.items():
126
- if not api_info.check_forward_info():
127
- logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
128
- continue
129
- forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
130
- kwargs = api_info.get_kwargs()
131
- forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
132
- forward_output_list = None
133
- try:
134
- forward_output_list = \
135
- self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
136
- except Exception as e:
137
- logger.warning(f"exception occurs when running and comparing {api_name_str} forward api"
138
- f"detailed exception information: {e}")
139
- self.record(forward_output_list)
140
-
141
- if not api_info.check_backward_info():
142
- logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
143
- continue
144
- gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
145
- backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
146
- backward_output_list = None
147
- try:
148
- backward_output_list = \
149
- self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
150
- except Exception as e:
151
- logger.warning(f"exception occurs when running and comparing {api_name_str} backward api"
152
- f"detailed exception information: {e}")
153
- self.record(backward_output_list)
154
-
155
- def record(self, output_list):
156
- if output_list is None:
157
- return
158
- for output in output_list:
159
- api_real_name, forward_or_backward, basic_info, compare_result_dict = output
160
- key = tuple([api_real_name, forward_or_backward])
161
- if key not in self.results:
162
- self.results[key] = []
163
- self.results[key].append(tuple([basic_info, compare_result_dict]))
164
-
165
-
166
- def to_detail_csv(self, csv_dir):
167
- # detail_csv
168
- detail_csv = []
169
- detail_csv_header_basic_info = [
170
- MsCompareConst.DETAIL_CSV_API_NAME,
171
- MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
172
- MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
173
- MsCompareConst.DETAIL_CSV_SHAPE,
174
- ]
175
- detail_csv_header_compare_result = list(compare_algorithms.keys())
176
- detail_csv_header_status = [
177
- MsCompareConst.DETAIL_CSV_PASS_STATUS,
178
- MsCompareConst.DETAIL_CSV_MESSAGE,
179
- ]
180
-
181
- detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
182
- detail_csv.append(detail_csv_header)
183
-
184
- for _, results in self.results.items():
185
- # detail csv
186
- for res in results:
187
- basic_info, compare_result_dict = res
188
- csv_row_basic_info = \
189
- [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
190
- csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
191
- for algorithm_name in detail_csv_header_compare_result)
192
- csv_row_status = [basic_info.status, basic_info.err_msg]
193
- csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
194
- detail_csv.append(csv_row)
195
-
196
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
197
- write_csv(detail_csv, file_name, mode="w")
198
-
199
-
200
- def to_result_csv(self, csv_dir):
201
- result_csv_dict = dict()
202
- for key, results in self.results.items():
203
- api_real_name, forward_or_backward = key
204
- forward_or_backward_pass_status = CompareConst.PASS
205
- forward_or_backward_overall_err_msg = ""
206
- # detail csv
207
- for res in results:
208
- basic_info, _ = res
209
- if basic_info.status != CompareConst.PASS:
210
- forward_or_backward_pass_status = CompareConst.ERROR
211
- forward_or_backward_overall_err_msg += basic_info.err_msg
212
- forward_or_backward_overall_err_msg = \
213
- "" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
214
-
215
- #result_csv_dict
216
- if api_real_name not in result_csv_dict:
217
- result_csv_dict[api_real_name] = ResultCsvEntry()
218
- if forward_or_backward == Const.FORWARD:
219
- result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
220
- result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
221
- else:
222
- result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
223
- result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
224
-
225
- #result_csv
226
- result_csv = []
227
- result_csv_header = [
228
- MsCompareConst.DETAIL_CSV_API_NAME,
229
- MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
230
- MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
231
- MsCompareConst.DETAIL_CSV_MESSAGE,
232
- ]
233
- result_csv.append(result_csv_header)
234
-
235
- for api_name, result_csv_entry in result_csv_dict.items():
236
- if result_csv_entry.forward_pass_status == CompareConst.PASS and \
237
- result_csv_entry.backward_pass_status == CompareConst.PASS:
238
- overall_err_msg = ""
239
- else:
240
- overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
241
- row = [api_name, result_csv_entry.forward_pass_status,
242
- result_csv_entry.backward_pass_status, overall_err_msg]
243
- result_csv.append(row)
244
-
245
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
1
+ import json
2
+ import os
3
+
4
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
5
+ from msprobe.core.common.utils import add_time_as_suffix
6
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
+ from msprobe.mindspore.common.log import logger
8
+ from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
+ from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
+ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
11
+ from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
12
+ trim_output_compute_element_list)
13
+
14
+
15
+ class BasicInfoAndStatus:
16
+ def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
17
+ self.api_name = api_name
18
+ self.bench_dtype = bench_dtype
19
+ self.tested_dtype = tested_dtype
20
+ self.shape = shape
21
+ self.status = status
22
+ self.err_msg = err_msg
23
+
24
+ class ResultCsvEntry:
25
+ def __init__(self) -> None:
26
+ self.forward_pass_status = None
27
+ self.backward_pass_status = None
28
+ self.forward_err_msg = ""
29
+ self.backward_err_msg = ""
30
+ self.overall_err_msg = None
31
+
32
+
33
class ApiAccuracyChecker:
    """Compare MindSpore ``mint`` API outputs against PyTorch benchmark outputs.

    ``parse`` loads API records from an api_info.json file, ``run_and_compare``
    runs the forward and backward checks for every loaded API and stores the
    results, and ``to_detail_csv`` / ``to_result_csv`` write them to CSV files.
    """

    def __init__(self):
        # api name (trailing forward/backward component stripped) -> ApiInfo
        self.api_infos = dict()
        # (api name, forward_or_backward) -> list of (BasicInfoAndStatus, {algorithm name: compare result})
        self.results = dict()

    @staticmethod
    def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
        '''
        Args:
            api_info: ApiInfo
            api_name_str: str
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str: Union["forward", "backward"]

        Return:
            output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]

        Description:
            Obtain the tested (MindSpore) outputs, which are re-computed from the
            constructed inputs when the global context is in constructed mode and
            otherwise taken from the dumped outputs in ``api_info``. Obtain the bench
            outputs by running the PyTorch API. Compare every output pair with each
            algorithm in ``compare_algorithms``.
            An output passes only if both the cosine and the max-abs-error comparisons
            pass; otherwise their error messages are concatenated.
            Outputs are paired with ``zip``, so if the two sides have different
            lengths the extras are ignored (a warning is logged).
        '''
        # get output
        if global_context.get_is_constructed():
            # constructed situation: run the mindspore api on the constructed input to get tested_output
            tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
        else:
            tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
        bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
        tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
        bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
        if len(tested_outputs) != len(bench_outputs):
            logger.warning(f"ApiAccuracyChecker.run_and_compare_helper: api: {api_name_str}.{forward_or_backward}, "
                           "number of bench outputs and tested outputs is different, comparing result can be wrong. "
                           f"tested outputs: {len(tested_outputs)}, bench outputs: {len(bench_outputs)}")

        # compare output
        output_list = []
        for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
            api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
            bench_dtype = bench_out.get_dtype()
            tested_dtype = tested_out.get_dtype()
            shape = bench_out.get_shape()

            compare_result_dict = {
                compare_algorithm_name: compare_algorithm(bench_out, tested_out)
                for compare_algorithm_name, compare_algorithm in compare_algorithms.items()
            }

            # The pass/fail verdict is decided by cosine and max-abs-error only.
            cosine_result = compare_result_dict.get(CompareConst.COSINE)
            max_abs_err_result = compare_result_dict.get(CompareConst.MAX_ABS_ERR)
            if cosine_result.pass_status == CompareConst.PASS and \
                    max_abs_err_result.pass_status == CompareConst.PASS:
                status = CompareConst.PASS
                err_msg = ""
            else:
                status = CompareConst.ERROR
                err_msg = cosine_result.err_msg + max_abs_err_result.err_msg
            basic_info_status = \
                BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
            output_list.append((api_name_str, forward_or_backward, basic_info_status, compare_result_dict))
        return output_list

    def parse(self, api_info_path):
        """Load api_info.json, initialise the global context and collect mint API records.

        Args:
            api_info_path: path of the api_info.json file.

        A statistics-task file runs in "constructed" mode with an empty dump dir.
        A tensor-task file must provide ``dump_data_dir``.
        Only entries whose first name component is ``MINT`` or ``MINT_FUNCTIONAL``
        are kept. An entry whose last name component is neither forward nor backward
        is skipped with a warning.
        """
        with FileOpen(api_info_path, "r") as f:
            api_info_dict = json.load(f)

        # init global context
        task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
                                            "task field in api_info.json", accepted_type=str,
                                            accepted_value=(MsCompareConst.STATISTICS_TASK,
                                                            MsCompareConst.TENSOR_TASK))
        is_constructed = task == MsCompareConst.STATISTICS_TASK
        if not is_constructed:
            dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
                                                         "dump_data_dir field in api_info.json", accepted_type=str)
        else:
            dump_data_dir = ""
        global_context.init(is_constructed, dump_data_dir)

        api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
                                                     "data field in api_info.json", accepted_type=dict)
        for api_name, api_info in api_info_data.items():
            name_parts = api_name.split(Const.SEP)
            if name_parts[0] not in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
                continue
            forbackward_str = name_parts[-1]
            if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
                logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
                # Without this the record would fall through and be loaded as backward info.
                continue
            api_real_name = Const.SEP.join(name_parts[:-1])  # www.xxx.yyy.zzz --> www.xxx.yyy
            if api_real_name not in self.api_infos:
                self.api_infos[api_real_name] = ApiInfo(api_real_name)

            if forbackward_str == Const.FORWARD:
                self.api_infos[api_real_name].load_forward_info(api_info)
            else:
                self.api_infos[api_real_name].load_backward_info(api_info)

    def run_and_compare(self):
        """Run forward and then backward checks for every parsed API and record the results.

        APIs without forward info are skipped entirely. The backward check is skipped when
        backward info is missing. An exception in one direction is logged and that direction
        records nothing; the other direction is still attempted.
        """
        for api_name_str, api_info in self.api_infos.items():
            if not api_info.check_forward_info():
                logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
                continue
            forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
            kwargs = api_info.get_kwargs()
            forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
            forward_output_list = None
            try:
                forward_output_list = \
                    self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
            except Exception as e:
                # best-effort: one failing API must not abort the whole run
                logger.warning(f"exception occurs when running and comparing {api_name_str} forward api, "
                               f"detailed exception information: {e}")
            self.record(forward_output_list)

            if not api_info.check_backward_info():
                logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
                continue
            gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
            backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
            backward_output_list = None
            try:
                backward_output_list = \
                    self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
            except Exception as e:
                logger.warning(f"exception occurs when running and comparing {api_name_str} backward api, "
                               f"detailed exception information: {e}")
            self.record(backward_output_list)

    def record(self, output_list):
        """Append ``(basic_info, compare_result_dict)`` entries to ``self.results``.

        Args:
            output_list: list of 4-tuples as returned by ``run_and_compare_helper``,
                or None, in which case nothing is recorded.
        """
        if output_list is None:
            return
        for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
            self.results.setdefault((api_real_name, forward_or_backward), []).append(
                (basic_info, compare_result_dict))

    def to_detail_csv(self, csv_dir):
        """Write one row per compared output slot to a detail CSV in ``csv_dir``.

        Columns are: api name, bench dtype, tested dtype and shape; then the
        ``compare_value`` for each algorithm in ``compare_algorithms``; then pass status
        and message. The file name is ``MsCompareConst.DETAIL_CSV_FILE_NAME`` passed through
        ``add_time_as_suffix``. ``create_directory(csv_dir)`` is called before writing.
        """
        detail_csv_header_basic_info = [
            MsCompareConst.DETAIL_CSV_API_NAME,
            MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
            MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
            MsCompareConst.DETAIL_CSV_SHAPE,
        ]
        detail_csv_header_compare_result = list(compare_algorithms.keys())
        detail_csv_header_status = [
            MsCompareConst.DETAIL_CSV_PASS_STATUS,
            MsCompareConst.DETAIL_CSV_MESSAGE,
        ]
        detail_csv = [detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status]

        for results in self.results.values():
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = \
                    [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
                csv_row_compare_result = [compare_result_dict.get(algorithm_name).compare_value
                                          for algorithm_name in detail_csv_header_compare_result]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                detail_csv.append(csv_row_basic_info + csv_row_compare_result + csv_row_status)

        file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
        create_directory(csv_dir)
        write_csv(detail_csv, file_name, mode="w")

    def to_result_csv(self, csv_dir):
        """Write one summary row per API (forward status, backward status, message) to ``csv_dir``.

        A direction is ERROR if any of its output slots failed, and its message is the
        concatenation of the failing slots' messages. A direction with no recorded
        results keeps the ``ResultCsvEntry`` default status (None). The message column is
        empty only when both directions passed. The file name is
        ``MsCompareConst.RESULT_CSV_FILE_NAME`` passed through ``add_time_as_suffix``.
        """
        result_csv_dict = dict()
        for (api_real_name, forward_or_backward), results in self.results.items():
            failed_err_msgs = [basic_info.err_msg for basic_info, _ in results
                               if basic_info.status != CompareConst.PASS]
            pass_status = CompareConst.ERROR if failed_err_msgs else CompareConst.PASS
            overall_err_msg = "".join(failed_err_msgs)

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            entry = result_csv_dict[api_real_name]
            if forward_or_backward == Const.FORWARD:
                entry.forward_pass_status = pass_status
                entry.forward_err_msg = overall_err_msg
            else:
                entry.backward_pass_status = pass_status
                entry.backward_err_msg = overall_err_msg

        result_csv_header = [
            MsCompareConst.DETAIL_CSV_API_NAME,
            MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
            MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
            MsCompareConst.DETAIL_CSV_MESSAGE,
        ]
        result_csv = [result_csv_header]

        for api_name, result_csv_entry in result_csv_dict.items():
            if result_csv_entry.forward_pass_status == CompareConst.PASS and \
                    result_csv_entry.backward_pass_status == CompareConst.PASS:
                overall_err_msg = ""
            else:
                overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
            result_csv.append([api_name, result_csv_entry.forward_pass_status,
                               result_csv_entry.backward_pass_status, overall_err_msg])

        file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
        create_directory(csv_dir)
        write_csv(result_csv, file_name, mode="w")