mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,299 +1,473 @@
1
- import multiprocessing
2
- import os
3
- import json
4
- import pandas as pd
5
- from msprobe.core.common.file_check import FileOpen
6
- from msprobe.core.common.const import CompareConst, Const
7
- from msprobe.core.common.exceptions import FileCheckException
8
- from msprobe.core.common.log import logger
9
- from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_file_not_exists
10
- from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op
11
- from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
12
- from msprobe.core.compare.utils import read_op, merge_tensor, CompareException, get_un_match_accuracy, get_accuracy
13
- from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
14
- from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
15
- get_error_message
16
- from msprobe.core.advisor.advisor import Advisor
17
-
18
-
19
- class Comparator:
20
-
21
- def __init__(self):
22
- pass
23
-
24
- @classmethod
25
- def make_result_table(cls,result, md5_compare, summary_compare, stack_mode):
26
- header = []
27
- if md5_compare:
28
- header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
29
- elif summary_compare:
30
- header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
31
- else:
32
- header = CompareConst.COMPARE_RESULT_HEADER[:]
33
-
34
- all_mode_bool = not (summary_compare or md5_compare)
35
- if stack_mode:
36
- if all_mode_bool:
37
- header.append(CompareConst.STACK)
38
- header.append(CompareConst.DATA_NAME)
39
- else:
40
- header.append(CompareConst.STACK)
41
- else:
42
- if all_mode_bool:
43
- for row in result:
44
- del row[-2]
45
- header.append(CompareConst.DATA_NAME)
46
- else:
47
- for row in result:
48
- del row[-1]
49
- result_df = pd.DataFrame(result, columns=header)
50
- return result_df
51
-
52
- @classmethod
53
- def gen_merge_list(self, json_data, op_name,stack_json_data, summary_compare, md5_compare):
54
- op_data = json_data['data'][op_name]
55
- op_parsed_list = read_op(op_data, op_name)
56
- if op_name in stack_json_data:
57
- op_parsed_list.append({'full_op_name': op_name, 'full_info': stack_json_data[op_name]})
58
- else:
59
- op_parsed_list.append({'full_op_name': op_name, 'full_info': None})
60
-
61
- merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
62
- return merge_list
63
-
64
- def check_op(self, npu_dict, bench_dict, fuzzy_match):
65
- a_op_name = npu_dict["op_name"]
66
- b_op_name = bench_dict["op_name"]
67
- graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
68
-
69
- frame_name = getattr(self,"frame_name")
70
- if frame_name == "PTComparator":
71
- from msprobe.pytorch.compare.match import graph_mapping
72
- if graph_mode:
73
- return graph_mapping.match(a_op_name[0], b_op_name[0])
74
- struct_match = check_struct_match(npu_dict, bench_dict)
75
- if not fuzzy_match:
76
- return a_op_name == b_op_name and struct_match
77
- is_match = True
78
- try:
79
- is_match = fuzzy_check_op(a_op_name, b_op_name)
80
- except Exception as err:
81
- logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
82
- is_match = False
83
- return is_match and struct_match
84
-
85
- def match_op(self, npu_queue, bench_queue, fuzzy_match):
86
- for b_index, b_op in enumerate(bench_queue[0: -1]):
87
- if self.check_op(npu_queue[-1], b_op, fuzzy_match):
88
- return len(npu_queue) - 1, b_index
89
- if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
90
- return len(npu_queue) - 1, len(bench_queue) - 1
91
- for n_index, n_op in enumerate(npu_queue[0: -1]):
92
- if self.check_op(n_op, bench_queue[-1], fuzzy_match):
93
- return n_index, len(bench_queue) - 1
94
- return -1, -1
95
-
96
- def compare_process(self, file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
97
- npu_json_handle, bench_json_handle, stack_json_handle = file_handles
98
- npu_json_data = json.load(npu_json_handle)
99
- bench_json_data = json.load(bench_json_handle)
100
- stack_json_data = json.load(stack_json_handle)
101
-
102
- if fuzzy_match:
103
- logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
104
-
105
- npu_ops_queue = []
106
- bench_ops_queue = []
107
- result = []
108
-
109
- ops_npu_iter = iter(npu_json_data['data'])
110
- ops_bench_iter = iter(bench_json_data['data'])
111
- read_err_npu = True
112
- read_err_bench = True
113
- last_npu_ops_len = 0
114
- last_bench_ops_len = 0
115
-
116
- while True:
117
- if not read_err_npu and not read_err_bench:
118
- break
119
- try:
120
- last_npu_ops_len = len(npu_ops_queue)
121
- op_name_npu = next(ops_npu_iter)
122
- read_err_npu = True
123
- npu_merge_list = self.gen_merge_list(npu_json_data,op_name_npu,stack_json_data,summary_compare,md5_compare)
124
- if npu_merge_list:
125
- npu_ops_queue.append(npu_merge_list)
126
- except StopIteration:
127
- read_err_npu = False
128
- try:
129
- last_bench_ops_len = len(bench_ops_queue)
130
- op_name_bench = next(ops_bench_iter)
131
- bench_merge_list = self.gen_merge_list(bench_json_data,op_name_bench,stack_json_data,summary_compare,md5_compare)
132
- if bench_merge_list:
133
- bench_ops_queue.append(bench_merge_list)
134
- except StopIteration:
135
- read_err_bench = False
136
-
137
- # merge all boolean expressions
138
- both_empty = not npu_ops_queue and not bench_ops_queue
139
- no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
140
- if both_empty or no_change:
141
- continue
142
-
143
- # APIs in NPU and Bench models unconsistent judgment
144
- if bool(npu_ops_queue) ^ bool(bench_ops_queue):
145
- logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
146
- break
147
-
148
- n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
149
- if n_match_point == -1 and b_match_point == -1:
150
- continue
151
- n_match_data = npu_ops_queue[n_match_point]
152
- b_match_data = bench_ops_queue[b_match_point]
153
- un_match_data = npu_ops_queue[0: n_match_point]
154
- for npu_data in un_match_data:
155
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
156
- get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
157
- del npu_ops_queue[0: n_match_point + 1]
158
- del bench_ops_queue[0: b_match_point + 1]
159
- if npu_ops_queue:
160
- for npu_data in npu_ops_queue:
161
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
162
-
163
- result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
164
- return result_df
165
-
166
- def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
167
- npu_bench_name_list = op_name_mapping_dict[npu_op_name]
168
- data_name = npu_bench_name_list[1]
169
- error_file, relative_err, error_flag = None, None, False
170
- if data_name == '-1' or data_name == -1: # 没有真实数据路径
171
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
172
- error_flag = True
173
- else:
174
- try:
175
- read_npy_data = getattr(self, "read_npy_data")
176
- frame_name = getattr(self, "frame_name")
177
- if frame_name == "MSComparator":
178
- n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
179
- if self.cross_frame:
180
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
181
- else:
182
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.NUMPY_SUFFIX)
183
- else:
184
- n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
185
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
186
- except IOError as error:
187
- error_file = error.filename
188
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
189
- error_flag = True
190
- except FileCheckException:
191
- error_file = data_name
192
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
193
- error_flag = True
194
-
195
- n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
196
- if not error_flag:
197
- relative_err = get_relative_err(n_value, b_value)
198
- n_value, b_value = reshape_value(n_value, b_value)
199
-
200
- err_msg = get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=error_file)
201
- result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)
202
-
203
- if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
204
- err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
205
- result_list.append(err_msg)
206
- return result_list
207
-
208
- def compare_core(self, input_parma, output_path, **kwargs):
209
- """
210
- Compares data from multiple JSON files and generates a comparison report.
211
-
212
- Args:
213
- input_parma (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
214
- "stack_path").
215
- output_path (str): The path where the output Excel report will be saved.
216
- **kwargs: Additional keyword arguments including:
217
- - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
218
- - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
219
- - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
220
- - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
221
- - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
222
- - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.
223
-
224
- Returns:
225
- """
226
- # get kwargs or set default value
227
- stack_mode = kwargs.get('stack_mode', False)
228
- auto_analyze = kwargs.get('auto_analyze', True)
229
- suffix = kwargs.get('suffix', '')
230
- fuzzy_match = kwargs.get('fuzzy_match', False)
231
- summary_compare = kwargs.get('summary_compare', False)
232
- md5_compare = kwargs.get('md5_compare', False)
233
-
234
- logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
235
- file_name = add_time_with_xlsx("compare_result" + suffix)
236
- file_path = os.path.join(os.path.realpath(output_path), file_name)
237
- check_file_not_exists(file_path)
238
- highlight_dict = {'red_rows': [], 'yellow_rows': []}
239
-
240
- with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
241
- FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
242
- FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
243
- result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
244
- summary_compare, md5_compare)
245
-
246
- if not md5_compare and not summary_compare:
247
- result_df = self._do_multi_process(input_parma, result_df)
248
- find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
249
- highlight_rows_xlsx(result_df, highlight_dict, file_path)
250
- if auto_analyze:
251
- advisor = Advisor(result_df, output_path)
252
- advisor.analysis()
253
-
254
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
    """Compare every (npu, bench) op pair listed in ``result_df`` and save the metrics.

    Column 0 of each row is the NPU op name and column 1 the benchmark op name.
    Per-row metrics come from ``self.compare_by_op``; they are collected
    column-wise and handed to ``_save_cmp_result`` as a ``ComparisonResult``.
    """
    verbose = input_param.get("is_print_compare_log")
    # Keys double as the ComparisonResult keyword names.
    columns = {
        "cos_result": [],
        "max_err_result": [],
        "max_relative_err_result": [],
        "err_msgs": [],
        "one_thousand_err_ratio_result": [],
        "five_thousand_err_ratio_result": [],
    }
    for row in range(len(result_df)):
        npu_name = result_df.iloc[row, 0]
        bench_name = result_df.iloc[row, 1]
        if verbose:
            logger.info("start compare: {}".format(npu_name))
        metrics = self.compare_by_op(npu_name, bench_name, dump_path_dict, input_param)
        cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = metrics
        if verbose:
            logger.info(
                "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, "
                "five_thousand_err_ratio {}".format(npu_name, cos_sim, max_abs_err, max_relative_err, err_msg,
                                                    one_thousand_err_ratio, five_thousand_err_ratio))
        columns["cos_result"].append(cos_sim)
        columns["max_err_result"].append(max_abs_err)
        columns["max_relative_err_result"].append(max_relative_err)
        columns["err_msgs"].append(err_msg)
        columns["one_thousand_err_ratio_result"].append(one_thousand_err_ratio)
        columns["five_thousand_err_ratio_result"].append(five_thousand_err_ratio)

    return _save_cmp_result(idx, ComparisonResult(**columns), result_df, lock)
291
-
292
def _do_multi_process(self, input_parma, result_df):
    """Dispatch ``self.compare_ops`` over ``result_df`` via ``_handle_multi_process``.

    A manager-backed re-entrant lock is passed to the workers. A ``ValueError``
    raised during dispatch is logged and re-raised as
    ``CompareException(CompareException.INVALID_DATA_ERROR)``.
    """
    try:
        updated_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
                                           multiprocessing.Manager().RLock())
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    return updated_df
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import multiprocessing
17
+ import os
18
+ import pandas as pd
19
+ from tqdm import tqdm
20
+ from msprobe.core.common.file_utils import load_json
21
+ from msprobe.core.common.const import CompareConst, Const
22
+ from msprobe.core.common.exceptions import FileCheckException
23
+ from msprobe.core.common.log import logger
24
+ from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid
25
+ from msprobe.core.common.file_utils import remove_path
26
+ from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
27
+ check_stack_json_str
28
+ from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
29
+ from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
30
+ from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
31
+ from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
32
+ get_error_message
33
+ from msprobe.core.advisor.advisor import Advisor
34
+
35
+
36
class Comparator:
    """Framework-agnostic core of the dump-data accuracy comparison.

    Attributes read but not set here (expected to come from subclasses):
    ``frame_name`` and ``read_npy_data`` (read via ``getattr``), ``cross_frame``
    (MSComparator path of ``compare_by_op``), ``data_mapping`` (``compare_core``)
    and ``data_mapping_dict`` (``get_accuracy``).
    NOTE(review): confirm every subclass defines these.
    """

    def __init__(self):
        pass

    @staticmethod
    def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
        """Build one result row for MD5 comparison mode.

        Args:
            ms_op_name: key of the NPU-side entry in ``npu_ops_all``.
            bench_op_name: key of the benchmark-side entry in ``bench_ops_all``.
            npu_ops_all / bench_ops_all: mappings as produced by ``merge_data``.
            *args: ``(has_stack, stack_info)``; when ``has_stack`` is truthy the
                items of ``stack_info`` are appended, otherwise ``CompareConst.NONE``.

        Returns:
            list: names, the first three ``struct`` fields of each side, then
            PASS/DIFF depending on whether ``struct[2]`` matches (presumably the
            md5 value — confirm against ``merge_tensor``), then the stack column(s).
        """
        npu_struct = npu_ops_all.get(ms_op_name).get('struct')
        bench_struct = bench_ops_all.get(bench_op_name).get('struct')
        result_item = [ms_op_name, bench_op_name,
                       npu_struct[0], bench_struct[0],
                       npu_struct[1], bench_struct[1],
                       npu_struct[2], bench_struct[2],
                       CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
        has_stack, stack_info = args[0], args[1] if len(args) > 1 else None
        if has_stack:
            result_item.extend(stack_info)
        else:
            result_item.append(CompareConst.NONE)
        return result_item

    @staticmethod
    def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
        """Fill the summary-mode diff columns of ``result_item`` in place.

        For each pair of numeric summary values, the absolute difference is written
        starting at the ``MAX_DIFF`` column. The relative difference (as a percentage
        string, or "N/A" when the bench value is 0) goes four columns later. A
        WARNING flag and message are appended when any pair differs by more than
        50% of the larger magnitude. Returns None.
        """
        err_msg = ""
        start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
        warning_flag = False
        for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
            if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
                diff = npu_val - bench_val
                if bench_val != 0:
                    relative = str(abs((diff / bench_val) * 100)) + '%'
                else:
                    relative = "N/A"
                result_item[start_idx + i] = diff
                # Relative-diff columns sit four slots after their absolute-diff columns.
                result_item[start_idx + i + 4] = relative
                # 1e-10 guards against division by zero when both values are 0.
                magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
                if magnitude_diff > 0.5:
                    warning_flag = True
            else:
                result_item[start_idx + i] = CompareConst.NONE
        accuracy_check = CompareConst.WARNING if warning_flag else ""
        err_msg += "Need double check api accuracy." if warning_flag else ""
        # Tag inf/nan with a trailing tab — presumably so the xlsx writer keeps
        # them as text rather than special floats (NOTE(review): confirm).
        for i in range(start_idx, len(result_item)):
            if str(result_item[i]) in ('inf', '-inf', 'nan'):
                result_item[i] = f'{result_item[i]}\t'
        result_item.append(accuracy_check)
        result_item.append(err_msg)

    @classmethod
    def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
        """Turn result rows into a DataFrame with the header matching the compare mode.

        When ``stack_mode`` is off, the stack column is deleted from every row
        *in place*. In all-data mode this is ``row[-2]``, since the data name is
        the last element. Otherwise it is ``row[-1]``.
        """
        if md5_compare:
            header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
        elif summary_compare:
            header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
        else:
            header = CompareConst.COMPARE_RESULT_HEADER[:]

        all_mode_bool = not (summary_compare or md5_compare)
        if stack_mode:
            if all_mode_bool:
                header.append(CompareConst.STACK)
                header.append(CompareConst.DATA_NAME)
            else:
                header.append(CompareConst.STACK)
        else:
            if all_mode_bool:
                for row in result:
                    del row[-2]
                header.append(CompareConst.DATA_NAME)
            else:
                for row in result:
                    del row[-1]
        result_df = pd.DataFrame(result, columns=header, dtype='object')
        return result_df

    @classmethod
    def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
        """Parse one op from a dump json (plus its stack info, if any) into a merged record."""
        op_data = json_data['data'][op_name]
        check_dump_json_str(op_data, op_name)
        op_parsed_list = read_op(op_data, op_name)

        stack_info = stack_json_data.get(op_name)
        if stack_info is not None:
            check_stack_json_str(stack_info, op_name)
            op_parsed_list.append({
                'full_op_name': op_name,
                'full_info': stack_info
            })

        merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
        return merge_list

    def check_op(self, npu_dict, bench_dict, fuzzy_match):
        """Return whether the NPU and bench merged records describe the same op.

        For the PyTorch comparator in graph mode, matching is delegated to
        ``graph_mapping.match``. Otherwise, a structural match is required together
        with either exact name equality or, when ``fuzzy_match`` is set, a
        successful ``fuzzy_check_op``.
        """
        a_op_name = npu_dict["op_name"]
        b_op_name = bench_dict["op_name"]
        graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])

        frame_name = getattr(self, "frame_name")
        if frame_name == "PTComparator":
            from msprobe.pytorch.compare.match import graph_mapping
            if graph_mode:
                return graph_mapping.match(a_op_name[0], b_op_name[0])
        struct_match = check_struct_match(npu_dict, bench_dict)
        if not fuzzy_match:
            return a_op_name == b_op_name and struct_match
        try:
            is_match = fuzzy_check_op(a_op_name, b_op_name)
        except Exception:
            # Best effort: an unmatchable pair is reported and treated as "no match".
            logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
            is_match = False
        return is_match and struct_match

    def match_op(self, npu_queue, bench_queue, fuzzy_match):
        """Find a matching (npu_index, bench_index) pair involving the newest queued ops.

        First the newest NPU op is matched against every older bench op, then
        against the newest bench op. Finally the newest bench op is matched against
        every older NPU op. Returns ``(-1, -1)`` when nothing matches.
        """
        for b_index, b_op in enumerate(bench_queue[0: -1]):
            if self.check_op(npu_queue[-1], b_op, fuzzy_match):
                return len(npu_queue) - 1, b_index
        if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
            return len(npu_queue) - 1, len(bench_queue) - 1
        for n_index, n_op in enumerate(npu_queue[0: -1]):
            if self.check_op(n_op, bench_queue[-1], fuzzy_match):
                return n_index, len(bench_queue) - 1
        return -1, -1

    def _read_next_op(self, ops_iter, json_data, stack_json_data, ops_queue, summary_compare, md5_compare):
        """Read the next op name from ``ops_iter`` and queue its merged record.

        Returns False once ``ops_iter`` is exhausted, True otherwise (even if the
        op produced an empty merge list and nothing was queued).
        """
        try:
            op_name = next(ops_iter)
        except StopIteration:
            return False
        check_op_str_pattern_valid(op_name)
        merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, summary_compare, md5_compare)
        if merge_list:
            ops_queue.append(merge_list)
        return True

    def compare_process(self, file_lists, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
        """Stream-match ops from the NPU and bench dump jsons and build the result table.

        Args:
            file_lists: ``[npu_json_path, bench_json_path, stack_json_path]``.
            stack_mode: keep the stack column in the output table.
            fuzzy_match: allow fuzzy op-name matching (see ``check_op``).
            summary_compare / md5_compare: select the comparison mode.

        Returns:
            pandas.DataFrame produced by ``make_result_table``.
        """
        npu_json_path, bench_json_path, stack_json_path = file_lists
        npu_json_data = load_json(npu_json_path)
        bench_json_data = load_json(bench_json_path)
        stack_json_data = load_json(stack_json_path)

        if fuzzy_match:
            logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")

        npu_ops_queue = []
        bench_ops_queue = []
        result = []

        ops_npu_iter = iter(npu_json_data['data'])
        ops_bench_iter = iter(bench_json_data['data'])
        npu_has_more = True
        bench_has_more = True

        npu_api_nums = len(npu_json_data['data'])
        # The context manager guarantees the bar is closed, even if parsing raises.
        with tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100) as progress_bar:
            while npu_has_more or bench_has_more:
                last_npu_ops_len = len(npu_ops_queue)
                npu_has_more = self._read_next_op(ops_npu_iter, npu_json_data, stack_json_data,
                                                  npu_ops_queue, summary_compare, md5_compare)
                last_bench_ops_len = len(bench_ops_queue)
                bench_has_more = self._read_next_op(ops_bench_iter, bench_json_data, stack_json_data,
                                                    bench_ops_queue, summary_compare, md5_compare)

                # Count only NPU ops actually read, so the bar never overshoots its total.
                if npu_has_more:
                    progress_bar.update(1)

                # Nothing to match yet, or nothing new was queued this round.
                both_empty = not npu_ops_queue and not bench_ops_queue
                no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
                if both_empty or no_change:
                    continue

                # Exactly one side has pending ops: the two models' API calls are inconsistent.
                if bool(npu_ops_queue) ^ bool(bench_ops_queue):
                    logger.info("Please check whether the number and calls of APIs in NPU and Bench models are "
                                "consistent.")
                    break

                n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
                if n_match_point == -1 and b_match_point == -1:
                    continue
                n_match_data = npu_ops_queue[n_match_point]
                b_match_data = bench_ops_queue[b_match_point]
                # NPU ops queued before the match point have no bench counterpart.
                for npu_data in npu_ops_queue[0: n_match_point]:
                    get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
                # Module-level helper from compare.utils, not the ``get_accuracy`` method below.
                get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
                del npu_ops_queue[0: n_match_point + 1]
                del bench_ops_queue[0: b_match_point + 1]

        for npu_data in npu_ops_queue:
            get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)

        result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
        return result_df

    def merge_data(self, json_data, stack_json_data, summary_compare, md5_compare):
        """Flatten every op in ``json_data['data']`` into a dict keyed by input/output tensor name.

        Each value holds the tensor's ``struct``, ``summary``, ``data_name`` (None
        when the merged record has no data names) and the op's ``stack_info``.
        Names containing neither INPUT/KWARGS nor OUTPUT components are skipped.
        """
        ops_all = {}
        for op_name in json_data.get('data', {}):
            merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, summary_compare,
                                             md5_compare)
            if not merge_list:
                continue
            data_names = merge_list.get('data_name')
            input_index, output_index = 0, 0
            for index, input_or_output in enumerate(merge_list['op_name']):
                input_or_output_list = input_or_output.split(Const.SEP)
                data_name = data_names[index] if data_names else None
                if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
                    ops_all[input_or_output] = {'struct': merge_list.get('input_struct')[input_index],
                                                'summary': merge_list.get('summary')[index],
                                                'data_name': data_name,
                                                'stack_info': merge_list.get('stack_info')}
                    input_index += 1
                elif Const.OUTPUT in input_or_output_list:
                    ops_all[input_or_output] = {'struct': merge_list.get('output_struct')[output_index],
                                                'summary': merge_list.get('summary')[index],
                                                'data_name': data_name,
                                                'stack_info': merge_list.get('stack_info')}
                    output_index += 1
        return ops_all

    def get_accuracy(self, npu_ops_all, bench_ops_all, summary_compare, md5_compare):
        """Build result rows for every pair in ``self.data_mapping_dict``.

        Pairs whose NPU or bench tensor is missing from the respective dump are
        skipped, with a warning naming the missing side.

        Returns:
            list of result rows (lists), one per successfully mapped pair.
        """
        result = []
        for ms_op_name, bench_op_name in self.data_mapping_dict.items():
            if ms_op_name not in npu_ops_all:
                logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
                continue
            # Previously tested against npu_ops_all, so missing bench ops could go unreported.
            if bench_op_name not in bench_ops_all:
                logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
                continue

            npu_op = npu_ops_all.get(ms_op_name)
            bench_op = bench_ops_all.get(bench_op_name)
            npu_stack_info = npu_op.get("stack_info", None)
            bench_stack_info = bench_op.get("stack_info", None)
            has_stack = npu_stack_info and bench_stack_info
            if md5_compare:
                result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
                                                          bench_ops_all, has_stack, npu_stack_info))
                continue

            npu_struct = npu_op.get('struct')
            bench_struct = bench_op.get('struct')
            result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
                           npu_struct[1], bench_struct[1]]
            # Placeholder metric columns: 8 in summary mode, 5 in all-data mode.
            result_item.extend([" "] * (8 if summary_compare else 5))
            npu_summary_data = npu_op.get("summary")
            result_item.extend(npu_summary_data)
            bench_summary_data = bench_op.get("summary")
            result_item.extend(bench_summary_data)
            if summary_compare:
                self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
            else:
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
                result_item.append("")
            if has_stack:
                result_item.extend(npu_stack_info)
            else:
                result_item.append(CompareConst.NONE)
            # md5 mode already returned above, so this is the all-data mode.
            if not summary_compare:
                result_item.append(npu_op.get("data_name", None))
            result.append(result_item)
        return result

    def compare_process_custom(self, file_lists, stack_mode, summary_compare=False, md5_compare=False):
        """Compare using the explicit ``data_mapping_dict`` instead of streaming matching.

        Args mirror ``compare_process`` (without ``fuzzy_match``). Returns a
        pandas.DataFrame produced by ``make_result_table``.
        """
        npu_json_path, bench_json_path, stack_json_path = file_lists
        npu_json_data = load_json(npu_json_path)
        bench_json_data = load_json(bench_json_path)
        stack_json_data = load_json(stack_json_path)

        npu_ops_all = self.merge_data(npu_json_data, stack_json_data, summary_compare, md5_compare)
        bench_ops_all = self.merge_data(bench_json_data, stack_json_data, summary_compare, md5_compare)

        result = self.get_accuracy(npu_ops_all, bench_ops_all, summary_compare, md5_compare)
        result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
        return result_df

    def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
        """Load the real tensors for one op pair and compute the comparison metrics.

        Args:
            npu_op_name / bench_op_name: names used to locate the dumped files.
            op_name_mapping_dict: maps ``npu_op_name`` to a list whose element 1 is
                the data name; ``'-1'``/``-1`` means no real data was dumped.
            input_param: reads ``npu_dump_data_dir`` and ``bench_dump_data_dir``.

        Returns:
            list: the metrics from ``compare_ops_apply`` followed by the error
            message. ``compare_ops`` unpacks it as (cos_sim, max_abs_err,
            max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg).
        """
        npu_bench_name_list = op_name_mapping_dict[npu_op_name]
        data_name = npu_bench_name_list[1]
        error_file, relative_err, error_flag = None, None, False
        if data_name == '-1' or data_name == -1:  # no real data path available
            n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
            error_flag = True
        else:
            try:
                read_npy_data = getattr(self, "read_npy_data")
                frame_name = getattr(self, "frame_name")
                if frame_name == "MSComparator":
                    n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
                    if self.cross_frame:
                        # Cross-framework comparison: the bench side was dumped as .pt files.
                        b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
                                                bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
                    else:
                        b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
                                                bench_op_name + Const.NUMPY_SUFFIX)
                else:
                    n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
                    b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
            except IOError as error:
                error_file = error.filename
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True
            except FileCheckException:
                error_file = data_name
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True

        n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
        if not error_flag:
            relative_err = get_relative_err(n_value, b_value)
            n_value, b_value = reshape_value(n_value, b_value)

        err_msg = get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=error_file)
        result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)

        if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
            err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
        result_list.append(err_msg)
        return result_list

    def compare_core(self, input_parma, output_path, **kwargs):
        """
        Compares data from multiple JSON files and generates a comparison report.

        Args:
            input_parma (dict): Input locations. Reads "npu_json_path", "bench_json_path"
                and "stack_json_path" here. In all-data mode it is also forwarded to the
                worker processes, which read "npu_dump_data_dir", "bench_dump_data_dir"
                and "is_print_compare_log".
            output_path (str): Directory where the output Excel report will be saved.
            **kwargs: Additional keyword arguments including:
                - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
                - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
                - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
                - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
                - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
                - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.

        Returns:
            None. The report is written to ``output_path``. Nothing is written when no op matches.
        """
        # get kwargs or set default value
        stack_mode = kwargs.get('stack_mode', False)
        auto_analyze = kwargs.get('auto_analyze', True)
        suffix = kwargs.get('suffix', '')
        fuzzy_match = kwargs.get('fuzzy_match', False)
        summary_compare = kwargs.get('summary_compare', False)
        md5_compare = kwargs.get('md5_compare', False)

        logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
        file_name = add_time_with_xlsx("compare_result" + suffix)
        file_path = os.path.join(os.path.realpath(output_path), file_name)
        remove_path(file_path)
        highlight_dict = {'red_rows': [], 'yellow_rows': []}

        npu_json = input_parma.get("npu_json_path")
        bench_json = input_parma.get("bench_json_path")
        stack_json = input_parma.get("stack_json_path")
        if self.data_mapping:
            result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode,
                                                    summary_compare, md5_compare)
        else:
            result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
                                             summary_compare, md5_compare)

        if not result_df.values.tolist():
            logger.warning("Can`t match any op.")
            return

        # Only all-data mode needs to load real tensors and compute metrics.
        if not md5_compare and not summary_compare:
            result_df = self._do_multi_process(input_parma, result_df)

        logger.info("Highlight suspicious API/Module start.")
        find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
        highlight_rows_xlsx(result_df, highlight_dict, file_path)
        logger.info("Highlight suspicious API/Module finish.")

        if auto_analyze:
            advisor = Advisor(result_df, output_path, suffix)
            advisor.analysis()

    def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
        """Compare every (npu, bench) op pair in ``result_df`` and save the metrics.

        Column 0 of each row is the NPU op name and column 1 the bench op name.
        Metrics from ``compare_by_op`` are collected column-wise and saved through
        ``_save_cmp_result`` under ``lock``.
        """
        cos_result = []
        max_err_result = []
        max_relative_err_result = []
        err_mess = []
        one_thousand_err_ratio_result = []
        five_thousand_err_ratio_result = []
        is_print_compare_log = input_param.get("is_print_compare_log")
        for i in range(len(result_df)):
            npu_op_name = result_df.iloc[i, 0]
            bench_op_name = result_df.iloc[i, 1]
            if is_print_compare_log:
                logger.info("start compare: {}".format(npu_op_name))
            cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
                self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
            if is_print_compare_log:
                # Implicit literal concatenation: a backslash continuation inside the
                # literal would embed the next line's indentation in the message.
                logger.info(
                    "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, "
                    "one_thousand_err_ratio {}, five_thousand_err_ratio {}".format(
                        npu_op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
                        one_thousand_err_ratio, five_thousand_err_ratio))
            cos_result.append(cos_sim)
            max_err_result.append(max_abs_err)
            max_relative_err_result.append(max_relative_err)
            err_mess.append(err_msg)
            one_thousand_err_ratio_result.append(one_thousand_err_ratio)
            five_thousand_err_ratio_result.append(five_thousand_err_ratio)

        cr = ComparisonResult(
            cos_result=cos_result,
            max_err_result=max_err_result,
            max_relative_err_result=max_relative_err_result,
            err_msgs=err_mess,
            one_thousand_err_ratio_result=one_thousand_err_ratio_result,
            five_thousand_err_ratio_result=five_thousand_err_ratio_result
        )

        return _save_cmp_result(idx, cr, result_df, lock)

    def _do_multi_process(self, input_parma, result_df):
        """Run ``compare_ops`` over ``result_df`` in worker processes and return the updated frame.

        Raises:
            CompareException: INVALID_DATA_ERROR when dispatch raises ``ValueError``.
        """
        try:
            # The manager's server process is shut down as soon as dispatch returns or fails.
            with multiprocessing.Manager() as manager:
                return _handle_multi_process(self.compare_ops, input_parma, result_df, manager.RLock())
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
299
473