mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,381 +1,386 @@
1
- # 进行比对及结果展示
2
- import os
3
- from collections import namedtuple
4
-
5
- import numpy as np
6
- from msprobe.core.common.utils import write_csv, get_json_contents, CompareException
7
- import torch
8
- from msprobe.core.common.const import Const, CompareConst
9
- from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
10
- get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
11
- get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
12
- check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
13
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
14
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
15
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
16
- DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
17
- ulp_standard_api, thousandth_standard_api, apis_threshold
18
- from msprobe.pytorch.common.log import logger
19
-
20
-
21
# Per-API comparison outcome handed to the CSV writers. Fields, in order:
#   full_api_name            - API name, split into three parts on Const.SEP by compare_output
#   fwd_success_status       - status of the forward comparison
#   bwd_success_status       - status of the backward comparison
#   fwd_compare_alg_results  - forward detail rows (one per compared output)
#   bwd_compare_alg_results  - backward detail rows
#   rank                     - rank used to choose the output CSV file
ResultInfo = namedtuple(
    'ResultInfo',
    'full_api_name fwd_success_status bwd_success_status '
    'fwd_compare_alg_results bwd_compare_alg_results rank'
)

# Positional indices used by write_summary_csv to pull the message of a skipped API:
# test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE].
INDEX_TEST_RESULT_GROUP = 3  # position of fwd_compare_alg_results inside ResultInfo
INDEX_FIRST_GROUP = 0  # first forward detail row
INDEX_MESSAGE = -1  # the message is the last cell of a detail row
28
-
29
-
30
- class Comparator:
31
# Header names for the summary result CSV.
# COLUMN_STACK_INFO is not referenced by write_csv_title, which writes the
# first three names plus a literal "Message" column.
COLUMN_API_NAME = "API name"
COLUMN_FORWARD_SUCCESS = "Forward Test Success"
COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
COLUMN_STACK_INFO = "Traceback callstack info"
36
-
37
def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None):
    """Prepare the summary/detail CSV destinations and load optional stack info.

    Args:
        result_csv_path: path of the summary result CSV.
        details_csv_path: path of the detailed result CSV.
        is_continue_run_ut: when falsy, write_csv_title() is called (it only writes
            to CSV files that do not exist yet).
        stack_info_json_path: optional JSON file whose contents become ``self.stack_info``;
            ``self.stack_info`` is None when it is not given.
        config: optional run config. When ``config.online_config.is_online`` is truthy,
            each rank in ``config.online_config.rank_list`` gets its own pair of CSV files,
            derived by replacing ".csv" with "_rank<rank>.csv" in the given paths.
    """
    online_mode = config and config.online_config.is_online
    if online_mode:
        ranks = config.online_config.rank_list
        self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
        self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
        self.save_path_list = [self.save_path_str.format(rank) for rank in ranks]
        self.detail_save_path_list = [self.detail_save_path_str.format(rank) for rank in ranks]
    else:
        self.save_path_str = result_csv_path
        self.detail_save_path_str = details_csv_path
        self.save_path_list = [result_csv_path]
        self.detail_save_path_list = [details_csv_path]

    if not is_continue_run_ut:
        self.write_csv_title()
    self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None
56
-
57
@staticmethod
def get_path_from_rank(rank, path_list, path_pattern):
    """Return the CSV path to write to for *rank*.

    When path_list holds exactly one entry, that entry is used regardless of rank;
    otherwise the path is produced by formatting path_pattern with the rank.
    """
    if len(path_list) == 1:
        return path_list[-1]
    return path_pattern.format(rank)
60
-
61
@staticmethod
def print_pretest_result():
    """Log an info message saying that run_ut/multi_run_ut finished."""
    done_message = "Successfully completed run_ut/multi_run_ut."
    logger.info(done_message)
64
-
65
@staticmethod
def _compare_dropout(bench_output, device_output):
    """Loosely compare dropout outputs by how many elements are zero.

    Values are not compared element-wise (presumably because dropout masks are
    random). For tensors with at least 100 elements the check passes when the
    bench and device zero counts differ by less than 10% of the element count;
    smaller tensors always pass.

    Returns:
        (CompareConst.PASS, 1) on pass, (CompareConst.ERROR, 0) otherwise.
    """
    element_count = bench_output.numel()
    if element_count < 100:
        return CompareConst.PASS, 1
    bench_zeros = (bench_output == 0).sum()
    device_zeros = (device_output == 0).cpu().sum()
    if abs(bench_zeros - device_zeros) / element_count < 0.1:
        return CompareConst.PASS, 1
    return CompareConst.ERROR, 0
75
-
76
@staticmethod
def _compare_builtin_type(bench_output, device_output, compare_column):
    """Compare bool/int/float/str outputs by plain equality.

    Bench outputs of any other type are not checked here and pass. On an exact
    match ``compare_column.error_rate`` is set to 0.

    Returns:
        (status, compare_column, message) where message is always "".
    """
    if isinstance(bench_output, (bool, int, float, str)):
        if bench_output != device_output:
            return CompareConst.ERROR, compare_column, ""
        compare_column.error_rate = 0
    return CompareConst.PASS, compare_column, ""
84
-
85
@staticmethod
def _compare_bool_tensor(bench_output, device_output):
    """Exact element-wise comparison of two arrays.

    The callers in this class pass numpy arrays (``.size`` is the element count).

    Returns:
        (error_rate, status, message): the fraction of mismatching elements, with
        PASS only when it is 0. An empty bench array yields
        (CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result.").
    """
    mismatch_count = (bench_output != device_output).sum()
    total = bench_output.size
    if total == 0:
        return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
    error_rate = float(mismatch_count / total)
    status = CompareConst.ERROR if error_rate != 0 else CompareConst.PASS
    return error_rate, status, ""
93
-
94
@staticmethod
def _get_absolute_threshold_attribute(api_name, dtype):
    """Read the absolute-standard thresholds of *api_name* for *dtype* from apis_threshold.

    Returns:
        (small_value_threshold, small_value_atol, rtol), taken from the
        'small_value', 'small_value_atol' and 'rtol' entries.
    """
    dtype_thresholds = apis_threshold.get(api_name).get(dtype)
    return (dtype_thresholds.get('small_value'),
            dtype_thresholds.get('small_value_atol'),
            dtype_thresholds.get('rtol'))
100
-
101
@staticmethod
def _get_run_ut_detail(test_result):
    """Build the detail-CSV rows of one API result (also used directly by online run_ut).

    test_result is read positionally: [0] is the API name prefix, [3] the forward and
    [4] the backward per-output results. Each per-output entry becomes one row named
    "<prefix>.forward.output.<i>" or "<prefix>.backward.output.<i>"; float cells are
    formatted with ``msCheckerConfig.precision`` decimals. Results that are not lists
    produce no rows.

    Raises:
        CompareException: INDEX_OUT_OF_BOUNDS_ERROR when test_result is too short.
    """
    try:
        subject_prefix = test_result[0]
        fwd_result = test_result[3]
        bwd_result = test_result[4]
    except IndexError as e:
        logger.error("List index out of bounds when writing detail CSV.")
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e

    rows = []
    for direction, results in ((".forward.output.", fwd_result), (".backward.output.", bwd_result)):
        if not isinstance(results, list):
            continue
        for index, cells in enumerate(results):
            formatted = [
                "{:.{}f}".format(cell, msCheckerConfig.precision) if isinstance(cell, float) else cell
                for cell in cells
            ]
            rows.append([subject_prefix + direction + str(index)] + formatted)
    return rows
126
-
127
def write_csv_title(self):
    """Write title rows to every summary and detail CSV that does not exist yet.

    Files that already exist are left untouched.
    """
    summary_title = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
                      self.COLUMN_BACKWARD_SUCCESS, "Message"]]
    for summary_path, detail_path in zip(self.save_path_list, self.detail_save_path_list):
        for csv_path, title_rows in ((summary_path, summary_title), (detail_path, DETAIL_TEST_ROWS)):
            if not os.path.exists(csv_path):
                write_csv(title_rows, csv_path)
135
-
136
def write_summary_csv(self, test_result):
    """Write the summary row of one API result to the summary CSV of its rank.

    Row layout: API name, forward status, backward status, message and, when
    ``self.stack_info`` is set, the API's call stack joined with newlines.
    The message cell holds the message of the first forward detail row when the
    forward status is "SKIP". When stack info is added for a non-skipped API, an
    empty message cell is written first so the stack info never lands under the
    "Message" header.

    Args:
        test_result: ResultInfo-like sequence (name, fwd status, bwd status,
            fwd results, bwd results, rank); the rank (last item) selects the file.

    Raises:
        CompareException: INDEX_OUT_OF_BOUNDS_ERROR when test_result does not
            have the expected layout.
    """
    test_rows = []
    try:
        name = test_result[0]
        df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
        if test_result[1] == "SKIP":
            df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
        elif self.stack_info:
            # Placeholder for the "Message" column, keeping the stack info in its own column.
            df_row.append("")
        if self.stack_info:
            df_row.append("\n".join(self.stack_info[name]))
        test_rows.append(df_row)
        save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str)
    except IndexError as e:
        logger.error("List index out of bounds when writing summary CSV.")
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
    write_csv(test_rows, save_path)
152
-
153
def write_detail_csv(self, test_result):
    """Write the detail rows of *test_result* to the detail CSV chosen by its rank (last item)."""
    target_path = None
    detail_rows = self._get_run_ut_detail(test_result)
    target_path = self.get_path_from_rank(test_result[-1], self.detail_save_path_list, self.detail_save_path_str)
    write_csv(detail_rows, target_path)
159
-
160
def record_results(self, args):
    """Write one result (the ResultInfo built by compare_output) to the summary CSV, then the detail CSV."""
    for writer in (self.write_summary_csv, self.write_detail_csv):
        writer(args)
163
-
164
def compare_output(self, full_api_name, data_info, is_online=False):
    """Compare the forward (and, when available, backward) results of one API.

    Args:
        full_api_name: name made of exactly three parts joined by Const.SEP; the
            middle part is passed to the comparison as the API name. Names
            containing "dropout" use the loose _compare_dropout check.
        data_info: provides bench_output, device_output, bench_grad, device_grad,
            backward_message and rank.
        is_online: True when called by the online API precision compare; the detail
            rows are then returned instead of being written to the CSV files.

    Returns:
        If is_online: the rows produced by _get_run_ut_detail.
        Otherwise: (forward_passed, backward_passed); backward also counts as passed
        when its status is CompareConst.SPACE.
    """
    _head, api_name, _tail = full_api_name.split(Const.SEP)
    bench_output, device_output = data_info.bench_output, data_info.device_output
    bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
    backward_message = data_info.backward_message
    use_dropout_check = "dropout" in full_api_name

    if use_dropout_check:
        fwd_success_status, fwd_compare_alg_results = self._compare_dropout(bench_output, device_output)
    else:
        fwd_success_status, fwd_compare_alg_results = self._compare_core_wrapper(
            api_name, bench_output, device_output)

    if not (bench_grad and device_grad):
        # Gradients missing on either side: backward is reported as blank.
        bwd_success_status, bwd_compare_alg_results = CompareConst.SPACE, []
    elif use_dropout_check:
        bwd_success_status, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], device_grad[0])
    else:
        bwd_success_status, bwd_compare_alg_results = self._compare_core_wrapper(
            api_name, bench_grad, device_grad)

    if backward_message:
        skip_row = CompareColumn().to_column_value(CompareConst.SKIP, backward_message)
        bwd_compare_alg_results = [skip_row]
    elif bwd_compare_alg_results is None:
        bwd_success_status = CompareConst.SPACE

    result_info = ResultInfo(full_api_name, fwd_success_status, bwd_success_status,
                             fwd_compare_alg_results, bwd_compare_alg_results, data_info.rank)
    if is_online:
        # Online mode only needs the comparison details, nothing is written to CSV.
        return self._get_run_ut_detail(result_info)
    self.record_results(result_info)
    fwd_passed = fwd_success_status == CompareConst.PASS
    bwd_passed = bwd_success_status == CompareConst.PASS or bwd_success_status == CompareConst.SPACE
    return fwd_passed, bwd_passed
202
-
203
def _compare_core_wrapper(self, api_name, bench_output, device_output):
    """Compare bench vs device outputs and aggregate the per-item statuses.

    List/tuple bench outputs are compared element-wise against the device outputs
    (extra device outputs are ignored). If the device output is not a list/tuple, or
    has fewer items than the bench output, a single ERROR row is reported instead.
    Any other bench output is compared as one item via _compare_core.

    Returns:
        (test_final_success, detailed_result_total):
        - test_final_success is ERROR if any item is ERROR, otherwise WARNING if any
          item is WARNING, otherwise PASS.
        - detailed_result_total holds one column value per compared item.
    """
    detailed_result_total = []
    test_final_success = CompareConst.PASS
    if isinstance(bench_output, (list, tuple)):
        status, compare_result, message = [], [], []
        if not isinstance(device_output, (list, tuple)) or len(bench_output) > len(device_output):
            # One status needs one matching column; an empty compare_result list
            # would make the aggregation loop below raise IndexError.
            status = [CompareConst.ERROR]
            compare_result = [CompareColumn()]
            message = ["bench and npu output structure is different."]
        else:
            device_output = device_output[:len(bench_output)]
            for b_out_i, n_out_i in zip(bench_output, device_output):
                status_i, compare_result_i, message_i = self._compare_core(api_name, b_out_i, n_out_i)
                status.append(status_i)
                compare_result.append(compare_result_i)
                message.append(message_i)
    else:
        status, compare_result, message = self._compare_core(api_name, bench_output, device_output)

    if not isinstance(status, list):
        # A single output is aggregated the same way as a one-item list.
        status, compare_result, message = [status], [compare_result], [message]

    for item_status, item_result, item_message in zip(status, compare_result, message):
        detailed_result_total.append(item_result.to_column_value(item_status, item_message))
        if item_status == CompareConst.ERROR:
            test_final_success = CompareConst.ERROR
        elif item_status == CompareConst.WARNING and test_final_success != CompareConst.ERROR:
            # A later WARNING must not mask an earlier ERROR.
            test_final_success = CompareConst.WARNING
    return test_final_success, detailed_result_total
234
-
235
def _compare_core(self, api_name, bench_output, device_output):
    """Compare one bench output with the matching device output.

    Dispatch is on the type of the bench output:
      * not an instance of the device output's type -> ERROR;
      * dict -> keys must match, then the value lists are passed back into this method;
      * torch.Tensor -> detached copies are compared by _compare_torch_tensor;
      * bool/int/float/str -> _compare_builtin_type;
      * None -> SKIP;
      * anything else -> PASS with an "Unexpected output type" message.

    Returns:
        (status, compare_column, message): always a 3-tuple, because every caller
        unpacks three values.
    """
    compare_column = CompareColumn()
    if not isinstance(bench_output, type(device_output)):
        return CompareConst.ERROR, compare_column, "bench and npu output type is different."
    elif isinstance(bench_output, dict):
        b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
        if b_keys != n_keys:
            return CompareConst.ERROR, compare_column, "bench and npu output dict keys are different."
        # NOTE(review): the values are passed on as lists, which this method has no branch
        # for, so they reach the generic PASS branch below instead of being compared.
        status, compare_result, message = self._compare_core(api_name, list(bench_output.values()),
                                                             list(device_output.values()))
    elif isinstance(bench_output, torch.Tensor):
        copy_bench_out = bench_output.detach().clone()
        copy_device_output = device_output.detach().clone()
        compare_column.bench_type = str(copy_bench_out.dtype)
        compare_column.npu_type = str(copy_device_output.dtype)
        compare_column.shape = tuple(device_output.shape)
        status, compare_result, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
                                                                     compare_column)
    elif isinstance(bench_output, (bool, int, float, str)):
        compare_column.bench_type = str(type(bench_output))
        compare_column.npu_type = str(type(device_output))
        status, compare_result, message = self._compare_builtin_type(bench_output, device_output, compare_column)
    elif bench_output is None:
        return CompareConst.SKIP, compare_column, "Bench output is None, skip this test."
    else:
        # The message must be part of the returned tuple: callers unpack three values.
        return (CompareConst.PASS, compare_column,
                "Unexpected output type in compare_core: {}".format(type(bench_output)))

    return status, compare_result, message
265
-
266
def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
    """Compare a bench tensor with a device tensor after converting both to numpy.

    When the device dtype is bfloat16, both tensors are cast to float32 before the
    numpy conversion. The shapes must be equal and the dtypes comparable
    (check_dtype_comparable). Bool/integer arrays are judged only by their mismatch
    rate via _compare_bool_tensor; all other dtypes go to _compare_float_tensor,
    which receives the original device dtype.

    Returns:
        (status, compare_column, message).
    """
    cpu_shape, npu_shape = bench_output.shape, device_output.shape
    npu_dtype = device_output.dtype
    if npu_dtype == torch.bfloat16:
        bench_output, device_output = bench_output.to(torch.float32), device_output.to(torch.float32)
    bench_array = bench_output.numpy()
    device_array = device_output.cpu().numpy()

    if cpu_shape != npu_shape:
        return CompareConst.ERROR, compare_column, \
            f"The shape of bench{str(cpu_shape)} and npu{str(npu_shape)} not equal."
    if not check_dtype_comparable(bench_array, device_array):
        return CompareConst.ERROR, compare_column, \
            f"Bench out dtype is {bench_array.dtype} but npu output dtype is {device_array.dtype}, cannot compare."

    error_rate_only_dtypes = (bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
                              np.int64, np.uint64)
    if bench_array.dtype in error_rate_only_dtypes:
        err_rate, status, msg = self._compare_bool_tensor(bench_array, device_array)
        compare_column.error_rate = err_rate
        message = (f"Compare algorithm is not supported for {bench_array.dtype} data. "
                   f"Only judged by Error Rate." + msg + "\n")
        return status, compare_column, message

    status, compare_column, message = self._compare_float_tensor(api_name, bench_array, device_array,
                                                                 compare_column, npu_dtype)
    return status, compare_column, message
294
-
295
def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
    """Evaluate floating-point numpy outputs and decide PASS / WARNING / ERROR.

    ``compare_column`` is filled with the metrics this function computes. The metrics
    depend on which standard list ``api_name`` belongs to, and on whether ``str(dtype)`` is
    in BENCHMARK_COMPARE_SUPPORT_LIST. The verdict is then derived from:

    1. cosine similarity,
    2. max absolute error, and
    3. the relative-error ratios, with dtype-specific thresholds.

    Returns:
        (status, compare_column, message)
    """
    note = ""
    bench_abs, bench_abs_eps = get_abs_bench_with_eps(bench_output, dtype)
    err_abs = get_abs_err(bench_output, device_output)
    origin_rel = get_rel_err_origin(err_abs, bench_abs_eps)

    if api_name in thousandth_standard_api:
        # Only the ratio is kept here; the status is recomputed further down.
        compare_column.rel_err_thousandth, _ = get_rel_err_ratio(
            origin_rel, CompareConst.THOUSAND_RATIO_THRESHOLD)

    if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
        finite_mask, nonfinite_mask = get_finite_and_infinite_mask(bench_output, device_output)
        if api_name in binary_standard_api:
            compare_column.error_rate = self._compare_bool_tensor(bench_output, device_output)[0]
        elif api_name in absolute_standard_api:
            sv_threshold, sv_atol, rtol = self._get_absolute_threshold_attribute(api_name, str(dtype))
            rel = err_abs / bench_abs_eps
            small_mask = get_small_value_mask(bench_abs, finite_mask, sv_threshold)
            normal_mask = np.logical_and(finite_mask, np.logical_not(small_mask))
            compare_column.inf_nan_error_ratio = check_inf_nan_value(
                nonfinite_mask, bench_output, device_output, dtype, rtol)
            compare_column.rel_err_ratio = check_norm_value(normal_mask, rel, rtol)
            compare_column.abs_err_ratio = check_small_value(err_abs, small_mask, sv_atol)
        elif api_name in ulp_standard_api:
            if bench_output.size == 0:
                compare_column.max_ulp_error = 0
                compare_column.mean_ulp_error = 0
                compare_column.ulp_error_proportion = 0
            else:
                ulp = get_ulp_err(bench_output, device_output, dtype)
                if dtype == torch.float32:
                    ulp_threshold = CompareConst.ULP_FLOAT32_THRESHOLD
                else:
                    ulp_threshold = CompareConst.ULP_FLOAT16_THRESHOLD
                compare_column.max_ulp_error = np.max(ulp)
                compare_column.mean_ulp_error = np.mean(ulp)
                compare_column.ulp_error_proportion = np.sum(ulp > ulp_threshold) / bench_output.size
        else:
            cfg = precision_configs.get(dtype)
            small_mask = get_small_value_mask(bench_abs, finite_mask, cfg['small_value'][0])
            over_atol = np.greater(err_abs, cfg['small_value_atol'][0])
            compare_column.small_value_err_ratio = get_small_value_err_ratio(small_mask, over_atol)
            rel = get_rel_err(err_abs, bench_abs_eps, small_mask, nonfinite_mask)
            compare_column.RMSE = get_rmse(err_abs, np.logical_or(nonfinite_mask, small_mask))
            compare_column.EB = get_error_balance(bench_output, device_output)
            if not rel.size:
                return CompareConst.ERROR, compare_column, "Relative error result list is empty."
            compare_column.Max_rel_error = get_max_rel_err(rel)
            compare_column.Mean_rel_error = get_mean_rel_err(rel)

    cos_value, cos_ok, cos_msg = cosine_sim(bench_output, device_output)
    compare_column.cosine_sim = cos_value
    note += cos_msg + "\n"
    if not cos_ok:
        note += "Cosine similarity is less than 0.99, consider as error, skip other check and set to SPACE.\n"
        return CompareConst.ERROR, compare_column, note

    max_abs_value, max_abs_ok = get_max_abs_err(err_abs)
    compare_column.max_abs_err = max_abs_value
    if max_abs_ok:
        note += "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
        return CompareConst.PASS, compare_column, note

    half_precision = dtype in [torch.float16, torch.bfloat16]
    if half_precision:
        hundred_value, hundred_ok = get_rel_err_ratio(origin_rel, CompareConst.HUNDRED_RATIO_THRESHOLD)
        compare_column.rel_err_hundredth = hundred_value
        if not hundred_ok:
            note += "Relative error is greater than 0.01, consider as error, skip other check and set to SPACE.\n"
            return CompareConst.ERROR, compare_column, note

    thousand_value, thousand_ok = get_rel_err_ratio(origin_rel, CompareConst.THOUSAND_RATIO_THRESHOLD)
    compare_column.rel_err_thousandth = thousand_value
    if half_precision:
        if not thousand_ok:
            note += "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"
            return CompareConst.WARNING, compare_column, note
        note += "Relative error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
        return CompareConst.PASS, compare_column, note

    ten_thousand_value, ten_thousand_ok = get_rel_err_ratio(
        origin_rel, CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
    compare_column.rel_err_ten_thousandth = ten_thousand_value
    if dtype in [torch.float32, torch.float64]:
        if not thousand_ok:
            note += "Relative error is greater than 0.001, consider as error, skip other check and set to SPACE.\n"
            return CompareConst.ERROR, compare_column, note
        if not ten_thousand_ok:
            note += "Relative error is greater than 0.0001, consider as warning, skip other check and set to SPACE.\n"
            return CompareConst.WARNING, compare_column, note
        note += "Relative error is less than 0.0001, consider as pass.\n"
    return CompareConst.PASS, compare_column, note
1
+ # 进行比对及结果展示
2
+ import os
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ from msprobe.core.common.utils import CompareException
7
+ from msprobe.core.common.file_utils import get_json_contents, write_csv
8
+ import torch
9
+ from msprobe.core.common.const import CompareConst
10
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
11
+ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
12
+ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
13
+ check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
14
+ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
15
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
16
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
17
+ DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
18
+ ulp_standard_api, thousandth_standard_api, apis_threshold
19
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
20
+ from msprobe.pytorch.common.log import logger
21
+
22
+
23
# One comparison outcome for a single API, as passed to Comparator.record_results.
ResultInfo = namedtuple(
    'ResultInfo',
    'full_api_name fwd_success_status bwd_success_status '
    'fwd_compare_alg_results bwd_compare_alg_results rank',
)


# Index path used by Comparator.write_summary_csv to pull a message out of a result:
# result[INDEX_TEST_RESULT_GROUP] is fwd_compare_alg_results (the 4th ResultInfo field),
# [INDEX_FIRST_GROUP] selects its first detail row, and [INDEX_MESSAGE] that row's last entry.
INDEX_TEST_RESULT_GROUP = 3
INDEX_FIRST_GROUP = 0
INDEX_MESSAGE = -1
30
+
31
+
32
class Comparator:
    """Compare bench outputs against device (NPU) outputs of an API and record the verdicts.

    Each comparison produces a row in a summary CSV (API name, forward/backward status,
    message) and one row per compared output element in a detail CSV. When
    ``config.online_config.is_online`` is set, every rank in
    ``config.online_config.rank_list`` gets its own ``*_rank{N}.csv`` pair.
    """

    # consts for result csv
    COLUMN_API_NAME = "API name"
    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
    COLUMN_STACK_INFO = "Traceback callstack info"

    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None):
        """Set up the output paths and, unless resuming, write the CSV headers.

        Args:
            result_csv_path: path of the summary CSV.
            details_csv_path: path of the detail CSV.
            is_continue_run_ut: when True, ``write_csv_title`` is not called.
            stack_info_json_path: optional JSON file; its contents are indexed by API name in
                ``write_summary_csv`` and the entries joined with newlines.
            config: optional run config. In online mode the paths become per-rank patterns
                (``.csv`` -> ``_rank{}.csv``) formatted for each rank in the rank list.
        """
        self.save_path_str = result_csv_path
        self.detail_save_path_str = details_csv_path
        self.save_path_list = [result_csv_path]
        self.detail_save_path_list = [details_csv_path]

        if config and config.online_config.is_online:
            # The "{}" placeholder is filled with the rank via str.format.
            self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
            self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
            rank_list = config.online_config.rank_list
            self.save_path_list = [self.save_path_str.format(rank) for rank in rank_list]
            self.detail_save_path_list = [self.detail_save_path_str.format(rank) for rank in rank_list]

        if not is_continue_run_ut:
            self.write_csv_title()
        self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None

    @staticmethod
    def get_path_from_rank(rank, path_list, path_pattern):
        """Return the CSV path to write for ``rank``.

        With exactly one configured path that path is used regardless of rank; otherwise
        ``path_pattern`` (which contains a ``{}`` placeholder) is formatted with the rank.
        """
        return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank)

    @staticmethod
    def print_pretest_result():
        """Log that the run_ut / multi_run_ut stage finished."""
        logger.info("Successfully completed run_ut/multi_run_ut.")

    @staticmethod
    def _compare_dropout(bench_output, device_output):
        """Loosely compare dropout outputs by how many elements are zero.

        Elements are not compared one by one. Instead, the zero counts of the two tensors
        must differ by less than 10% of the element count. Tensors with fewer than 100
        elements always pass.

        Returns:
            (CompareConst.PASS, 1) on pass, (CompareConst.ERROR, 0) on failure.
        """
        tensor_num = bench_output.numel()
        if tensor_num < 100:
            return CompareConst.PASS, 1
        zero_count_diff = abs((bench_output == 0).sum() - (device_output == 0).cpu().sum())
        if zero_count_diff / tensor_num < 0.1:
            return CompareConst.PASS, 1
        return CompareConst.ERROR, 0

    @staticmethod
    def _compare_builtin_type(bench_output, device_output, compare_column):
        """Compare bool/int/float/str outputs by plain equality.

        Non-builtin inputs pass without checks. On equality ``compare_column.error_rate`` is
        set to 0.

        Returns:
            (status, compare_column, "")
        """
        if not isinstance(bench_output, (bool, int, float, str)):
            return CompareConst.PASS, compare_column, ""
        if bench_output != device_output:
            return CompareConst.ERROR, compare_column, ""
        compare_column.error_rate = 0
        return CompareConst.PASS, compare_column, ""

    @staticmethod
    def _compare_bool_tensor(bench_output, device_output):
        """Exact element-wise comparison of two numpy arrays.

        Returns:
            (error_rate, status, message). ``error_rate`` is the fraction of mismatching
            elements, and status is PASS only when it is 0. An empty bench array yields
            (NAN, ERROR, <message>).
        """
        error_nums = (bench_output != device_output).sum()
        if bench_output.size == 0:
            return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
        error_rate = float(error_nums / bench_output.size)
        result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
        return error_rate, result, ""

    @staticmethod
    def _get_absolute_threshold_attribute(api_name, dtype):
        """Look up (small_value, small_value_atol, rtol) in ``apis_threshold[api_name][dtype]``."""
        dtype_threshold = apis_threshold.get(api_name).get(dtype)
        return (dtype_threshold.get('small_value'),
                dtype_threshold.get('small_value_atol'),
                dtype_threshold.get('rtol'))

    @staticmethod
    def _get_run_ut_detail(test_result):
        """Build detail CSV rows from a result tuple; also called by online run_ut.

        ``test_result[0]`` is the API name, and indices 3 and 4 hold the forward and backward
        detail rows. Each row becomes ``[<api>.<forward|backward>.output.<i>, *values]``, with
        float values formatted to ``msCheckerConfig.precision`` decimals. Non-list results
        (e.g. the int returned by ``_compare_dropout``) produce no rows.

        Raises:
            CompareException: if ``test_result`` is too short.
        """
        try:
            subject_prefix = test_result[0]
            fwd_result = test_result[3]
            bwd_result = test_result[4]
        except IndexError as e:
            logger.error("List index out of bounds when writing detail CSV.")
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e

        test_rows = []
        for direction, results in (("forward", fwd_result), ("backward", bwd_result)):
            if not isinstance(results, list):
                continue
            for i, test_subject in enumerate(results):
                subject = subject_prefix + f".{direction}.output." + str(i)
                formatted = ["{:.{}f}".format(item, msCheckerConfig.precision)
                             if isinstance(item, float) else item for item in test_subject]
                test_rows.append([subject] + formatted)
        return test_rows

    def write_csv_title(self):
        """Write header rows to every summary/detail CSV that does not exist yet."""
        summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
                              self.COLUMN_BACKWARD_SUCCESS, "Message"]]
        for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list):
            if not os.path.exists(save_path):
                write_csv(summary_test_rows, save_path)
            if not os.path.exists(detail_save_path):
                write_csv(DETAIL_TEST_ROWS, detail_save_path)

    def write_summary_csv(self, test_result):
        """Append one summary row for ``test_result`` to the summary CSV of its rank.

        The row is (name, forward status, backward status), followed by:
        - the first forward detail message, when the forward status is SKIP;
        - the joined call stack, when stack info was loaded.

        Raises:
            CompareException: if ``test_result`` is too short.
        """
        test_rows = []
        try:
            name = test_result[0]
            df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
            if test_result[1] == CompareConst.SKIP:
                df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
            # NOTE(review): when the status is not SKIP no message is appended, so the stack
            # text lands in the "Message" header column — confirm this is intended.
            if self.stack_info:
                stack_info = "\n".join(self.stack_info[name])
                df_row.append(stack_info)
            test_rows.append(df_row)
            save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str)
        except IndexError as e:
            logger.error("List index out of bounds when writing summary CSV.")
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
        write_csv(test_rows, save_path)

    def write_detail_csv(self, test_result):
        """Append the detail rows for ``test_result`` to the detail CSV of its rank."""
        test_rows = self._get_run_ut_detail(test_result)
        detail_save_path = self.get_path_from_rank(test_result[-1],
                                                   self.detail_save_path_list,
                                                   self.detail_save_path_str)
        write_csv(test_rows, detail_save_path)

    def record_results(self, args):
        """Write ``args`` (a result tuple) to both the summary and the detail CSV."""
        self.write_summary_csv(args)
        self.write_detail_csv(args)

    def compare_output(self, full_api_name, data_info, is_online=False):
        """Get compare result and write to result and detail csv.

        Args:
            full_api_name: full API name; its basic API segment selects the comparison standard.
            data_info: carries bench/device outputs and grads, an optional backward message,
                and the rank.
            is_online: bool, default False. When True (online api precision compare), only
                compare and return the detail rows without writing any CSV.

        Returns:
            The detail rows when ``is_online``; otherwise ``(forward_passed, backward_passed)``,
            where backward also counts as passed when it was not compared (SPACE).

        Raises:
            ValueError: if no basic API name can be extracted from ``full_api_name``.
        """
        _, api_name = extract_basic_api_segments(full_api_name)
        if not api_name:
            raise ValueError(f"API name {full_api_name} has not been adapted.")
        bench_output, device_output = data_info.bench_output, data_info.device_output
        bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
        backward_message = data_info.backward_message
        is_dropout = "dropout" in full_api_name

        if is_dropout:
            fwd_success_status, fwd_compare_alg_results = self._compare_dropout(bench_output, device_output)
        else:
            fwd_success_status, fwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_output,
                                                                                     device_output)
        if not (bench_grad and device_grad):
            bwd_success_status, bwd_compare_alg_results = (CompareConst.SPACE, [])
        elif is_dropout:
            bwd_success_status, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], device_grad[0])
        else:
            bwd_success_status, bwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_grad,
                                                                                     device_grad)
        if backward_message:
            backward_column = CompareColumn()
            bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)]
        else:
            bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE
        result_info = ResultInfo(full_api_name,
                                 fwd_success_status,
                                 bwd_success_status,
                                 fwd_compare_alg_results,
                                 bwd_compare_alg_results,
                                 data_info.rank)
        if is_online:
            # get run_ut compare detail
            return self._get_run_ut_detail(result_info)
        self.record_results(result_info)
        return (fwd_success_status == CompareConst.PASS,
                bwd_success_status in (CompareConst.PASS, CompareConst.SPACE))

    @staticmethod
    def _merge_status(current, new):
        """Fold one element status into the aggregate: ERROR beats WARNING, which beats the rest."""
        if new == CompareConst.ERROR:
            return CompareConst.ERROR
        if new == CompareConst.WARNING and current != CompareConst.ERROR:
            return CompareConst.WARNING
        return current

    def _compare_core_wrapper(self, api_name, bench_output, device_output):
        """Compare one output, or a list/tuple of outputs, and aggregate the verdict.

        Returns:
            (final_status, detailed_rows). ``final_status`` is ERROR if any element errored,
            otherwise WARNING if any element warned, otherwise PASS. ``detailed_rows`` holds
            one ``CompareColumn.to_column_value`` row per compared element.
        """
        if isinstance(bench_output, (list, tuple)):
            statuses, columns, messages = [], [], []
            if len(bench_output) > len(device_output):
                # Pair the structural error with an empty column so its detail row can still be built.
                statuses = [CompareConst.ERROR]
                columns = [CompareColumn()]
                messages = ["bench and npu output structure is different."]
            else:
                # Device outputs beyond the bench length are ignored.
                for b_out_i, n_out_i in zip(bench_output, device_output[:len(bench_output)]):
                    status_i, column_i, message_i = self._compare_core(api_name, b_out_i, n_out_i)
                    statuses.append(status_i)
                    columns.append(column_i)
                    messages.append(message_i)
        else:
            status, column, message = self._compare_core(api_name, bench_output, device_output)
            statuses, columns, messages = [status], [column], [message]

        test_final_success = CompareConst.PASS
        detailed_result_total = []
        for item_status, column, message in zip(statuses, columns, messages):
            detailed_result_total.append(column.to_column_value(item_status, message))
            # An earlier ERROR must not be downgraded by a later WARNING.
            test_final_success = self._merge_status(test_final_success, item_status)
        return test_final_success, detailed_result_total

    def _compare_core(self, api_name, bench_output, device_output):
        """Compare a single bench/device output pair.

        Returns:
            (status, compare_column, message) — always a 3-tuple, since every caller unpacks three values.
        """
        compare_column = CompareColumn()
        if not isinstance(bench_output, type(device_output)):
            return CompareConst.ERROR, compare_column, "bench and npu output type is different."
        if isinstance(bench_output, dict):
            if set(bench_output.keys()) != set(device_output.keys()):
                return CompareConst.ERROR, compare_column, "bench and npu output dict keys are different."
            # NOTE(review): values are forwarded as lists, which no branch below handles, so they
            # end in the "Unexpected output type" result; dict contents are not compared element-wise.
            return self._compare_core(api_name, list(bench_output.values()), list(device_output.values()))
        if isinstance(bench_output, torch.Tensor):
            copy_bench_out = bench_output.detach().clone()
            copy_device_output = device_output.detach().clone()
            compare_column.bench_type = str(copy_bench_out.dtype)
            compare_column.npu_type = str(copy_device_output.dtype)
            compare_column.shape = tuple(device_output.shape)
            return self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, compare_column)
        if isinstance(bench_output, (bool, int, float, str)):
            compare_column.bench_type = str(type(bench_output))
            compare_column.npu_type = str(type(device_output))
            return self._compare_builtin_type(bench_output, device_output, compare_column)
        if bench_output is None:
            return CompareConst.SKIP, compare_column, "Bench output is None, skip this test."
        # Keep the message inside the returned tuple: a trailing comma followed by the string on
        # its own line would return a 2-tuple and break the callers' 3-way unpacking.
        return (CompareConst.PASS, compare_column,
                "Unexpected output type in compare_core: {}".format(type(bench_output)))

    def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
        """Compare two tensors after converting both to CPU numpy arrays.

        Shapes and dtypes must be compatible. Bool/integer data are judged only by their
        element mismatch rate; floating data go to ``_compare_float_tensor``.

        Returns:
            (status, compare_column, message)
        """
        cpu_shape = bench_output.shape
        npu_shape = device_output.shape
        npu_dtype = device_output.dtype
        if npu_dtype == torch.bfloat16:
            # numpy has no bfloat16 dtype, so both sides are upcast before conversion.
            bench_output = bench_output.to(torch.float32)
            device_output = device_output.to(torch.float32)
        bench_output = bench_output.cpu().numpy()
        device_output = device_output.cpu().numpy()
        if cpu_shape != npu_shape:
            return CompareConst.ERROR, compare_column, f"The shape of bench{str(cpu_shape)} " \
                                                       f"and npu{str(npu_shape)} not equal."
        if not check_dtype_comparable(bench_output, device_output):
            return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
                                                       f"npu output dtype is {device_output.dtype}, cannot compare."
        message = ""
        if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
                                  np.int64, np.uint64]:
            message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
                       f"Only judged by Error Rate."
            err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
            message += msg + "\n"
            compare_column.error_rate = err_rate
            return status, compare_column, message
        return self._compare_float_tensor(api_name, bench_output, device_output, compare_column, npu_dtype)

    def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
        """Evaluate floating-point numpy outputs and decide PASS / WARNING / ERROR.

        Benchmark metrics are filled into ``compare_column`` when ``str(dtype)`` is in
        BENCHMARK_COMPARE_SUPPORT_LIST. Which metrics depends on the standard list
        ``api_name`` belongs to:
        - binary: error rate;
        - absolute: inf/nan, relative and absolute error ratios;
        - ulp: ULP error statistics;
        - otherwise: small-value, RMSE, EB and relative error statistics.

        The verdict is then taken from cosine similarity, max absolute error and the
        relative-error ratios, with stricter thresholds for float32/float64.

        Returns:
            (status, compare_column, message)
        """
        message = ""
        abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
        abs_err = get_abs_err(bench_output, device_output)
        rel_err_origin = get_rel_err_origin(abs_err, abs_bench_with_eps)
        if api_name in thousandth_standard_api:
            thousand_res, _ = get_rel_err_ratio(rel_err_origin, CompareConst.THOUSAND_RATIO_THRESHOLD)
            compare_column.rel_err_thousandth = thousand_res
        if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
            both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
            if api_name in binary_standard_api:
                err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
                compare_column.error_rate = err_rate
            elif api_name in absolute_standard_api:
                small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
                    api_name, str(dtype))
                rel_err = abs_err / abs_bench_with_eps
                small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold)
                normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
                compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output,
                                                                         dtype, rtol)
                compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
                compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
            elif api_name in ulp_standard_api:
                if bench_output.size == 0:
                    compare_column.max_ulp_error = 0
                    compare_column.mean_ulp_error = 0
                    compare_column.ulp_error_proportion = 0
                else:
                    ulp_err = get_ulp_err(bench_output, device_output, dtype)
                    compare_column.max_ulp_error = np.max(ulp_err)
                    compare_column.mean_ulp_error = np.mean(ulp_err)
                    if dtype == torch.float32:
                        ulp_threshold = CompareConst.ULP_FLOAT32_THRESHOLD
                    else:
                        ulp_threshold = CompareConst.ULP_FLOAT16_THRESHOLD
                    compare_column.ulp_error_proportion = np.sum(ulp_err > ulp_threshold) / bench_output.size
            else:
                dtype_config = precision_configs.get(dtype)
                small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
                abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
                compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask,
                                                                                 abs_err_greater_mask)
                rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
                compare_column.RMSE = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
                compare_column.EB = get_error_balance(bench_output, device_output)
                if rel_err.size == 0:
                    return CompareConst.ERROR, compare_column, "Relative error result list is empty."
                compare_column.Max_rel_error = get_max_rel_err(rel_err)
                compare_column.Mean_rel_error = get_mean_rel_err(rel_err)

        cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
        compare_column.cosine_sim = cos_res
        message += msg + "\n"
        if not cos_status:
            message += "Cosine similarity is less than 0.99, consider as error, skip other check and set to SPACE.\n"
            return CompareConst.ERROR, compare_column, message

        max_abs_res, max_abs_status = get_max_abs_err(abs_err)
        compare_column.max_abs_err = max_abs_res
        if max_abs_status:
            message += "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
            return CompareConst.PASS, compare_column, message

        is_half_precision = dtype in [torch.float16, torch.bfloat16]
        if is_half_precision:
            hundred_res, hundred_status = get_rel_err_ratio(rel_err_origin, CompareConst.HUNDRED_RATIO_THRESHOLD)
            compare_column.rel_err_hundredth = hundred_res
            if not hundred_status:
                message += "Relative error is greater than 0.01, consider as error, skip other check and set to SPACE.\n"
                return CompareConst.ERROR, compare_column, message
        thousand_res, thousand_status = get_rel_err_ratio(rel_err_origin, CompareConst.THOUSAND_RATIO_THRESHOLD)
        compare_column.rel_err_thousandth = thousand_res
        if is_half_precision:
            if thousand_status:
                message += "Relative error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
                return CompareConst.PASS, compare_column, message
            message += "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"
            return CompareConst.WARNING, compare_column, message
        ten_thousand_res, ten_thousand_status = get_rel_err_ratio(rel_err_origin,
                                                                  CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
        compare_column.rel_err_ten_thousandth = ten_thousand_res
        if dtype in [torch.float32, torch.float64]:
            if not thousand_status:
                message += "Relative error is greater than 0.001, consider as error, skip other check and set to SPACE.\n"
                return CompareConst.ERROR, compare_column, message
            if not ten_thousand_status:
                message += "Relative error is greater than 0.0001, consider as warning, skip other check and set to SPACE.\n"
                return CompareConst.WARNING, compare_column, message
            message += "Relative error is less than 0.0001, consider as pass.\n"
        return CompareConst.PASS, compare_column, message