mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,74 +1,74 @@
1
- from msprobe.core.common.const import CompareConst
2
-
3
-
4
- class CompareColumn:
5
- def __init__(self):
6
- self.bench_type = CompareConst.SPACE
7
- self.npu_type = CompareConst.SPACE
8
- self.shape = CompareConst.SPACE
9
- self.cosine_sim = CompareConst.SPACE
10
- self.max_abs_err = CompareConst.SPACE
11
- self.rel_err_hundredth = CompareConst.SPACE
12
- self.rel_err_thousandth = CompareConst.SPACE
13
- self.rel_err_ten_thousandth = CompareConst.SPACE
14
- self.error_rate = CompareConst.SPACE
15
- self.EB = CompareConst.SPACE
16
- self.RMSE = CompareConst.SPACE
17
- self.small_value_err_ratio = CompareConst.SPACE
18
- self.Max_rel_error = CompareConst.SPACE
19
- self.Mean_rel_error = CompareConst.SPACE
20
- self.inf_nan_error_ratio = CompareConst.SPACE
21
- self.rel_err_ratio = CompareConst.SPACE
22
- self.abs_err_ratio = CompareConst.SPACE
23
- self.max_ulp_error = CompareConst.SPACE
24
- self.mean_ulp_error = CompareConst.SPACE
25
- self.ulp_error_proportion = CompareConst.SPACE
26
-
27
- def to_column_value(self, is_pass, message):
28
- return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
29
- self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.EB, self.RMSE,
30
- self.small_value_err_ratio, self.Max_rel_error, self.Mean_rel_error, self.inf_nan_error_ratio,
31
- self.rel_err_ratio, self.abs_err_ratio, self.max_ulp_error, self.mean_ulp_error,
32
- self.ulp_error_proportion, is_pass, message]
33
-
34
-
35
- class ApiPrecisionOutputColumn:
36
- def __init__(self):
37
- self.api_name = CompareConst.SPACE
38
- self.small_value_err_ratio = CompareConst.SPACE
39
- self.small_value_err_status = CompareConst.SPACE
40
- self.rmse_ratio = CompareConst.SPACE
41
- self.rmse_status = CompareConst.SPACE
42
- self.max_rel_err_ratio = CompareConst.SPACE
43
- self.max_rel_err_status = CompareConst.SPACE
44
- self.mean_rel_err_ratio = CompareConst.SPACE
45
- self.mean_rel_err_status = CompareConst.SPACE
46
- self.eb_ratio = CompareConst.SPACE
47
- self.eb_status = CompareConst.SPACE
48
- self.inf_nan_error_ratio = CompareConst.SPACE
49
- self.inf_nan_error_ratio_status = CompareConst.SPACE
50
- self.rel_err_ratio = CompareConst.SPACE
51
- self.rel_err_ratio_status = CompareConst.SPACE
52
- self.abs_err_ratio = CompareConst.SPACE
53
- self.abs_err_ratio_status = CompareConst.SPACE
54
- self.error_rate = CompareConst.SPACE
55
- self.error_rate_status = CompareConst.SPACE
56
- self.mean_ulp_err = CompareConst.SPACE
57
- self.ulp_err_proportion = CompareConst.SPACE
58
- self.ulp_err_proportion_ratio = CompareConst.SPACE
59
- self.ulp_err_status = CompareConst.SPACE
60
- self.rel_err_thousandth = CompareConst.SPACE
61
- self.rel_err_thousandth_status = CompareConst.SPACE
62
- self.compare_result = CompareConst.SPACE
63
- self.compare_algorithm = CompareConst.SPACE
64
- self.compare_message = CompareConst.SPACE
65
-
66
- def to_column_value(self):
67
- return [self.api_name, self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
68
- self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
69
- self.mean_rel_err_status, self.eb_ratio, self.eb_status, self.inf_nan_error_ratio,
70
- self.inf_nan_error_ratio_status, self.rel_err_ratio, self.rel_err_ratio_status, self.abs_err_ratio,
71
- self.abs_err_ratio_status, self.error_rate, self.error_rate_status, self.mean_ulp_err,
72
- self.ulp_err_proportion, self.ulp_err_proportion_ratio, self.ulp_err_status, self.rel_err_thousandth,
73
- self.rel_err_thousandth_status, self.compare_result, self.compare_algorithm, self.compare_message]
1
+ from msprobe.core.common.const import CompareConst
2
+
3
+
4
+ class CompareColumn:
5
+ def __init__(self):
6
+ self.bench_type = CompareConst.SPACE
7
+ self.npu_type = CompareConst.SPACE
8
+ self.shape = CompareConst.SPACE
9
+ self.cosine_sim = CompareConst.SPACE
10
+ self.max_abs_err = CompareConst.SPACE
11
+ self.rel_err_hundredth = CompareConst.SPACE
12
+ self.rel_err_thousandth = CompareConst.SPACE
13
+ self.rel_err_ten_thousandth = CompareConst.SPACE
14
+ self.error_rate = CompareConst.SPACE
15
+ self.EB = CompareConst.SPACE
16
+ self.RMSE = CompareConst.SPACE
17
+ self.small_value_err_ratio = CompareConst.SPACE
18
+ self.Max_rel_error = CompareConst.SPACE
19
+ self.Mean_rel_error = CompareConst.SPACE
20
+ self.inf_nan_error_ratio = CompareConst.SPACE
21
+ self.rel_err_ratio = CompareConst.SPACE
22
+ self.abs_err_ratio = CompareConst.SPACE
23
+ self.max_ulp_error = CompareConst.SPACE
24
+ self.mean_ulp_error = CompareConst.SPACE
25
+ self.ulp_error_proportion = CompareConst.SPACE
26
+
27
+ def to_column_value(self, is_pass, message):
28
+ return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
29
+ self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.EB, self.RMSE,
30
+ self.small_value_err_ratio, self.Max_rel_error, self.Mean_rel_error, self.inf_nan_error_ratio,
31
+ self.rel_err_ratio, self.abs_err_ratio, self.max_ulp_error, self.mean_ulp_error,
32
+ self.ulp_error_proportion, is_pass, message]
33
+
34
+
35
+ class ApiPrecisionOutputColumn:
36
+ def __init__(self):
37
+ self.api_name = CompareConst.SPACE
38
+ self.small_value_err_ratio = CompareConst.SPACE
39
+ self.small_value_err_status = CompareConst.SPACE
40
+ self.rmse_ratio = CompareConst.SPACE
41
+ self.rmse_status = CompareConst.SPACE
42
+ self.max_rel_err_ratio = CompareConst.SPACE
43
+ self.max_rel_err_status = CompareConst.SPACE
44
+ self.mean_rel_err_ratio = CompareConst.SPACE
45
+ self.mean_rel_err_status = CompareConst.SPACE
46
+ self.eb_ratio = CompareConst.SPACE
47
+ self.eb_status = CompareConst.SPACE
48
+ self.inf_nan_error_ratio = CompareConst.SPACE
49
+ self.inf_nan_error_ratio_status = CompareConst.SPACE
50
+ self.rel_err_ratio = CompareConst.SPACE
51
+ self.rel_err_ratio_status = CompareConst.SPACE
52
+ self.abs_err_ratio = CompareConst.SPACE
53
+ self.abs_err_ratio_status = CompareConst.SPACE
54
+ self.error_rate = CompareConst.SPACE
55
+ self.error_rate_status = CompareConst.SPACE
56
+ self.mean_ulp_err = CompareConst.SPACE
57
+ self.ulp_err_proportion = CompareConst.SPACE
58
+ self.ulp_err_proportion_ratio = CompareConst.SPACE
59
+ self.ulp_err_status = CompareConst.SPACE
60
+ self.rel_err_thousandth = CompareConst.SPACE
61
+ self.rel_err_thousandth_status = CompareConst.SPACE
62
+ self.compare_result = CompareConst.SPACE
63
+ self.compare_algorithm = CompareConst.SPACE
64
+ self.compare_message = CompareConst.SPACE
65
+
66
+ def to_column_value(self):
67
+ return [self.api_name, self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
68
+ self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
69
+ self.mean_rel_err_status, self.eb_ratio, self.eb_status, self.inf_nan_error_ratio,
70
+ self.inf_nan_error_ratio_status, self.rel_err_ratio, self.rel_err_ratio_status, self.abs_err_ratio,
71
+ self.abs_err_ratio_status, self.error_rate, self.error_rate_status, self.mean_ulp_err,
72
+ self.ulp_err_proportion, self.ulp_err_proportion_ratio, self.ulp_err_status, self.rel_err_thousandth,
73
+ self.rel_err_thousandth_status, self.compare_result, self.compare_algorithm, self.compare_message]
74
74
 
@@ -1,245 +1,246 @@
1
- import time
2
- import os
3
- import math
4
-
5
- import torch
6
-
7
- from msprobe.core.common.utils import CompareException, load_yaml
8
- from msprobe.core.common.const import Const
9
- from msprobe.pytorch.common.log import logger
10
-
11
-
12
- current_time = time.strftime("%Y%m%d%H%M%S")
13
- API_PRECISION_COMPARE_RESULT_FILE_NAME = "api_precision_compare_result_" + current_time + ".csv"
14
- API_PRECISION_COMPARE_DETAILS_FILE_NAME = "api_precision_compare_details_" + current_time + ".csv"
15
- BENCHMARK_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32']
16
- API_PRECISION_COMPARE_UNSUPPORT_LIST = ['torch.float64', 'torch.complex64', 'torch.complex128']
17
- ULP_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32']
18
- BINARY_COMPARE_UNSUPPORT_LIST = BENCHMARK_COMPARE_SUPPORT_LIST + API_PRECISION_COMPARE_UNSUPPORT_LIST
19
-
20
-
21
- cur_path = os.path.dirname(os.path.realpath(__file__))
22
- standard_yaml_path = os.path.join(cur_path, "api_precision_standard.yaml")
23
- apis = load_yaml(standard_yaml_path)
24
- absolute_standard_api = apis.get('AbsoluteThreshStandard')
25
- binary_standard_api = apis.get('BinaryCompareStandard')
26
- ulp_standard_api = apis.get('ULPStandard')
27
- thousandth_standard_api = apis.get('ThousandthStandard')
28
-
29
-
30
- threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
31
- apis_threshold = load_yaml(threshold_yaml_path)
32
-
33
-
34
- DETAIL_TEST_ROWS = [[
35
- "API Name", "Bench Dtype", "DEVICE Dtype", "Shape",
36
- "余弦相似度",
37
- "最大绝对误差",
38
- "双百指标",
39
- "双千指标",
40
- "双万指标",
41
- "二进制一致错误率",
42
- "误差均衡性",
43
- "均方根误差",
44
- "小值域错误占比",
45
- "相对误差最大值",
46
- "相对误差平均值",
47
- "inf/nan错误率",
48
- "相对误差错误率",
49
- "绝对误差错误率",
50
- "ULP误差最大值",
51
- "ULP误差平均值",
52
- "ULP误差大于阈值占比",
53
- "Status",
54
- "Message"
55
- ]]
56
-
57
-
58
- precision_configs = {
59
- torch.float16 : {
60
- 'small_value' : [
61
- 1e-3
62
- ],
63
- 'small_value_atol' : [
64
- 1e-5
65
- ]
66
- },
67
- torch.bfloat16: {
68
- 'small_value' : [
69
- 1e-3
70
- ],
71
- 'small_value_atol' : [
72
- 1e-5
73
- ]
74
- },
75
- torch.float32:{
76
- 'small_value' : [
77
- 1e-6
78
- ],
79
- 'small_value_atol' : [
80
- 1e-9
81
- ]
82
- }
83
- }
84
-
85
-
86
- ULP_PARAMETERS = {
87
- torch.float16 : {
88
- 'min_eb' : [
89
- -14
90
- ],
91
- 'exponent_num' : [
92
- 10
93
- ]
94
- },
95
- torch.bfloat16 : {
96
- 'min_eb' : [
97
- -126
98
- ],
99
- 'exponent_num' : [
100
- 7
101
- ]
102
- },
103
- torch.float32 : {
104
- 'min_eb' : [
105
- -126
106
- ],
107
- 'exponent_num' : [
108
- 23
109
- ]
110
- }
111
- }
112
-
113
-
114
- class ApiPrecisionCompareColumn:
115
- API_NAME = 'API Name'
116
- DEVICE_DTYPE = 'DEVICE Dtype'
117
- SMALL_VALUE_ERROR_RATE = '小值域错误占比'
118
- RMSE = '均方根误差'
119
- MAX_REL_ERR = '相对误差最大值'
120
- MEAN_REL_ERR = '相对误差平均值'
121
- EB = '误差均衡性'
122
- SMALL_VALUE_ERROR_RATIO = '小值域错误比值'
123
- SMALL_VALUE_ERROR_STATUS = '小值域判定结果'
124
- RMSE_RATIO = '均方根误差比值'
125
- RMSE_STATUS = '均方根误差判定结果'
126
- MAX_REL_ERR_RATIO = '相对误差最大值比值'
127
- MAX_REL_ERR_STATUS = '相对误差最大值判定结果'
128
- MEAN_REL_ERR_RATIO = '相对误差平均值比值'
129
- MEAN_REL_ERR_STATUS = '相对误差平均值判定结果'
130
- EB_RATIO = '误差均衡性比值'
131
- EB_STATUS = '误差均衡性判定结果'
132
- ERROR_RATE = '二进制一致错误率'
133
- ERROR_RATE_STATUS = '二进制一致错误率判定结果'
134
- INF_NAN_ERROR_RATIO = 'inf/nan错误率'
135
- INF_NAN_ERROR_RATIO_STATUS = 'inf/nan判定结果'
136
- REL_ERR_RATIO = '相对误差错误率'
137
- REL_ERR_RATIO_STATUS = '相对误差判定结果'
138
- ABS_ERR_RATIO = '绝对误差错误率'
139
- ABS_ERR_RATIO_STATUS = '绝对误差判定结果'
140
- MEAN_ULP_ERR = 'ULP误差平均值'
141
- ULP_ERR_PROPORTION = 'ULP误差大于阈值占比'
142
- ULP_ERR_PROPORTION_RATIO = 'ULP误差大于阈值占比比值'
143
- ULP_ERR_STATUS = 'ULP误差判定结果'
144
- REL_ERR_THOUSANDTH = '双千指标'
145
- REL_ERR_THOUSANDTH_STATUS = '双千指标判定结果'
146
- FINAL_RESULT = '比对结果'
147
- ALGORITHM = '比对算法'
148
- FORWWARD_STATUS = 'Forward Test Success'
149
- BACKWARD_STATUS = 'Backward Test Success'
150
- MESSAGE = 'Message'
151
-
152
- @staticmethod
153
- def to_required_columns():
154
- return [ApiPrecisionCompareColumn.API_NAME, ApiPrecisionCompareColumn.DEVICE_DTYPE,
155
- ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE, ApiPrecisionCompareColumn.RMSE,
156
- ApiPrecisionCompareColumn.MAX_REL_ERR, ApiPrecisionCompareColumn.MEAN_REL_ERR, ApiPrecisionCompareColumn.EB,
157
- ApiPrecisionCompareColumn.ERROR_RATE, ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO,
158
- ApiPrecisionCompareColumn.REL_ERR_RATIO, ApiPrecisionCompareColumn.ABS_ERR_RATIO,
159
- ApiPrecisionCompareColumn.MEAN_ULP_ERR, ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
160
- ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
161
-
162
- @staticmethod
163
- def get_detail_csv_title():
164
- return [ApiPrecisionCompareColumn.API_NAME,
165
- ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATIO, ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_STATUS,
166
- ApiPrecisionCompareColumn.RMSE_RATIO, ApiPrecisionCompareColumn.RMSE_STATUS,
167
- ApiPrecisionCompareColumn.MAX_REL_ERR_RATIO, ApiPrecisionCompareColumn.MAX_REL_ERR_STATUS,
168
- ApiPrecisionCompareColumn.MEAN_REL_ERR_RATIO, ApiPrecisionCompareColumn.MEAN_REL_ERR_STATUS,
169
- ApiPrecisionCompareColumn.EB_RATIO, ApiPrecisionCompareColumn.EB_STATUS,
170
- ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO, ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO_STATUS,
171
- ApiPrecisionCompareColumn.REL_ERR_RATIO, ApiPrecisionCompareColumn.REL_ERR_RATIO_STATUS,
172
- ApiPrecisionCompareColumn.ABS_ERR_RATIO, ApiPrecisionCompareColumn.ABS_ERR_RATIO_STATUS,
173
- ApiPrecisionCompareColumn.ERROR_RATE, ApiPrecisionCompareColumn.ERROR_RATE_STATUS,
174
- ApiPrecisionCompareColumn.MEAN_ULP_ERR, ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
175
- ApiPrecisionCompareColumn.ULP_ERR_PROPORTION_RATIO, ApiPrecisionCompareColumn.ULP_ERR_STATUS,
176
- ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH, ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH_STATUS,
177
- ApiPrecisionCompareColumn.FINAL_RESULT, ApiPrecisionCompareColumn.ALGORITHM, ApiPrecisionCompareColumn.MESSAGE]
178
-
179
- @staticmethod
180
- def get_result_csv_title():
181
- return [ApiPrecisionCompareColumn.API_NAME, ApiPrecisionCompareColumn.FORWWARD_STATUS,
182
- ApiPrecisionCompareColumn.BACKWARD_STATUS, ApiPrecisionCompareColumn.MESSAGE]
183
-
184
-
185
- CompareMessage = {
186
- "topk" : "在npu上,topk的入参sorted=False时不生效,会返回有序tensor,而cpu上会返回无序tensor。 如果topk精度不达标,请检查是否是该原因导致的。"
187
- }
188
-
189
-
190
def check_dtype_comparable(x, y):
    """Tell whether ``x`` and ``y`` have dtypes from the same family.

    The families are checked in the order float, bool, int. The first
    family that contains ``x.dtype`` decides the answer: True only when
    ``y.dtype`` belongs to that same family. If ``x.dtype`` is in none of
    them, a warning is logged and False is returned.
    """
    for family in (Const.FLOAT_TYPE, Const.BOOL_TYPE, Const.INT_TYPE):
        if x.dtype in family:
            return y.dtype in family
    logger.warning(f"Compare: Unexpected dtype {x.dtype}, {y.dtype}")
    return False
205
-
206
-
207
def convert_str_to_float(input_data):
    """Convert ``input_data`` to ``float`` or raise ``CompareException``.

    Args:
        input_data: A string or any other value accepted by ``float()``.

    Returns:
        The converted float.

    Raises:
        CompareException: With ``INVALID_DATA_ERROR`` when the input is an
            empty or whitespace-only string, or when ``float()`` rejects
            it for any reason.
    """
    if isinstance(input_data, str) and input_data.strip() == "":
        msg = 'ERROR: Input data is an empty string'
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
    try:
        return float(input_data)
    except (TypeError, ValueError, OverflowError) as e:
        # float() raises TypeError for non-numeric objects such as None and
        # OverflowError for out-of-range ints; report those through the same
        # domain exception as malformed strings instead of leaking them.
        msg = 'ERROR: Input data cannot be converted to float'
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg) from e
217
-
218
-
219
def is_inf_or_nan(x):
    """Return True when ``x`` is NaN or positive/negative infinity."""
    # A number is non-finite exactly when it is NaN or +/-inf.
    return not math.isfinite(x)
221
-
222
-
223
def handle_infinity(x, y, column_name):
    """Judge a pair of values for the infinity case.

    Returns a ``(value, passed, message)`` tuple. ``value`` is always NaN.
    ``passed`` is True only when both inputs are infinite with the same
    sign. ``message`` is ``column_name`` followed by a fixed Chinese note.
    """
    same_signed_inf = math.isinf(x) and math.isinf(y) and x == y
    if same_signed_inf:
        return float("nan"), True, f"{column_name}同为同号inf或nan\n"
    return float("nan"), False, f"{column_name}inf或nan不一致\n"
231
-
232
-
233
def handle_nan(x, y, column_name):
    """Judge a pair of values for the NaN case.

    Returns a ``(value, passed, message)`` tuple. ``value`` is always NaN.
    ``passed`` is True only when both inputs are NaN.
    """
    both_nan = math.isnan(x) and math.isnan(y)
    if both_nan:
        note = f"{column_name}同为同号inf或nan\n"
    else:
        note = f"{column_name}inf或nan不一致\n"
    return float("nan"), both_nan, note
238
-
239
-
240
def check_inf_or_nan(x, y, column_name):
    """Dispatch a pair of values to the infinity or the NaN handler.

    If either value is infinite the pair goes to ``handle_infinity``,
    otherwise to ``handle_nan``. The handler's
    ``(value, passed, message)`` tuple is returned unchanged.
    """
    any_inf = math.isinf(x) or math.isinf(y)
    handler = handle_infinity if any_inf else handle_nan
    return handler(x, y, column_name)
1
+ import time
2
+ import os
3
+ import math
4
+
5
+ import torch
6
+
7
+ from msprobe.core.common.utils import CompareException
8
+ from msprobe.core.common.file_utils import load_yaml
9
+ from msprobe.core.common.const import Const
10
+ from msprobe.pytorch.common.log import logger
11
+
12
+
13
# Timestamp taken once at import time; both output file names embed it.
current_time = time.strftime("%Y%m%d%H%M%S")
API_PRECISION_COMPARE_RESULT_FILE_NAME = f"api_precision_compare_result_{current_time}.csv"
API_PRECISION_COMPARE_DETAILS_FILE_NAME = f"api_precision_compare_details_{current_time}.csv"

# Dtype names, kept as strings.
BENCHMARK_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32']
API_PRECISION_COMPARE_UNSUPPORT_LIST = ['torch.float64', 'torch.complex64', 'torch.complex128']
ULP_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32']
# Union of the benchmark-supported and the unsupported dtype names.
BINARY_COMPARE_UNSUPPORT_LIST = [
    *BENCHMARK_COMPARE_SUPPORT_LIST,
    *API_PRECISION_COMPARE_UNSUPPORT_LIST,
]


# Both YAML files are resolved relative to this module's real location.
cur_path = os.path.dirname(os.path.realpath(__file__))

# Per-standard API lists, loaded from api_precision_standard.yaml.
standard_yaml_path = os.path.join(cur_path, "api_precision_standard.yaml")
apis = load_yaml(standard_yaml_path)
(
    absolute_standard_api,
    binary_standard_api,
    ulp_standard_api,
    thousandth_standard_api,
) = map(
    apis.get,
    ('AbsoluteThreshStandard', 'BinaryCompareStandard', 'ULPStandard', 'ThousandthStandard'),
)


# Content of api_precision_threshold.yaml.
threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
apis_threshold = load_yaml(threshold_yaml_path)
33
+
34
+
35
# A list holding one row of 23 column labels.
# NOTE(review): presumably data rows are appended after this header row;
# confirm with the callers.
DETAIL_TEST_ROWS = [[
    # Identification columns.
    "API Name", "Bench Dtype", "DEVICE Dtype", "Shape",
    # Metric columns (labels are in Chinese).
    "余弦相似度", "最大绝对误差",
    "双百指标", "双千指标", "双万指标",
    "二进制一致错误率", "误差均衡性", "均方根误差",
    "小值域错误占比", "相对误差最大值", "相对误差平均值",
    "inf/nan错误率", "相对误差错误率", "绝对误差错误率",
    "ULP误差最大值", "ULP误差平均值", "ULP误差大于阈值占比",
    # Verdict columns.
    "Status", "Message",
]]
57
+
58
+
59
# Per-dtype settings; every value is wrapped in a one-element list.
# NOTE(review): 'small_value' looks like the magnitude below which a value
# is treated as "small" and 'small_value_atol' like the absolute tolerance
# used there; confirm against the code that reads them.
precision_configs = {
    torch.float16: {'small_value': [1e-3], 'small_value_atol': [1e-5]},
    torch.bfloat16: {'small_value': [1e-3], 'small_value_atol': [1e-5]},
    torch.float32: {'small_value': [1e-6], 'small_value_atol': [1e-9]},
}
85
+
86
+
87
# Per-dtype parameters for the ULP comparison; every value is wrapped in a
# one-element list.
# NOTE(review): the 'min_eb' values match the minimum normal exponents of
# fp16/bf16/fp32, and the 'exponent_num' values (10/7/23) match their
# fraction-bit widths despite the key name; confirm with the ULP code.
ULP_PARAMETERS = {
    torch.float16: {'min_eb': [-14], 'exponent_num': [10]},
    torch.bfloat16: {'min_eb': [-126], 'exponent_num': [7]},
    torch.float32: {'min_eb': [-126], 'exponent_num': [23]},
}
113
+
114
+
115
class ApiPrecisionCompareColumn:
    """Column-title constants for the API precision compare CSV files.

    The attributes are the literal header strings (mostly Chinese). The
    static methods return ordered lists of them: the columns required from
    the input, the detail-CSV header and the summary-CSV header.
    """

    # Identification columns.
    API_NAME = 'API Name'
    DEVICE_DTYPE = 'DEVICE Dtype'

    # Raw metric columns; returned by to_required_columns().
    SMALL_VALUE_ERROR_RATE = '小值域错误占比'
    RMSE = '均方根误差'
    MAX_REL_ERR = '相对误差最大值'
    MEAN_REL_ERR = '相对误差平均值'
    EB = '误差均衡性'

    # "<metric> ratio" / "<metric> verdict" column pairs of the detail CSV.
    SMALL_VALUE_ERROR_RATIO = '小值域错误比值'
    SMALL_VALUE_ERROR_STATUS = '小值域判定结果'
    RMSE_RATIO = '均方根误差比值'
    RMSE_STATUS = '均方根误差判定结果'
    MAX_REL_ERR_RATIO = '相对误差最大值比值'
    MAX_REL_ERR_STATUS = '相对误差最大值判定结果'
    MEAN_REL_ERR_RATIO = '相对误差平均值比值'
    MEAN_REL_ERR_STATUS = '相对误差平均值判定结果'
    EB_RATIO = '误差均衡性比值'
    EB_STATUS = '误差均衡性判定结果'
    ERROR_RATE = '二进制一致错误率'
    ERROR_RATE_STATUS = '二进制一致错误率判定结果'
    INF_NAN_ERROR_RATIO = 'inf/nan错误率'
    INF_NAN_ERROR_RATIO_STATUS = 'inf/nan判定结果'
    REL_ERR_RATIO = '相对误差错误率'
    REL_ERR_RATIO_STATUS = '相对误差判定结果'
    ABS_ERR_RATIO = '绝对误差错误率'
    ABS_ERR_RATIO_STATUS = '绝对误差判定结果'
    MEAN_ULP_ERR = 'ULP误差平均值'
    ULP_ERR_PROPORTION = 'ULP误差大于阈值占比'
    ULP_ERR_PROPORTION_RATIO = 'ULP误差大于阈值占比比值'
    ULP_ERR_STATUS = 'ULP误差判定结果'
    REL_ERR_THOUSANDTH = '双千指标'
    REL_ERR_THOUSANDTH_STATUS = '双千指标判定结果'

    # Overall verdict columns.
    FINAL_RESULT = '比对结果'
    ALGORITHM = '比对算法'
    FORWARD_STATUS = 'Forward Test Success'
    # Misspelled historical name; kept as an alias so existing callers
    # that reference it keep working.
    FORWWARD_STATUS = FORWARD_STATUS
    BACKWARD_STATUS = 'Backward Test Success'
    MESSAGE = 'Message'

    @staticmethod
    def to_required_columns():
        """Return the titles of the identification and raw metric columns."""
        col = ApiPrecisionCompareColumn
        return [
            col.API_NAME, col.DEVICE_DTYPE,
            col.SMALL_VALUE_ERROR_RATE, col.RMSE,
            col.MAX_REL_ERR, col.MEAN_REL_ERR, col.EB,
            col.ERROR_RATE, col.INF_NAN_ERROR_RATIO,
            col.REL_ERR_RATIO, col.ABS_ERR_RATIO,
            col.MEAN_ULP_ERR, col.ULP_ERR_PROPORTION,
            col.REL_ERR_THOUSANDTH,
        ]

    @staticmethod
    def get_detail_csv_title():
        """Return the ordered column titles of the detail CSV.

        The row is: API name, then (value, status) column pairs per
        metric, then the final result, the algorithm and the message.
        """
        col = ApiPrecisionCompareColumn
        return [
            col.API_NAME,
            col.SMALL_VALUE_ERROR_RATIO, col.SMALL_VALUE_ERROR_STATUS,
            col.RMSE_RATIO, col.RMSE_STATUS,
            col.MAX_REL_ERR_RATIO, col.MAX_REL_ERR_STATUS,
            col.MEAN_REL_ERR_RATIO, col.MEAN_REL_ERR_STATUS,
            col.EB_RATIO, col.EB_STATUS,
            col.INF_NAN_ERROR_RATIO, col.INF_NAN_ERROR_RATIO_STATUS,
            col.REL_ERR_RATIO, col.REL_ERR_RATIO_STATUS,
            col.ABS_ERR_RATIO, col.ABS_ERR_RATIO_STATUS,
            col.ERROR_RATE, col.ERROR_RATE_STATUS,
            col.MEAN_ULP_ERR, col.ULP_ERR_PROPORTION,
            col.ULP_ERR_PROPORTION_RATIO, col.ULP_ERR_STATUS,
            col.REL_ERR_THOUSANDTH, col.REL_ERR_THOUSANDTH_STATUS,
            col.FINAL_RESULT, col.ALGORITHM, col.MESSAGE,
        ]

    @staticmethod
    def get_result_csv_title():
        """Return the ordered column titles of the summary result CSV."""
        col = ApiPrecisionCompareColumn
        return [col.API_NAME, col.FORWARD_STATUS, col.BACKWARD_STATUS, col.MESSAGE]
184
+
185
+
186
# Maps an API name to an explanatory note (in Chinese).
# English gist of the "topk" note: on NPU, topk ignores sorted=False and
# returns a sorted tensor, whereas CPU returns an unsorted one; if topk
# precision fails, check whether this is the cause.
CompareMessage = {
    "topk": (
        "在npu上,topk的入参sorted=False时不生效,会返回有序tensor,而cpu上会返回无序tensor。 "
        "如果topk精度不达标,请检查是否是该原因导致的。"
    ),
}
189
+
190
+
191
def check_dtype_comparable(x, y):
    """Tell whether ``x`` and ``y`` have dtypes from the same family.

    The families are checked in the order float, bool, int. The first
    family that contains ``x.dtype`` decides the answer: True only when
    ``y.dtype`` belongs to that same family. If ``x.dtype`` is in none of
    them, a warning is logged and False is returned.
    """
    for family in (Const.FLOAT_TYPE, Const.BOOL_TYPE, Const.INT_TYPE):
        if x.dtype in family:
            return y.dtype in family
    logger.warning(f"Compare: Unexpected dtype {x.dtype}, {y.dtype}")
    return False
206
+
207
+
208
def convert_str_to_float(input_data):
    """Convert ``input_data`` to ``float`` or raise ``CompareException``.

    Args:
        input_data: A string or any other value accepted by ``float()``.

    Returns:
        The converted float.

    Raises:
        CompareException: With ``INVALID_DATA_ERROR`` when the input is an
            empty or whitespace-only string, or when ``float()`` rejects
            it for any reason.
    """
    if isinstance(input_data, str) and input_data.strip() == "":
        msg = 'ERROR: Input data is an empty string'
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
    try:
        return float(input_data)
    except (TypeError, ValueError, OverflowError) as e:
        # float() raises TypeError for non-numeric objects such as None and
        # OverflowError for out-of-range ints; report those through the same
        # domain exception as malformed strings instead of leaking them.
        msg = 'ERROR: Input data cannot be converted to float'
        raise CompareException(CompareException.INVALID_DATA_ERROR, msg) from e
218
+
219
+
220
def is_inf_or_nan(x):
    """Return True when ``x`` is NaN or positive/negative infinity."""
    # A number is non-finite exactly when it is NaN or +/-inf.
    return not math.isfinite(x)
222
+
223
+
224
def handle_infinity(x, y, column_name):
    """Judge a pair of values for the infinity case.

    Returns a ``(value, passed, message)`` tuple. ``value`` is always NaN.
    ``passed`` is True only when both inputs are infinite with the same
    sign. ``message`` is ``column_name`` followed by a fixed Chinese note.
    """
    same_signed_inf = math.isinf(x) and math.isinf(y) and x == y
    if same_signed_inf:
        return float("nan"), True, f"{column_name}同为同号inf或nan\n"
    return float("nan"), False, f"{column_name}inf或nan不一致\n"
232
+
233
+
234
def handle_nan(x, y, column_name):
    """Judge a pair of values for the NaN case.

    Returns a ``(value, passed, message)`` tuple. ``value`` is always NaN.
    ``passed`` is True only when both inputs are NaN.
    """
    both_nan = math.isnan(x) and math.isnan(y)
    if both_nan:
        note = f"{column_name}同为同号inf或nan\n"
    else:
        note = f"{column_name}inf或nan不一致\n"
    return float("nan"), both_nan, note
239
+
240
+
241
def check_inf_or_nan(x, y, column_name):
    """Dispatch a pair of values to the infinity or the NaN handler.

    If either value is infinite the pair goes to ``handle_infinity``,
    otherwise to ``handle_nan``. The handler's
    ``(value, passed, message)`` tuple is returned unchanged.
    """
    any_inf = math.isinf(x) or math.isinf(y)
    handler = handle_infinity if any_inf else handle_nan
    return handler(x, y, column_name)
245
246