mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,236 +1,235 @@
1
- # 进行比对及结果展示
2
- import os
3
- import sys
4
- import csv
5
- import json
6
- from collections import namedtuple
7
- from rich.table import Table
8
- from rich.console import Console
9
- from msprobe.core.common.const import CompareConst, FileCheckConst
10
- from msprobe.core.common.file_check import FileOpen, change_mode
11
- from .single_compare import single_benchmark_compare_wrap
12
- from msprobe.pytorch.common.log import logger
13
- from msprobe.core.common.utils import CompareException
14
-
15
# Minimum element count for the dropout zero-count comparison to be applied;
# smaller tensors are accepted without checking.
ELEMENT_NUM_THRESHOLD = 100
# Largest tolerated |zeros(bench) - zeros(npu)| / numel for a dropout output to pass.
ZERO_NUM_THRESHOLD = 0.1
# Number of decimal places used when formatting float metrics in the detail CSV.
FLOAT_PRECISION = 14

# Outcome of comparing one API: name, forward/backward status and the
# per-output comparison details for each direction.
ResultInfo = namedtuple(
    "ResultInfo",
    "api_name is_fwd_success is_bwd_success fwd_compare_alg_results bwd_compare_alg_results",
)
21
-
22
def get_file_content_bytes(file):
    """Return the whole content of ``file`` as bytes, read through ``FileOpen``."""
    with FileOpen(file, 'rb') as reader:
        content = reader.read()
    return content
25
-
26
-
27
def get_json_contents(file_path):
    """Parse ``file_path`` as JSON and return the resulting dict.

    Raises:
        CompareException: with ``INVALID_FILE_ERROR`` when the content is not
            valid JSON (chained from the ``ValueError``) or when the decoded
            top-level value is not a dict. An error is logged in both cases.
    """
    raw_bytes = get_file_content_bytes(file_path)
    try:
        parsed = json.loads(raw_bytes)
    except ValueError as error:
        logger.error('Failed to load "%s". %s' % (file_path, str(error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
    if isinstance(parsed, dict):
        return parsed
    logger.error('Json file %s, content is not a dictionary!' % file_path)
    raise CompareException(CompareException.INVALID_FILE_ERROR)
38
-
39
-
40
def write_csv(data, filepath):
    """Append the rows in ``data`` to the CSV at ``filepath``, then reset its file mode.

    The file is opened in append mode with the ``utf-8-sig`` encoding; after
    writing, its permissions are set to ``FileCheckConst.DATA_FILE_AUTHORITY``.
    """
    with FileOpen(filepath, 'a', encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
45
-
46
-
47
class Saver:
    """Persist per-API comparison results to CSV files and summarize them.

    Two CSV files are appended to through ``write_csv``:

    * the summary CSV (``save_path``): one row per API with the forward and
      backward status, a message column and, when ``stack_info`` is set, the
      API's call stack;
    * the detail CSV (``detail_save_path``): one row per compared output.
    """
    # consts for result csv
    COLUMN_API_NAME = "API name"
    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
    COLUMN_STACK_INFO = "Traceback callstack info"

    def __init__(self, save_path, detail_save_path, stack_info):
        """
        Args:
            save_path: Path of the summary CSV.
            detail_save_path: Path of the detail CSV.
            stack_info: Mapping of API name -> iterable of stack lines (joined
                with newlines when written), or a falsy value to disable the
                stack column.
        """
        self.save_path = save_path
        self.detail_save_path = detail_save_path
        self.stack_info = stack_info
        self.test_result_cnt = self._empty_result_counter()

    @staticmethod
    def _empty_result_counter():
        """Return a statistics dict with every counter set to zero."""
        return {
            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
            "total_num": 0, "forward_or_backward_fail_num": 0
        }

    def write_csv_title(self):
        """Write the header row of the summary CSV and of the detail CSV."""
        summary_title = [self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]
        if self.stack_info:
            # write_summary_csv appends a stack column in this mode, so the header declares it too.
            summary_title.append(self.COLUMN_STACK_INFO)
        write_csv([summary_title], self.save_path)

        detail_test_rows = [[
            "Npu Name", "Bench Dtype", "NPU Dtype", "Shape",
            "error_balance", "max_abs_diff", "max_abs_idx",
            "max_rel_diff", "max_rel_idx", "eb_thd",
            "error_thd", "Status", "Message"
        ]]
        write_csv(detail_test_rows, self.detail_save_path)

    def print_pretest_result(self):
        """Recount statistics from the summary CSV and print them as two rich tables."""
        self.get_statistics_from_result_csv()
        if self.test_result_cnt.get("total_num") != 0:
            passing_rate = str(self.test_result_cnt.get("success_num") /
                               (self.test_result_cnt.get("total_num") + sys.float_info.epsilon))
        else:
            passing_rate = "0"

        console = Console()
        table_total = Table(
            show_header=True, title="Overall Statistics", show_lines=True, width=75
        )
        table_total.add_column("Result")
        table_total.add_column("Statistics")
        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num")))
        # Total failures = failed in both directions + failed in exactly one direction.
        table_total.add_row("[red]Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num") +
                                                   self.test_result_cnt.get("forward_or_backward_fail_num")))
        table_total.add_row("Passing Rate", passing_rate)

        table_detail = Table(
            show_header=True, title="Detail Statistics", show_lines=True, width=75
        )
        table_detail.add_column("Result")
        table_detail.add_column("Statistics")
        table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num")))
        table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num")))
        table_detail.add_row(
            "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num")))

        console.print(table_total)
        console.print(table_detail)

    def get_statistics_from_result_csv(self):
        """Recount pass/fail statistics from the summary CSV into ``self.test_result_cnt``.

        The first CSV row is treated as the header and skipped. Rows whose
        forward status is SKIP are not counted.

        Raises:
            ValueError: if a data row has fewer than 3 columns, or its 2nd/3rd
                column (upper-cased) is empty or not in the accepted checklist.
        """
        # The CSV is the complete record, so start from zero to avoid double-counting on repeated calls.
        self.test_result_cnt.update(self._empty_result_counter())
        checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
        with FileOpen(self.save_path, 'r') as file:
            result_csv_rows = list(csv.reader(file))
        result_csv_name = os.path.basename(self.save_path)
        for item in result_csv_rows[1:]:
            if not isinstance(item, list) or len(item) < 3:
                raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
            if not all(item[i] and item[i].upper() in checklist for i in (1, 2)):
                raise ValueError(
                    "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
                    % result_csv_name)
            column1 = item[1].upper()
            column2 = item[2].upper()
            if column1 == CompareConst.SKIP:
                continue
            self.test_result_cnt["total_num"] += 1
            if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
                self.test_result_cnt['success_num'] += 1
            elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
                self.test_result_cnt['forward_and_backward_fail_num'] += 1
            elif column1 == CompareConst.FALSE:
                self.test_result_cnt['forward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1
            else:
                self.test_result_cnt['backward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1

    def write_summary_csv(self, test_result):
        """Append one summary row for ``test_result`` (a ``ResultInfo``).

        Row layout: API name, forward status, backward status, message (filled
        with ``fwd_compare_alg_results`` only when either status is "SKIP"),
        and, when stack info is enabled, the API's call stack joined by newlines.
        """
        name = test_result.api_name
        df_row = [name, test_result.is_fwd_success, test_result.is_bwd_success]
        if test_result.is_fwd_success == "SKIP" or test_result.is_bwd_success == "SKIP":
            df_row.append(test_result.fwd_compare_alg_results)
        if self.stack_info:
            if len(df_row) == 3:
                # Leave the "Message" column empty so the stack lands under its own header.
                df_row.append("")
            df_row.append("\n".join(self.stack_info[name]))
        write_csv([df_row], self.save_path)

    def write_detail_csv(self, test_result):
        """Append one detail row per forward/backward output of ``test_result``.

        Only list-valued comparison results produce rows. Each row starts with
        "<api_name>.<forward|backward>.output.<index>" and float entries are
        formatted with ``FLOAT_PRECISION`` decimal places.
        """
        def get_rows_from_list(result, name, sub_prefix):
            rows = []
            if isinstance(result, list):
                for i, test_subject in enumerate(result):
                    subject = sub_prefix + "." + name + ".output." + str(i)
                    formatted = ["{:.{}f}".format(item, FLOAT_PRECISION) if isinstance(item, float) else item
                                 for item in test_subject]
                    rows.append([subject] + formatted)
            return rows

        test_rows = []
        subject_prefix = test_result.api_name
        test_rows.extend(get_rows_from_list(test_result.fwd_compare_alg_results, "forward", subject_prefix))
        test_rows.extend(get_rows_from_list(test_result.bwd_compare_alg_results, "backward", subject_prefix))

        write_csv(test_rows, self.detail_save_path)

    def record_results(self, result_info):
        """Write ``result_info`` to both the summary CSV and the detail CSV."""
        self.write_summary_csv(result_info)
        self.write_detail_csv(result_info)
176
-
177
-
178
class Comparator:
    """Compare benchmark and NPU outputs of an API and record the outcome through a ``Saver``."""

    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
        """
        Args:
            result_csv_path: Path of the summary CSV.
            details_csv_path: Path of the detail CSV.
            is_continue_run_ut: Header rows are written only when this is truthy
                and neither CSV file exists yet.
            stack_info_json_path: Optional JSON file holding a dict of per-API
                stack info, loaded with ``get_json_contents``.
        """
        self.save_path = result_csv_path
        self.detail_save_path = details_csv_path
        self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None
        self.saver = Saver(result_csv_path, details_csv_path, self.stack_info)

        csv_paths = (self.save_path, self.detail_save_path)
        if is_continue_run_ut and not any(os.path.exists(path) for path in csv_paths):
            self.saver.write_csv_title()

    @staticmethod
    def _compare_core_wrapper(bench_out, npu_out):
        """Run ``single_benchmark_compare_wrap`` and normalize its result.

        Returns:
            (overall_success, details_list). When the wrapped call returns a
            list of statuses, success requires every status to be truthy and
            ``details`` is looked up by output index (presumably a dict keyed
            by position — confirm against single_benchmark_compare_wrap).
        """
        status, details = single_benchmark_compare_wrap(npu_out, bench_out)
        if not isinstance(status, list):
            return status, [details]
        per_output_details = [details.get(index, 'key does not exist') for index in range(len(status))]
        return all(status), per_output_details

    @staticmethod
    def _compare_dropout(bench_out, npu_out):
        """Compare dropout outputs by how many elements are zero.

        Tensors with fewer than ``ELEMENT_NUM_THRESHOLD`` elements always pass.
        Otherwise they pass when the zero-count difference divided by the
        element count is below ``ZERO_NUM_THRESHOLD``. Returns (True, 1) on pass
        and (False, 0) on failure.
        """
        element_count = bench_out.numel()
        if element_count < ELEMENT_NUM_THRESHOLD:
            return True, 1
        zero_count_gap = abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum())
        if zero_count_gap / element_count < ZERO_NUM_THRESHOLD:
            return True, 1
        return False, 0

    def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None):
        """Compare forward outputs (and gradients when both are given), record, and return statuses.

        APIs whose name contains "dropout" use the zero-count comparison; the
        gradient variant compares only the first gradient. When no backward
        comparison is done, the recorded backward status is ``CompareConst.NAN``
        while the returned one is True.

        Returns:
            (is_fwd_success, is_bwd_success)
        """
        is_dropout = "dropout" in api_name

        if is_dropout:
            is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out)
        else:
            is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out)

        if not (bench_grad and npu_grad):
            is_bwd_success, bwd_compare_alg_results = True, None
        elif is_dropout:
            is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], npu_grad[0])
        else:
            is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad)

        if is_bwd_success and bwd_compare_alg_results is None:
            recorded_bwd_status = CompareConst.NAN
        else:
            recorded_bwd_status = is_bwd_success
        self.saver.record_results(ResultInfo(api_name, is_fwd_success, recorded_bwd_status,
                                             fwd_compare_alg_results, bwd_compare_alg_results))
        return is_fwd_success, is_bwd_success
1
+ # 进行比对及结果展示
2
+ import os
3
+ import sys
4
+ import csv
5
+ import json
6
+ from collections import namedtuple
7
+ from rich.table import Table
8
+ from rich.console import Console
9
+ from msprobe.core.common.const import CompareConst, FileCheckConst
10
+ from msprobe.core.common.file_utils import FileOpen, change_mode, read_csv
11
+ from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap
12
+ from msprobe.pytorch.common.log import logger
13
+ from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
14
+
15
# Element-count threshold; presumably gates the dropout zero-count comparison
# (its use is not visible in this hunk — confirm).
ELEMENT_NUM_THRESHOLD = 100
# Tolerated zero-ratio difference; presumably for the dropout comparison (confirm).
ZERO_NUM_THRESHOLD = 0.1
# Decimal places for float formatting; presumably used when writing the detail CSV (confirm).
FLOAT_PRECISION = 14

# Outcome of comparing one API: name, forward/backward status and the
# per-direction comparison details.
ResultInfo = namedtuple(
    "ResultInfo",
    "api_name is_fwd_success is_bwd_success fwd_compare_alg_results bwd_compare_alg_results",
)
21
+
22
def get_file_content_bytes(file):
    """Return the whole content of ``file`` as bytes, read through ``FileOpen``."""
    with FileOpen(file, 'rb') as reader:
        content = reader.read()
    return content
25
+
26
+
27
def get_json_contents(file_path):
    """Parse ``file_path`` as JSON and return the resulting dict.

    Raises:
        CompareException: with ``INVALID_FILE_ERROR`` when the content is not
            valid JSON (chained from the ``ValueError``) or when the decoded
            top-level value is not a dict. An error is logged in both cases.
    """
    raw_bytes = get_file_content_bytes(file_path)
    try:
        parsed = json.loads(raw_bytes)
    except ValueError as error:
        logger.error('Failed to load "%s". %s' % (file_path, str(error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
    if isinstance(parsed, dict):
        return parsed
    logger.error('Json file %s, content is not a dictionary!' % file_path)
    raise CompareException(CompareException.INVALID_FILE_ERROR)
38
+
39
+
40
def write_csv(data, filepath):
    """Append the rows in ``data`` to the CSV at ``filepath``, then reset its file mode.

    The file is opened in append mode with the ``utf-8-sig`` encoding; after
    writing, its permissions are set to ``FileCheckConst.DATA_FILE_AUTHORITY``.
    """
    with FileOpen(filepath, 'a', encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
45
+
46
+
47
class Saver:
    """Write online-dispatch comparison results to a summary CSV and a detail CSV.

    ``save_path`` receives one summary row per recorded API: the forward and
    backward pass flags, an optional message and an optional call stack.
    ``detail_save_path`` receives one row per compared output with its metrics.
    When ``stack_info`` is truthy, it maps an API name to an iterable of
    call-stack strings. Those strings are joined with newlines and written into
    an extra "Traceback callstack info" column of the summary CSV.
    """

    # consts for result csv
    COLUMN_API_NAME = "API name"
    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
    COLUMN_STACK_INFO = "Traceback callstack info"

    def __init__(self, save_path, detail_save_path, stack_info):
        self.save_path = save_path
        self.detail_save_path = detail_save_path
        self.stack_info = stack_info

        # Counters populated by get_statistics_from_result_csv().
        self.test_result_cnt = {
            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
            "total_num": 0, "forward_or_backward_fail_num": 0
        }

    @staticmethod
    def _normalize_cell(value):
        """Return one status cell of the summary CSV as an upper-case string.

        ``read_csv`` presumably returns a pandas DataFrame, since ``iterrows``
        is called on it. Under pandas' default parsing, a column of
        "True"/"False" becomes booleans, and the literal "N/A" (or an empty
        cell) becomes NaN. Booleans are therefore stringified, and NaN is mapped
        back to "N/A" so that the documented value is accepted.
        """
        if isinstance(value, float) and value != value:  # NaN is the only value unequal to itself
            return "N/A"
        return str(value).upper()

    def write_csv_title(self):
        """Write the header row of the summary CSV and of the detail CSV."""
        summary_header = [self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]
        if self.stack_info:
            # write_summary_csv appends a call-stack cell to every row in this case,
            # so the header needs the matching column.
            summary_header.append(self.COLUMN_STACK_INFO)
        write_csv([summary_header], self.save_path)

        detail_test_rows = [[
            "Npu Name", "Bench Dtype", "NPU Dtype", "Shape",
            "error_balance", "max_abs_diff", "max_abs_idx",
            "max_rel_diff", "max_rel_idx", "eb_thd",
            "error_thd", "Status", "Message"
        ]]
        write_csv(detail_test_rows, self.detail_save_path)

    def print_pretest_result(self):
        """Tally the summary CSV and print an overall table and a detail table to the console."""
        self.get_statistics_from_result_csv()
        if self.test_result_cnt.get("total_num") != 0:
            # The epsilon is only a divide-by-zero guard; total_num is non-zero on this branch.
            passing_rate = str(self.test_result_cnt.get("success_num") /
                               (self.test_result_cnt.get("total_num") + sys.float_info.epsilon))
        else:
            passing_rate = "0"

        console = Console()
        table_total = Table(
            show_header=True, title="Overall Statistics", show_lines=True, width=75
        )
        table_total.add_column("Result")
        table_total.add_column("Statistics")
        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num")))
        table_total.add_row("[red]Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num") +
                                                   self.test_result_cnt.get("forward_or_backward_fail_num")))
        table_total.add_row("Passing Rate", passing_rate)

        table_detail = Table(
            show_header=True, title="Detail Statistics", show_lines=True, width=75
        )
        table_detail.add_column("Result")
        table_detail.add_column("Statistics")
        table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num")))
        table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num")))
        table_detail.add_row(
            "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num")))

        console.print(table_total)
        console.print(table_detail)

    def get_statistics_from_result_csv(self):
        """Accumulate pass/fail counts from the summary CSV into ``self.test_result_cnt``.

        Rows whose forward column is SKIP are ignored. Counters are added to and
        never reset, so calling this twice on the same instance double-counts.

        Raises:
            ValueError: if a row has fewer than 3 columns, or if a forward or
                backward cell is not TRUE, FALSE, SKIP or N/A.
        """
        checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
        data = read_csv(self.save_path)
        result_csv_name = os.path.basename(self.save_path)
        for _, row in data.iterrows():
            if len(row) < 3:
                raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
            # Use positional access: the row is indexed by header names, and row[1]
            # would depend on pandas' deprecated integer-key fallback.
            column1 = self._normalize_cell(row.iloc[1])
            column2 = self._normalize_cell(row.iloc[2])
            if column1 not in checklist or column2 not in checklist:
                raise ValueError(
                    "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
                    % result_csv_name)
            if column1 == CompareConst.SKIP:
                continue
            self.test_result_cnt["total_num"] += 1
            if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
                self.test_result_cnt['success_num'] += 1
            elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
                self.test_result_cnt['forward_and_backward_fail_num'] += 1
            elif column1 == CompareConst.FALSE:
                self.test_result_cnt['forward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1
            else:
                self.test_result_cnt['backward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1

    def write_summary_csv(self, test_result):
        """Append one summary row for ``test_result`` to the summary CSV.

        The row holds the API name and the forward/backward flags. When either
        flag is "SKIP", the forward comparison result is added as the message.
        When ``stack_info`` is set, the joined call stack is added under the
        "Traceback callstack info" column.
        """
        check_op_str_pattern_valid(test_result.api_name)
        df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
        is_skipped = test_result.is_fwd_success == "SKIP" or test_result.is_bwd_success == "SKIP"
        if is_skipped:
            df_row.append(test_result.fwd_compare_alg_results)
        if self.stack_info:
            if not is_skipped:
                # Empty "Message" cell keeps the stack trace under its own header.
                df_row.append("")
            check_op_str_pattern_valid(self.stack_info[test_result.api_name])
            stack_info = "\n".join(self.stack_info[test_result.api_name])
            df_row.append(stack_info)
        write_csv([df_row], self.save_path)

    def write_detail_csv(self, test_result):
        """Append one detail row per compared forward/backward output to the detail CSV.

        Only list-valued comparison results produce rows. Float metrics are
        formatted with ``FLOAT_PRECISION`` decimals.
        """
        def get_rows_from_list(result, name, sub_prefix):
            rows = []
            if isinstance(result, list):
                for i, test_subject in enumerate(result):
                    subject = sub_prefix + "." + name + ".output." + str(i)
                    test_subject = ["{:.{}f}".format(item, FLOAT_PRECISION) if isinstance(item, float) else item for
                                    item in test_subject]
                    rows.append([subject] + list(test_subject))
            return rows

        test_rows = []
        subject_prefix = test_result.api_name
        fwd_result = test_result.fwd_compare_alg_results
        bwd_result = test_result.bwd_compare_alg_results

        test_rows.extend(get_rows_from_list(fwd_result, "forward", subject_prefix))
        test_rows.extend(get_rows_from_list(bwd_result, "backward", subject_prefix))

        write_csv(test_rows, self.detail_save_path)

    def record_results(self, result_info):
        """Write ``result_info`` to both the summary CSV and the detail CSV."""
        self.write_summary_csv(result_info)
        self.write_detail_csv(result_info)
175
+
176
+
177
class Comparator:
    """Compare bench and NPU outputs (and optionally gradients) of one API and record the outcome.

    Results go to the summary and detail CSVs through a ``Saver``. When
    ``stack_info_json_path`` is given, its JSON contents are used as the
    call-stack mapping passed to the ``Saver``.
    """

    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
        self.save_path = result_csv_path
        self.detail_save_path = details_csv_path
        self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None
        self.saver = Saver(result_csv_path, details_csv_path, self.stack_info)

        # NOTE(review): headers are written only when resuming AND neither CSV exists yet;
        # a fresh run (is_continue_run_ut=False) writes none. Confirm this is intended.
        if is_continue_run_ut and not any(os.path.exists(path)
                                          for path in (self.save_path, self.detail_save_path)):
            self.saver.write_csv_title()

    @staticmethod
    def _compare_core_wrapper(bench_out, npu_out):
        """Run ``single_benchmark_compare_wrap`` and return ``(success, details_list)``.

        If the wrapper returns a scalar status, that status is returned as is,
        with its details wrapped in a one-element list. If it returns a list of
        statuses, success requires every entry to be truthy. The details for
        entry ``i`` are looked up as ``details.get(i, 'key does not exist')``.
        """
        status, details = single_benchmark_compare_wrap(npu_out, bench_out)
        if not isinstance(status, list):
            return status, [details]
        item_passed = [bool(item_status) for item_status in status]
        collected = [details.get(idx, 'key does not exist') for idx in range(len(status))]
        return all(item_passed), collected

    @staticmethod
    def _compare_dropout(bench_out, npu_out):
        """Compare dropout outputs by their number of zeros rather than element-wise.

        Assumes tensor-like inputs (``numel``, ``== 0``, ``.sum()``, ``.cpu()``).
        Inputs with fewer than ``ELEMENT_NUM_THRESHOLD`` elements always pass.
        Larger inputs pass when the zero-count difference, divided by the
        element count, is below ``ZERO_NUM_THRESHOLD``.

        Returns:
            ``(True, 1)`` on pass, ``(False, 0)`` on failure.
        """
        element_total = bench_out.numel()
        if element_total < ELEMENT_NUM_THRESHOLD:
            return True, 1
        zero_count_gap = abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum())
        if zero_count_gap / element_total < ZERO_NUM_THRESHOLD:
            return True, 1
        return False, 0

    def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None):
        """Compare forward outputs and, when both gradients are truthy, backward gradients.

        APIs whose name contains "dropout" use the zero-count comparison. For
        those APIs only the first gradient of each side is compared. When no
        gradients are compared, backward counts as successful, and the recorded
        backward cell is ``CompareConst.NAN``.

        Returns:
            ``(is_fwd_success, is_bwd_success)``.
        """
        is_dropout = "dropout" in api_name
        if is_dropout:
            is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out)
        else:
            is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out)

        if not (bench_grad and npu_grad):
            is_bwd_success, bwd_compare_alg_results = True, None
        elif is_dropout:
            is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], npu_grad[0])
        else:
            is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad)

        # NOTE(review): Saver.get_statistics_from_result_csv accepts only TRUE/FALSE/SKIP/N/A in the
        # backward column; confirm that CompareConst.NAN is read back as one of those.
        recorded_bwd = CompareConst.NAN if is_bwd_success and bwd_compare_alg_results is None else is_bwd_success
        self.saver.record_results(ResultInfo(api_name, is_fwd_success, recorded_bwd, fwd_compare_alg_results,
                                             bwd_compare_alg_results))
        return is_fwd_success, is_bwd_success