mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,88 +1,148 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import multiprocessing
2
17
  import os
3
- import json
18
+ import re
19
+ from copy import deepcopy
20
+
4
21
  import pandas as pd
5
- from msprobe.core.common.file_utils import FileOpen
22
+ from msprobe.core.advisor.advisor import Advisor
6
23
  from msprobe.core.common.const import CompareConst, Const
7
24
  from msprobe.core.common.exceptions import FileCheckException
8
- from msprobe.core.common.log import logger
9
- from msprobe.core.common.utils import add_time_with_xlsx, CompareException
25
+ from msprobe.core.common.file_utils import load_json
10
26
  from msprobe.core.common.file_utils import remove_path
11
- from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op
27
+ from msprobe.core.common.log import logger
28
+ from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid, safe_get_value
29
+ from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
30
+ check_stack_json_str
12
31
  from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
13
- from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
14
32
  from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
15
33
  from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
16
34
  get_error_message
17
- from msprobe.core.advisor.advisor import Advisor
35
+ from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy, \
36
+ get_rela_diff_summary_mode, print_compare_ends_info
37
+ from tqdm import tqdm
18
38
 
19
39
 
20
40
  class Comparator:
21
-
41
+
22
42
  def __init__(self):
23
43
  pass
24
-
25
- @classmethod
26
- def make_result_table(cls,result, md5_compare, summary_compare, stack_mode):
27
- header = []
28
- if md5_compare:
29
- header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
30
- elif summary_compare:
31
- header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
44
+
45
+ @staticmethod
46
+ def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
47
+ npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
48
+ bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
49
+
50
+ if len(npu_struct) < 3 or len(bench_struct) < 3:
51
+ logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
52
+ f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
53
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
54
+
55
+ result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
56
+ npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
57
+ CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
58
+
59
+ if len(args) >= 2 and args[0]:
60
+ result_item.extend(args[1])
32
61
  else:
33
- header = CompareConst.COMPARE_RESULT_HEADER[:]
62
+ result_item.append(CompareConst.NONE)
63
+ return result_item
64
+
65
+ @staticmethod
66
+ def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
67
+ err_msg = ""
68
+ result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
69
+ bench_summary_data, err_msg)
70
+ result_item.append(accuracy_check)
71
+ result_item.append(err_msg)
72
+
73
+ @staticmethod
74
+ def _generate_na_data(ops_all):
75
+ if not ops_all:
76
+ return {}
77
+ key = next(iter(ops_all))
78
+ value = deepcopy(ops_all[key])
79
+ for k, v in value.items():
80
+ if isinstance(v, tuple):
81
+ value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
82
+ elif isinstance(v, list):
83
+ value[k] = [CompareConst.N_A] * len(v)
84
+ else:
85
+ value[k] = CompareConst.N_A
86
+ return value
87
+
88
+ @classmethod
89
+ def make_result_table(cls, result, stack_mode, dump_mode):
90
+ header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
34
91
 
35
- all_mode_bool = not (summary_compare or md5_compare)
36
92
  if stack_mode:
37
- if all_mode_bool:
38
- header.append(CompareConst.STACK)
93
+ header.append(CompareConst.STACK)
94
+ if dump_mode == Const.ALL:
39
95
  header.append(CompareConst.DATA_NAME)
40
- else:
41
- header.append(CompareConst.STACK)
42
96
  else:
43
- if all_mode_bool:
97
+ if dump_mode == Const.ALL:
44
98
  for row in result:
45
- del row[-2]
99
+ del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
46
100
  header.append(CompareConst.DATA_NAME)
47
101
  else:
48
102
  for row in result:
49
- del row[-1]
50
- result_df = pd.DataFrame(result, columns=header)
51
- return result_df
52
-
103
+ del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
104
+ result_df = pd.DataFrame(result, columns=header, dtype='object')
105
+ return result_df
106
+
53
107
  @classmethod
54
- def gen_merge_list(self, json_data, op_name,stack_json_data, summary_compare, md5_compare):
108
+ def gen_merge_list(cls, json_data, op_name, stack_json_data, dump_mode):
55
109
  op_data = json_data['data'][op_name]
110
+ check_dump_json_str(op_data, op_name)
56
111
  op_parsed_list = read_op(op_data, op_name)
57
- if op_name in stack_json_data:
58
- op_parsed_list.append({'full_op_name': op_name, 'full_info': stack_json_data[op_name]})
59
- else:
60
- op_parsed_list.append({'full_op_name': op_name, 'full_info': None})
61
-
62
- merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
112
+
113
+ stack_info = stack_json_data.get(op_name)
114
+ if stack_info is not None:
115
+ check_stack_json_str(stack_info, op_name)
116
+ op_parsed_list.append({
117
+ 'full_op_name': op_name,
118
+ 'full_info': stack_info
119
+ })
120
+
121
+ merge_list = merge_tensor(op_parsed_list, dump_mode)
63
122
  return merge_list
64
-
123
+
65
124
  def check_op(self, npu_dict, bench_dict, fuzzy_match):
66
- a_op_name = npu_dict["op_name"]
67
- b_op_name = bench_dict["op_name"]
68
- graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
69
-
70
- frame_name = getattr(self,"frame_name")
125
+ npu_op_name = npu_dict[CompareConst.OP_NAME]
126
+ bench_op_name = bench_dict[CompareConst.OP_NAME]
127
+ graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
128
+ safe_get_value(bench_op_name, 0, "bench_op_name"))
129
+
130
+ frame_name = getattr(self, "frame_name")
71
131
  if frame_name == "PTComparator":
72
132
  from msprobe.pytorch.compare.match import graph_mapping
73
133
  if graph_mode:
74
- return graph_mapping.match(a_op_name[0], b_op_name[0])
134
+ return graph_mapping.match(npu_op_name[0], bench_op_name[0])
75
135
  struct_match = check_struct_match(npu_dict, bench_dict)
76
136
  if not fuzzy_match:
77
- return a_op_name == b_op_name and struct_match
137
+ return npu_op_name == bench_op_name and struct_match
78
138
  is_match = True
79
139
  try:
80
- is_match = fuzzy_check_op(a_op_name, b_op_name)
140
+ is_match = fuzzy_check_op(npu_op_name, bench_op_name)
81
141
  except Exception as err:
82
- logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
142
+ logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
83
143
  is_match = False
84
144
  return is_match and struct_match
85
-
145
+
86
146
  def match_op(self, npu_queue, bench_queue, fuzzy_match):
87
147
  for b_index, b_op in enumerate(bench_queue[0: -1]):
88
148
  if self.check_op(npu_queue[-1], b_op, fuzzy_match):
@@ -93,12 +153,12 @@ class Comparator:
93
153
  if self.check_op(n_op, bench_queue[-1], fuzzy_match):
94
154
  return n_index, len(bench_queue) - 1
95
155
  return -1, -1
96
-
97
- def compare_process(self, file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
98
- npu_json_handle, bench_json_handle, stack_json_handle = file_handles
99
- npu_json_data = json.load(npu_json_handle)
100
- bench_json_data = json.load(bench_json_handle)
101
- stack_json_data = json.load(stack_json_handle)
156
+
157
+ def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
158
+ npu_json_path, bench_json_path, stack_json_path = file_lists
159
+ npu_json_data = load_json(npu_json_path)
160
+ bench_json_data = load_json(bench_json_path)
161
+ stack_json_data = load_json(stack_json_path)
102
162
 
103
163
  if fuzzy_match:
104
164
  logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
@@ -114,14 +174,18 @@ class Comparator:
114
174
  last_npu_ops_len = 0
115
175
  last_bench_ops_len = 0
116
176
 
177
+ npu_api_nums = len(npu_json_data['data'])
178
+ progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100)
179
+
117
180
  while True:
118
181
  if not read_err_npu and not read_err_bench:
119
182
  break
120
183
  try:
121
184
  last_npu_ops_len = len(npu_ops_queue)
122
185
  op_name_npu = next(ops_npu_iter)
186
+ check_op_str_pattern_valid(op_name_npu)
123
187
  read_err_npu = True
124
- npu_merge_list = self.gen_merge_list(npu_json_data,op_name_npu,stack_json_data,summary_compare,md5_compare)
188
+ npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data, dump_mode)
125
189
  if npu_merge_list:
126
190
  npu_ops_queue.append(npu_merge_list)
127
191
  except StopIteration:
@@ -129,12 +193,15 @@ class Comparator:
129
193
  try:
130
194
  last_bench_ops_len = len(bench_ops_queue)
131
195
  op_name_bench = next(ops_bench_iter)
132
- bench_merge_list = self.gen_merge_list(bench_json_data,op_name_bench,stack_json_data,summary_compare,md5_compare)
196
+ check_op_str_pattern_valid(op_name_bench)
197
+ bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data, dump_mode)
133
198
  if bench_merge_list:
134
199
  bench_ops_queue.append(bench_merge_list)
135
200
  except StopIteration:
136
201
  read_err_bench = False
137
202
 
203
+ progress_bar.update(1)
204
+
138
205
  # merge all boolean expressions
139
206
  both_empty = not npu_ops_queue and not bench_ops_queue
140
207
  no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
@@ -153,24 +220,144 @@ class Comparator:
153
220
  b_match_data = bench_ops_queue[b_match_point]
154
221
  un_match_data = npu_ops_queue[0: n_match_point]
155
222
  for npu_data in un_match_data:
156
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
157
- get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
223
+ get_un_match_accuracy(result, npu_data, dump_mode)
224
+ get_accuracy(result, n_match_data, b_match_data, dump_mode)
158
225
  del npu_ops_queue[0: n_match_point + 1]
159
226
  del bench_ops_queue[0: b_match_point + 1]
227
+ progress_bar.close()
160
228
  if npu_ops_queue:
161
229
  for npu_data in npu_ops_queue:
162
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
163
-
164
- result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
230
+ get_un_match_accuracy(result, npu_data, dump_mode)
231
+
232
+ result_df = self.make_result_table(result, stack_mode, dump_mode)
165
233
  return result_df
166
-
167
- def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
234
+
235
+ def merge_data(self, json_data, stack_json_data, dump_mode):
236
+ ops_all = {}
237
+ for op_name in json_data.get('data', {}):
238
+ merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, dump_mode)
239
+ if merge_list:
240
+ input_index, output_index = 0, 0
241
+ for index, input_or_output in enumerate(merge_list[CompareConst.OP_NAME]):
242
+ input_or_output_list = input_or_output.split(Const.SEP)
243
+ data_name = merge_list.get('data_name')
244
+ data_name = data_name[index] if data_name else None
245
+ if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
246
+ ops_all[input_or_output] = {
247
+ CompareConst.STRUCT: safe_get_value(merge_list, input_index, "merge_list",
248
+ key=CompareConst.INPUT_STRUCT),
249
+ CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
250
+ key=CompareConst.SUMMARY),
251
+ 'data_name': data_name,
252
+ 'stack_info': merge_list.get('stack_info')
253
+ }
254
+ input_index += 1
255
+
256
+ elif Const.OUTPUT in input_or_output_list:
257
+ ops_all[input_or_output] = {
258
+ CompareConst.STRUCT: safe_get_value(merge_list, output_index, "merge_list",
259
+ key=CompareConst.OUTPUT_STRUCT),
260
+ CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
261
+ key=CompareConst.SUMMARY),
262
+ 'data_name': data_name,
263
+ 'stack_info': merge_list.get('stack_info')
264
+ }
265
+ output_index += 1
266
+ return ops_all
267
+
268
+ def get_accuracy(self, npu_ops_all, bench_ops_all, dump_mode):
269
+ result = []
270
+ bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
271
+ for ms_op_name, bench_op_name in self.data_mapping_dict.items():
272
+ if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
273
+ npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
274
+ bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
275
+ has_stack = npu_stack_info and bench_stack_info
276
+ if dump_mode == Const.MD5:
277
+ result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
278
+ bench_ops_all, has_stack, npu_stack_info))
279
+ continue
280
+
281
+ npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
282
+ bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
283
+
284
+ if len(npu_struct) < 2 or len(bench_struct) < 2:
285
+ logger.error(
286
+ f"The length of npu_struct and bench_struct must be >= 2, "
287
+ f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
288
+ f"Please check!"
289
+ )
290
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
291
+
292
+ base_result_item = [
293
+ ms_op_name, bench_op_name,
294
+ npu_struct[0],
295
+ bench_struct[0],
296
+ npu_struct[1],
297
+ bench_struct[1]
298
+ ]
299
+
300
+ if dump_mode == Const.SUMMARY:
301
+ result_item = base_result_item + [" "] * 8
302
+ else:
303
+ result_item = base_result_item + [" "] * 5
304
+
305
+ npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
306
+ result_item.extend(npu_summary_data)
307
+ bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
308
+ result_item.extend(bench_summary_data)
309
+ if dump_mode == Const.SUMMARY:
310
+ self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
311
+ else:
312
+ result_item.append(CompareConst.ACCURACY_CHECK_YES)
313
+ result_item.append("")
314
+ if has_stack:
315
+ result_item.extend(npu_stack_info)
316
+ else:
317
+ result_item.append(CompareConst.NONE)
318
+ if dump_mode == Const.ALL:
319
+ result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
320
+ result.append(result_item)
321
+ elif ms_op_name not in npu_ops_all:
322
+ logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
323
+ elif bench_op_name not in npu_ops_all:
324
+ logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
325
+ return result
326
+
327
+ def compare_process_custom(self, file_lists, stack_mode, dump_mode):
328
+ npu_json_path, bench_json_path, stack_json_path = file_lists
329
+ npu_json_data = load_json(npu_json_path)
330
+ bench_json_data = load_json(bench_json_path)
331
+ stack_json_data = load_json(stack_json_path)
332
+
333
+ npu_ops_all = self.merge_data(npu_json_data, stack_json_data, dump_mode)
334
+ bench_ops_all = self.merge_data(bench_json_data, stack_json_data, dump_mode)
335
+
336
+ result = self.get_accuracy(npu_ops_all, bench_ops_all, dump_mode)
337
+ result_df = self.make_result_table(result, stack_mode, dump_mode)
338
+ return result_df
339
+
340
+ def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
341
+ """
342
+ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
343
+ :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
344
+ :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
345
+ :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
346
+ :param bench_data: bench的dump数据中"data"字段
347
+ :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
348
+ 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
349
+ 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
350
+ """
168
351
  npu_bench_name_list = op_name_mapping_dict[npu_op_name]
169
- data_name = npu_bench_name_list[1]
352
+ data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
170
353
  error_file, relative_err, error_flag = None, None, False
354
+ bench_data_name = get_bench_data_name(bench_op_name, bench_data)
171
355
  if data_name == '-1' or data_name == -1: # 没有真实数据路径
172
356
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
173
357
  error_flag = True
358
+ elif not bench_data_name:
359
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
360
+ error_file = 'no_bench_data'
174
361
  else:
175
362
  try:
176
363
  read_npy_data = getattr(self, "read_npy_data")
@@ -178,17 +365,18 @@ class Comparator:
178
365
  if frame_name == "MSComparator":
179
366
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
180
367
  if self.cross_frame:
181
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
368
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
369
+ load_pt_file=True)
182
370
  else:
183
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.NUMPY_SUFFIX)
371
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
184
372
  else:
185
373
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
186
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
374
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
187
375
  except IOError as error:
188
376
  error_file = error.filename
189
377
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
190
378
  error_flag = True
191
- except FileCheckException:
379
+ except (FileCheckException, CompareException):
192
380
  error_file = data_name
193
381
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
194
382
  error_flag = True
@@ -205,7 +393,7 @@ class Comparator:
205
393
  err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
206
394
  result_list.append(err_msg)
207
395
  return result_list
208
-
396
+
209
397
  def compare_core(self, input_parma, output_path, **kwargs):
210
398
  """
211
399
  Compares data from multiple JSON files and generates a comparison report.
@@ -219,8 +407,7 @@ class Comparator:
219
407
  - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
220
408
  - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
221
409
  - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
222
- - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
223
- - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.
410
+ - dump_mode (str): ALL, SUMMARY, MD5.
224
411
 
225
412
  Returns:
226
413
  """
@@ -229,29 +416,43 @@ class Comparator:
229
416
  auto_analyze = kwargs.get('auto_analyze', True)
230
417
  suffix = kwargs.get('suffix', '')
231
418
  fuzzy_match = kwargs.get('fuzzy_match', False)
232
- summary_compare = kwargs.get('summary_compare', False)
233
- md5_compare = kwargs.get('md5_compare', False)
419
+ dump_mode = kwargs.get('dump_mode', None)
234
420
 
235
421
  logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
236
422
  file_name = add_time_with_xlsx("compare_result" + suffix)
237
423
  file_path = os.path.join(os.path.realpath(output_path), file_name)
238
424
  remove_path(file_path)
239
- highlight_dict = {'red_rows': [], 'yellow_rows': []}
240
-
241
- with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
242
- FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
243
- FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
244
- result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
245
- summary_compare, md5_compare)
246
-
247
- if not md5_compare and not summary_compare:
248
- result_df = self._do_multi_process(input_parma, result_df)
249
- find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
425
+ highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
426
+
427
+ npu_json = input_parma.get("npu_json_path")
428
+ bench_json = input_parma.get("bench_json_path")
429
+ stack_json = input_parma.get("stack_json_path")
430
+ if self.data_mapping:
431
+ result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode, dump_mode)
432
+ else:
433
+ result_df = self.compare_process(
434
+ [npu_json, bench_json, stack_json],
435
+ stack_mode,
436
+ fuzzy_match,
437
+ dump_mode
438
+ )
439
+
440
+ if not result_df.values.tolist():
441
+ logger.warning("Can`t match any op.")
442
+ return
443
+
444
+ if dump_mode == Const.ALL:
445
+ result_df = self.do_multi_process(input_parma, result_df)
446
+
447
+ find_compare_result_error_rows(result_df, highlight_dict, dump_mode)
250
448
  highlight_rows_xlsx(result_df, highlight_dict, file_path)
449
+
251
450
  if auto_analyze:
252
- advisor = Advisor(result_df, output_path)
451
+ advisor = Advisor(result_df, output_path, suffix)
253
452
  advisor.analysis()
254
-
453
+
454
+ print_compare_ends_info()
455
+
255
456
  def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
256
457
  cos_result = []
257
458
  max_err_result = []
@@ -260,18 +461,22 @@ class Comparator:
260
461
  one_thousand_err_ratio_result = []
261
462
  five_thousand_err_ratio_result = []
262
463
  is_print_compare_log = input_param.get("is_print_compare_log")
464
+ bench_data = load_json(input_param.get("bench_json_path")).get('data')
263
465
  for i in range(len(result_df)):
264
466
  npu_op_name = result_df.iloc[i, 0]
265
467
  bench_op_name = result_df.iloc[i, 1]
266
468
  if is_print_compare_log:
267
469
  logger.info("start compare: {}".format(npu_op_name))
268
- cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = self.compare_by_op(
269
- npu_op_name, bench_op_name, dump_path_dict, input_param)
470
+
471
+ cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
472
+ self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data)
473
+
270
474
  if is_print_compare_log:
271
475
  logger.info(
272
- "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, "
273
- "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
274
- one_thousand_err_ratio, five_thousand_err_ratio))
476
+ "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
477
+ one_thousand_err_ratio {}, "
478
+ "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
479
+ err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
275
480
  cos_result.append(cos_sim)
276
481
  max_err_result.append(max_abs_err)
277
482
  max_relative_err_result.append(max_relative_err)
@@ -288,13 +493,46 @@ class Comparator:
288
493
  five_thousand_err_ratio_result=five_thousand_err_ratio_result
289
494
  )
290
495
 
291
- return _save_cmp_result(idx, cr, result_df, lock)
292
-
293
- def _do_multi_process(self,input_parma, result_df):
496
+ return _save_cmp_result(idx, cr, result_df, lock)
497
+
498
+ def do_multi_process(self, input_parma, result_df):
294
499
  try:
295
- result_df = _handle_multi_process(self.compare_ops, input_parma, result_df, multiprocessing.Manager().RLock())
500
+ result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
501
+ multiprocessing.Manager().RLock())
296
502
  return result_df
297
503
  except ValueError as e:
298
504
  logger.error('result dataframe is not found.')
299
505
  raise CompareException(CompareException.INVALID_DATA_ERROR) from e
300
-
506
+
507
+ def get_bench_data_name(bench_op_name, bench_data):
508
+ bench_name_list = re.split(r'\.(input|output|kwargs)\.', bench_op_name)
509
+ bench_data_bundle = bench_data.get(bench_name_list[0], {})
510
+ if not bench_data_bundle or len(bench_name_list) < 3:
511
+ return None
512
+ layers = bench_name_list[2].split(Const.SEP)
513
+
514
+ def get(key, container):
515
+ if isinstance(container, dict):
516
+ return container.get(key)
517
+ if isinstance(container, list):
518
+ try:
519
+ return container[int(key)]
520
+ except (ValueError, IndexError):
521
+ return None
522
+ return None
523
+
524
+ def get_by_layer(container):
525
+ data = container
526
+ for layer in layers:
527
+ data = get(layer, data)
528
+ return get(CompareConst.DATA_NAME.lower(), data)
529
+
530
+ if Const.INPUT == bench_name_list[1]:
531
+ return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
532
+ elif Const.KWARGS == bench_name_list[1]:
533
+ return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
534
+ elif Const.OUTPUT == bench_name_list[1]:
535
+ return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
536
+ else:
537
+ return None
538
+