mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,213 +1,230 @@
1
- # 定义比对算法及比对标准
2
- import torch
3
- import numpy as np
4
-
5
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
6
- from msprobe.core.common.const import CompareConst
7
-
8
-
9
def cosine_sim(bench_output, device_output):
    """Cosine similarity between the flattened bench and device outputs.

    Each flattened output is divided by its own max-abs value before the dot
    product.

    Args:
        bench_output: benchmark (golden) output array.
        device_output: device output array.

    Returns:
        tuple ``(cos, passed, msg)``:
        * element-count mismatch -> ``(-1, False, msg)``;
        * a single element, or both outputs all-zero ->
          ``(CompareConst.SPACE, True, msg)``;
        * exactly one output all-zero -> ``(CompareConst.SPACE, False, msg)``;
        * otherwise the similarity clipped to [-1, 1], ``cos > 0.99``, and a
          message that is non-empty only when the similarity is NaN.
    """
    n_value = device_output.reshape(-1)
    b_value = bench_output.reshape(-1)
    if n_value.shape != b_value.shape:
        msg = f"Shape of device and bench outputs don't match. device: {n_value.shape}, bench: {b_value.shape}."
        return -1, False, msg
    if len(n_value) == 1:
        msg = "All the data in device dump data is scalar. Please refer to other compare algorithms."
        return CompareConst.SPACE, True, msg
    zero_eps = np.finfo(float).eps
    n_value_max = np.max(np.abs(n_value))
    b_value_max = np.max(np.abs(b_value))
    if n_value_max <= zero_eps and b_value_max <= zero_eps:
        return CompareConst.SPACE, True, "All the data in device and bench outputs are zero."
    if n_value_max <= zero_eps:
        return CompareConst.SPACE, False, "All the data is zero in device dump data."
    if b_value_max <= zero_eps:
        return CompareConst.SPACE, False, "All the data is zero in bench dump data."
    # Silence divide/invalid warnings (e.g. from NaN/inf data) only for this
    # computation instead of changing NumPy's process-wide error state.
    with np.errstate(divide="ignore", invalid="ignore"):
        n_value = n_value.astype(float) / n_value_max
        b_value = b_value.astype(float) / b_value_max
        cos = np.dot(n_value, b_value) / (np.linalg.norm(n_value) * np.linalg.norm(b_value))
    msg = ""
    if np.isnan(cos):
        msg = "Dump data has NaN when comparing with Cosine Similarity."
    cos = np.clip(cos, -1, 1)
    return cos, cos > 0.99, msg
41
-
42
-
43
def get_rmse(abs_err, inf_nan_mask):
    """Root-mean-square of ``abs_err`` with inf/NaN positions excluded.

    Positions flagged in ``inf_nan_mask`` contribute 0 to the mean square.
    The mean square is then rescaled by about ``size / valid_count``, so the
    result approximates the RMSE over the unflagged positions only. The
    1e-4 terms keep the divisor non-zero when every element is flagged.
    """
    total = abs_err.size
    flagged = np.sum(inf_nan_mask)
    squared = np.square(np.where(inf_nan_mask, 0, abs_err))
    rescale = total / (total - flagged + 0.0001) + 0.0001
    return np.sqrt(np.mean(squared) * rescale)
51
-
52
-
53
def get_error_balance(bench_data, device_data):
    """Error balance: |#(device > bench) - #(device < bench)| / element count.

    ``bench_data`` is cast to ``device_data``'s dtype before differencing.
    Returns 0 when ``bench_data`` has no elements.
    """
    delta = device_data - bench_data.astype(device_data.dtype)
    above = np.sum(np.greater(delta, 0))
    below = np.sum(np.less(delta, 0))
    count = bench_data.size
    if count <= 0:
        return 0
    return abs(above - below) / count
60
-
61
-
62
def get_small_value_err_ratio(small_value_mask, abs_err_greater_mask):
    """Share of small-value positions that are also flagged in ``abs_err_greater_mask``.

    Returns 0 when ``small_value_mask`` selects no position.
    """
    failing = np.sum(np.logical_and(small_value_mask, abs_err_greater_mask))
    small_total = np.sum(small_value_mask)
    if small_total == 0:
        return 0
    return failing / small_total
68
-
69
-
70
def get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask):
    """Element-wise relative error ``abs_err / abs_bench_with_eps``.

    Positions flagged in ``small_value_mask`` or ``inf_nan_mask`` are replaced
    by the sentinel -1 so that they can be filtered out later.
    """
    ratio = abs_err / abs_bench_with_eps
    excluded = np.logical_or(small_value_mask, inf_nan_mask)
    return np.where(excluded, -1, ratio)
75
-
76
-
77
def get_abs_err(bench_data, device_data):
    """Element-wise absolute difference ``|device_data - bench_data|``."""
    return np.abs(device_data - bench_data)
80
-
81
-
82
def get_rel_err_origin(abs_err, b_value):
    """Raw relative error ``|abs_err / b_value|`` (no eps padding, no masking)."""
    return np.abs(abs_err / b_value)
85
-
86
-
87
def get_max_abs_err(abs_err):
    """Return ``(largest absolute error, largest absolute error < 0.001)``."""
    peak = abs_err.max()
    return peak, peak < 0.001
91
-
92
-
93
def get_max_rel_err(rel_err):
    """Maximum relative error, or 0 when the maximum is not non-negative.

    Negative entries (the -1 sentinel written by ``get_rel_err`` for excluded
    positions) therefore never become the result.
    """
    peak = np.max(rel_err)
    if peak >= 0:
        return peak
    return 0
96
-
97
-
98
def get_mean_rel_err(rel_err):
    """Mean of the non-negative entries of ``rel_err``; 0 if there are none."""
    kept = rel_err[rel_err >= 0]
    if kept.size == 0:
        return 0
    return np.mean(kept)
102
-
103
-
104
def get_rel_err_ratio(rel_err, thresholding):
    """Fraction of entries strictly below ``thresholding``, and whether it passes.

    An empty ``rel_err`` yields a ratio of 1. The check passes when the ratio
    is greater than ``1 - thresholding``.
    """
    total = np.size(rel_err)
    ratio = 1 if total == 0 else np.divide(np.sum(rel_err < thresholding), total)
    return ratio, ratio > (1 - thresholding)
111
-
112
-
113
def get_finite_and_infinite_mask(bench_output, device_output):
    """Masks of positions finite in both outputs, and of all other positions.

    ``bench_output`` is first cast to ``device_output``'s dtype, so its
    finiteness is judged in the device dtype.

    Returns:
        ``(both_finite_mask, inf_nan_mask)``; the second is the logical
        negation of the first.
    """
    bench_cast = bench_output.astype(device_output.dtype)
    both_finite = np.logical_and(np.isfinite(device_output), np.isfinite(bench_cast))
    return both_finite, np.logical_not(both_finite)
119
-
120
-
121
def get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold):
    """Mark positions in ``both_finite_mask`` whose ``abs_bench`` is <= ``small_value_threshold``."""
    at_or_below = np.less_equal(abs_bench, small_value_threshold)
    return np.logical_and(at_or_below, both_finite_mask)
125
-
126
-
127
def get_abs_bench_with_eps(bench, dtype):
    """Return ``(|bench|, |bench| + eps)``.

    ``eps`` is the machine epsilon of ``bench``'s NumPy dtype, except when
    ``dtype`` is ``torch.bfloat16``, where ``CompareConst.BFLOAT16_EPS`` is used.
    """
    magnitude = np.abs(bench)
    if dtype == torch.bfloat16:
        eps = CompareConst.BFLOAT16_EPS
    else:
        eps = np.finfo(bench.dtype).eps
    return magnitude, magnitude + eps
132
-
133
-
134
def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
    """Ratio of inf/NaN positions where the NPU and golden outputs disagree.

    Used by the absolute-threshold method of the new precision standard to
    check whether the inf/NaN values of the NPU output are consistent with
    the golden output. Both outputs are clipped to the finite range of the
    NPU dtype (``CompareConst.BFLOAT16_MIN``/``BFLOAT16_MAX`` when ``dtype``
    is ``torch.bfloat16``). The clipped absolute error is divided by
    ``|bench_output| + eps``. A position passes when that relative error is
    ``<= rtol``, or when it is NaN in both the device output and the clipped
    golden output.

    Args:
        inf_nan_mask: mask of positions where the NPU or golden output is inf/NaN.
        bench_output: golden output.
        device_output: NPU output.
        dtype: dtype of the NPU output.
        rtol: relative tolerance applied to the clipped error.

    Returns:
        Fraction of ``inf_nan_mask`` positions that do not pass, or 0 when
        ``inf_nan_mask`` selects nothing.
    """
    # Only the eps-padded magnitude is needed here.
    _, abs_gpu_with_eps = get_abs_bench_with_eps(bench_output, dtype)
    golden_same_dtype = bench_output.astype(device_output.dtype)
    a_min = np.finfo(device_output.dtype).min if dtype != torch.bfloat16 else CompareConst.BFLOAT16_MIN
    a_max = np.finfo(device_output.dtype).max if dtype != torch.bfloat16 else CompareConst.BFLOAT16_MAX
    golden_clip = np.clip(golden_same_dtype, a_min, a_max)
    npu_clip = np.clip(device_output, a_min, a_max)
    clipped_abs_ae = np.abs(npu_clip - golden_clip)
    # NOTE(review): the denominator comes from the *unclipped* bench_output, so
    # where the golden value is +/-inf a finite clipped difference gives a
    # relative error of 0 and the position passes -- confirm this is intended.
    clipped_re = clipped_abs_ae / abs_gpu_with_eps
    pass_mask = np.less_equal(clipped_re, rtol)
    both_nan_mask = np.logical_and(np.isnan(device_output), np.isnan(golden_clip))
    pass_mask = np.logical_or(pass_mask, both_nan_mask)
    # Only failures inside the inf/NaN region are counted.
    not_pass_mask = np.logical_and(np.logical_not(pass_mask), inf_nan_mask)

    inf_nan_cnt = np.sum(inf_nan_mask)
    return 0 if inf_nan_cnt == 0 else np.sum(not_pass_mask) / inf_nan_cnt
161
-
162
-
163
def check_small_value(abs_err, small_value_mask, small_value_atol):
    """Fraction of small-value positions whose absolute error exceeds the tolerance.

    Args:
        abs_err: absolute error between the NPU output and the golden output.
        small_value_mask: mask of the small-value positions to check.
        small_value_atol: absolute-error tolerance for small values.

    Returns:
        Fraction of ``small_value_mask`` positions with
        ``abs_err > small_value_atol``, or 0 when the mask selects nothing.
    """
    greater_mask = np.greater(abs_err, small_value_atol)
    err_mask = np.logical_and(greater_mask, small_value_mask)
    small_cnt = np.sum(small_value_mask)
    return 0 if small_cnt == 0 else np.sum(err_mask) / small_cnt
177
-
178
-
179
def check_norm_value(normal_value_mask, rel_err, rtol):
    """Fraction of normal-value positions whose relative error exceeds ``rtol``.

    Args:
        normal_value_mask: mask of the normal-value positions to check.
        rel_err: relative error between the NPU output and the golden output.
        rtol: relative-error tolerance.

    Returns:
        Fraction of ``normal_value_mask`` positions with ``rel_err > rtol``,
        or 0 when the mask selects nothing.
    """
    err_mask = np.logical_and(np.greater(rel_err, rtol), normal_value_mask)
    normal_cnt = np.sum(normal_value_mask)
    return 0 if normal_cnt == 0 else np.sum(err_mask) / normal_cnt
193
-
194
-
195
def get_ulp_err(bench_output, device_output, dtype):
    """Absolute error of ``device_output`` against ``bench_output`` in ULP units.

    The per-element exponent ``eb`` is ``floor(log2(|bench|))`` (0 where the
    bench value is 0), raised to at least the dtype's ``min_eb``. The error is
    then scaled by ``2 ** (exponent_num - eb)`` via ``calc_ulp_err``. The
    arithmetic runs in float64 for ``torch.float32`` and in float32 for any
    other dtype.

    Args:
        bench_output: golden output.
        device_output: device output.
        dtype: torch dtype of the device output; it must be a key of
            ``ULP_PARAMETERS`` (otherwise ``parameters`` is None and ``.get``
            fails).

    Returns:
        Element-wise absolute ULP error.
    """
    parameters = ULP_PARAMETERS.get(dtype)
    min_eb = parameters.get('min_eb')[0]
    exponent_num = parameters.get('exponent_num')[0]
    abs_bench = np.abs(bench_output)
    # np.where evaluates log2 on every element, zeros included; those results
    # are discarded, so suppress the divide-by-zero warning locally.
    with np.errstate(divide="ignore"):
        eb = np.where(abs_bench == 0, 0, np.floor(np.log2(abs_bench)))
    eb = np.maximum(eb, min_eb)

    compute_type = np.float64 if dtype == torch.float32 else np.float32
    ulp_err = calc_ulp_err(bench_output, device_output, eb, exponent_num, compute_type)
    return np.abs(ulp_err)
209
-
210
-
211
def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
    """Signed ULP error ``(device - bench) * 2 ** (-eb + exponent_num)``.

    The difference and the scale factor are each cast to ``data_type``
    before they are multiplied.
    """
    diff = (device_output.astype(data_type) - bench_output).astype(data_type)
    scale = np.exp2(-eb + exponent_num).astype(data_type)
    return diff * scale
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # 定义比对算法及比对标准
19
+ import torch
20
+ import numpy as np
21
+
22
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
23
+ from msprobe.core.common.const import CompareConst
24
+
25
+
26
def cosine_sim(bench_output, device_output):
    """Cosine similarity between the flattened bench and device outputs.

    Each flattened output is divided by its own max-abs value before the dot
    product.

    Args:
        bench_output: benchmark (golden) output array.
        device_output: device output array.

    Returns:
        tuple ``(cos, passed, msg)``:
        * element-count mismatch -> ``(-1, False, msg)``;
        * a single element, or both outputs all-zero ->
          ``(CompareConst.SPACE, True, msg)``;
        * exactly one output all-zero -> ``(CompareConst.SPACE, False, msg)``;
        * otherwise the similarity clipped to [-1, 1], ``cos > 0.99``, and a
          message that is non-empty only when the similarity is NaN.
    """
    n_value = device_output.reshape(-1)
    b_value = bench_output.reshape(-1)
    if n_value.shape != b_value.shape:
        msg = f"Shape of device and bench outputs don't match. device: {n_value.shape}, bench: {b_value.shape}."
        return -1, False, msg
    if len(n_value) == 1:
        msg = "All the data in device dump data is scalar. Please refer to other compare algorithms."
        return CompareConst.SPACE, True, msg
    zero_eps = np.finfo(float).eps
    n_value_max = np.max(np.abs(n_value))
    b_value_max = np.max(np.abs(b_value))
    if n_value_max <= zero_eps and b_value_max <= zero_eps:
        return CompareConst.SPACE, True, "All the data in device and bench outputs are zero."
    if n_value_max <= zero_eps:
        return CompareConst.SPACE, False, "All the data is zero in device dump data."
    if b_value_max <= zero_eps:
        return CompareConst.SPACE, False, "All the data is zero in bench dump data."
    # Silence divide/invalid warnings (e.g. from NaN/inf data) only for this
    # computation instead of changing NumPy's process-wide error state.
    with np.errstate(divide="ignore", invalid="ignore"):
        n_value = n_value.astype(float) / n_value_max
        b_value = b_value.astype(float) / b_value_max
        cos = np.dot(n_value, b_value) / (np.linalg.norm(n_value) * np.linalg.norm(b_value))
    msg = ""
    if np.isnan(cos):
        msg = "Dump data has NaN when comparing with Cosine Similarity."
    cos = np.clip(cos, -1, 1)
    return cos, cos > 0.99, msg
58
+
59
+
60
def get_rmse(abs_err, inf_nan_mask):
    """Root-mean-square of ``abs_err`` with inf/NaN positions excluded.

    Positions flagged in ``inf_nan_mask`` contribute 0 to the mean square.
    The mean square is then rescaled by about ``size / valid_count``, so the
    result approximates the RMSE over the unflagged positions only. The
    1e-4 terms keep the divisor non-zero when every element is flagged.
    """
    total = abs_err.size
    flagged = np.sum(inf_nan_mask)
    squared = np.square(np.where(inf_nan_mask, 0, abs_err))
    rescale = total / (total - flagged + 0.0001) + 0.0001
    return np.sqrt(np.mean(squared) * rescale)
68
+
69
+
70
def get_error_balance(bench_data, device_data):
    """Error balance: |#(device > bench) - #(device < bench)| / element count.

    ``bench_data`` is cast to ``device_data``'s dtype before differencing.
    Returns 0 when ``bench_data`` has no elements.
    """
    delta = device_data - bench_data.astype(device_data.dtype)
    above = np.sum(np.greater(delta, 0))
    below = np.sum(np.less(delta, 0))
    count = bench_data.size
    if count <= 0:
        return 0
    return abs(above - below) / count
77
+
78
+
79
def get_small_value_err_ratio(small_value_mask, abs_err_greater_mask):
    """Share of small-value positions that are also flagged in ``abs_err_greater_mask``.

    Returns 0 when ``small_value_mask`` selects no position.
    """
    failing = np.sum(np.logical_and(small_value_mask, abs_err_greater_mask))
    small_total = np.sum(small_value_mask)
    if small_total == 0:
        return 0
    return failing / small_total
85
+
86
+
87
def get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask):
    """Element-wise relative error ``abs_err / abs_bench_with_eps``.

    Positions flagged in ``small_value_mask`` or ``inf_nan_mask`` are replaced
    by the sentinel -1 so that they can be filtered out later.
    """
    ratio = abs_err / abs_bench_with_eps
    excluded = np.logical_or(small_value_mask, inf_nan_mask)
    return np.where(excluded, -1, ratio)
92
+
93
+
94
def get_abs_err(bench_data, device_data):
    """Element-wise absolute difference ``|device_data - bench_data|``."""
    return np.abs(device_data - bench_data)
97
+
98
+
99
def get_rel_err_origin(abs_err, b_value):
    """Raw relative error ``|abs_err / b_value|`` (no eps padding, no masking)."""
    return np.abs(abs_err / b_value)
102
+
103
+
104
def get_max_abs_err(abs_err):
    """Return ``(largest absolute error, largest absolute error < 0.001)``."""
    peak = abs_err.max()
    return peak, peak < 0.001
108
+
109
+
110
def get_max_rel_err(rel_err):
    """Maximum relative error, or 0 when the maximum is not non-negative.

    Negative entries (the -1 sentinel written by ``get_rel_err`` for excluded
    positions) therefore never become the result.
    """
    peak = np.max(rel_err)
    if peak >= 0:
        return peak
    return 0
113
+
114
+
115
def get_mean_rel_err(rel_err):
    """Mean of the non-negative entries of ``rel_err``; 0 if there are none."""
    kept = rel_err[rel_err >= 0]
    if kept.size == 0:
        return 0
    return np.mean(kept)
119
+
120
+
121
def get_rel_err_ratio(rel_err, thresholding):
    """Fraction of entries strictly below ``thresholding``, and whether it passes.

    An empty ``rel_err`` yields a ratio of 1. The check passes when the ratio
    is greater than ``1 - thresholding``.
    """
    total = np.size(rel_err)
    ratio = 1 if total == 0 else np.divide(np.sum(rel_err < thresholding), total)
    return ratio, ratio > (1 - thresholding)
128
+
129
+
130
def get_finite_and_infinite_mask(bench_output, device_output):
    """Build masks of positions where both outputs are finite, and its complement.

    The bench output is cast to the device dtype before the finiteness test,
    so values that overflow in the device dtype count as non-finite.

    Returns:
        Tuple ``(both_finite_mask, inf_nan_mask)``, where inf_nan_mask is True
        wherever either output is inf or nan.
    """
    device_ok = np.isfinite(device_output)
    bench_ok = np.isfinite(bench_output.astype(device_output.dtype))
    both_finite_mask = np.logical_and(device_ok, bench_ok)
    return both_finite_mask, np.logical_not(both_finite_mask)
136
+
137
+
138
def get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold):
    """Mask of finite positions whose |bench| is at most small_value_threshold."""
    at_or_below = np.less_equal(abs_bench, small_value_threshold)
    return np.logical_and(at_or_below, both_finite_mask)
142
+
143
+
144
def get_abs_bench_with_eps(bench, dtype):
    """Return ``(|bench|, |bench| + eps)``.

    eps is the machine epsilon of bench's numpy dtype, except when ``dtype``
    is torch.bfloat16, where CompareConst.BFLOAT16_EPS is used instead
    (presumably because numpy has no bfloat16 type — confirm).

    Args:
        bench: golden array.
        dtype: dtype of the device output, compared against torch.bfloat16.
    """
    abs_bench = np.abs(bench)
    if dtype != torch.bfloat16:
        eps = np.finfo(bench.dtype).eps
    else:
        eps = CompareConst.BFLOAT16_EPS
    return abs_bench, abs_bench + eps
149
+
150
+
151
def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
    '''
    Absolute-threshold method of the new accuracy standard: check whether the
    inf/nan values of the NPU output and the golden output agree.

    Both outputs are clipped to the finite range of the device dtype (or the
    bfloat16 range when dtype is torch.bfloat16). A position passes when the
    clipped relative error is <= rtol, or when both outputs are nan.

    Args:
        inf_nan_mask: mask of positions where the NPU or golden output is inf/nan.
        bench_output: golden output.
        device_output: NPU output.
        dtype: dtype of the NPU output.
        rtol: relative-error tolerance applied to the clipped values.

    Returns:
        inf_nan_err_ratio: fraction of inf/nan positions that do not pass;
        0 when inf_nan_mask selects nothing.
    '''
    _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
    golden_in_device_dtype = bench_output.astype(device_output.dtype)
    if dtype != torch.bfloat16:
        type_info = np.finfo(device_output.dtype)
        lower, upper = type_info.min, type_info.max
    else:
        lower, upper = CompareConst.BFLOAT16_MIN, CompareConst.BFLOAT16_MAX
    golden_clip = np.clip(golden_in_device_dtype, lower, upper)
    npu_clip = np.clip(device_output, lower, upper)

    clipped_rel = np.abs(npu_clip - golden_clip) / abs_bench_with_eps
    within_tol = np.less_equal(clipped_rel, rtol)
    # nan is preserved by clip, so matching nans count as a pass.
    nan_on_both = np.logical_and(np.isnan(device_output), np.isnan(golden_clip))
    passed = np.logical_or(within_tol, nan_on_both)
    failed_inf_nan = np.logical_and(np.logical_not(passed), inf_nan_mask)

    failed_cnt = np.sum(failed_inf_nan)
    inf_nan_total = np.sum(inf_nan_mask)
    if inf_nan_total == 0:
        return 0
    return failed_cnt / inf_nan_total
178
+
179
+
180
def check_small_value(abs_err, small_value_mask, small_value_atol):
    '''
    Absolute-threshold check of the small-value region (new accuracy standard):
    measure how many small-value positions have an absolute error above the
    small-value tolerance.

    Args:
        abs_err: absolute error between the NPU output and the golden output.
        small_value_mask: mask of small-value positions.
        small_value_atol: absolute-error threshold for small values.

    Returns:
        Fraction of small-value positions with abs_err > small_value_atol;
        0 when small_value_mask selects nothing.
    '''
    greater_mask = np.greater(abs_err, small_value_atol)
    err_cnt = np.sum(np.logical_and(greater_mask, small_value_mask))
    # Count the denominator once instead of summing the mask twice.
    small_value_cnt = np.sum(small_value_mask)
    return 0 if small_value_cnt == 0 else err_cnt / small_value_cnt
194
+
195
+
196
def check_norm_value(normal_value_mask, rel_err, rtol):
    '''
    Relative-threshold check of the normal-value region (new accuracy
    standard): measure how many normal-value positions have a relative error
    above rtol.

    Args:
        normal_value_mask: mask of normal-value positions.
        rel_err: relative error between the NPU output and the golden output.
        rtol: relative-error threshold.

    Returns:
        Fraction of normal-value positions with rel_err > rtol;
        0 when normal_value_mask selects nothing.
    '''
    err_mask = np.logical_and(np.greater(rel_err, rtol), normal_value_mask)
    err_cnt = np.sum(err_mask)
    # Count the denominator once instead of summing the mask twice.
    normal_cnt = np.sum(normal_value_mask)
    return 0 if normal_cnt == 0 else err_cnt / normal_cnt
210
+
211
+
212
def get_ulp_err(bench_output, device_output, dtype):
    """Return the absolute ULP error of device_output relative to bench_output.

    The exponent eb is floor(log2|bench|), set to 0 where bench == 0 and
    clamped from below to the dtype's ``min_eb``. Float32 outputs are evaluated
    in float64; all other dtypes are evaluated in float32.

    Args:
        bench_output: golden array.
        device_output: device output array.
        dtype: device dtype, used as the key into ULP_PARAMETERS
            (compared with torch.float32).

    Returns:
        |ulp_err| as computed by calc_ulp_err.
    """
    # NOTE(review): ULP_PARAMETERS.get returns None for an unsupported dtype,
    # which surfaces as an AttributeError on the next line — confirm callers
    # only pass supported dtypes.
    parameters = ULP_PARAMETERS.get(dtype)
    min_eb = parameters.get('min_eb')[0]
    exponent_num = parameters.get('exponent_num')[0]
    abs_bench = np.abs(bench_output)
    # np.where evaluates log2 on every element, zeros included; those entries
    # are replaced by 0, so silence the spurious divide-by-zero warning.
    with np.errstate(divide='ignore'):
        eb = np.where(abs_bench == 0, 0, np.floor(np.log2(abs_bench)))
    eb = np.maximum(eb, min_eb)

    if dtype == torch.float32:
        ulp_err = calc_ulp_err(bench_output, device_output, eb, exponent_num, np.float64)
    else:
        ulp_err = calc_ulp_err(bench_output, device_output, eb, exponent_num, np.float32)
    return np.abs(ulp_err)
226
+
227
+
228
def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
    """Return the signed ULP error (device - bench) * 2**(exponent_num - eb).

    Both the difference and the scale factor are cast to ``data_type``
    before multiplying.
    """
    delta = (device_output.astype(data_type) - bench_output).astype(data_type)
    ulp_scale = np.exp2(-eb + exponent_num).astype(data_type)
    return delta * ulp_scale