mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,246 +1,279 @@
1
- import json
2
- import os
3
-
4
- from msprobe.core.common.file_check import FileOpen
5
- from msprobe.core.common.utils import write_csv, add_time_as_suffix
6
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
- from msprobe.core.common.log import logger
8
- from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
- from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
- from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
11
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
12
-
13
-
14
- class BasicInfoAndStatus:
15
- def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
16
- self.api_name = api_name
17
- self.bench_dtype = bench_dtype
18
- self.tested_dtype = tested_dtype
19
- self.shape = shape
20
- self.status = status
21
- self.err_msg = err_msg
22
-
23
- class ResultCsvEntry:
24
- def __init__(self) -> None:
25
- self.forward_pass_status = None
26
- self.backward_pass_status = None
27
- self.forward_err_msg = ""
28
- self.backward_err_msg = ""
29
- self.overall_err_msg = None
30
-
31
-
32
- class ApiAccuracyChecker:
33
- def __init__(self):
34
- self.api_infos = dict()
35
- self.results = dict()
36
-
37
- @staticmethod
38
- def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
39
- '''
40
- Args:
41
- api_info: ApiInfo
42
- api_name_str: str
43
- api_input_aggregation: ApiInputAggregation
44
- forward_or_backward: str: Union["forward", "backward"]
45
-
46
- Return:
47
- output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]
48
-
49
- Description:
50
- get mindspore api output, run torch api and get output.
51
- compare output.
52
- record compare result.
53
- '''
54
- # get output
55
- if global_context.get_is_constructed():
56
- # constructed situation, need use constructed input to run mindspore api getting tested_output
57
- tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
58
- else:
59
- tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
60
- bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
61
-
62
- # compare output
63
- output_list = []
64
- for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
65
- api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
66
- bench_dtype = bench_out.get_dtype()
67
- tested_dtype = tested_out.get_dtype()
68
- shape = bench_out.get_shape()
69
-
70
- compare_result_dict = dict()
71
- for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
72
- compare_result = compare_algorithm(bench_out, tested_out)
73
- compare_result_dict[compare_algorithm_name] = compare_result
74
-
75
- if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
76
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
77
- status = CompareConst.PASS
78
- err_msg = ""
79
- else:
80
- status = CompareConst.ERROR
81
- err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
82
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
83
- basic_info_status = \
84
- BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
85
- output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
86
- return output_list
87
-
88
- def parse(self, api_info_path):
89
- with FileOpen(api_info_path, "r") as f:
90
- api_info_dict = json.load(f)
91
-
92
- # init global context
93
- task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
94
- "task field in api_info.json",accepted_type=str,
95
- accepted_value=(MsCompareConst.STATISTICS_TASK,
96
- MsCompareConst.TENSOR_TASK))
97
- is_constructed = task == MsCompareConst.STATISTICS_TASK
98
- if not is_constructed:
99
- dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
100
- "dump_data_dir field in api_info.json", accepted_type=str)
101
- else:
102
- dump_data_dir = ""
103
- global_context.init(is_constructed, dump_data_dir)
104
-
105
- api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
106
- "data field in api_info.json", accepted_type=dict)
107
- for api_name, api_info in api_info_data.items():
108
- is_mint = api_name.split(Const.SEP)[0] in \
109
- (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
110
- if not is_mint:
111
- continue
112
- forbackward_str = api_name.split(Const.SEP)[-1]
113
- if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
114
- logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
115
- api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
116
- if api_name not in self.api_infos:
117
- self.api_infos[api_name] = ApiInfo(api_name)
118
-
119
- if forbackward_str == Const.FORWARD:
120
- self.api_infos[api_name].load_forward_info(api_info)
121
- else:
122
- self.api_infos[api_name].load_backward_info(api_info)
123
-
124
- def run_and_compare(self):
125
- for api_name_str, api_info in self.api_infos.items():
126
- if not api_info.check_forward_info():
127
- logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
128
- continue
129
- forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
130
- kwargs = api_info.get_kwargs()
131
- forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
132
- forward_output_list = None
133
- try:
134
- forward_output_list = \
135
- self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
136
- except Exception as e:
137
- logger.warning(f"exception occurs when running and comparing {api_name_str} forward api"
138
- f"detailed exception information: {e}")
139
- self.record(forward_output_list)
140
-
141
- if not api_info.check_backward_info():
142
- logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
143
- continue
144
- gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
145
- backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
146
- backward_output_list = None
147
- try:
148
- backward_output_list = \
149
- self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
150
- except Exception as e:
151
- logger.warning(f"exception occurs when running and comparing {api_name_str} backward api"
152
- f"detailed exception information: {e}")
153
- self.record(backward_output_list)
154
-
155
- def record(self, output_list):
156
- if output_list is None:
157
- return
158
- for output in output_list:
159
- api_real_name, forward_or_backward, basic_info, compare_result_dict = output
160
- key = tuple([api_real_name, forward_or_backward])
161
- if key not in self.results:
162
- self.results[key] = []
163
- self.results[key].append(tuple([basic_info, compare_result_dict]))
164
-
165
-
166
- def to_detail_csv(self, csv_dir):
167
- # detail_csv
168
- detail_csv = []
169
- detail_csv_header_basic_info = [
170
- MsCompareConst.DETAIL_CSV_API_NAME,
171
- MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
172
- MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
173
- MsCompareConst.DETAIL_CSV_SHAPE,
174
- ]
175
- detail_csv_header_compare_result = list(compare_algorithms.keys())
176
- detail_csv_header_status = [
177
- MsCompareConst.DETAIL_CSV_PASS_STATUS,
178
- MsCompareConst.DETAIL_CSV_MESSAGE,
179
- ]
180
-
181
- detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
182
- detail_csv.append(detail_csv_header)
183
-
184
- for _, results in self.results.items():
185
- # detail csv
186
- for res in results:
187
- basic_info, compare_result_dict = res
188
- csv_row_basic_info = \
189
- [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
190
- csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
191
- for algorithm_name in detail_csv_header_compare_result)
192
- csv_row_status = [basic_info.status, basic_info.err_msg]
193
- csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
194
- detail_csv.append(csv_row)
195
-
196
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
197
- write_csv(detail_csv, file_name, mode="w")
198
-
199
-
200
- def to_result_csv(self, csv_dir):
201
- result_csv_dict = dict()
202
- for key, results in self.results.items():
203
- api_real_name, forward_or_backward = key
204
- forward_or_backward_pass_status = CompareConst.PASS
205
- forward_or_backward_overall_err_msg = ""
206
- # detail csv
207
- for res in results:
208
- basic_info, _ = res
209
- if basic_info.status != CompareConst.PASS:
210
- forward_or_backward_pass_status = CompareConst.ERROR
211
- forward_or_backward_overall_err_msg += basic_info.err_msg
212
- forward_or_backward_overall_err_msg = \
213
- "" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
214
-
215
- #result_csv_dict
216
- if api_real_name not in result_csv_dict:
217
- result_csv_dict[api_real_name] = ResultCsvEntry()
218
- if forward_or_backward == Const.FORWARD:
219
- result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
220
- result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
221
- else:
222
- result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
223
- result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
224
-
225
- #result_csv
226
- result_csv = []
227
- result_csv_header = [
228
- MsCompareConst.DETAIL_CSV_API_NAME,
229
- MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
230
- MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
231
- MsCompareConst.DETAIL_CSV_MESSAGE,
232
- ]
233
- result_csv.append(result_csv_header)
234
-
235
- for api_name, result_csv_entry in result_csv_dict.items():
236
- if result_csv_entry.forward_pass_status == CompareConst.PASS and \
237
- result_csv_entry.backward_pass_status == CompareConst.PASS:
238
- overall_err_msg = ""
239
- else:
240
- overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
241
- row = [api_name, result_csv_entry.forward_pass_status,
242
- result_csv_entry.backward_pass_status, overall_err_msg]
243
- result_csv.append(row)
244
-
245
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
1
+ import json
2
+ import os
3
+
4
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
5
+ from msprobe.core.common.utils import add_time_as_suffix
6
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
+ from msprobe.mindspore.common.log import logger
8
+ from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
+ from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
+ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
11
+ from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
12
+ trim_output_compute_element_list)
13
+
14
+
15
+ class BasicInfoAndStatus:
16
+ def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
17
+ self.api_name = api_name
18
+ self.bench_dtype = bench_dtype
19
+ self.tested_dtype = tested_dtype
20
+ self.shape = shape
21
+ self.status = status
22
+ self.err_msg = err_msg
23
+
24
+ class ResultCsvEntry:
25
+ def __init__(self) -> None:
26
+ self.forward_pass_status = None
27
+ self.backward_pass_status = None
28
+ self.forward_err_msg = ""
29
+ self.backward_err_msg = ""
30
+ self.overall_err_msg = None
31
+
32
+
33
+ class ApiAccuracyChecker:
34
+ def __init__(self):
35
+ self.api_infos = dict()
36
+ self.results = dict()
37
+
38
+ @staticmethod
39
+ def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
40
+ '''
41
+ Args:
42
+ api_info: ApiInfo
43
+ api_name_str: str
44
+ api_input_aggregation: ApiInputAggregation
45
+ forward_or_backward: str: Union["forward", "backward"]
46
+
47
+ Return:
48
+ output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]
49
+
50
+ Description:
51
+ get mindspore api output, run torch api and get output.
52
+ compare output.
53
+ record compare result.
54
+ '''
55
+ # get output
56
+ if global_context.get_is_constructed():
57
+ # constructed situation, need use constructed input to run mindspore api getting tested_output
58
+ tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
59
+ else:
60
+ tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
61
+ bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
62
+ tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
63
+ bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
64
+ if len(tested_outputs) != len(bench_outputs):
65
+ logger.warning(f"ApiAccuracyChecker.run_and_compare_helper: api: {api_name_str}.{forward_or_backward}, "
66
+ "number of bench outputs and tested outputs is different, comparing result can be wrong. "
67
+ f"tested outputs: {len(tested_outputs)}, bench outputs: {len(bench_outputs)}")
68
+
69
+ # compare output
70
+ output_list = []
71
+ for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
72
+ api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
73
+ bench_dtype = bench_out.get_dtype()
74
+ tested_dtype = tested_out.get_dtype()
75
+ shape = bench_out.get_shape()
76
+
77
+ compare_result_dict = dict()
78
+ for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
79
+ compare_result = compare_algorithm(bench_out, tested_out)
80
+ compare_result_dict[compare_algorithm_name] = compare_result
81
+
82
+ if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
83
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
84
+ status = CompareConst.PASS
85
+ err_msg = ""
86
+ else:
87
+ status = CompareConst.ERROR
88
+ err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
89
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
90
+ basic_info_status = \
91
+ BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
92
+ output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
93
+ return output_list
94
+
95
+ @staticmethod
96
+ def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
97
+ '''
98
+ Args:
99
+ api_info: ApiInfo
100
+ forward_or_backward: str
101
+ Returns:
102
+ ApiInputAggregation
103
+ '''
104
+ forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
105
+ kwargs = api_info.get_kwargs()
106
+ if forward_or_backward == Const.FORWARD:
107
+ gradient_inputs = None
108
+ else:
109
+ gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
110
+ return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
111
+
112
+ def parse(self, api_info_path):
113
+ with FileOpen(api_info_path, "r") as f:
114
+ api_info_dict = json.load(f)
115
+
116
+ # init global context
117
+ task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
118
+ "task field in api_info.json",accepted_type=str,
119
+ accepted_value=(MsCompareConst.STATISTICS_TASK,
120
+ MsCompareConst.TENSOR_TASK))
121
+ is_constructed = task == MsCompareConst.STATISTICS_TASK
122
+ if not is_constructed:
123
+ dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
124
+ "dump_data_dir field in api_info.json", accepted_type=str)
125
+ else:
126
+ dump_data_dir = ""
127
+ global_context.init(is_constructed, dump_data_dir)
128
+
129
+ api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
130
+ "data field in api_info.json", accepted_type=dict)
131
+ for api_name, api_info in api_info_data.items():
132
+ is_mint = api_name.split(Const.SEP)[0] in \
133
+ (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
134
+ if not is_mint:
135
+ continue
136
+ forbackward_str = api_name.split(Const.SEP)[-1]
137
+ if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
138
+ logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
139
+ api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
140
+ if api_name not in self.api_infos:
141
+ self.api_infos[api_name] = ApiInfo(api_name)
142
+
143
+ if forbackward_str == Const.FORWARD:
144
+ self.api_infos[api_name].load_forward_info(api_info)
145
+ else:
146
+ self.api_infos[api_name].load_backward_info(api_info)
147
+
148
+ def run_and_compare(self):
149
+ for api_name_str, api_info in self.api_infos.items():
150
+ if not api_info.check_forward_info():
151
+ logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check.")
152
+ continue
153
+ try:
154
+ forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
155
+ except Exception as e:
156
+ logger.warning(f"exception occurs when getting inputs for {api_name_str} forward api. "
157
+ f"skip forward and backward check. detailed exception information: {e}.")
158
+ continue
159
+ forward_output_list = None
160
+ try:
161
+ forward_output_list = \
162
+ self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
163
+ except Exception as e:
164
+ logger.warning(f"exception occurs when running and comparing {api_name_str} forward api. "
165
+ f"detailed exception information: {e}.")
166
+ self.record(forward_output_list)
167
+
168
+ if not api_info.check_backward_info():
169
+ logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check.")
170
+ continue
171
+ try:
172
+ backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
173
+ except Exception as e:
174
+ logger.warning(f"exception occurs when getting inputs for {api_name_str} backward api. "
175
+ f"skip backward check. detailed exception information: {e}.")
176
+ continue
177
+ backward_output_list = None
178
+ try:
179
+ backward_output_list = \
180
+ self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
181
+ except Exception as e:
182
+ logger.warning(f"exception occurs when running and comparing {api_name_str} backward api. "
183
+ f"detailed exception information: {e}.")
184
+ self.record(backward_output_list)
185
+
186
+ def record(self, output_list):
187
+ if output_list is None:
188
+ return
189
+ for output in output_list:
190
+ api_real_name, forward_or_backward, basic_info, compare_result_dict = output
191
+ key = tuple([api_real_name, forward_or_backward])
192
+ if key not in self.results:
193
+ self.results[key] = []
194
+ self.results[key].append(tuple([basic_info, compare_result_dict]))
195
+
196
+
197
+ def to_detail_csv(self, csv_dir):
198
+ # detail_csv
199
+ detail_csv = []
200
+ detail_csv_header_basic_info = [
201
+ MsCompareConst.DETAIL_CSV_API_NAME,
202
+ MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
203
+ MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
204
+ MsCompareConst.DETAIL_CSV_SHAPE,
205
+ ]
206
+ detail_csv_header_compare_result = list(compare_algorithms.keys())
207
+ detail_csv_header_status = [
208
+ MsCompareConst.DETAIL_CSV_PASS_STATUS,
209
+ MsCompareConst.DETAIL_CSV_MESSAGE,
210
+ ]
211
+
212
+ detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
213
+ detail_csv.append(detail_csv_header)
214
+
215
+ for _, results in self.results.items():
216
+ # detail csv
217
+ for res in results:
218
+ basic_info, compare_result_dict = res
219
+ csv_row_basic_info = \
220
+ [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
221
+ csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
222
+ for algorithm_name in detail_csv_header_compare_result)
223
+ csv_row_status = [basic_info.status, basic_info.err_msg]
224
+ csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
225
+ detail_csv.append(csv_row)
226
+
227
+ file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
228
+ create_directory(csv_dir)
229
+ write_csv(detail_csv, file_name, mode="w")
230
+
231
+
232
+ def to_result_csv(self, csv_dir):
233
+ result_csv_dict = dict()
234
+ for key, results in self.results.items():
235
+ api_real_name, forward_or_backward = key
236
+ forward_or_backward_pass_status = CompareConst.PASS
237
+ forward_or_backward_overall_err_msg = ""
238
+ # detail csv
239
+ for res in results:
240
+ basic_info, _ = res
241
+ if basic_info.status != CompareConst.PASS:
242
+ forward_or_backward_pass_status = CompareConst.ERROR
243
+ forward_or_backward_overall_err_msg += basic_info.err_msg
244
+ forward_or_backward_overall_err_msg = \
245
+ "" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
246
+
247
+ #result_csv_dict
248
+ if api_real_name not in result_csv_dict:
249
+ result_csv_dict[api_real_name] = ResultCsvEntry()
250
+ if forward_or_backward == Const.FORWARD:
251
+ result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
252
+ result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
253
+ else:
254
+ result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
255
+ result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
256
+
257
+ #result_csv
258
+ result_csv = []
259
+ result_csv_header = [
260
+ MsCompareConst.DETAIL_CSV_API_NAME,
261
+ MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
262
+ MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
263
+ MsCompareConst.DETAIL_CSV_MESSAGE,
264
+ ]
265
+ result_csv.append(result_csv_header)
266
+
267
+ for api_name, result_csv_entry in result_csv_dict.items():
268
+ if result_csv_entry.forward_pass_status == CompareConst.PASS and \
269
+ result_csv_entry.backward_pass_status == CompareConst.PASS:
270
+ overall_err_msg = ""
271
+ else:
272
+ overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
273
+ row = [api_name, result_csv_entry.forward_pass_status,
274
+ result_csv_entry.backward_pass_status, overall_err_msg]
275
+ result_csv.append(row)
276
+
277
+ file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
278
+ create_directory(csv_dir)
246
279
  write_csv(result_csv, file_name, mode="w")