mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,381 +1,416 @@
1
- # 进行比对及结果展示
2
- import os
3
- from collections import namedtuple
4
-
5
- import numpy as np
6
- from msprobe.core.common.utils import write_csv, get_json_contents, CompareException
7
- import torch
8
- from msprobe.core.common.const import Const, CompareConst
9
- from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
10
- get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
11
- get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
12
- check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
13
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
14
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
15
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
16
- DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
17
- ulp_standard_api, thousandth_standard_api, apis_threshold
18
- from msprobe.pytorch.common.log import logger
19
-
20
-
21
- ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
22
- 'fwd_compare_alg_results', 'bwd_compare_alg_results', 'rank'])
23
-
24
-
25
- INDEX_TEST_RESULT_GROUP = 3
26
- INDEX_FIRST_GROUP = 0
27
- INDEX_MESSAGE = -1
28
-
29
-
30
- class Comparator:
31
- # consts for result csv
32
- COLUMN_API_NAME = "API name"
33
- COLUMN_FORWARD_SUCCESS = "Forward Test Success"
34
- COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
35
- COLUMN_STACK_INFO = "Traceback callstack info"
36
-
37
- def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None):
38
- self.save_path_str = result_csv_path
39
- self.detail_save_path_str = details_csv_path
40
- self.save_path_list = [result_csv_path]
41
- self.detail_save_path_list = [details_csv_path]
42
-
43
- if config and config.online_config.is_online:
44
- self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
45
- self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
46
- self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
47
- self.detail_save_path_list = \
48
- [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
49
-
50
- if not is_continue_run_ut:
51
- self.write_csv_title()
52
- if stack_info_json_path:
53
- self.stack_info = get_json_contents(stack_info_json_path)
54
- else:
55
- self.stack_info = None
56
-
57
- @staticmethod
58
- def get_path_from_rank(rank, path_list, path_pattern):
59
- return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank)
60
-
61
- @staticmethod
62
- def print_pretest_result():
63
- logger.info("Successfully completed run_ut/multi_run_ut.")
64
-
65
- @staticmethod
66
- def _compare_dropout(bench_output, device_output):
67
- tensor_num = bench_output.numel()
68
- if tensor_num >= 100:
69
- if abs((bench_output == 0).sum() - (device_output == 0).cpu().sum()) / tensor_num < 0.1:
70
- return CompareConst.PASS, 1
71
- else:
72
- return CompareConst.ERROR, 0
73
- else:
74
- return CompareConst.PASS, 1
75
-
76
- @staticmethod
77
- def _compare_builtin_type(bench_output, device_output, compare_column):
78
- if not isinstance(bench_output, (bool, int, float, str)):
79
- return CompareConst.PASS, compare_column, ""
80
- if bench_output != device_output:
81
- return CompareConst.ERROR, compare_column, ""
82
- compare_column.error_rate = 0
83
- return CompareConst.PASS, compare_column, ""
84
-
85
- @staticmethod
86
- def _compare_bool_tensor(bench_output, device_output):
87
- error_nums = (bench_output != device_output).sum()
88
- if bench_output.size == 0:
89
- return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
90
- error_rate = float(error_nums / bench_output.size)
91
- result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
92
- return error_rate, result, ""
93
-
94
- @staticmethod
95
- def _get_absolute_threshold_attribute(api_name, dtype):
96
- small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
97
- small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol')
98
- rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
99
- return small_value_threshold, small_value_atol, rtol
100
-
101
- @staticmethod
102
- def _get_run_ut_detail(test_result):
103
- """get run_ut detail before write to csv, called by online run_ut"""
104
- test_rows = []
105
- try:
106
- subject_prefix = test_result[0]
107
- fwd_result = test_result[3]
108
- bwd_result = test_result[4]
109
- except IndexError as e:
110
- logger.error("List index out of bounds when writing detail CSV.")
111
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
112
-
113
- if isinstance(fwd_result, list):
114
- for i, test_subject in enumerate(fwd_result):
115
- subject = subject_prefix + ".forward.output." + str(i)
116
- test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision)
117
- if isinstance(item, float) else item for item in test_subject]
118
- test_rows.append([subject] + list(test_subject))
119
- if isinstance(bwd_result, list):
120
- for i, test_subject in enumerate(bwd_result):
121
- subject = subject_prefix + ".backward.output." + str(i)
122
- test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision)
123
- if isinstance(item, float) else item for item in test_subject]
124
- test_rows.append([subject] + list(test_subject))
125
- return test_rows
126
-
127
- def write_csv_title(self):
128
- summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
129
- self.COLUMN_BACKWARD_SUCCESS, "Message"]]
130
- for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list):
131
- if not os.path.exists(save_path):
132
- write_csv(summary_test_rows, save_path)
133
- if not os.path.exists(detail_save_path):
134
- write_csv(DETAIL_TEST_ROWS, detail_save_path)
135
-
136
- def write_summary_csv(self, test_result):
137
- test_rows = []
138
- try:
139
- name = test_result[0]
140
- df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
141
- if test_result[1] == "SKIP":
142
- df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
143
- if self.stack_info:
144
- stack_info = "\n".join(self.stack_info[name])
145
- df_row.append(stack_info)
146
- test_rows.append(df_row)
147
- save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str)
148
- except IndexError as e:
149
- logger.error("List index out of bounds when writing summary CSV.")
150
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
151
- write_csv(test_rows, save_path)
152
-
153
- def write_detail_csv(self, test_result):
154
- test_rows = self._get_run_ut_detail(test_result)
155
- detail_save_path = self.get_path_from_rank(test_result[-1],
156
- self.detail_save_path_list,
157
- self.detail_save_path_str)
158
- write_csv(test_rows, detail_save_path)
159
-
160
- def record_results(self, args):
161
- self.write_summary_csv(args)
162
- self.write_detail_csv(args)
163
-
164
- def compare_output(self, full_api_name, data_info, is_online=False):
165
- """Get compare result and write to result and detail csv.
166
- is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
167
- """
168
- _, api_name, _ = full_api_name.split(Const.SEP)
169
- bench_output, device_output = data_info.bench_output, data_info.device_output
170
- bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
171
- backward_message = data_info.backward_message
172
- if "dropout" in full_api_name:
173
- fwd_success_status, fwd_compare_alg_results = self._compare_dropout(bench_output, device_output)
174
- else:
175
- fwd_success_status, fwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_output,
176
- device_output)
177
- if not (bench_grad and device_grad):
178
- bwd_success_status, bwd_compare_alg_results = (CompareConst.SPACE, [])
179
- else:
180
- if "dropout" in full_api_name:
181
- bwd_success_status, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], device_grad[0])
182
- else:
183
- bwd_success_status, bwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_grad,
184
- device_grad)
185
- if backward_message:
186
- backward_column = CompareColumn()
187
- bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)]
188
- else:
189
- bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE
190
- result_info = ResultInfo(full_api_name,
191
- fwd_success_status,
192
- bwd_success_status,
193
- fwd_compare_alg_results,
194
- bwd_compare_alg_results,
195
- data_info.rank)
196
- if is_online:
197
- # get run_ut compare detail
198
- return self._get_run_ut_detail(result_info)
199
- self.record_results(result_info)
200
- return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
201
- or bwd_success_status == CompareConst.SPACE
202
-
203
- def _compare_core_wrapper(self, api_name, bench_output, device_output):
204
- detailed_result_total = []
205
- test_final_success = CompareConst.PASS
206
- if isinstance(bench_output, (list, tuple)):
207
- status, compare_result, message = [], [], []
208
- if len(bench_output) > len(device_output):
209
- status = [CompareConst.ERROR]
210
- message = ["bench and npu output structure is different."]
211
- else:
212
- device_output = device_output[:len(bench_output)]
213
- for b_out_i, n_out_i in zip(bench_output, device_output):
214
- status_i, compare_result_i, message_i = self._compare_core(api_name, b_out_i, n_out_i)
215
- status.append(status_i)
216
- compare_result.append(compare_result_i)
217
- message.append(message_i)
218
- else:
219
- status, compare_result, message = self._compare_core(api_name, bench_output, device_output)
220
- if not isinstance(status, list):
221
- detailed_result_total.append(compare_result.to_column_value(status, message))
222
- if status == CompareConst.ERROR:
223
- test_final_success = CompareConst.ERROR
224
- elif status == CompareConst.WARNING:
225
- test_final_success = CompareConst.WARNING
226
- else:
227
- for item, item_status in enumerate(status):
228
- detailed_result_total.append(compare_result[item].to_column_value(item_status, message[item]))
229
- if item_status == CompareConst.ERROR:
230
- test_final_success = CompareConst.ERROR
231
- elif item_status == CompareConst.WARNING:
232
- test_final_success = CompareConst.WARNING
233
- return test_final_success, detailed_result_total
234
-
235
- def _compare_core(self, api_name, bench_output, device_output):
236
- compare_column = CompareColumn()
237
- if not isinstance(bench_output, type(device_output)):
238
- return CompareConst.ERROR, compare_column, "bench and npu output type is different."
239
- elif isinstance(bench_output, dict):
240
- b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
241
- if b_keys != n_keys:
242
- return CompareConst.ERROR, compare_column, "bench and npu output dict keys are different."
243
- else:
244
- status, compare_result, message = self._compare_core(api_name, list(bench_output.values()),
245
- list(device_output.values()))
246
- elif isinstance(bench_output, torch.Tensor):
247
- copy_bench_out = bench_output.detach().clone()
248
- copy_device_output = device_output.detach().clone()
249
- compare_column.bench_type = str(copy_bench_out.dtype)
250
- compare_column.npu_type = str(copy_device_output.dtype)
251
- compare_column.shape = tuple(device_output.shape)
252
- status, compare_result, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
253
- compare_column)
254
- elif isinstance(bench_output, (bool, int, float, str)):
255
- compare_column.bench_type = str(type(bench_output))
256
- compare_column.npu_type = str(type(device_output))
257
- status, compare_result, message = self._compare_builtin_type(bench_output, device_output, compare_column)
258
- elif bench_output is None:
259
- return CompareConst.SKIP, compare_column, "Bench output is None, skip this test."
260
- else:
261
- return CompareConst.PASS, compare_column,
262
- "Unexpected output type in compare_core: {}".format(type(bench_output))
263
-
264
- return status, compare_result, message
265
-
266
def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
    """Compare two tensors after converting them to NumPy arrays.

    Integer and boolean data is judged by error rate only; every other dtype
    is delegated to ``_compare_float_tensor``.

    Args:
        api_name: API name forwarded to the float comparison.
        bench_output: bench tensor.
        device_output: device (npu) tensor.
        compare_column: record that receives the computed metrics.

    Returns:
        tuple: ``(status, compare_column, message)``.
    """
    cpu_shape = bench_output.shape
    npu_shape = device_output.shape
    npu_dtype = device_output.dtype
    # Check shapes first so mismatching tensors are rejected before any copy.
    if cpu_shape != npu_shape:
        return CompareConst.ERROR, compare_column, f"The shape of bench{str(cpu_shape)} " \
                                                   f"and npu{str(npu_shape)} not equal."
    # NumPy has no bfloat16, so such tensors are widened before conversion.
    if npu_dtype == torch.bfloat16:
        bench_output = bench_output.to(torch.float32)
        device_output = device_output.to(torch.float32)
    elif bench_output.dtype == torch.bfloat16:
        bench_output = bench_output.to(torch.float32)
    # Tensor.numpy() only works on CPU tensors, so move both sides explicitly.
    bench_output = bench_output.cpu().numpy()
    device_output = device_output.cpu().numpy()
    if not check_dtype_comparable(bench_output, device_output):
        return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
                                                   f"npu output dtype is {device_output.dtype}, cannot compare."
    integer_like = [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
                    np.int64, np.uint64]
    if bench_output.dtype in integer_like:
        message = f"Compare algorithm is not supported for {bench_output.dtype} data. " \
                  f"Only judged by Error Rate."
        err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
        message += msg + "\n"
        compare_column.error_rate = err_rate
        return status, compare_column, message
    # The original device dtype (before any widening) selects the thresholds.
    return self._compare_float_tensor(api_name, bench_output, device_output, compare_column, npu_dtype)
294
-
295
def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
    """Fill the floating-point metrics of ``compare_column`` and grade them.

    Args:
        api_name: API name used to select the comparison standard.
        bench_output: bench values (``.size`` is read, so presumably a NumPy
            array - confirm against the caller).
        device_output: device values in the same form.
        compare_column: record whose metric attributes are set in place.
        dtype: torch dtype used to select thresholds.

    Returns:
        tuple: ``(status, compare_column, message)``.
    """
    half_precision = dtype in (torch.float16, torch.bfloat16)
    single_or_double = dtype in (torch.float32, torch.float64)

    abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
    abs_err = get_abs_err(bench_output, device_output)
    rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)

    if api_name in thousandth_standard_api:
        compare_column.rel_err_thousandth, _ = get_rel_err_ratio(
            rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)

    if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
        finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
        if api_name in binary_standard_api:
            mismatch_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
            compare_column.error_rate = mismatch_rate
        elif api_name in absolute_standard_api:
            small_threshold, small_atol, rtol = self._get_absolute_threshold_attribute(api_name, str(dtype))
            rel_err = abs_err / abs_bench_with_eps
            small_mask = get_small_value_mask(abs_bench, finite_mask, small_threshold)
            normal_mask = np.logical_and(finite_mask, np.logical_not(small_mask))
            compare_column.inf_nan_error_ratio = check_inf_nan_value(
                inf_nan_mask, bench_output, device_output, dtype, rtol)
            compare_column.rel_err_ratio = check_norm_value(normal_mask, rel_err, rtol)
            compare_column.abs_err_ratio = check_small_value(abs_err, small_mask, small_atol)
        elif api_name in ulp_standard_api:
            if bench_output.size == 0:
                compare_column.max_ulp_error = 0
                compare_column.mean_ulp_error = 0
                compare_column.ulp_error_proportion = 0
            else:
                ulp_err = get_ulp_err(bench_output, device_output, dtype)
                compare_column.max_ulp_error = np.max(ulp_err)
                compare_column.mean_ulp_error = np.mean(ulp_err)
                ulp_limit = (CompareConst.ULP_FLOAT32_THRESHOLD if dtype == torch.float32
                             else CompareConst.ULP_FLOAT16_THRESHOLD)
                compare_column.ulp_error_proportion = np.sum(ulp_err > ulp_limit) / bench_output.size
        else:
            dtype_config = precision_configs.get(dtype)
            small_mask = get_small_value_mask(abs_bench, finite_mask, dtype_config['small_value'][0])
            over_atol_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
            compare_column.small_value_err_ratio = get_small_value_err_ratio(small_mask, over_atol_mask)
            rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_mask, inf_nan_mask)
            compare_column.RMSE = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_mask))
            compare_column.EB = get_error_balance(bench_output, device_output)
            if rel_err.size == 0:
                return CompareConst.ERROR, compare_column, "Relative error result list is empty."
            compare_column.Max_rel_error = get_max_rel_err(rel_err)
            compare_column.Mean_rel_error = get_mean_rel_err(rel_err)

    # Grading cascade: each check may settle the verdict and return early.
    similarity, similar_enough, cos_msg = cosine_sim(bench_output, device_output)
    compare_column.cosine_sim = similarity
    message = cos_msg + "\n"
    if not similar_enough:
        message += "Cosine similarity is less than 0.99, consider as error, skip other check and set to SPACE.\n"
        return CompareConst.ERROR, compare_column, message

    max_abs_res, max_abs_small = get_max_abs_err(abs_err)
    compare_column.max_abs_err = max_abs_res
    if max_abs_small:
        message += "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
        return CompareConst.PASS, compare_column, message

    if half_precision:
        hundred_res, hundred_ok = get_rel_err_ratio(rel_err_orign, CompareConst.HUNDRED_RATIO_THRESHOLD)
        compare_column.rel_err_hundredth = hundred_res
        if not hundred_ok:
            message += "Relative error is greater than 0.01, consider as error, skip other check and set to SPACE.\n"
            return CompareConst.ERROR, compare_column, message

    thousand_res, thousand_ok = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
    compare_column.rel_err_thousandth = thousand_res
    if half_precision:
        if thousand_ok:
            message += "Relative error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
            return CompareConst.PASS, compare_column, message
        message += "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"
        return CompareConst.WARNING, compare_column, message

    ten_thousand_res, ten_thousand_ok = get_rel_err_ratio(
        rel_err_orign, CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
    compare_column.rel_err_ten_thousandth = ten_thousand_res
    if single_or_double:
        if not thousand_ok:
            message += "Relative error is greater than 0.001, consider as error, skip other check and set to SPACE.\n"
            return CompareConst.ERROR, compare_column, message
        if not ten_thousand_ok:
            message += "Relative error is greater than 0.0001, consider as warning, skip other check and set to SPACE.\n"
            return CompareConst.WARNING, compare_column, message
    message += "Relative error is less than 0.0001, consider as pass.\n"
    return CompareConst.PASS, compare_column, message
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # Run the comparison and present the results
19
+ import os
20
+ from collections import namedtuple
21
+
22
+ import numpy as np
23
+ from msprobe.core.common.utils import CompareException
24
+ from msprobe.core.common.file_utils import get_json_contents, write_csv
25
+ import torch
26
+ from msprobe.core.common.const import CompareConst
27
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
28
+ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
29
+ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
30
+ check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
31
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
32
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
33
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
34
+ DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
35
+ ulp_standard_api, thousandth_standard_api, apis_threshold
36
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
37
+ from msprobe.pytorch.common.log import logger
38
+
39
+
40
# One comparison result per API; the field order is relied on by the
# positional index constants below.
_RESULT_INFO_FIELDS = (
    'full_api_name',
    'fwd_success_status',
    'bwd_success_status',
    'fwd_compare_alg_results',
    'bwd_compare_alg_results',
    'rank',
)
ResultInfo = namedtuple('ResultInfo', _RESULT_INFO_FIELDS)


# Position of the forward result group inside a result record.
INDEX_TEST_RESULT_GROUP = 3
# First entry of a result group.
INDEX_FIRST_GROUP = 0
# The message is the last item of a result entry.
INDEX_MESSAGE = -1
47
+
48
+
49
+ class Comparator:
50
+ # consts for result csv
51
+ COLUMN_API_NAME = "API name"
52
+ COLUMN_FORWARD_SUCCESS = "Forward Test Success"
53
+ COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
54
+ COLUMN_STACK_INFO = "Traceback callstack info"
55
+
56
+ def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None):
57
+ self.save_path_str = result_csv_path
58
+ self.detail_save_path_str = details_csv_path
59
+ self.save_path_list = [result_csv_path]
60
+ self.detail_save_path_list = [details_csv_path]
61
+
62
+ if config and config.online_config.is_online:
63
+ self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
64
+ self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
65
+ self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
66
+ self.detail_save_path_list = \
67
+ [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
68
+
69
+ if not is_continue_run_ut:
70
+ self.write_csv_title()
71
+ if stack_info_json_path:
72
+ self.stack_info = get_json_contents(stack_info_json_path)
73
+ else:
74
+ self.stack_info = None
75
+
76
+ @staticmethod
77
+ def get_path_from_rank(rank, path_list, path_pattern):
78
+ return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank)
79
+
80
@staticmethod
def print_pretest_result():
    """Log that run_ut/multi_run_ut finished successfully."""
    done_message = "Successfully completed run_ut/multi_run_ut."
    logger.info(done_message)
83
+
84
@staticmethod
def _compare_dropout(bench_output, device_output):
    """Judge dropout outputs by how many zeros each side contains.

    Outputs with fewer than 100 elements always pass. Larger ones pass when
    the zero counts differ by less than 10% of the bench element count.

    Returns:
        tuple: ``(status, 1)`` on pass, ``(status, 0)`` on error.
    """
    element_count = bench_output.numel()
    if element_count < 100:
        return CompareConst.PASS, 1

    zero_count_gap = abs((bench_output == 0).sum() - (device_output == 0).cpu().sum())
    if zero_count_gap / element_count < 0.1:
        return CompareConst.PASS, 1
    return CompareConst.ERROR, 0
94
+
95
@staticmethod
def _compare_builtin_type(bench_output, device_output, compare_column):
    """Compare two plain Python values for equality.

    Values that are not ``bool``/``int``/``float``/``str`` pass unchecked.
    When the values are equal, ``compare_column.error_rate`` is set to 0.

    Returns:
        tuple: ``(status, compare_column, "")``.
    """
    checkable = isinstance(bench_output, (bool, int, float, str))
    if checkable and bench_output != device_output:
        return CompareConst.ERROR, compare_column, ""
    if checkable:
        compare_column.error_rate = 0
    return CompareConst.PASS, compare_column, ""
103
+
104
@staticmethod
def _compare_bool_tensor(bench_output, device_output):
    """Return the fraction of elements that differ between the two outputs.

    ``.size`` is read as a number, so the inputs are presumably NumPy arrays
    - confirm against the callers.

    Returns:
        tuple: ``(error_rate, status, message)``. For an empty bench output
        the rate is ``CompareConst.NAN`` and the status is an error.
    """
    mismatch_count = (bench_output != device_output).sum()
    if bench_output.size == 0:
        return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."

    error_rate = float(mismatch_count / bench_output.size)
    if error_rate == 0:
        verdict = CompareConst.PASS
    else:
        verdict = CompareConst.ERROR
    return error_rate, verdict, ""
112
+
113
@staticmethod
def _get_absolute_threshold_attribute(api_name, dtype):
    """Look up the absolute-standard thresholds for an API and dtype.

    Returns:
        tuple: ``(small_value, small_value_atol, rtol)`` as stored in
        ``apis_threshold[api_name][dtype]``.
    """
    thresholds = apis_threshold.get(api_name).get(dtype)
    return (
        thresholds.get('small_value'),
        thresholds.get('small_value_atol'),
        thresholds.get('rtol'),
    )
119
+
120
+ @staticmethod
121
+ def _get_run_ut_detail(test_result):
122
+ """get run_ut detail before write to csv, called by online run_ut"""
123
+ test_rows = []
124
+ try:
125
+ subject_prefix = test_result[0]
126
+ fwd_result = test_result[3]
127
+ bwd_result = test_result[4]
128
+ except IndexError as e:
129
+ logger.error("List index out of bounds when writing detail CSV.")
130
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
131
+
132
+ if isinstance(fwd_result, list):
133
+ for i, test_subject in enumerate(fwd_result):
134
+ subject = subject_prefix + ".forward.output." + str(i)
135
+ test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision)
136
+ if isinstance(item, float) else item for item in test_subject]
137
+ test_rows.append([subject] + list(test_subject))
138
+ if isinstance(bwd_result, list):
139
+ for i, test_subject in enumerate(bwd_result):
140
+ subject = subject_prefix + ".backward.output." + str(i)
141
+ test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision)
142
+ if isinstance(item, float) else item for item in test_subject]
143
+ test_rows.append([subject] + list(test_subject))
144
+ return test_rows
145
+
146
def write_csv_title(self):
    """Write the header rows to every summary/detail CSV that does not exist yet."""
    summary_header = [[
        self.COLUMN_API_NAME,
        self.COLUMN_FORWARD_SUCCESS,
        self.COLUMN_BACKWARD_SUCCESS,
        "Message",
    ]]
    for summary_path, detail_path in zip(self.save_path_list, self.detail_save_path_list):
        # Existing files are left untouched.
        for header_rows, target in ((summary_header, summary_path), (DETAIL_TEST_ROWS, detail_path)):
            if not os.path.exists(target):
                write_csv(header_rows, target)
158
+
159
def write_summary_csv(self, test_result):
    """Append the summary row of one result to the summary CSV.

    The row holds the first ``INDEX_TEST_RESULT_GROUP`` items of
    ``test_result``. For a skipped forward test the message of the first
    forward result entry is appended, and the joined stack info is appended
    when stack info was loaded. The last item of ``test_result`` selects the
    target file.

    Raises:
        CompareException: when ``test_result`` has too few items.
    """
    try:
        api_name = test_result[0]
        row = list(test_result[:INDEX_TEST_RESULT_GROUP])
        if test_result[1] == CompareConst.SKIP:
            first_entry = test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP]
            row.append(first_entry[INDEX_MESSAGE])
        if self.stack_info:
            # NOTE(review): for non-skipped results this lands in the fourth
            # column - confirm it matches the header layout.
            row.append("\n".join(self.stack_info[api_name]))
        save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str)
    except IndexError as e:
        logger.error("List index out of bounds when writing summary CSV.")
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
    write_csv([row], save_path)
175
+
176
def write_detail_csv(self, test_result):
    """Append the detail rows of one result to the detail CSV.

    The last item of ``test_result`` selects the target file.
    """
    rows = self._get_run_ut_detail(test_result)
    rank = test_result[-1]
    target = self.get_path_from_rank(rank, self.detail_save_path_list, self.detail_save_path_str)
    write_csv(rows, target)
182
+
183
def record_results(self, args):
    """Write one result to the summary CSV first, then to the detail CSV."""
    for write_rows in (self.write_summary_csv, self.write_detail_csv):
        write_rows(args)
186
+
187
+
188
def compare_output(self, full_api_name, data_info, is_online=False):
    """Compare forward and backward results of one API and record them.

    Args:
        full_api_name: full API name; names containing "dropout" use the
            dropout comparison.
        data_info: object providing ``bench_output``, ``device_output``,
            ``bench_grad``, ``device_grad``, ``backward_message`` and ``rank``.
        is_online: when True the detail rows are returned and nothing is
            written to CSV.

    Returns:
        The detail rows when ``is_online`` is True, otherwise a tuple
        ``(forward_passed, backward_passed)``.

    Raises:
        ValueError: when no API name can be extracted from ``full_api_name``.
    """
    _, api_name = extract_basic_api_segments(full_api_name)
    if not api_name:
        raise ValueError(f"API name {full_api_name} has not been adapted.")

    bench_output, device_output = data_info.bench_output, data_info.device_output
    bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
    backward_message = data_info.backward_message
    is_dropout = "dropout" in full_api_name

    if is_dropout:
        forward = self._compare_dropout(bench_output, device_output)
    else:
        forward = self._compare_core_wrapper(api_name, bench_output, device_output)
    fwd_status, fwd_details = forward

    if not (bench_grad and device_grad):
        # No gradients on one side: the backward comparison is left blank.
        bwd_status, bwd_details = CompareConst.SPACE, []
    elif is_dropout:
        bwd_status, bwd_details = self._compare_dropout(bench_grad[0], device_grad[0])
    else:
        bwd_status, bwd_details = self._compare_core_wrapper(api_name, bench_grad, device_grad)

    if backward_message:
        # A backward message replaces the details with a single skip entry.
        bwd_details = [CompareColumn().to_column_value(CompareConst.SKIP, backward_message)]
    elif bwd_details is None:
        bwd_status = CompareConst.SPACE

    result_info = ResultInfo(full_api_name, fwd_status, bwd_status, fwd_details, bwd_details, data_info.rank)
    if is_online:
        # get run_ut compare detail
        return self._get_run_ut_detail(result_info)

    self.record_results(result_info)
    forward_passed = fwd_status == CompareConst.PASS
    backward_passed = bwd_status == CompareConst.PASS or bwd_status == CompareConst.SPACE
    return forward_passed, backward_passed
228
+
229
def _compare_core_wrapper(self, api_name, bench_output, device_output):
    """Compare one output or a list/tuple of outputs and aggregate the status.

    Args:
        api_name: API name forwarded to ``_compare_core``.
        bench_output: bench output; a list/tuple is compared item by item.
        device_output: device (npu) output in the matching structure.

    Returns:
        tuple: ``(final_status, detail_rows)``. ``final_status`` is the most
        severe item status (error over warning over pass) and ``detail_rows``
        holds one ``to_column_value`` row per compared item.
    """
    if isinstance(bench_output, (list, tuple)):
        structure_differs = (not isinstance(device_output, (list, tuple))
                             or len(bench_output) > len(device_output))
        if structure_differs:
            # Pair the error with a blank column; the original left the column
            # list empty and then failed with IndexError while building rows.
            results = [(CompareConst.ERROR, CompareColumn(), "bench and npu output structure is different.")]
        else:
            # zip stops at the bench length, so extra device items are ignored.
            results = [self._compare_core(api_name, b_out_i, n_out_i)
                       for b_out_i, n_out_i in zip(bench_output, device_output)]
    else:
        results = [self._compare_core(api_name, bench_output, device_output)]

    test_final_success = CompareConst.PASS
    detailed_result_total = []
    for status, compare_result, message in results:
        detailed_result_total.append(compare_result.to_column_value(status, message))
        if status == CompareConst.ERROR:
            test_final_success = CompareConst.ERROR
        elif status == CompareConst.WARNING and test_final_success != CompareConst.ERROR:
            # A later warning must not downgrade an earlier error.
            test_final_success = CompareConst.WARNING
    return test_final_success, detailed_result_total
260
+
261
def _compare_core(self, api_name, bench_output, device_output):
    """Compare a single bench output against the matching device output.

    The branch is chosen by the type of ``bench_output``: dicts are compared
    key by key, tensors go to ``_compare_torch_tensor``,
    ``bool``/``int``/``float``/``str`` values go to ``_compare_builtin_type``,
    ``None`` is skipped and any other type is reported as an error.

    Args:
        api_name: API name forwarded to the tensor comparison.
        bench_output: output produced by the bench run.
        device_output: output produced by the device (npu) run.

    Returns:
        tuple: ``(status, compare_column, message)``.
    """
    compare_column = CompareColumn()
    if not isinstance(bench_output, type(device_output)):
        return CompareConst.ERROR, compare_column, "bench and npu output type is different."

    if isinstance(bench_output, dict):
        if set(bench_output.keys()) != set(device_output.keys()):
            return CompareConst.ERROR, compare_column, "bench and npu output dict keys are different."
        return self._compare_dict_values(api_name, bench_output, device_output, compare_column)

    if isinstance(bench_output, torch.Tensor):
        # Work on detached copies so the comparison never touches the originals.
        copy_bench_out = bench_output.detach().clone()
        copy_device_output = device_output.detach().clone()
        compare_column.bench_type = str(copy_bench_out.dtype)
        compare_column.npu_type = str(copy_device_output.dtype)
        compare_column.shape = tuple(device_output.shape)
        return self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, compare_column)

    if isinstance(bench_output, (bool, int, float, str)):
        compare_column.bench_type = str(type(bench_output))
        compare_column.npu_type = str(type(device_output))
        return self._compare_builtin_type(bench_output, device_output, compare_column)

    if bench_output is None:
        return CompareConst.SKIP, compare_column, "Bench output is None, skip this test."

    return CompareConst.ERROR, compare_column, \
        "Unexpected output type in compare_core: {}".format(type(bench_output))

def _compare_dict_values(self, api_name, bench_output, device_output, compare_column):
    """Compare dict outputs key by key and return the most severe result.

    The original passed both value collections on as plain lists, which
    ``_compare_core`` does not handle, so matching dicts were always reported
    as an unexpected type. Values are paired by key, so a different insertion
    order on the two sides cannot mismatch them.

    Returns:
        tuple: the ``(status, compare_column, message)`` of the first value
        with the highest severity (error over warning over anything else);
        a pass with the blank ``compare_column`` for empty dicts.
    """
    severity = {CompareConst.ERROR: 2, CompareConst.WARNING: 1}
    worst = None
    for key, bench_value in bench_output.items():
        result = self._compare_core(api_name, bench_value, device_output[key])
        if worst is None or severity.get(result[0], 0) > severity.get(worst[0], 0):
            worst = result
    if worst is None:
        # Two empty dicts have nothing that could differ.
        return CompareConst.PASS, compare_column, ""
    return worst
294
+
295
def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
    """Compare two tensors after converting them to NumPy arrays.

    Integer and boolean data is judged by error rate only; every other dtype
    is delegated to ``_compare_float_tensor``.

    Args:
        api_name: API name forwarded to the float comparison.
        bench_output: bench tensor.
        device_output: device (npu) tensor.
        compare_column: record that receives the computed metrics.

    Returns:
        tuple: ``(status, compare_column, message)``.
    """
    cpu_shape = bench_output.shape
    npu_shape = device_output.shape
    npu_dtype = device_output.dtype
    # Check shapes first so mismatching tensors are rejected before any copy.
    if cpu_shape != npu_shape:
        return CompareConst.ERROR, compare_column, f"The shape of bench{str(cpu_shape)} " \
                                                   f"and npu{str(npu_shape)} not equal."
    # NumPy has no bfloat16, so such tensors are widened before conversion.
    if npu_dtype == torch.bfloat16:
        bench_output = bench_output.to(torch.float32)
        device_output = device_output.to(torch.float32)
    elif bench_output.dtype == torch.bfloat16:
        # Previously only the device dtype was checked, so a bfloat16 bench
        # tensor paired with another device dtype failed inside .numpy().
        bench_output = bench_output.to(torch.float32)
    bench_output = bench_output.cpu().numpy()
    device_output = device_output.cpu().numpy()
    if not check_dtype_comparable(bench_output, device_output):
        return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
                                                   f"npu output dtype is {device_output.dtype}, cannot compare."
    integer_like = [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
                    np.int64, np.uint64]
    if bench_output.dtype in integer_like:
        message = f"Compare algorithm is not supported for {bench_output.dtype} data. " \
                  f"Only judged by Error Rate."
        err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
        message += msg + "\n"
        compare_column.error_rate = err_rate
        return status, compare_column, message
    # The original device dtype (before any widening) selects the thresholds.
    return self._compare_float_tensor(api_name, bench_output, device_output, compare_column, npu_dtype)
323
+
324
def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
    """Fill the floating-point metrics of ``compare_column`` and grade them.

    Args:
        api_name: API name used to select the comparison standard.
        bench_output: bench values (``.size`` is read, so presumably a NumPy
            array - confirm against the caller).
        device_output: device values in the same form.
        compare_column: record whose metric attributes are set in place.
        dtype: torch dtype used to select thresholds.

    Returns:
        tuple: ``(status, compare_column, message)``.
    """
    half_precision = dtype in (torch.float16, torch.bfloat16)
    single_or_double = dtype in (torch.float32, torch.float64)

    abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
    abs_err = get_abs_err(bench_output, device_output)
    rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)

    if api_name in thousandth_standard_api:
        compare_column.rel_err_thousandth, _ = get_rel_err_ratio(
            rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)

    if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
        finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
        if api_name in binary_standard_api:
            mismatch_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
            compare_column.error_rate = mismatch_rate
        elif api_name in absolute_standard_api:
            small_threshold, small_atol, rtol = self._get_absolute_threshold_attribute(api_name, str(dtype))
            rel_err = abs_err / abs_bench_with_eps
            small_mask = get_small_value_mask(abs_bench, finite_mask, small_threshold)
            normal_mask = np.logical_and(finite_mask, np.logical_not(small_mask))
            compare_column.inf_nan_error_ratio = check_inf_nan_value(
                inf_nan_mask, bench_output, device_output, dtype, rtol)
            compare_column.rel_err_ratio = check_norm_value(normal_mask, rel_err, rtol)
            compare_column.abs_err_ratio = check_small_value(abs_err, small_mask, small_atol)
        elif api_name in ulp_standard_api:
            if bench_output.size == 0:
                compare_column.max_ulp_error = 0
                compare_column.mean_ulp_error = 0
                compare_column.ulp_error_proportion = 0
            else:
                ulp_err = get_ulp_err(bench_output, device_output, dtype)
                compare_column.max_ulp_error = np.max(ulp_err)
                compare_column.mean_ulp_error = np.mean(ulp_err)
                ulp_limit = (CompareConst.ULP_FLOAT32_THRESHOLD if dtype == torch.float32
                             else CompareConst.ULP_FLOAT16_THRESHOLD)
                compare_column.ulp_error_proportion = np.sum(ulp_err > ulp_limit) / bench_output.size
        else:
            dtype_config = precision_configs.get(dtype)
            small_mask = get_small_value_mask(abs_bench, finite_mask, dtype_config['small_value'][0])
            over_atol_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
            compare_column.small_value_err_ratio = get_small_value_err_ratio(small_mask, over_atol_mask)
            rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_mask, inf_nan_mask)
            compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_mask))
            compare_column.eb = get_error_balance(bench_output, device_output)
            if rel_err.size == 0:
                return CompareConst.ERROR, compare_column, "Relative error result list is empty."
            compare_column.max_rel_error = get_max_rel_err(rel_err)
            compare_column.mean_rel_error = get_mean_rel_err(rel_err)

    # Grading cascade: each check may settle the verdict and return early.
    similarity, similar_enough, cos_msg = cosine_sim(bench_output, device_output)
    compare_column.cosine_sim = similarity
    message = cos_msg + "\n"
    if not similar_enough:
        message += "Cosine similarity is less than 0.99, consider as error, skip other check and set to SPACE.\n"
        return CompareConst.ERROR, compare_column, message

    max_abs_res, max_abs_small = get_max_abs_err(abs_err)
    compare_column.max_abs_err = max_abs_res
    if max_abs_small:
        message += "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
        return CompareConst.PASS, compare_column, message

    if half_precision:
        hundred_res, hundred_ok = get_rel_err_ratio(rel_err_orign, CompareConst.HUNDRED_RATIO_THRESHOLD)
        compare_column.rel_err_hundredth = hundred_res
        if not hundred_ok:
            message += ("Relative error is greater than 0.01, consider as error, "
                        "skip other check and set to SPACE.\n")
            return CompareConst.ERROR, compare_column, message

    thousand_res, thousand_ok = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
    compare_column.rel_err_thousandth = thousand_res
    if half_precision:
        if thousand_ok:
            message += "Relative error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
            return CompareConst.PASS, compare_column, message
        message += "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"
        return CompareConst.WARNING, compare_column, message

    ten_thousand_res, ten_thousand_ok = get_rel_err_ratio(
        rel_err_orign, CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
    compare_column.rel_err_ten_thousandth = ten_thousand_res
    if single_or_double:
        if not thousand_ok:
            message += ("Relative error is greater than 0.001, consider as error, "
                        "skip other check and set to SPACE.\n")
            return CompareConst.ERROR, compare_column, message
        if not ten_thousand_ok:
            message += ("Relative error is greater than 0.0001, consider as warning, "
                        "skip other check and set to SPACE.\n")
            return CompareConst.WARNING, compare_column, message
    message += "Relative error is less than 0.0001, consider as pass.\n"
    return CompareConst.PASS, compare_column, message