mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -15,41 +15,48 @@
15
15
 
16
16
  import multiprocessing
17
17
  import os
18
+ import re
19
+ from copy import deepcopy
20
+
18
21
  import pandas as pd
19
- from tqdm import tqdm
20
- from msprobe.core.common.file_utils import load_json
22
+ from msprobe.core.advisor.advisor import Advisor
21
23
  from msprobe.core.common.const import CompareConst, Const
22
24
  from msprobe.core.common.exceptions import FileCheckException
23
- from msprobe.core.common.log import logger
24
- from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid
25
+ from msprobe.core.common.file_utils import load_json
25
26
  from msprobe.core.common.file_utils import remove_path
27
+ from msprobe.core.common.log import logger
28
+ from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid, safe_get_value
26
29
  from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
27
- check_stack_json_str
30
+ check_stack_json_str
28
31
  from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
29
- from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
30
32
  from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
31
33
  from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
32
34
  get_error_message
33
- from msprobe.core.advisor.advisor import Advisor
35
+ from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy, \
36
+ get_rela_diff_summary_mode, print_compare_ends_info
37
+ from tqdm import tqdm
34
38
 
35
39
 
36
40
  class Comparator:
37
-
41
+
38
42
  def __init__(self):
39
43
  pass
40
44
 
41
45
  @staticmethod
42
46
  def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
43
- result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
44
- bench_ops_all.get(bench_op_name).get('struct')[0],
45
- npu_ops_all.get(ms_op_name).get('struct')[1],
46
- bench_ops_all.get(bench_op_name).get('struct')[1],
47
- npu_ops_all.get(ms_op_name).get('struct')[2],
48
- bench_ops_all.get(bench_op_name).get('struct')[2],
49
- CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2]
50
- == bench_ops_all.get(bench_op_name).get('struct')[2]
51
- else CompareConst.DIFF]
52
- if args[0]:
47
+ npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
48
+ bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
49
+
50
+ if len(npu_struct) < 3 or len(bench_struct) < 3:
51
+ logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
52
+ f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
53
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
54
+
55
+ result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
56
+ npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
57
+ CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
58
+
59
+ if len(args) >= 2 and args[0]:
53
60
  result_item.extend(args[1])
54
61
  else:
55
62
  result_item.append(CompareConst.NONE)
@@ -58,59 +65,47 @@ class Comparator:
58
65
  @staticmethod
59
66
  def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
60
67
  err_msg = ""
61
- start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
62
- warning_flag = False
63
- for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
64
- if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
65
- diff = npu_val - bench_val
66
- if bench_val != 0:
67
- relative = str(abs((diff / bench_val) * 100)) + '%'
68
- else:
69
- relative = "N/A"
70
- result_item[start_idx + i] = diff
71
- result_item[start_idx + i + 4] = relative
72
- magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
73
- if magnitude_diff > 0.5:
74
- warning_flag = True
75
- else:
76
- result_item[start_idx + i] = CompareConst.NONE
77
- accuracy_check = CompareConst.WARNING if warning_flag else ""
78
- err_msg += "Need double check api accuracy." if warning_flag else ""
79
- for i in range(start_idx, len(result_item)):
80
- if str(result_item[i]) in ('inf', '-inf', 'nan'):
81
- result_item[i] = f'{result_item[i]}\t'
68
+ result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
69
+ bench_summary_data, err_msg)
82
70
  result_item.append(accuracy_check)
83
71
  result_item.append(err_msg)
84
-
72
+
73
+ @staticmethod
74
+ def _generate_na_data(ops_all):
75
+ if not ops_all:
76
+ return {}
77
+ key = next(iter(ops_all))
78
+ value = deepcopy(ops_all[key])
79
+ for k, v in value.items():
80
+ if isinstance(v, tuple):
81
+ value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
82
+ elif isinstance(v, list):
83
+ value[k] = [CompareConst.N_A] * len(v)
84
+ else:
85
+ value[k] = CompareConst.N_A
86
+ return value
87
+
85
88
  @classmethod
86
- def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
87
- if md5_compare:
88
- header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
89
- elif summary_compare:
90
- header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
91
- else:
92
- header = CompareConst.COMPARE_RESULT_HEADER[:]
89
+ def make_result_table(cls, result, stack_mode, dump_mode):
90
+ header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
93
91
 
94
- all_mode_bool = not (summary_compare or md5_compare)
95
92
  if stack_mode:
96
- if all_mode_bool:
97
- header.append(CompareConst.STACK)
93
+ header.append(CompareConst.STACK)
94
+ if dump_mode == Const.ALL:
98
95
  header.append(CompareConst.DATA_NAME)
99
- else:
100
- header.append(CompareConst.STACK)
101
96
  else:
102
- if all_mode_bool:
97
+ if dump_mode == Const.ALL:
103
98
  for row in result:
104
- del row[-2]
99
+ del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
105
100
  header.append(CompareConst.DATA_NAME)
106
101
  else:
107
102
  for row in result:
108
- del row[-1]
103
+ del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
109
104
  result_df = pd.DataFrame(result, columns=header, dtype='object')
110
- return result_df
111
-
105
+ return result_df
106
+
112
107
  @classmethod
113
- def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
108
+ def gen_merge_list(cls, json_data, op_name, stack_json_data, dump_mode):
114
109
  op_data = json_data['data'][op_name]
115
110
  check_dump_json_str(op_data, op_name)
116
111
  op_parsed_list = read_op(op_data, op_name)
@@ -122,31 +117,32 @@ class Comparator:
122
117
  'full_op_name': op_name,
123
118
  'full_info': stack_info
124
119
  })
125
-
126
- merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
120
+
121
+ merge_list = merge_tensor(op_parsed_list, dump_mode)
127
122
  return merge_list
128
-
123
+
129
124
  def check_op(self, npu_dict, bench_dict, fuzzy_match):
130
- a_op_name = npu_dict["op_name"]
131
- b_op_name = bench_dict["op_name"]
132
- graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
133
-
125
+ npu_op_name = npu_dict[CompareConst.OP_NAME]
126
+ bench_op_name = bench_dict[CompareConst.OP_NAME]
127
+ graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
128
+ safe_get_value(bench_op_name, 0, "bench_op_name"))
129
+
134
130
  frame_name = getattr(self, "frame_name")
135
131
  if frame_name == "PTComparator":
136
132
  from msprobe.pytorch.compare.match import graph_mapping
137
133
  if graph_mode:
138
- return graph_mapping.match(a_op_name[0], b_op_name[0])
134
+ return graph_mapping.match(npu_op_name[0], bench_op_name[0])
139
135
  struct_match = check_struct_match(npu_dict, bench_dict)
140
136
  if not fuzzy_match:
141
- return a_op_name == b_op_name and struct_match
137
+ return npu_op_name == bench_op_name and struct_match
142
138
  is_match = True
143
139
  try:
144
- is_match = fuzzy_check_op(a_op_name, b_op_name)
140
+ is_match = fuzzy_check_op(npu_op_name, bench_op_name)
145
141
  except Exception as err:
146
- logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
142
+ logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
147
143
  is_match = False
148
144
  return is_match and struct_match
149
-
145
+
150
146
  def match_op(self, npu_queue, bench_queue, fuzzy_match):
151
147
  for b_index, b_op in enumerate(bench_queue[0: -1]):
152
148
  if self.check_op(npu_queue[-1], b_op, fuzzy_match):
@@ -157,8 +153,8 @@ class Comparator:
157
153
  if self.check_op(n_op, bench_queue[-1], fuzzy_match):
158
154
  return n_index, len(bench_queue) - 1
159
155
  return -1, -1
160
-
161
- def compare_process(self, file_lists, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
156
+
157
+ def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
162
158
  npu_json_path, bench_json_path, stack_json_path = file_lists
163
159
  npu_json_data = load_json(npu_json_path)
164
160
  bench_json_data = load_json(bench_json_path)
@@ -189,8 +185,7 @@ class Comparator:
189
185
  op_name_npu = next(ops_npu_iter)
190
186
  check_op_str_pattern_valid(op_name_npu)
191
187
  read_err_npu = True
192
- npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
193
- summary_compare, md5_compare)
188
+ npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data, dump_mode)
194
189
  if npu_merge_list:
195
190
  npu_ops_queue.append(npu_merge_list)
196
191
  except StopIteration:
@@ -199,8 +194,7 @@ class Comparator:
199
194
  last_bench_ops_len = len(bench_ops_queue)
200
195
  op_name_bench = next(ops_bench_iter)
201
196
  check_op_str_pattern_valid(op_name_bench)
202
- bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
203
- summary_compare, md5_compare)
197
+ bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data, dump_mode)
204
198
  if bench_merge_list:
205
199
  bench_ops_queue.append(bench_merge_list)
206
200
  except StopIteration:
@@ -226,71 +220,93 @@ class Comparator:
226
220
  b_match_data = bench_ops_queue[b_match_point]
227
221
  un_match_data = npu_ops_queue[0: n_match_point]
228
222
  for npu_data in un_match_data:
229
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
230
- get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
223
+ get_un_match_accuracy(result, npu_data, dump_mode)
224
+ get_accuracy(result, n_match_data, b_match_data, dump_mode)
231
225
  del npu_ops_queue[0: n_match_point + 1]
232
226
  del bench_ops_queue[0: b_match_point + 1]
227
+ progress_bar.close()
233
228
  if npu_ops_queue:
234
229
  for npu_data in npu_ops_queue:
235
- get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
236
-
237
- result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
230
+ get_un_match_accuracy(result, npu_data, dump_mode)
231
+
232
+ result_df = self.make_result_table(result, stack_mode, dump_mode)
238
233
  return result_df
239
234
 
240
- def merge_data(self, json_data, stack_json_data, summary_compare, md5_compare):
235
+ def merge_data(self, json_data, stack_json_data, dump_mode):
241
236
  ops_all = {}
242
237
  for op_name in json_data.get('data', {}):
243
- merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, summary_compare,
244
- md5_compare)
238
+ merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, dump_mode)
245
239
  if merge_list:
246
240
  input_index, output_index = 0, 0
247
- for index, input_or_output in enumerate(merge_list['op_name']):
241
+ for index, input_or_output in enumerate(merge_list[CompareConst.OP_NAME]):
248
242
  input_or_output_list = input_or_output.split(Const.SEP)
249
243
  data_name = merge_list.get('data_name')
250
244
  data_name = data_name[index] if data_name else None
251
245
  if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
252
- ops_all[input_or_output] = {'struct': merge_list.get('input_struct')[input_index],
253
- 'summary': merge_list.get('summary')[index],
254
- 'data_name': data_name,
255
- 'stack_info': merge_list.get('stack_info')}
246
+ ops_all[input_or_output] = {
247
+ CompareConst.STRUCT: safe_get_value(merge_list, input_index, "merge_list",
248
+ key=CompareConst.INPUT_STRUCT),
249
+ CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
250
+ key=CompareConst.SUMMARY),
251
+ 'data_name': data_name,
252
+ 'stack_info': merge_list.get('stack_info')
253
+ }
256
254
  input_index += 1
257
255
 
258
256
  elif Const.OUTPUT in input_or_output_list:
259
- ops_all[input_or_output] = {'struct': merge_list.get('output_struct')[output_index],
260
- 'summary': merge_list.get('summary')[index],
261
- 'data_name': data_name,
262
- 'stack_info': merge_list.get('stack_info')}
257
+ ops_all[input_or_output] = {
258
+ CompareConst.STRUCT: safe_get_value(merge_list, output_index, "merge_list",
259
+ key=CompareConst.OUTPUT_STRUCT),
260
+ CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
261
+ key=CompareConst.SUMMARY),
262
+ 'data_name': data_name,
263
+ 'stack_info': merge_list.get('stack_info')
264
+ }
263
265
  output_index += 1
264
266
  return ops_all
265
267
 
266
- def get_accuracy(self, npu_ops_all, bench_ops_all, summary_compare, md5_compare):
268
+ def get_accuracy(self, npu_ops_all, bench_ops_all, dump_mode):
267
269
  result = []
270
+ bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
268
271
  for ms_op_name, bench_op_name in self.data_mapping_dict.items():
269
272
  if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
270
273
  npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
271
274
  bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
272
275
  has_stack = npu_stack_info and bench_stack_info
273
- if md5_compare:
276
+ if dump_mode == Const.MD5:
274
277
  result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
275
278
  bench_ops_all, has_stack, npu_stack_info))
276
279
  continue
277
- if summary_compare:
278
- result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
279
- bench_ops_all.get(bench_op_name).get('struct')[0],
280
- npu_ops_all.get(ms_op_name).get('struct')[1],
281
- bench_ops_all.get(bench_op_name).get('struct')[1],
282
- " ", " ", " ", " ", " ", " ", " ", " "]
280
+
281
+ npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
282
+ bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
283
+
284
+ if len(npu_struct) < 2 or len(bench_struct) < 2:
285
+ logger.error(
286
+ f"The length of npu_struct and bench_struct must be >= 2, "
287
+ f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
288
+ f"Please check!"
289
+ )
290
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
291
+
292
+ base_result_item = [
293
+ ms_op_name, bench_op_name,
294
+ npu_struct[0],
295
+ bench_struct[0],
296
+ npu_struct[1],
297
+ bench_struct[1]
298
+ ]
299
+
300
+ if dump_mode == Const.SUMMARY:
301
+ result_item = base_result_item + [" "] * 8
283
302
  else:
284
- result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
285
- bench_ops_all.get(bench_op_name).get('struct')[0],
286
- npu_ops_all.get(ms_op_name).get('struct')[1],
287
- bench_ops_all.get(bench_op_name).get('struct')[1],
288
- " ", " ", " ", " ", " "]
303
+ result_item = base_result_item + [" "] * 5
304
+
289
305
  npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
290
306
  result_item.extend(npu_summary_data)
291
307
  bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
292
308
  result_item.extend(bench_summary_data)
293
- if summary_compare:
309
+ if dump_mode == Const.SUMMARY:
294
310
  self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
295
311
  else:
296
312
  result_item.append(CompareConst.ACCURACY_CHECK_YES)
@@ -299,7 +315,7 @@ class Comparator:
299
315
  result_item.extend(npu_stack_info)
300
316
  else:
301
317
  result_item.append(CompareConst.NONE)
302
- if not (summary_compare or md5_compare):
318
+ if dump_mode == Const.ALL:
303
319
  result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
304
320
  result.append(result_item)
305
321
  elif ms_op_name not in npu_ops_all:
@@ -308,26 +324,40 @@ class Comparator:
308
324
  logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
309
325
  return result
310
326
 
311
def compare_process_custom(self, file_lists, stack_mode, dump_mode):
    """
    Build the comparison result table from explicitly supplied dump files.

    :param file_lists: three-item sequence holding, in order, the NPU dump json path,
        the bench dump json path and the stack json path
    :param stack_mode: forwarded unchanged to ``make_result_table``
    :param dump_mode: forwarded to ``merge_data``, ``get_accuracy`` and ``make_result_table``
    :return: whatever ``make_result_table`` returns (presumably a DataFrame, since the
        caller reads ``.values`` from it -- confirm against ``make_result_table``)
    """
    npu_path, bench_path, stack_path = file_lists

    # Load the three json files in the same order as they are listed.
    npu_dump = load_json(npu_path)
    bench_dump = load_json(bench_path)
    stack_info = load_json(stack_path)

    # Both dumps are merged against the same stack data.
    merged_npu_ops = self.merge_data(npu_dump, stack_info, dump_mode)
    merged_bench_ops = self.merge_data(bench_dump, stack_info, dump_mode)

    accuracy_rows = self.get_accuracy(merged_npu_ops, merged_bench_ops, dump_mode)
    return self.make_result_table(accuracy_rows, stack_mode, dump_mode)
323
339
 
324
- def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
340
+ def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
341
+ """
342
+ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
343
+ :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
344
+ :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
345
+ :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
346
+ :param bench_data: bench的dump数据中"data"字段
347
+ :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
348
+ 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
349
+ 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
350
+ """
325
351
  npu_bench_name_list = op_name_mapping_dict[npu_op_name]
326
- data_name = npu_bench_name_list[1]
352
+ data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
327
353
  error_file, relative_err, error_flag = None, None, False
354
+ bench_data_name = get_bench_data_name(bench_op_name, bench_data)
328
355
  if data_name == '-1' or data_name == -1: # 没有真实数据路径
329
356
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
330
357
  error_flag = True
358
+ elif not bench_data_name:
359
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
360
+ error_file = 'no_bench_data'
331
361
  else:
332
362
  try:
333
363
  read_npy_data = getattr(self, "read_npy_data")
@@ -335,19 +365,18 @@ class Comparator:
335
365
  if frame_name == "MSComparator":
336
366
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
337
367
  if self.cross_frame:
338
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
339
- bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
368
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
369
+ load_pt_file=True)
340
370
  else:
341
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
342
- bench_op_name + Const.NUMPY_SUFFIX)
371
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
343
372
  else:
344
373
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
345
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
374
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
346
375
  except IOError as error:
347
376
  error_file = error.filename
348
377
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
349
378
  error_flag = True
350
- except FileCheckException:
379
+ except (FileCheckException, CompareException):
351
380
  error_file = data_name
352
381
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
353
382
  error_flag = True
@@ -364,7 +393,7 @@ class Comparator:
364
393
  err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
365
394
  result_list.append(err_msg)
366
395
  return result_list
367
-
396
+
368
397
  def compare_core(self, input_parma, output_path, **kwargs):
369
398
  """
370
399
  Compares data from multiple JSON files and generates a comparison report.
@@ -378,8 +407,7 @@ class Comparator:
378
407
  - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
379
408
  - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
380
409
  - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
381
- - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
382
- - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.
410
+ - dump_mode (str): ALL, SUMMARY, MD5.
383
411
 
384
412
  Returns:
385
413
  """
@@ -388,41 +416,43 @@ class Comparator:
388
416
  auto_analyze = kwargs.get('auto_analyze', True)
389
417
  suffix = kwargs.get('suffix', '')
390
418
  fuzzy_match = kwargs.get('fuzzy_match', False)
391
- summary_compare = kwargs.get('summary_compare', False)
392
- md5_compare = kwargs.get('md5_compare', False)
419
+ dump_mode = kwargs.get('dump_mode', None)
393
420
 
394
421
  logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
395
422
  file_name = add_time_with_xlsx("compare_result" + suffix)
396
423
  file_path = os.path.join(os.path.realpath(output_path), file_name)
397
424
  remove_path(file_path)
398
- highlight_dict = {'red_rows': [], 'yellow_rows': []}
425
+ highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
399
426
 
400
427
  npu_json = input_parma.get("npu_json_path")
401
428
  bench_json = input_parma.get("bench_json_path")
402
429
  stack_json = input_parma.get("stack_json_path")
403
430
  if self.data_mapping:
404
- result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode,
405
- summary_compare, md5_compare)
431
+ result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode, dump_mode)
406
432
  else:
407
- result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
408
- summary_compare, md5_compare)
433
+ result_df = self.compare_process(
434
+ [npu_json, bench_json, stack_json],
435
+ stack_mode,
436
+ fuzzy_match,
437
+ dump_mode
438
+ )
409
439
 
410
440
  if not result_df.values.tolist():
411
441
  logger.warning("Can`t match any op.")
412
442
  return
413
443
 
414
- if not md5_compare and not summary_compare:
415
- result_df = self._do_multi_process(input_parma, result_df)
444
+ if dump_mode == Const.ALL:
445
+ result_df = self.do_multi_process(input_parma, result_df)
416
446
 
417
- logger.info("Highlight suspicious API/Module start.")
418
- find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
447
+ find_compare_result_error_rows(result_df, highlight_dict, dump_mode)
419
448
  highlight_rows_xlsx(result_df, highlight_dict, file_path)
420
- logger.info("Highlight suspicious API/Module finish.")
421
449
 
422
450
  if auto_analyze:
423
451
  advisor = Advisor(result_df, output_path, suffix)
424
452
  advisor.analysis()
425
-
453
+
454
+ print_compare_ends_info()
455
+
426
456
  def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
427
457
  cos_result = []
428
458
  max_err_result = []
@@ -431,13 +461,16 @@ class Comparator:
431
461
  one_thousand_err_ratio_result = []
432
462
  five_thousand_err_ratio_result = []
433
463
  is_print_compare_log = input_param.get("is_print_compare_log")
464
+ bench_data = load_json(input_param.get("bench_json_path")).get('data')
434
465
  for i in range(len(result_df)):
435
466
  npu_op_name = result_df.iloc[i, 0]
436
467
  bench_op_name = result_df.iloc[i, 1]
437
468
  if is_print_compare_log:
438
469
  logger.info("start compare: {}".format(npu_op_name))
470
+
439
471
  cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
440
- self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
472
+ self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data)
473
+
441
474
  if is_print_compare_log:
442
475
  logger.info(
443
476
  "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
@@ -460,9 +493,9 @@ class Comparator:
460
493
  five_thousand_err_ratio_result=five_thousand_err_ratio_result
461
494
  )
462
495
 
463
- return _save_cmp_result(idx, cr, result_df, lock)
464
-
465
- def _do_multi_process(self, input_parma, result_df):
496
+ return _save_cmp_result(idx, cr, result_df, lock)
497
+
498
+ def do_multi_process(self, input_parma, result_df):
466
499
  try:
467
500
  result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
468
501
  multiprocessing.Manager().RLock())
@@ -470,4 +503,36 @@ class Comparator:
470
503
  except ValueError as e:
471
504
  logger.error('result dataframe is not found.')
472
505
  raise CompareException(CompareException.INVALID_DATA_ERROR) from e
473
-
506
+
507
def get_bench_data_name(bench_op_name, bench_data):
    """
    Look up the dump-file name recorded for one bench tensor.

    ``bench_op_name`` is split on its ``.input.`` / ``.output.`` / ``.kwargs.`` markers:
    the text before the first marker is the key into ``bench_data``, the marker selects
    which section of that entry is searched, and the text right after the marker is a
    ``Const.SEP``-separated path of dict keys / list indices leading to the tensor record.

    :param bench_op_name: bench tensor name, e.g. ``Functional.conv2d.0.forward.input.3.0``
    :param bench_data: the ``"data"`` field of the bench dump json; may be ``None`` when
        the json has no such field
    :return: the value stored under the lower-cased ``CompareConst.DATA_NAME`` key of the
        tensor record, or ``None`` when the name cannot be parsed or any step of the
        lookup fails
    """
    # A dump json without a "data" field yields None here; report "no bench data"
    # instead of failing with AttributeError on ``None.get``.
    if not isinstance(bench_op_name, str) or not isinstance(bench_data, dict):
        return None

    # The capture group keeps the matched marker in the result:
    # [api_name, marker, remainder, ...]
    bench_name_list = re.split(r'\.(input|output|kwargs)\.', bench_op_name)
    if len(bench_name_list) < 3:
        return None
    api_name, section, layer_path = bench_name_list[:3]

    bench_data_bundle = bench_data.get(api_name)
    # An absent, empty or non-dict entry cannot hold the sections looked up below.
    if not bench_data_bundle or not isinstance(bench_data_bundle, dict):
        return None
    layers = layer_path.split(Const.SEP)

    def get(key, container):
        """Return ``container[key]`` for dicts and lists, or None when it is unavailable."""
        if isinstance(container, dict):
            return container.get(key)
        if isinstance(container, list):
            try:
                return container[int(key)]
            except (ValueError, IndexError):
                return None
        return None

    def get_by_layer(container):
        """Walk ``layers`` down from ``container`` and read the data-name field."""
        data = container
        for layer in layers:
            data = get(layer, data)
        return get(CompareConst.DATA_NAME.lower(), data)

    if section == Const.INPUT:
        # Two layouts are accepted: a Const.INPUT section, else a Const.INPUT_ARGS section.
        return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
    if section == Const.KWARGS:
        return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
    if section == Const.OUTPUT:
        return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
    return None
538
+
@@ -35,18 +35,15 @@ dtype_mapping = {
35
35
  "BFloat16": "torch.bfloat16",
36
36
  "Complex64": "torch.complex64",
37
37
  "Complex128": "torch.complex128"
38
- }
38
+ }
39
39
 
40
40
 
41
- def check_struct_match(npu_dict, bench_dict, cross_frame=False):
41
+ def check_struct_match(npu_dict, bench_dict):
42
42
  npu_struct_in = npu_dict.get("input_struct")
43
43
  bench_struct_in = bench_dict.get("input_struct")
44
44
  npu_struct_out = npu_dict.get("output_struct")
45
45
  bench_struct_out = bench_dict.get("output_struct")
46
46
 
47
- if cross_frame:
48
- npu_struct_in = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_in]
49
- npu_struct_out = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_out]
50
47
  is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
51
48
  if not is_match:
52
49
  if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
@@ -14,17 +14,22 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import json
17
- from msprobe.core.common.file_utils import FileOpen, check_file_type
17
+ from msprobe.core.common.file_utils import check_file_type, load_json
18
18
  from msprobe.core.common.const import FileCheckConst, Const
19
19
  from msprobe.core.common.utils import CompareException
20
20
  from msprobe.core.common.log import logger
21
21
 
22
22
 
23
23
  def compare_cli(args):
24
- with FileOpen(args.input_path, "r") as file:
25
- input_param = json.load(file)
24
+ input_param = load_json(args.input_path)
26
25
  npu_path = input_param.get("npu_path", None)
27
26
  bench_path = input_param.get("bench_path", None)
27
+ if not npu_path:
28
+ logger.error(f"Missing npu_path in configuration file {args.input_path}, please check!")
29
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
30
+ if not bench_path:
31
+ logger.error(f"Missing bench_path in configuration file {args.input_path}, please check!")
32
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
28
33
  frame_name = args.framework
29
34
  auto_analyze = not args.compare_only
30
35
  if frame_name == Const.PT_FRAMEWORK:
@@ -34,6 +39,9 @@ def compare_cli(args):
34
39
  from msprobe.mindspore.compare.ms_compare import ms_compare
35
40
  from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
36
41
  if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
42
+ if "stack_path" not in input_param:
43
+ logger.error(f"Missing stack_path in configuration file {args.input_path}, please check!")
44
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
37
45
  input_param["npu_json_path"] = input_param.pop("npu_path")
38
46
  input_param["bench_json_path"] = input_param.pop("bench_path")
39
47
  input_param["stack_json_path"] = input_param.pop("stack_path")
@@ -56,7 +64,16 @@ def compare_cli(args):
56
64
 
57
65
  ms_compare(input_param, args.output_path, **kwargs)
58
66
  elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
59
- kwargs = {"stack_mode": args.stack_mode, "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match}
67
+ kwargs = {
68
+ "stack_mode": args.stack_mode,
69
+ "auto_analyze": auto_analyze,
70
+ "fuzzy_match": args.fuzzy_match,
71
+ "is_print_compare_log": input_param.get("is_print_compare_log", True),
72
+ "cell_mapping": args.cell_mapping,
73
+ "api_mapping": args.api_mapping,
74
+ "data_mapping": args.data_mapping,
75
+ "layer_mapping": args.layer_mapping
76
+ }
60
77
  if input_param.get("rank_id") is not None:
61
78
  ms_graph_compare(input_param, args.output_path)
62
79
  return