mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,244 +1,295 @@
1
- import abc
2
- import numpy as np
3
- from msprobe.core.common.utils import format_value
4
- from msprobe.core.common.const import Const, CompareConst
5
- from msprobe.pytorch.common.log import logger
6
-
7
-
8
def handle_inf_nan(n_value, b_value):
    """Neutralise inf/nan entries when both arrays have them at the same places.

    If either array contains inf or nan, the boolean position masks of inf and
    of nan are compared between the two arrays:

    * masks agree -> every inf/nan entry is overwritten with 0 *in place* and
      the (now mutated) input arrays are returned;
    * masks differ -> ``(CompareConst.NAN, CompareConst.NAN)`` is returned.

    Arrays without any inf/nan are returned untouched.

    NOTE(review): only positions are compared (``np.isinf`` ignores sign), so
    +inf in one array and -inf at the same index in the other count as a
    match and are both zeroed — confirm this is intended.
    """
    inf_npu, inf_bench = np.isinf(n_value), np.isinf(b_value)
    nan_npu, nan_bench = np.isnan(n_value), np.isnan(b_value)

    npu_has_invalid = np.any(inf_npu) or np.any(nan_npu)
    bench_has_invalid = np.any(inf_bench) or np.any(nan_bench)
    if not (npu_has_invalid or bench_has_invalid):
        return n_value, b_value

    masks_agree = np.array_equal(inf_npu, inf_bench) and np.array_equal(nan_npu, nan_bench)
    if not masks_agree:
        return CompareConst.NAN, CompareConst.NAN

    # Zero inf entries first, then nan entries, on both arrays.
    for npu_mask, bench_mask in ((inf_npu, inf_bench), (nan_npu, nan_bench)):
        n_value[npu_mask] = 0
        b_value[bench_mask] = 0
    return n_value, b_value
25
-
26
-
27
def get_error_type(n_value, b_value, error_flag):
    """Classify an NPU/bench pair and report whether it can be compared.

    Returns a ``(n_value, b_value, error_flag)`` triple. On failure both values
    are replaced by the same ``CompareConst`` marker and the flag is True:

    * ``READ_NONE``     - the caller already passed ``error_flag=True``;
    * ``NONE``          - the NPU array is empty (size 0);
    * ``SHAPE_UNMATCH`` - the NPU and bench shapes differ;
    * ``NAN``           - inf/nan positions differ (see ``handle_inf_nan``).

    0-d (scalar) arrays are returned unchanged with the flag False. Otherwise
    the arrays produced by ``handle_inf_nan`` are passed through.
    """
    if error_flag:
        return CompareConst.READ_NONE, CompareConst.READ_NONE, True
    if n_value.size == 0:
        # Nothing was read for the NPU side.
        return CompareConst.NONE, CompareConst.NONE, True
    if n_value.shape != b_value.shape:
        return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, True
    if not n_value.shape:
        # Scalars skip the inf/nan position check.
        return n_value, b_value, False

    checked_npu, checked_bench = handle_inf_nan(n_value, b_value)
    # handle_inf_nan returns the CompareConst.NAN object itself on mismatch; an
    # identity test avoids an element-wise ``==`` against an ndarray.
    if checked_npu is CompareConst.NAN or checked_bench is CompareConst.NAN:
        return CompareConst.NAN, CompareConst.NAN, True
    return checked_npu, checked_bench, False
42
-
43
-
44
def reshape_value(n_value, b_value):
    """Prepare both arrays for the numeric metrics.

    0-d (scalar) inputs keep their shape. They are cast to ``float`` only when
    the NPU value is boolean; otherwise they are returned as-is. Every other
    input is flattened to 1-D and cast to ``float``.
    """
    is_scalar = not n_value.shape
    if is_scalar and n_value.dtype != bool:
        return n_value, b_value
    if is_scalar:
        return n_value.astype(float), b_value.astype(float)
    flat_npu = n_value.reshape(-1).astype(float)
    flat_bench = b_value.reshape(-1).astype(float)
    return flat_npu, flat_bench
55
-
56
-
57
def get_error_message(n_value, b_value, op_name, error_flag, error_file=None):
    """Return a human-readable note explaining why a pair cannot be compared.

    With ``error_flag`` set, ``n_value`` is expected to hold one of the
    ``CompareConst`` markers produced by ``get_error_type``, and the matching
    message is returned (for ``READ_NONE``, the missing ``error_file`` is named
    when given, otherwise ``CompareConst.NO_BENCH`` is returned).

    Without the flag, scalar data and dtype mismatches are reported; a dtype
    mismatch is also logged as a warning tagged with ``op_name``.

    An empty string means there is nothing to report.
    """
    if not error_flag:
        if not n_value.shape:
            return "This is type of scalar data, can not compare."
        if n_value.dtype == b_value.dtype:
            return ""
        logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(op_name))
        return "Dtype of NPU and bench Tensor do not match."

    if n_value == CompareConst.READ_NONE:
        return "Dump file: {} not found.".format(error_file) if error_file else CompareConst.NO_BENCH
    if n_value == CompareConst.NONE:
        return "This is empty data, can not compare."
    if n_value == CompareConst.SHAPE_UNMATCH:
        return "Shape of NPU and bench Tensor do not match. Skipped."
    if n_value == CompareConst.NAN:
        return "The position of inf or nan in NPU and bench Tensor do not match."
    return ""
77
-
78
-
79
class TensorComparisonBasic(abc.ABC):
    """Abstract interface for one NPU-vs-bench comparison metric.

    Concrete metrics implement ``apply`` and return a ``(result, message)``
    pair, where ``message`` is an explanatory string (possibly empty).
    """

    @abc.abstractmethod
    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Compute the metric for one tensor pair; subclasses must override."""
        raise NotImplementedError
84
-
85
-
86
class GetCosineSimilarity(TensorComparisonBasic):
    """Cosine similarity between the NPU and bench tensors."""

    @staticmethod
    def correct_data(result):
        """Snap values above ``CompareConst.COSINE_THRESHOLD`` to 1.0.

        ``CompareConst.NAN`` is passed through unchanged.
        """
        if result != CompareConst.NAN and float(result) > CompareConst.COSINE_THRESHOLD:
            return 1.0
        return result

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        if error_flag:
            # Map the marker set by get_error_type to this metric's placeholder.
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ''
            if n_value == CompareConst.NONE:
                return CompareConst.UNSUPPORTED, ''
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ''
            if n_value == CompareConst.NAN:
                return "N/A", ''

        if not n_value.shape:
            return CompareConst.UNSUPPORTED, ''

        with np.errstate(divide='ignore', invalid='ignore'):
            if len(n_value) == 1:
                return CompareConst.UNSUPPORTED, "This tensor is scalar."
            dot_product = n_value.dot(b_value)
            npu_norm = np.linalg.norm(n_value)
            bench_norm = np.linalg.norm(b_value)

            npu_is_zero = npu_norm <= Const.FLOAT_EPSILON
            bench_is_zero = bench_norm <= Const.FLOAT_EPSILON
            if npu_is_zero and bench_is_zero:
                return 1.0, ''
            if npu_is_zero:
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.'
            if bench_is_zero:
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.'

            cosine = dot_product / (npu_norm * bench_norm)
            if np.isnan(cosine):
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, the dump data has NaN.'
            corrected = self.correct_data(format_value(cosine))
            # A second snap with a fixed 0.99999 cut-off is applied after correct_data.
            if float(corrected) > 0.99999:
                return 1.0, ''
            return corrected, ''
130
-
131
-
132
class GetMaxAbsErr(TensorComparisonBasic):
    """Largest absolute element-wise difference, ``max(|n - b|)``."""

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        if error_flag:
            # Map the marker set by get_error_type to this metric's placeholder.
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ""
            if n_value == CompareConst.NONE:
                return 0, ""
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ""
            if n_value == CompareConst.NAN:
                return "N/A", ""

        abs_diff = np.abs(n_value - b_value)
        worst = np.max(abs_diff)
        return format_value(worst), ""
148
-
149
-
150
def get_relative_err(n_value, b_value):
    """Return the element-wise absolute relative error ``|(n - b) / b|``.

    Bench data whose dtype is not in ``CompareConst.FLOAT_TYPE`` is cast to
    ``float`` together with the NPU data. Where the bench value is exactly 0,
    the machine epsilon of the bench dtype is added to both values so the
    division is defined.

    The input arrays are never modified. The computation runs on copies, and
    the NPU copy is promoted (``np.result_type``) to a dtype that can hold the
    float epsilon. This means an integer NPU array paired with float bench
    data no longer fails on the in-place add.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        if b_value.dtype not in CompareConst.FLOAT_TYPE:
            npu, bench = n_value.astype(float), b_value.astype(float)
        else:
            # Copy rather than adding eps in place, so the caller's arrays stay intact.
            bench = b_value.copy()
            npu = n_value.astype(np.result_type(n_value.dtype, bench.dtype))
        zero_mask = (bench == 0)
        eps = np.finfo(bench.dtype).eps
        bench[zero_mask] += eps
        npu[zero_mask] += eps
        relative_err = np.divide((npu - bench), bench)
        return np.abs(relative_err)
160
-
161
-
162
class GetMaxRelativeErr(TensorComparisonBasic):
    """Largest element-wise relative error between the NPU and bench data.

    A precomputed ``relative_err`` array is used when supplied; otherwise it is
    computed with ``get_relative_err``.
    """

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        if error_flag:
            # Map the marker set by get_error_type to this metric's placeholder.
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ''
            if n_value == CompareConst.NONE:
                return 0, ''
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ''
            if n_value == CompareConst.NAN:
                return "N/A", ''

        errors = get_relative_err(n_value, b_value) if relative_err is None else relative_err
        worst = np.max(np.abs(errors))
        if not np.isnan(worst):
            return format_value(worst), ''
        return CompareConst.NAN, 'Cannot compare by MaxRelativeError, the data contains nan in dump data.'
182
-
183
-
184
class GetThousandErrRatio(TensorComparisonBasic):
    """Fraction of elements whose relative error is below
    ``CompareConst.THOUSAND_RATIO_THRESHOLD`` (described as one thousandth).
    """

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        if error_flag:
            # Map the marker set by get_error_type to this metric's placeholder.
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ""
            if n_value == CompareConst.NONE:
                return 0, ""
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ""
            if n_value == CompareConst.NAN:
                return "N/A", ""

        if not n_value.shape:
            return CompareConst.NAN, ""
        errors = get_relative_err(n_value, b_value) if relative_err is None else relative_err
        total = np.size(errors)
        if not total:
            return CompareConst.NAN, ""
        within = np.sum(errors < CompareConst.THOUSAND_RATIO_THRESHOLD)
        return format_value(within / total), ""
204
-
205
-
206
class GetFiveThousandErrRatio(TensorComparisonBasic):
    """Fraction of elements whose relative error is below
    ``CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD`` (described as five thousandths).
    """

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        if error_flag:
            # Map the marker set by get_error_type to this metric's placeholder.
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ""
            if n_value == CompareConst.NONE:
                return 0, ""
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ""
            if n_value == CompareConst.NAN:
                return "N/A", ""

        if not n_value.shape:
            return CompareConst.NAN, ""
        errors = get_relative_err(n_value, b_value) if relative_err is None else relative_err
        total = np.size(errors)
        if not total:
            return CompareConst.NAN, ""
        within = np.sum(errors < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD)
        return format_value(within / total), ""
226
-
227
-
228
class CompareOps:
    """Registry of the metrics computed for every NPU/bench tensor pair.

    Insertion order matters: ``compare_ops_apply`` iterates the values in this
    order, which therefore fixes the order of the results it returns.
    """

    compare_ops = dict(
        cosine_similarity=GetCosineSimilarity(),
        max_abs_error=GetMaxAbsErr(),
        max_relative_error=GetMaxRelativeErr(),
        one_thousand_err_ratio=GetThousandErrRatio(),
        five_thousand_err_ratio=GetFiveThousandErrRatio(),
    )
236
-
237
-
238
def compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=None):
    """Run every registered metric on one NPU/bench pair.

    Returns ``(results, err_msg)``. ``results`` lists each metric's value in
    ``CompareOps.compare_ops`` order. The returned ``err_msg`` is the incoming
    message followed by every metric's message, concatenated.
    """
    results = []
    messages = [err_msg]
    for metric in CompareOps.compare_ops.values():
        value, note = metric.apply(n_value, b_value, error_flag, relative_err=relative_err)
        results.append(value)
        messages.append(note)
    return results, "".join(messages)
1
+ import abc
2
+ import numpy as np
3
+ from msprobe.core.common.utils import format_value
4
+ from msprobe.core.common.const import Const, CompareConst
5
+ from msprobe.core.common.log import logger
6
+
7
+
8
def handle_inf_nan(n_value, b_value):
    """Neutralise inf/nan entries when both arrays have them at the same places.

    If either array contains inf or nan, the boolean position masks of inf and
    of nan are compared between the two arrays:

    * masks agree -> every inf/nan entry is overwritten with 0 *in place* and
      the (now mutated) input arrays are returned;
    * masks differ -> ``(CompareConst.NAN, CompareConst.NAN)`` is returned.

    Arrays without any inf/nan are returned untouched.

    NOTE(review): only positions are compared (``np.isinf`` ignores sign), so
    +inf in one array and -inf at the same index in the other count as a
    match and are both zeroed — confirm this is intended.
    """
    inf_npu, inf_bench = np.isinf(n_value), np.isinf(b_value)
    nan_npu, nan_bench = np.isnan(n_value), np.isnan(b_value)

    npu_has_invalid = np.any(inf_npu) or np.any(nan_npu)
    bench_has_invalid = np.any(inf_bench) or np.any(nan_bench)
    if not (npu_has_invalid or bench_has_invalid):
        return n_value, b_value

    masks_agree = np.array_equal(inf_npu, inf_bench) and np.array_equal(nan_npu, nan_bench)
    if not masks_agree:
        return CompareConst.NAN, CompareConst.NAN

    # Zero inf entries first, then nan entries, on both arrays.
    for npu_mask, bench_mask in ((inf_npu, inf_bench), (nan_npu, nan_bench)):
        n_value[npu_mask] = 0
        b_value[bench_mask] = 0
    return n_value, b_value
25
+
26
+
27
def get_error_type(n_value, b_value, error_flag):
    """Classify an NPU/bench pair and report whether it can be compared.

    Returns a ``(n_value, b_value, error_flag)`` triple. On failure both values
    are replaced by the same ``CompareConst`` marker and the flag is True:

    * ``READ_NONE``     - the caller already passed ``error_flag=True``;
    * ``NONE``          - the NPU array is empty (size 0);
    * ``SHAPE_UNMATCH`` - the NPU and bench shapes differ;
    * ``NAN``           - inf/nan positions differ (see ``handle_inf_nan``).

    0-d (scalar) arrays are returned unchanged with the flag False. Otherwise
    the arrays produced by ``handle_inf_nan`` are passed through.
    """
    if error_flag:
        return CompareConst.READ_NONE, CompareConst.READ_NONE, True
    if n_value.size == 0:
        # Nothing was read for the NPU side.
        return CompareConst.NONE, CompareConst.NONE, True
    if n_value.shape != b_value.shape:
        return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, True
    if not n_value.shape:
        # Scalars skip the inf/nan position check.
        return n_value, b_value, False

    checked_npu, checked_bench = handle_inf_nan(n_value, b_value)
    # handle_inf_nan returns the CompareConst.NAN object itself on mismatch; an
    # identity test avoids an element-wise ``==`` against an ndarray.
    if checked_npu is CompareConst.NAN or checked_bench is CompareConst.NAN:
        return CompareConst.NAN, CompareConst.NAN, True
    return checked_npu, checked_bench, False
42
+
43
+
44
def reshape_value(n_value, b_value):
    """Prepare the NPU and bench arrays for numeric comparison.

    Arrays with at least one dimension are flattened to 1-D and cast to float.
    0-d (scalar) arrays keep their shape. Boolean scalars are cast to float;
    any other scalar is returned as-is.
    """
    if n_value.shape:
        return n_value.reshape(-1).astype(float), b_value.reshape(-1).astype(float)
    if n_value.dtype != bool:
        return n_value, b_value
    return n_value.astype(float), b_value.astype(float)
55
+
56
+
57
def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
    """Build the human-readable note for one comparison row.

    When ``error_flag`` is False, the arrays themselves are inspected:
    - scalar data is reported as not comparable;
    - a dtype mismatch is reported and also logged with ``npu_op_name``.

    When ``error_flag`` is True, ``n_value`` is expected to be one of the
    ``CompareConst`` markers produced by ``get_error_type``, and the matching
    explanation is returned.

    An empty string means there is nothing to report.
    """
    if not error_flag:
        if not n_value.shape:
            return "This is type of scalar data, can not compare."
        if n_value.dtype != b_value.dtype:
            logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
            return "Dtype of NPU and bench Tensor do not match."
        return ""

    if n_value == CompareConst.READ_NONE:
        return "Dump file: {} not found.".format(error_file) if error_file else CompareConst.NO_BENCH

    explanations = (
        (CompareConst.NONE, "This is empty data, can not compare."),
        (CompareConst.SHAPE_UNMATCH, "Shape of NPU and bench Tensor do not match. Skipped."),
        (CompareConst.NAN, "The position of inf or nan in NPU and bench Tensor do not match."),
    )
    for marker, text in explanations:
        if n_value == marker:
            return text
    return ""
77
+
78
+
79
def npy_data_check(n_value, b_value):
    """Validate a pair of NPU/bench arrays before comparing them.

    Checks run in this order:
    1. Missing data: ``None`` or an empty-string placeholder.
    2. Empty arrays.
    3. Scalar data, shape mismatch and dtype mismatch.
    4. Only if all of the above passed: whether the inf/nan positions agree.

    Step 4 goes through ``handle_inf_nan``, which zeroes matching inf/nan
    entries of the input arrays in place.

    Returns:
        (error_flag, error_message): ``error_flag`` is True when any check
        failed. ``error_message`` joins one newline-terminated line per failed
        check.
    """
    def _is_missing(value):
        # Only compare real strings with "": ``ndarray == ""`` is elementwise
        # on NumPy >= 1.25, and its truth value is ambiguous for more than one
        # element.
        return value is None or (isinstance(value, str) and value == "")

    error_message = ""
    if _is_missing(n_value) or _is_missing(b_value):
        error_message += "Dump file not found.\n"

    # Check whether n_value and b_value are empty.
    if not error_message and (n_value.size == 0 or b_value.size == 0):
        error_message += "This is empty data, can not compare.\n"

    if not error_message:
        if not n_value.shape or not b_value.shape:
            error_message += "This is type of scalar data, can not compare.\n"
        if n_value.shape != b_value.shape:
            error_message += "Shape of NPU and bench Tensor do not match.\n"
        if n_value.dtype != b_value.dtype:
            error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"

    if not error_message:
        n_value, b_value = handle_inf_nan(n_value, b_value)  # check for nan/inf data
        # handle_inf_nan signals mismatched positions by returning the
        # CompareConst.NAN marker object. Test identity (as get_error_type
        # does) instead of ``in``, which would compare arrays with a string.
        if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
            error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"

    return bool(error_message), error_message
107
+
108
+
109
def statistics_data_check(result_dict):
    """Validate the summary metadata of an NPU/bench pair.

    Reads names, shapes and dtypes from ``result_dict`` using the
    ``CompareConst`` keys, and reports:
    - a missing dump when either name is None;
    - scalar data when either shape is falsy, otherwise a shape mismatch;
    - a dtype mismatch.

    Returns:
        (error_flag, error_message): the flag is True when any problem was
        found. The message joins one newline-terminated line per problem.
    """
    npu_shape = result_dict.get(CompareConst.NPU_SHAPE)
    bench_shape = result_dict.get(CompareConst.BENCH_SHAPE)

    problems = []
    if result_dict.get(CompareConst.NPU_NAME) is None or result_dict.get(CompareConst.BENCH_NAME) is None:
        problems.append("Dump file not found.\n")

    if not (npu_shape and bench_shape):
        problems.append("This is type of scalar data, can not compare.\n")
    elif npu_shape != bench_shape:
        problems.append("Tensor shapes do not match.\n")

    if result_dict.get(CompareConst.NPU_DTYPE) != result_dict.get(CompareConst.BENCH_DTYPE):
        problems.append("Dtype of NPU and bench Tensor do not match. Skipped.\n")

    return bool(problems), "".join(problems)
128
+
129
+
130
class TensorComparisonBasic(abc.ABC):
    """Abstract interface shared by every NPU-vs-bench metric.

    Concrete metrics implement ``apply`` and return a ``(result, message)``
    pair.
    """

    @abc.abstractmethod
    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Compute this metric for one NPU/bench pair."""
        raise NotImplementedError
135
+
136
+
137
class GetCosineSimilarity(TensorComparisonBasic):
    """Cosine similarity between the NPU and bench tensors."""

    @staticmethod
    def correct_data(result):
        """Round ``result`` to 6 decimals when it exceeds ``COSINE_THRESHOLD``.

        The ``CompareConst.NAN`` marker and values not above the threshold are
        returned unchanged.
        """
        if result == CompareConst.NAN:
            return result
        as_float = float(result)
        return round(as_float, 6) if as_float > CompareConst.COSINE_THRESHOLD else result

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(cosine_similarity, message)`` for one NPU/bench pair.

        Behaviour:
        - With ``error_flag`` set, the ``get_error_type`` marker in ``n_value``
          is mapped to a placeholder result.
        - 0-d and single-element inputs are reported as UNSUPPORTED.
        - Two all-zero inputs count as identical (1.0).
        - A single zero-norm side, or a NaN result, yields ``CompareConst.NAN``
          with an explanation.
        - Results above 0.99999 are reported as exactly 1.0.

        ``relative_err`` is accepted for interface compatibility and is unused.

        NOTE(review): ``dot`` only yields a cosine for 1-D inputs. Callers
        presumably flatten with ``reshape_value`` first; confirm.
        """
        if error_flag:
            placeholders = (
                (CompareConst.READ_NONE, CompareConst.NONE),
                (CompareConst.NONE, CompareConst.UNSUPPORTED),
                (CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH),
                (CompareConst.NAN, "N/A"),
            )
            for marker, placeholder in placeholders:
                if n_value == marker:
                    return placeholder, ''

        if not n_value.shape:
            return CompareConst.UNSUPPORTED, ''

        with np.errstate(divide='ignore', invalid='ignore'):
            if len(n_value) == 1:
                return CompareConst.UNSUPPORTED, "This tensor is scalar."
            inner = n_value.dot(b_value)
            npu_norm = np.linalg.norm(n_value)
            bench_norm = np.linalg.norm(b_value)
            npu_is_zero = npu_norm <= Const.FLOAT_EPSILON
            bench_is_zero = bench_norm <= Const.FLOAT_EPSILON

            if npu_is_zero and bench_is_zero:
                # Two all-zero tensors are treated as identical.
                return 1.0, ''
            if npu_is_zero:
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.'
            if bench_is_zero:
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.'

            cos = inner / (npu_norm * bench_norm)
            if np.isnan(cos):
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, the dump data has NaN.'
            corrected = self.correct_data(format_value(cos))
            if float(corrected) > 0.99999:
                return 1.0, ''
            return corrected, ''
181
+
182
+
183
class GetMaxAbsErr(TensorComparisonBasic):
    """Maximum absolute element-wise difference between NPU and bench."""

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(max(|n_value - b_value|), "")``.

        With ``error_flag`` set, the ``get_error_type`` marker in ``n_value``
        maps to a placeholder instead (0 for empty data).
        ``relative_err`` is accepted for interface compatibility and is unused.
        """
        if error_flag:
            fallbacks = (
                (CompareConst.READ_NONE, CompareConst.NONE),
                (CompareConst.NONE, 0),
                (CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH),
                (CompareConst.NAN, "N/A"),
            )
            for marker, fallback in fallbacks:
                if n_value == marker:
                    return fallback, ""

        abs_diff = np.abs(n_value - b_value)
        return format_value(np.max(abs_diff)), ""
199
+
200
+
201
def get_relative_err(n_value, b_value):
    """Compute the element-wise absolute relative error ``|(n - b) / b|``.

    If the bench dtype is not in ``CompareConst.FLOAT_TYPE``, both arrays are
    converted to float first. Where the bench value is exactly 0, the machine
    epsilon of the bench dtype is added to both values so the division stays
    finite.

    The caller's arrays are never modified: the epsilon adjustment is applied
    to copies. Previously it wrote into float inputs in place, which silently
    changed data that later metrics still used.

    Returns:
        numpy.ndarray of absolute relative errors with the inputs' shape.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        if b_value.dtype not in CompareConst.FLOAT_TYPE:
            # astype() already returns fresh arrays.
            n_value, b_value = n_value.astype(float), b_value.astype(float)
        else:
            # Copy so the in-place epsilon adjustment below stays local.
            n_value, b_value = n_value.copy(), b_value.copy()
        zero_mask = (b_value == 0)
        eps = np.finfo(b_value.dtype).eps
        b_value[zero_mask] += eps
        n_value[zero_mask] += eps
        relative_err = np.divide((n_value - b_value), b_value)
        return np.abs(relative_err)
211
+
212
+
213
class GetMaxRelativeErr(TensorComparisonBasic):
    """Maximum element-wise relative error between the NPU and bench tensors."""

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(max_relative_error, message)`` for one NPU/bench pair.

        Args:
            n_value, b_value: the NPU and bench arrays, or the marker strings
                from ``get_error_type`` when ``error_flag`` is set.
            error_flag: True when the pair was already found not comparable.
            relative_err: optional precomputed relative-error array. When
                omitted, it is computed with ``get_relative_err``.

        Returns:
            A ``(value, message)`` tuple. ``value`` is:
            - a placeholder, when ``n_value`` is an error marker;
            - ``CompareConst.NAN``, when there is nothing to reduce or the
              maximum is NaN;
            - otherwise, the formatted maximum.
        """
        if error_flag:
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ''
            if n_value == CompareConst.NONE:
                return 0, ''
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ''
            if n_value == CompareConst.NAN:
                return "N/A", ''

        if relative_err is None:
            relative_err = get_relative_err(n_value, b_value)
        if not np.size(relative_err):
            # np.max() raises ValueError on an empty array. Mirror the ratio
            # metrics, which report NaN when there is no element to evaluate.
            return CompareConst.NAN, ''
        max_relative_err = np.max(np.abs(relative_err))
        if np.isnan(max_relative_err):
            message = 'Cannot compare by MaxRelativeError, the data contains nan in dump data.'
            return CompareConst.NAN, message
        return format_value(max_relative_err), ''
233
+
234
+
235
class GetThousandErrRatio(TensorComparisonBasic):
    """Share of elements whose relative error is below ``THOUSAND_RATIO_THRESHOLD``.

    The original author describes this threshold as one-thousandth.
    """

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(ratio, "")``.

        ``ratio`` is the fraction of relative errors below the threshold.
        - With ``error_flag`` set, the ``get_error_type`` marker maps to a
          placeholder.
        - 0-d input or an empty error array yields ``CompareConst.NAN``.
        - ``relative_err`` is computed with ``get_relative_err`` when omitted.
        """
        if error_flag:
            fallbacks = (
                (CompareConst.READ_NONE, CompareConst.NONE),
                (CompareConst.NONE, 0),
                (CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH),
                (CompareConst.NAN, "N/A"),
            )
            for marker, fallback in fallbacks:
                if n_value == marker:
                    return fallback, ""

        if not n_value.shape:
            return CompareConst.NAN, ""
        errors = get_relative_err(n_value, b_value) if relative_err is None else relative_err
        element_count = np.size(errors)
        if not element_count:
            return CompareConst.NAN, ""
        passing = np.sum(errors < CompareConst.THOUSAND_RATIO_THRESHOLD)
        return format_value(passing / element_count), ""
255
+
256
+
257
class GetFiveThousandErrRatio(TensorComparisonBasic):
    """Share of elements whose relative error is below ``FIVE_THOUSAND_RATIO_THRESHOLD``.

    The original author describes this threshold as five-thousandths.
    """

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(ratio, "")``.

        ``ratio`` is the fraction of relative errors below the threshold.
        - With ``error_flag`` set, the ``get_error_type`` marker maps to a
          placeholder.
        - 0-d input or an empty error array yields ``CompareConst.NAN``.
        - ``relative_err`` is computed with ``get_relative_err`` when omitted.
        """
        if error_flag:
            fallbacks = (
                (CompareConst.READ_NONE, CompareConst.NONE),
                (CompareConst.NONE, 0),
                (CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH),
                (CompareConst.NAN, "N/A"),
            )
            for marker, fallback in fallbacks:
                if n_value == marker:
                    return fallback, ""

        if not n_value.shape:
            return CompareConst.NAN, ""
        errors = get_relative_err(n_value, b_value) if relative_err is None else relative_err
        element_count = np.size(errors)
        if not element_count:
            return CompareConst.NAN, ""
        passing = np.sum(errors < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD)
        return format_value(passing / element_count), ""
277
+
278
+
279
class CompareOps:
    """Registry of the metrics applied to every NPU/bench pair.

    Insertion order defines the order of results returned by
    ``compare_ops_apply``.
    """

    compare_ops = dict(
        cosine_similarity=GetCosineSimilarity(),
        max_abs_error=GetMaxAbsErr(),
        max_relative_error=GetMaxRelativeErr(),
        one_thousand_err_ratio=GetThousandErrRatio(),
        five_thousand_err_ratio=GetFiveThousandErrRatio(),
    )
287
+
288
+
289
def compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=None):
    """Apply every registered comparison metric to one NPU/bench pair.

    Metrics run in ``CompareOps.compare_ops`` insertion order. Each metric's
    message is appended to ``err_msg``.

    Returns:
        (result_list, err_msg): one result per metric in registry order, and
        the accumulated message string.
    """
    result_list = []
    pieces = [err_msg]
    for metric in CompareOps.compare_ops.values():
        value, note = metric.apply(n_value, b_value, error_flag, relative_err=relative_err)
        result_list.append(value)
        pieces.append(note)
    return result_list, "".join(pieces)