mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,41 +15,57 @@
15
15
 
16
16
  import multiprocessing
17
17
  import os
18
+ import re
19
+ from copy import deepcopy
20
+
18
21
  import pandas as pd
19
22
  from tqdm import tqdm
20
- from msprobe.core.common.file_utils import load_json
23
+
24
+ from msprobe.core.advisor.advisor import Advisor
21
25
  from msprobe.core.common.const import CompareConst, Const
22
26
  from msprobe.core.common.exceptions import FileCheckException
27
+ from msprobe.core.common.file_utils import load_json, remove_path
23
28
  from msprobe.core.common.log import logger
24
- from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid
25
- from msprobe.core.common.file_utils import remove_path
26
- from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
27
- check_stack_json_str
29
+ from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value
30
+ from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \
31
+ check_struct_match, fuzzy_check_op
28
32
  from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
29
- from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
30
- from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
31
- from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
32
- get_error_message
33
- from msprobe.core.advisor.advisor import Advisor
33
+ from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result
34
+ from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
35
+ from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \
36
+ print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list
37
+
38
+
39
+ class ModeConfig:
40
+ def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None):
41
+ self.stack_mode = stack_mode
42
+ self.auto_analyze = auto_analyze
43
+ self.fuzzy_match = fuzzy_match
44
+ self.dump_mode = dump_mode
34
45
 
35
46
 
36
47
  class Comparator:
37
-
38
- def __init__(self):
39
- pass
48
+ def __init__(self, mode_config: ModeConfig):
49
+ self.stack_mode = mode_config.stack_mode
50
+ self.auto_analyze = mode_config.auto_analyze
51
+ self.fuzzy_match = mode_config.fuzzy_match
52
+ self.dump_mode = mode_config.dump_mode
40
53
 
41
54
  @staticmethod
42
55
  def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
43
- result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
44
- bench_ops_all.get(bench_op_name).get('struct')[0],
45
- npu_ops_all.get(ms_op_name).get('struct')[1],
46
- bench_ops_all.get(bench_op_name).get('struct')[1],
47
- npu_ops_all.get(ms_op_name).get('struct')[2],
48
- bench_ops_all.get(bench_op_name).get('struct')[2],
49
- CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2]
50
- == bench_ops_all.get(bench_op_name).get('struct')[2]
51
- else CompareConst.DIFF]
52
- if args[0]:
56
+ npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
57
+ bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
58
+
59
+ if len(npu_struct) < 3 or len(bench_struct) < 3:
60
+ logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
61
+ f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
62
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
63
+
64
+ result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
65
+ npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
66
+ CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
67
+
68
+ if len(args) >= 2 and args[0]:
53
69
  result_item.extend(args[1])
54
70
  else:
55
71
  result_item.append(CompareConst.NONE)
@@ -58,113 +74,102 @@ class Comparator:
58
74
  @staticmethod
59
75
  def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
60
76
  err_msg = ""
61
- start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
62
- warning_flag = False
63
- for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
64
- if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
65
- diff = npu_val - bench_val
66
- if bench_val != 0:
67
- relative = str(abs((diff / bench_val) * 100)) + '%'
68
- else:
69
- relative = "N/A"
70
- result_item[start_idx + i] = diff
71
- result_item[start_idx + i + 4] = relative
72
- magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
73
- if magnitude_diff > 0.5:
74
- warning_flag = True
75
- else:
76
- result_item[start_idx + i] = CompareConst.NONE
77
- accuracy_check = CompareConst.WARNING if warning_flag else ""
78
- err_msg += "Need double check api accuracy." if warning_flag else ""
79
- for i in range(start_idx, len(result_item)):
80
- if str(result_item[i]) in ('inf', '-inf', 'nan'):
81
- result_item[i] = f'{result_item[i]}\t'
77
+ result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
78
+ bench_summary_data, err_msg)
82
79
  result_item.append(accuracy_check)
83
80
  result_item.append(err_msg)
84
-
85
- @classmethod
86
- def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
87
- if md5_compare:
88
- header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
89
- elif summary_compare:
90
- header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
91
- else:
92
- header = CompareConst.COMPARE_RESULT_HEADER[:]
93
81
 
94
- all_mode_bool = not (summary_compare or md5_compare)
95
- if stack_mode:
96
- if all_mode_bool:
97
- header.append(CompareConst.STACK)
98
- header.append(CompareConst.DATA_NAME)
82
+ @staticmethod
83
+ def _generate_na_data(ops_all):
84
+ if not ops_all:
85
+ return {}
86
+ key = next(iter(ops_all))
87
+ value = deepcopy(ops_all[key])
88
+ for k, v in value.items():
89
+ if isinstance(v, tuple):
90
+ value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
91
+ elif isinstance(v, list):
92
+ value[k] = [CompareConst.N_A] * len(v)
99
93
  else:
100
- header.append(CompareConst.STACK)
94
+ value[k] = CompareConst.N_A
95
+ return value
96
+
97
+ def make_result_table(self, result):
98
+ header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
99
+
100
+ if self.stack_mode:
101
+ header.append(CompareConst.STACK)
102
+ if self.dump_mode == Const.ALL:
103
+ header.append(CompareConst.DATA_NAME)
101
104
  else:
102
- if all_mode_bool:
105
+ if self.dump_mode == Const.ALL:
103
106
  for row in result:
104
- del row[-2]
107
+ del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
105
108
  header.append(CompareConst.DATA_NAME)
106
109
  else:
107
110
  for row in result:
108
- del row[-1]
111
+ del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
109
112
  result_df = pd.DataFrame(result, columns=header, dtype='object')
110
- return result_df
111
-
112
- @classmethod
113
- def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
113
+ return result_df
114
+
115
+ def gen_merge_list(self, json_data, op_name, stack_json_data):
114
116
  op_data = json_data['data'][op_name]
115
117
  check_dump_json_str(op_data, op_name)
116
118
  op_parsed_list = read_op(op_data, op_name)
117
119
 
118
- stack_info = stack_json_data.get(op_name)
119
- if stack_info is not None:
120
- check_stack_json_str(stack_info, op_name)
121
- op_parsed_list.append({
122
- 'full_op_name': op_name,
123
- 'full_info': stack_info
124
- })
125
-
126
- merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
120
+ if self.stack_mode:
121
+ stack_info = stack_json_data.get(op_name)
122
+ if stack_info is not None:
123
+ check_stack_json_str(stack_info, op_name)
124
+ # append only when stack_mode is True,
125
+ op_parsed_list.append({
126
+ 'full_op_name': op_name,
127
+ 'full_info': stack_info
128
+ })
129
+
130
+ merge_list = merge_tensor(op_parsed_list, self.dump_mode)
127
131
  return merge_list
128
-
129
- def check_op(self, npu_dict, bench_dict, fuzzy_match):
130
- a_op_name = npu_dict["op_name"]
131
- b_op_name = bench_dict["op_name"]
132
- graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
133
-
132
+
133
+ def check_op(self, npu_dict, bench_dict):
134
+ npu_op_name = npu_dict[CompareConst.OP_NAME]
135
+ bench_op_name = bench_dict[CompareConst.OP_NAME]
136
+ graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
137
+ safe_get_value(bench_op_name, 0, "bench_op_name"))
138
+
134
139
  frame_name = getattr(self, "frame_name")
135
140
  if frame_name == "PTComparator":
136
141
  from msprobe.pytorch.compare.match import graph_mapping
137
142
  if graph_mode:
138
- return graph_mapping.match(a_op_name[0], b_op_name[0])
143
+ return graph_mapping.match(npu_op_name[0], bench_op_name[0])
139
144
  struct_match = check_struct_match(npu_dict, bench_dict)
140
- if not fuzzy_match:
141
- return a_op_name == b_op_name and struct_match
142
- is_match = True
145
+ if not self.fuzzy_match:
146
+ name_match = npu_op_name == bench_op_name
147
+ return name_match and struct_match
143
148
  try:
144
- is_match = fuzzy_check_op(a_op_name, b_op_name)
149
+ name_match = fuzzy_check_op(npu_op_name, bench_op_name)
145
150
  except Exception as err:
146
- logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
147
- is_match = False
148
- return is_match and struct_match
149
-
150
- def match_op(self, npu_queue, bench_queue, fuzzy_match):
151
+ logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
152
+ name_match = False
153
+ return name_match and struct_match
154
+
155
+ def match_op(self, npu_queue, bench_queue):
151
156
  for b_index, b_op in enumerate(bench_queue[0: -1]):
152
- if self.check_op(npu_queue[-1], b_op, fuzzy_match):
157
+ if self.check_op(npu_queue[-1], b_op):
153
158
  return len(npu_queue) - 1, b_index
154
- if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
159
+ if self.check_op(npu_queue[-1], bench_queue[-1]):
155
160
  return len(npu_queue) - 1, len(bench_queue) - 1
156
161
  for n_index, n_op in enumerate(npu_queue[0: -1]):
157
- if self.check_op(n_op, bench_queue[-1], fuzzy_match):
162
+ if self.check_op(n_op, bench_queue[-1]):
158
163
  return n_index, len(bench_queue) - 1
159
164
  return -1, -1
160
-
161
- def compare_process(self, file_lists, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
165
+
166
+ def compare_process(self, file_lists):
162
167
  npu_json_path, bench_json_path, stack_json_path = file_lists
163
168
  npu_json_data = load_json(npu_json_path)
164
169
  bench_json_data = load_json(bench_json_path)
165
- stack_json_data = load_json(stack_json_path)
170
+ stack_json_data = load_json(stack_json_path) if self.stack_mode else None
166
171
 
167
- if fuzzy_match:
172
+ if self.fuzzy_match:
168
173
  logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
169
174
 
170
175
  npu_ops_queue = []
@@ -188,9 +193,7 @@ class Comparator:
188
193
  last_npu_ops_len = len(npu_ops_queue)
189
194
  op_name_npu = next(ops_npu_iter)
190
195
  check_op_str_pattern_valid(op_name_npu)
191
- read_err_npu = True
192
- npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
193
- summary_compare, md5_compare)
196
+ npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data)
194
197
  if npu_merge_list:
195
198
  npu_ops_queue.append(npu_merge_list)
196
199
  except StopIteration:
@@ -199,8 +202,7 @@ class Comparator:
199
202
  last_bench_ops_len = len(bench_ops_queue)
200
203
  op_name_bench = next(ops_bench_iter)
201
204
  check_op_str_pattern_valid(op_name_bench)
202
- bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
203
- summary_compare, md5_compare)
205
+ bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data)
204
206
  if bench_merge_list:
205
207
  bench_ops_queue.append(bench_merge_list)
206
208
  except StopIteration:
@@ -219,78 +221,105 @@ class Comparator:
219
221
  logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
220
222
  break
221
223
 
222
- n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
224
+ n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
225
+
226
+ # 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中
223
227
  if n_match_point == -1 and b_match_point == -1:
224
228
  continue
229
+
225
230
  n_match_data = npu_ops_queue[n_match_point]
226
231
  b_match_data = bench_ops_queue[b_match_point]
227
232
  un_match_data = npu_ops_queue[0: n_match_point]
228
233
  for npu_data in un_match_data:
229
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
230
- get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
234
+ get_un_match_accuracy(result, npu_data, self.dump_mode)
235
+ get_accuracy(result, n_match_data, b_match_data, self.dump_mode)
231
236
  del npu_ops_queue[0: n_match_point + 1]
232
237
  del bench_ops_queue[0: b_match_point + 1]
238
+ progress_bar.close()
233
239
  if npu_ops_queue:
234
240
  for npu_data in npu_ops_queue:
235
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
236
-
237
- result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
241
+ get_un_match_accuracy(result, npu_data, self.dump_mode)
242
+
243
+ result_df = self.make_result_table(result)
238
244
  return result_df
239
245
 
240
- def merge_data(self, json_data, stack_json_data, summary_compare, md5_compare):
246
+ def merge_data(self, json_data, stack_json_data):
241
247
  ops_all = {}
242
248
  for op_name in json_data.get('data', {}):
243
- merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, summary_compare,
244
- md5_compare)
249
+ merge_list = self.gen_merge_list(json_data, op_name, stack_json_data)
245
250
  if merge_list:
246
- input_index, output_index = 0, 0
247
- for index, input_or_output in enumerate(merge_list['op_name']):
248
- input_or_output_list = input_or_output.split(Const.SEP)
249
- data_name = merge_list.get('data_name')
250
- data_name = data_name[index] if data_name else None
251
- if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
252
- ops_all[input_or_output] = {'struct': merge_list.get('input_struct')[input_index],
253
- 'summary': merge_list.get('summary')[index],
254
- 'data_name': data_name,
255
- 'stack_info': merge_list.get('stack_info')}
256
- input_index += 1
257
-
258
- elif Const.OUTPUT in input_or_output_list:
259
- ops_all[input_or_output] = {'struct': merge_list.get('output_struct')[output_index],
260
- 'summary': merge_list.get('summary')[index],
261
- 'data_name': data_name,
262
- 'stack_info': merge_list.get('stack_info')}
263
- output_index += 1
251
+ struct_to_index_mapping = {
252
+ CompareConst.INPUT_STRUCT: 0,
253
+ CompareConst.OUTPUT_STRUCT: 0,
254
+ CompareConst.PARAMS_STRUCT: 0,
255
+ CompareConst.PARAMS_GRAD_STRUCT: 0
256
+ }
257
+
258
+ op_name_list = merge_list.get(CompareConst.OP_NAME)
259
+ summary_list = merge_list.get(Const.SUMMARY)
260
+ data_name_list = merge_list.get('data_name')
261
+ op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
262
+ summary_list,
263
+ data_name_list)
264
+ for index, op_full_name in enumerate(op_name_reorder):
265
+ data_name = data_name_reorder[index] if data_name_reorder else None
266
+
267
+ _, state = get_name_and_state(op_full_name)
268
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
269
+ if not struct_key:
270
+ continue
271
+ ops_all[op_full_name] = {
272
+ CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key),
273
+ "merge_list", key=struct_key),
274
+ CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"),
275
+ 'data_name': data_name,
276
+ 'stack_info': merge_list.get('stack_info')
277
+ }
278
+ struct_to_index_mapping[struct_key] += 1
264
279
  return ops_all
265
280
 
266
- def get_accuracy(self, npu_ops_all, bench_ops_all, summary_compare, md5_compare):
281
+ def get_accuracy(self, npu_ops_all, bench_ops_all):
267
282
  result = []
283
+ bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
268
284
  for ms_op_name, bench_op_name in self.data_mapping_dict.items():
269
285
  if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
270
286
  npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
271
287
  bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
272
288
  has_stack = npu_stack_info and bench_stack_info
273
- if md5_compare:
289
+ if self.dump_mode == Const.MD5:
274
290
  result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
275
291
  bench_ops_all, has_stack, npu_stack_info))
276
292
  continue
277
- if summary_compare:
278
- result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
279
- bench_ops_all.get(bench_op_name).get('struct')[0],
280
- npu_ops_all.get(ms_op_name).get('struct')[1],
281
- bench_ops_all.get(bench_op_name).get('struct')[1],
282
- " ", " ", " ", " ", " ", " ", " ", " "]
293
+
294
+ npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
295
+ bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
296
+
297
+ if len(npu_struct) < 2 or len(bench_struct) < 2:
298
+ logger.error(
299
+ f"The length of npu_struct and bench_struct must be >= 2, "
300
+ f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
301
+ f"Please check!"
302
+ )
303
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
304
+
305
+ base_result_item = [
306
+ ms_op_name, bench_op_name,
307
+ npu_struct[0],
308
+ bench_struct[0],
309
+ npu_struct[1],
310
+ bench_struct[1]
311
+ ]
312
+
313
+ if self.dump_mode == Const.SUMMARY:
314
+ result_item = base_result_item + [" "] * 8
283
315
  else:
284
- result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
285
- bench_ops_all.get(bench_op_name).get('struct')[0],
286
- npu_ops_all.get(ms_op_name).get('struct')[1],
287
- bench_ops_all.get(bench_op_name).get('struct')[1],
288
- " ", " ", " ", " ", " "]
316
+ result_item = base_result_item + [" "] * 5
317
+
289
318
  npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
290
319
  result_item.extend(npu_summary_data)
291
320
  bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
292
321
  result_item.extend(bench_summary_data)
293
- if summary_compare:
322
+ if self.dump_mode == Const.SUMMARY:
294
323
  self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
295
324
  else:
296
325
  result_item.append(CompareConst.ACCURACY_CHECK_YES)
@@ -299,7 +328,7 @@ class Comparator:
299
328
  result_item.extend(npu_stack_info)
300
329
  else:
301
330
  result_item.append(CompareConst.NONE)
302
- if not (summary_compare or md5_compare):
331
+ if self.dump_mode == Const.ALL:
303
332
  result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
304
333
  result.append(result_item)
305
334
  elif ms_op_name not in npu_ops_all:
@@ -308,26 +337,39 @@ class Comparator:
308
337
  logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
309
338
  return result
310
339
 
311
- def compare_process_custom(self, file_lists, stack_mode, summary_compare=False, md5_compare=False):
340
+ def compare_process_custom(self, file_lists):
312
341
  npu_json_path, bench_json_path, stack_json_path = file_lists
313
342
  npu_json_data = load_json(npu_json_path)
314
343
  bench_json_data = load_json(bench_json_path)
315
- stack_json_data = load_json(stack_json_path)
344
+ stack_json_data = load_json(stack_json_path) if self.stack_mode else None
345
+ npu_ops_all = self.merge_data(npu_json_data, stack_json_data)
346
+ bench_ops_all = self.merge_data(bench_json_data, stack_json_data)
316
347
 
317
- npu_ops_all = self.merge_data(npu_json_data, stack_json_data, summary_compare, md5_compare)
318
- bench_ops_all = self.merge_data(bench_json_data, stack_json_data, summary_compare, md5_compare)
319
-
320
- result = self.get_accuracy(npu_ops_all, bench_ops_all, summary_compare, md5_compare)
321
- result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
348
+ result = self.get_accuracy(npu_ops_all, bench_ops_all)
349
+ result_df = self.make_result_table(result)
322
350
  return result_df
323
351
 
324
- def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
352
+ def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
353
+ """
354
+ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
355
+ :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
356
+ :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
357
+ :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
358
+ :param bench_data: bench的dump数据中"data"字段
359
+ :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
360
+ 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
361
+ 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
362
+ """
325
363
  npu_bench_name_list = op_name_mapping_dict[npu_op_name]
326
- data_name = npu_bench_name_list[1]
364
+ data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
327
365
  error_file, relative_err, error_flag = None, None, False
366
+ bench_data_name = get_bench_data_name(bench_op_name, bench_data)
328
367
  if data_name == '-1' or data_name == -1: # 没有真实数据路径
329
368
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
330
369
  error_flag = True
370
+ elif not bench_data_name:
371
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
372
+ error_file = 'no_bench_data'
331
373
  else:
332
374
  try:
333
375
  read_npy_data = getattr(self, "read_npy_data")
@@ -335,42 +377,39 @@ class Comparator:
335
377
  if frame_name == "MSComparator":
336
378
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
337
379
  if self.cross_frame:
338
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
339
- bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
380
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
381
+ load_pt_file=True)
340
382
  else:
341
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
342
- bench_op_name + Const.NUMPY_SUFFIX)
383
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
343
384
  else:
344
385
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
345
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
386
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
346
387
  except IOError as error:
347
388
  error_file = error.filename
348
389
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
349
390
  error_flag = True
350
- except FileCheckException:
391
+ except (FileCheckException, CompareException):
351
392
  error_file = data_name
352
393
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
353
394
  error_flag = True
354
395
 
355
- n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
356
- if not error_flag:
357
- relative_err = get_relative_err(n_value, b_value)
358
- n_value, b_value = reshape_value(n_value, b_value)
396
+ # 通过n_value, b_value同时得到错误标志和错误信息
397
+ n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
398
+ error_flag=error_flag, error_file=error_file)
359
399
 
360
- err_msg = get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=error_file)
361
- result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)
400
+ result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
362
401
 
363
- if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
402
+ if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
364
403
  err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
365
404
  result_list.append(err_msg)
366
405
  return result_list
367
-
368
- def compare_core(self, input_parma, output_path, **kwargs):
406
+
407
+ def compare_core(self, input_param, output_path, **kwargs):
369
408
  """
370
409
  Compares data from multiple JSON files and generates a comparison report.
371
410
 
372
411
  Args:
373
- input_parma (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
412
+ input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
374
413
  "stack_path").
375
414
  output_path (str): The path where the output Excel report will be saved.
376
415
  **kwargs: Additional keyword arguments including:
@@ -378,51 +417,43 @@ class Comparator:
378
417
  - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
379
418
  - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
380
419
  - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
381
- - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
382
- - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.
420
+ - dump_mode (str): ALL, SUMMARY, MD5.
383
421
 
384
422
  Returns:
385
423
  """
386
424
  # get kwargs or set default value
387
- stack_mode = kwargs.get('stack_mode', False)
388
- auto_analyze = kwargs.get('auto_analyze', True)
389
425
  suffix = kwargs.get('suffix', '')
390
- fuzzy_match = kwargs.get('fuzzy_match', False)
391
- summary_compare = kwargs.get('summary_compare', False)
392
- md5_compare = kwargs.get('md5_compare', False)
393
426
 
394
427
  logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
395
428
  file_name = add_time_with_xlsx("compare_result" + suffix)
396
429
  file_path = os.path.join(os.path.realpath(output_path), file_name)
397
430
  remove_path(file_path)
398
- highlight_dict = {'red_rows': [], 'yellow_rows': []}
431
+ highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
399
432
 
400
- npu_json = input_parma.get("npu_json_path")
401
- bench_json = input_parma.get("bench_json_path")
402
- stack_json = input_parma.get("stack_json_path")
433
+ npu_json = input_param.get("npu_json_path")
434
+ bench_json = input_param.get("bench_json_path")
435
+ stack_json = input_param.get("stack_json_path")
403
436
  if self.data_mapping:
404
- result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode,
405
- summary_compare, md5_compare)
437
+ result_df = self.compare_process_custom([npu_json, bench_json, stack_json])
406
438
  else:
407
- result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
408
- summary_compare, md5_compare)
439
+ result_df = self.compare_process([npu_json, bench_json, stack_json])
409
440
 
410
441
  if not result_df.values.tolist():
411
442
  logger.warning("Can`t match any op.")
412
443
  return
413
444
 
414
- if not md5_compare and not summary_compare:
415
- result_df = self._do_multi_process(input_parma, result_df)
445
+ if self.dump_mode == Const.ALL:
446
+ result_df = self.do_multi_process(input_param, result_df)
416
447
 
417
- logger.info("Highlight suspicious API/Module start.")
418
- find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
448
+ find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode)
419
449
  highlight_rows_xlsx(result_df, highlight_dict, file_path)
420
- logger.info("Highlight suspicious API/Module finish.")
421
450
 
422
- if auto_analyze:
451
+ if self.auto_analyze:
423
452
  advisor = Advisor(result_df, output_path, suffix)
424
453
  advisor.analysis()
425
-
454
+
455
+ print_compare_ends_info()
456
+
426
457
  def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
427
458
  cos_result = []
428
459
  max_err_result = []
@@ -431,13 +462,16 @@ class Comparator:
431
462
  one_thousand_err_ratio_result = []
432
463
  five_thousand_err_ratio_result = []
433
464
  is_print_compare_log = input_param.get("is_print_compare_log")
465
+ bench_data = load_json(input_param.get("bench_json_path")).get('data')
434
466
  for i in range(len(result_df)):
435
467
  npu_op_name = result_df.iloc[i, 0]
436
468
  bench_op_name = result_df.iloc[i, 1]
437
469
  if is_print_compare_log:
438
470
  logger.info("start compare: {}".format(npu_op_name))
471
+
439
472
  cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
440
- self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
473
+ self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data)
474
+
441
475
  if is_print_compare_log:
442
476
  logger.info(
443
477
  "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
@@ -460,9 +494,9 @@ class Comparator:
460
494
  five_thousand_err_ratio_result=five_thousand_err_ratio_result
461
495
  )
462
496
 
463
- return _save_cmp_result(idx, cr, result_df, lock)
464
-
465
- def _do_multi_process(self, input_parma, result_df):
497
+ return _save_cmp_result(idx, cr, result_df, lock)
498
+
499
+ def do_multi_process(self, input_parma, result_df):
466
500
  try:
467
501
  result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
468
502
  multiprocessing.Manager().RLock())
@@ -470,4 +504,46 @@ class Comparator:
470
504
  except ValueError as e:
471
505
  logger.error('result dataframe is not found.')
472
506
  raise CompareException(CompareException.INVALID_DATA_ERROR) from e
473
-
507
+
508
+
509
+ def get_bench_data_name(bench_op_name, bench_data):
510
+ bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name)
511
+ if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD:
512
+ bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {})
513
+ else:
514
+ bench_data_bundle = bench_data.get(bench_name_list[0], {})
515
+ if not bench_data_bundle or len(bench_name_list) < 3:
516
+ return None
517
+ layers = bench_name_list[2].split(Const.SEP)
518
+
519
+ def _get(key, container):
520
+ if isinstance(container, dict):
521
+ return container.get(key)
522
+ if isinstance(container, list):
523
+ try:
524
+ return container[int(key)]
525
+ except (ValueError, IndexError):
526
+ return None
527
+ return None
528
+
529
+ def get_by_layer(container, params_grad=False):
530
+ data = container
531
+ # dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0'
532
+ if params_grad:
533
+ layers.append('0')
534
+ for layer in layers:
535
+ data = _get(layer, data)
536
+ return _get(CompareConst.DATA_NAME.lower(), data)
537
+
538
+ if Const.INPUT == bench_name_list[1]:
539
+ return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
540
+ elif Const.KWARGS == bench_name_list[1]:
541
+ return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
542
+ elif Const.OUTPUT == bench_name_list[1]:
543
+ return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
544
+ elif Const.PARAMS == bench_name_list[1]:
545
+ return get_by_layer(bench_data_bundle.get(Const.PARAMS))
546
+ elif Const.PARAMS_GRAD == bench_name_list[1]:
547
+ return get_by_layer(bench_data_bundle, params_grad=True)
548
+ else:
549
+ return None