mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,28 +19,37 @@ import re
19
19
  from copy import deepcopy
20
20
 
21
21
  import pandas as pd
22
+ from tqdm import tqdm
23
+
22
24
  from msprobe.core.advisor.advisor import Advisor
23
25
  from msprobe.core.common.const import CompareConst, Const
24
26
  from msprobe.core.common.exceptions import FileCheckException
25
- from msprobe.core.common.file_utils import load_json
26
- from msprobe.core.common.file_utils import remove_path
27
+ from msprobe.core.common.file_utils import load_json, remove_path
27
28
  from msprobe.core.common.log import logger
28
- from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid, safe_get_value
29
- from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
30
- check_stack_json_str
29
+ from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value
30
+ from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \
31
+ check_struct_match, fuzzy_check_op
31
32
  from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
32
- from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
33
- from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
34
- get_error_message
35
- from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy, \
36
- get_rela_diff_summary_mode, print_compare_ends_info
37
- from tqdm import tqdm
33
+ from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result
34
+ from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
35
+ from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \
36
+ print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list
38
37
 
39
38
 
40
- class Comparator:
39
+ class ModeConfig:
40
+ def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None):
41
+ self.stack_mode = stack_mode
42
+ self.auto_analyze = auto_analyze
43
+ self.fuzzy_match = fuzzy_match
44
+ self.dump_mode = dump_mode
41
45
 
42
- def __init__(self):
43
- pass
46
+
47
+ class Comparator:
48
+ def __init__(self, mode_config: ModeConfig):
49
+ self.stack_mode = mode_config.stack_mode
50
+ self.auto_analyze = mode_config.auto_analyze
51
+ self.fuzzy_match = mode_config.fuzzy_match
52
+ self.dump_mode = mode_config.dump_mode
44
53
 
45
54
  @staticmethod
46
55
  def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
@@ -85,16 +94,15 @@ class Comparator:
85
94
  value[k] = CompareConst.N_A
86
95
  return value
87
96
 
88
- @classmethod
89
- def make_result_table(cls, result, stack_mode, dump_mode):
90
- header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
97
+ def make_result_table(self, result):
98
+ header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
91
99
 
92
- if stack_mode:
100
+ if self.stack_mode:
93
101
  header.append(CompareConst.STACK)
94
- if dump_mode == Const.ALL:
102
+ if self.dump_mode == Const.ALL:
95
103
  header.append(CompareConst.DATA_NAME)
96
104
  else:
97
- if dump_mode == Const.ALL:
105
+ if self.dump_mode == Const.ALL:
98
106
  for row in result:
99
107
  del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
100
108
  header.append(CompareConst.DATA_NAME)
@@ -104,24 +112,25 @@ class Comparator:
104
112
  result_df = pd.DataFrame(result, columns=header, dtype='object')
105
113
  return result_df
106
114
 
107
- @classmethod
108
- def gen_merge_list(cls, json_data, op_name, stack_json_data, dump_mode):
115
+ def gen_merge_list(self, json_data, op_name, stack_json_data):
109
116
  op_data = json_data['data'][op_name]
110
117
  check_dump_json_str(op_data, op_name)
111
118
  op_parsed_list = read_op(op_data, op_name)
112
119
 
113
- stack_info = stack_json_data.get(op_name)
114
- if stack_info is not None:
115
- check_stack_json_str(stack_info, op_name)
116
- op_parsed_list.append({
117
- 'full_op_name': op_name,
118
- 'full_info': stack_info
119
- })
120
-
121
- merge_list = merge_tensor(op_parsed_list, dump_mode)
120
+ if self.stack_mode:
121
+ stack_info = stack_json_data.get(op_name)
122
+ if stack_info is not None:
123
+ check_stack_json_str(stack_info, op_name)
124
+ # append only when stack_mode is True,
125
+ op_parsed_list.append({
126
+ 'full_op_name': op_name,
127
+ 'full_info': stack_info
128
+ })
129
+
130
+ merge_list = merge_tensor(op_parsed_list, self.dump_mode)
122
131
  return merge_list
123
132
 
124
- def check_op(self, npu_dict, bench_dict, fuzzy_match):
133
+ def check_op(self, npu_dict, bench_dict):
125
134
  npu_op_name = npu_dict[CompareConst.OP_NAME]
126
135
  bench_op_name = bench_dict[CompareConst.OP_NAME]
127
136
  graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
@@ -133,34 +142,34 @@ class Comparator:
133
142
  if graph_mode:
134
143
  return graph_mapping.match(npu_op_name[0], bench_op_name[0])
135
144
  struct_match = check_struct_match(npu_dict, bench_dict)
136
- if not fuzzy_match:
137
- return npu_op_name == bench_op_name and struct_match
138
- is_match = True
145
+ if not self.fuzzy_match:
146
+ name_match = npu_op_name == bench_op_name
147
+ return name_match and struct_match
139
148
  try:
140
- is_match = fuzzy_check_op(npu_op_name, bench_op_name)
149
+ name_match = fuzzy_check_op(npu_op_name, bench_op_name)
141
150
  except Exception as err:
142
151
  logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
143
- is_match = False
144
- return is_match and struct_match
152
+ name_match = False
153
+ return name_match and struct_match
145
154
 
146
- def match_op(self, npu_queue, bench_queue, fuzzy_match):
155
+ def match_op(self, npu_queue, bench_queue):
147
156
  for b_index, b_op in enumerate(bench_queue[0: -1]):
148
- if self.check_op(npu_queue[-1], b_op, fuzzy_match):
157
+ if self.check_op(npu_queue[-1], b_op):
149
158
  return len(npu_queue) - 1, b_index
150
- if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
159
+ if self.check_op(npu_queue[-1], bench_queue[-1]):
151
160
  return len(npu_queue) - 1, len(bench_queue) - 1
152
161
  for n_index, n_op in enumerate(npu_queue[0: -1]):
153
- if self.check_op(n_op, bench_queue[-1], fuzzy_match):
162
+ if self.check_op(n_op, bench_queue[-1]):
154
163
  return n_index, len(bench_queue) - 1
155
164
  return -1, -1
156
165
 
157
- def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
166
+ def compare_process(self, file_lists):
158
167
  npu_json_path, bench_json_path, stack_json_path = file_lists
159
168
  npu_json_data = load_json(npu_json_path)
160
169
  bench_json_data = load_json(bench_json_path)
161
- stack_json_data = load_json(stack_json_path)
170
+ stack_json_data = load_json(stack_json_path) if self.stack_mode else None
162
171
 
163
- if fuzzy_match:
172
+ if self.fuzzy_match:
164
173
  logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
165
174
 
166
175
  npu_ops_queue = []
@@ -184,8 +193,7 @@ class Comparator:
184
193
  last_npu_ops_len = len(npu_ops_queue)
185
194
  op_name_npu = next(ops_npu_iter)
186
195
  check_op_str_pattern_valid(op_name_npu)
187
- read_err_npu = True
188
- npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data, dump_mode)
196
+ npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data)
189
197
  if npu_merge_list:
190
198
  npu_ops_queue.append(npu_merge_list)
191
199
  except StopIteration:
@@ -194,7 +202,7 @@ class Comparator:
194
202
  last_bench_ops_len = len(bench_ops_queue)
195
203
  op_name_bench = next(ops_bench_iter)
196
204
  check_op_str_pattern_valid(op_name_bench)
197
- bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data, dump_mode)
205
+ bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data)
198
206
  if bench_merge_list:
199
207
  bench_ops_queue.append(bench_merge_list)
200
208
  except StopIteration:
@@ -213,59 +221,64 @@ class Comparator:
213
221
  logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
214
222
  break
215
223
 
216
- n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
224
+ n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
225
+
226
+ # 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中
217
227
  if n_match_point == -1 and b_match_point == -1:
218
228
  continue
229
+
219
230
  n_match_data = npu_ops_queue[n_match_point]
220
231
  b_match_data = bench_ops_queue[b_match_point]
221
232
  un_match_data = npu_ops_queue[0: n_match_point]
222
233
  for npu_data in un_match_data:
223
- get_un_match_accuracy(result, npu_data, dump_mode)
224
- get_accuracy(result, n_match_data, b_match_data, dump_mode)
234
+ get_un_match_accuracy(result, npu_data, self.dump_mode)
235
+ get_accuracy(result, n_match_data, b_match_data, self.dump_mode)
225
236
  del npu_ops_queue[0: n_match_point + 1]
226
237
  del bench_ops_queue[0: b_match_point + 1]
227
238
  progress_bar.close()
228
239
  if npu_ops_queue:
229
240
  for npu_data in npu_ops_queue:
230
- get_un_match_accuracy(result, npu_data, dump_mode)
241
+ get_un_match_accuracy(result, npu_data, self.dump_mode)
231
242
 
232
- result_df = self.make_result_table(result, stack_mode, dump_mode)
243
+ result_df = self.make_result_table(result)
233
244
  return result_df
234
245
 
235
- def merge_data(self, json_data, stack_json_data, dump_mode):
246
+ def merge_data(self, json_data, stack_json_data):
236
247
  ops_all = {}
237
248
  for op_name in json_data.get('data', {}):
238
- merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, dump_mode)
249
+ merge_list = self.gen_merge_list(json_data, op_name, stack_json_data)
239
250
  if merge_list:
240
- input_index, output_index = 0, 0
241
- for index, input_or_output in enumerate(merge_list[CompareConst.OP_NAME]):
242
- input_or_output_list = input_or_output.split(Const.SEP)
243
- data_name = merge_list.get('data_name')
244
- data_name = data_name[index] if data_name else None
245
- if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
246
- ops_all[input_or_output] = {
247
- CompareConst.STRUCT: safe_get_value(merge_list, input_index, "merge_list",
248
- key=CompareConst.INPUT_STRUCT),
249
- CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
250
- key=CompareConst.SUMMARY),
251
- 'data_name': data_name,
252
- 'stack_info': merge_list.get('stack_info')
253
- }
254
- input_index += 1
255
-
256
- elif Const.OUTPUT in input_or_output_list:
257
- ops_all[input_or_output] = {
258
- CompareConst.STRUCT: safe_get_value(merge_list, output_index, "merge_list",
259
- key=CompareConst.OUTPUT_STRUCT),
260
- CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
261
- key=CompareConst.SUMMARY),
262
- 'data_name': data_name,
263
- 'stack_info': merge_list.get('stack_info')
264
- }
265
- output_index += 1
251
+ struct_to_index_mapping = {
252
+ CompareConst.INPUT_STRUCT: 0,
253
+ CompareConst.OUTPUT_STRUCT: 0,
254
+ CompareConst.PARAMS_STRUCT: 0,
255
+ CompareConst.PARAMS_GRAD_STRUCT: 0
256
+ }
257
+
258
+ op_name_list = merge_list.get(CompareConst.OP_NAME)
259
+ summary_list = merge_list.get(Const.SUMMARY)
260
+ data_name_list = merge_list.get('data_name')
261
+ op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
262
+ summary_list,
263
+ data_name_list)
264
+ for index, op_full_name in enumerate(op_name_reorder):
265
+ data_name = data_name_reorder[index] if data_name_reorder else None
266
+
267
+ _, state = get_name_and_state(op_full_name)
268
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
269
+ if not struct_key:
270
+ continue
271
+ ops_all[op_full_name] = {
272
+ CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key),
273
+ "merge_list", key=struct_key),
274
+ CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"),
275
+ 'data_name': data_name,
276
+ 'stack_info': merge_list.get('stack_info')
277
+ }
278
+ struct_to_index_mapping[struct_key] += 1
266
279
  return ops_all
267
280
 
268
- def get_accuracy(self, npu_ops_all, bench_ops_all, dump_mode):
281
+ def get_accuracy(self, npu_ops_all, bench_ops_all):
269
282
  result = []
270
283
  bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
271
284
  for ms_op_name, bench_op_name in self.data_mapping_dict.items():
@@ -273,7 +286,7 @@ class Comparator:
273
286
  npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
274
287
  bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
275
288
  has_stack = npu_stack_info and bench_stack_info
276
- if dump_mode == Const.MD5:
289
+ if self.dump_mode == Const.MD5:
277
290
  result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
278
291
  bench_ops_all, has_stack, npu_stack_info))
279
292
  continue
@@ -297,7 +310,7 @@ class Comparator:
297
310
  bench_struct[1]
298
311
  ]
299
312
 
300
- if dump_mode == Const.SUMMARY:
313
+ if self.dump_mode == Const.SUMMARY:
301
314
  result_item = base_result_item + [" "] * 8
302
315
  else:
303
316
  result_item = base_result_item + [" "] * 5
@@ -306,7 +319,7 @@ class Comparator:
306
319
  result_item.extend(npu_summary_data)
307
320
  bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
308
321
  result_item.extend(bench_summary_data)
309
- if dump_mode == Const.SUMMARY:
322
+ if self.dump_mode == Const.SUMMARY:
310
323
  self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
311
324
  else:
312
325
  result_item.append(CompareConst.ACCURACY_CHECK_YES)
@@ -315,7 +328,7 @@ class Comparator:
315
328
  result_item.extend(npu_stack_info)
316
329
  else:
317
330
  result_item.append(CompareConst.NONE)
318
- if dump_mode == Const.ALL:
331
+ if self.dump_mode == Const.ALL:
319
332
  result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
320
333
  result.append(result_item)
321
334
  elif ms_op_name not in npu_ops_all:
@@ -324,17 +337,16 @@ class Comparator:
324
337
  logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
325
338
  return result
326
339
 
327
- def compare_process_custom(self, file_lists, stack_mode, dump_mode):
340
+ def compare_process_custom(self, file_lists):
328
341
  npu_json_path, bench_json_path, stack_json_path = file_lists
329
342
  npu_json_data = load_json(npu_json_path)
330
343
  bench_json_data = load_json(bench_json_path)
331
- stack_json_data = load_json(stack_json_path)
332
-
333
- npu_ops_all = self.merge_data(npu_json_data, stack_json_data, dump_mode)
334
- bench_ops_all = self.merge_data(bench_json_data, stack_json_data, dump_mode)
344
+ stack_json_data = load_json(stack_json_path) if self.stack_mode else None
345
+ npu_ops_all = self.merge_data(npu_json_data, stack_json_data)
346
+ bench_ops_all = self.merge_data(bench_json_data, stack_json_data)
335
347
 
336
- result = self.get_accuracy(npu_ops_all, bench_ops_all, dump_mode)
337
- result_df = self.make_result_table(result, stack_mode, dump_mode)
348
+ result = self.get_accuracy(npu_ops_all, bench_ops_all)
349
+ result_df = self.make_result_table(result)
338
350
  return result_df
339
351
 
340
352
  def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
@@ -381,25 +393,23 @@ class Comparator:
381
393
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
382
394
  error_flag = True
383
395
 
384
- n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
385
- if not error_flag:
386
- relative_err = get_relative_err(n_value, b_value)
387
- n_value, b_value = reshape_value(n_value, b_value)
396
+ # 通过n_value, b_value同时得到错误标志和错误信息
397
+ n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
398
+ error_flag=error_flag, error_file=error_file)
388
399
 
389
- err_msg = get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=error_file)
390
- result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)
400
+ result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
391
401
 
392
- if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
402
+ if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
393
403
  err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
394
404
  result_list.append(err_msg)
395
405
  return result_list
396
406
 
397
- def compare_core(self, input_parma, output_path, **kwargs):
407
+ def compare_core(self, input_param, output_path, **kwargs):
398
408
  """
399
409
  Compares data from multiple JSON files and generates a comparison report.
400
410
 
401
411
  Args:
402
- input_parma (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
412
+ input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
403
413
  "stack_path").
404
414
  output_path (str): The path where the output Excel report will be saved.
405
415
  **kwargs: Additional keyword arguments including:
@@ -412,11 +422,7 @@ class Comparator:
412
422
  Returns:
413
423
  """
414
424
  # get kwargs or set default value
415
- stack_mode = kwargs.get('stack_mode', False)
416
- auto_analyze = kwargs.get('auto_analyze', True)
417
425
  suffix = kwargs.get('suffix', '')
418
- fuzzy_match = kwargs.get('fuzzy_match', False)
419
- dump_mode = kwargs.get('dump_mode', None)
420
426
 
421
427
  logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
422
428
  file_name = add_time_with_xlsx("compare_result" + suffix)
@@ -424,30 +430,25 @@ class Comparator:
424
430
  remove_path(file_path)
425
431
  highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
426
432
 
427
- npu_json = input_parma.get("npu_json_path")
428
- bench_json = input_parma.get("bench_json_path")
429
- stack_json = input_parma.get("stack_json_path")
433
+ npu_json = input_param.get("npu_json_path")
434
+ bench_json = input_param.get("bench_json_path")
435
+ stack_json = input_param.get("stack_json_path")
430
436
  if self.data_mapping:
431
- result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode, dump_mode)
437
+ result_df = self.compare_process_custom([npu_json, bench_json, stack_json])
432
438
  else:
433
- result_df = self.compare_process(
434
- [npu_json, bench_json, stack_json],
435
- stack_mode,
436
- fuzzy_match,
437
- dump_mode
438
- )
439
+ result_df = self.compare_process([npu_json, bench_json, stack_json])
439
440
 
440
441
  if not result_df.values.tolist():
441
442
  logger.warning("Can`t match any op.")
442
443
  return
443
444
 
444
- if dump_mode == Const.ALL:
445
- result_df = self.do_multi_process(input_parma, result_df)
445
+ if self.dump_mode == Const.ALL:
446
+ result_df = self.do_multi_process(input_param, result_df)
446
447
 
447
- find_compare_result_error_rows(result_df, highlight_dict, dump_mode)
448
+ find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode)
448
449
  highlight_rows_xlsx(result_df, highlight_dict, file_path)
449
450
 
450
- if auto_analyze:
451
+ if self.auto_analyze:
451
452
  advisor = Advisor(result_df, output_path, suffix)
452
453
  advisor.analysis()
453
454
 
@@ -504,14 +505,18 @@ class Comparator:
504
505
  logger.error('result dataframe is not found.')
505
506
  raise CompareException(CompareException.INVALID_DATA_ERROR) from e
506
507
 
508
+
507
509
  def get_bench_data_name(bench_op_name, bench_data):
508
- bench_name_list = re.split(r'\.(input|output|kwargs)\.', bench_op_name)
509
- bench_data_bundle = bench_data.get(bench_name_list[0], {})
510
+ bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name)
511
+ if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD:
512
+ bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {})
513
+ else:
514
+ bench_data_bundle = bench_data.get(bench_name_list[0], {})
510
515
  if not bench_data_bundle or len(bench_name_list) < 3:
511
516
  return None
512
517
  layers = bench_name_list[2].split(Const.SEP)
513
518
 
514
- def get(key, container):
519
+ def _get(key, container):
515
520
  if isinstance(container, dict):
516
521
  return container.get(key)
517
522
  if isinstance(container, list):
@@ -521,11 +526,14 @@ def get_bench_data_name(bench_op_name, bench_data):
521
526
  return None
522
527
  return None
523
528
 
524
- def get_by_layer(container):
529
+ def get_by_layer(container, params_grad=False):
525
530
  data = container
531
+ # dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0'
532
+ if params_grad:
533
+ layers.append('0')
526
534
  for layer in layers:
527
- data = get(layer, data)
528
- return get(CompareConst.DATA_NAME.lower(), data)
535
+ data = _get(layer, data)
536
+ return _get(CompareConst.DATA_NAME.lower(), data)
529
537
 
530
538
  if Const.INPUT == bench_name_list[1]:
531
539
  return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
@@ -533,6 +541,9 @@ def get_bench_data_name(bench_op_name, bench_data):
533
541
  return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
534
542
  elif Const.OUTPUT == bench_name_list[1]:
535
543
  return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
544
+ elif Const.PARAMS == bench_name_list[1]:
545
+ return get_by_layer(bench_data_bundle.get(Const.PARAMS))
546
+ elif Const.PARAMS_GRAD == bench_name_list[1]:
547
+ return get_by_layer(bench_data_bundle, params_grad=True)
536
548
  else:
537
549
  return None
538
-
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,8 +16,7 @@
16
16
  from msprobe.core.common.log import logger
17
17
  from msprobe.core.compare.utils import rename_api
18
18
  from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
19
- from msprobe.core.common.const import Const
20
-
19
+ from msprobe.core.common.const import CompareConst, Const
21
20
 
22
21
  dtype_mapping = {
23
22
  "Int8": "torch.int8",
@@ -38,31 +37,40 @@ dtype_mapping = {
38
37
  }
39
38
 
40
39
 
41
- def check_struct_match(npu_dict, bench_dict):
42
- npu_struct_in = npu_dict.get("input_struct")
43
- bench_struct_in = bench_dict.get("input_struct")
44
- npu_struct_out = npu_dict.get("output_struct")
45
- bench_struct_out = bench_dict.get("output_struct")
40
+ def compare_op_dict_struct(npu_dict, bench_dict):
41
+ return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY)
46
42
 
47
- is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
43
+
44
+ def check_struct_match(npu_dict, bench_dict):
45
+ is_match = compare_op_dict_struct(npu_dict, bench_dict)
48
46
  if not is_match:
49
- if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
50
- return False
47
+ struct_match_list = []
51
48
  try:
52
- struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
53
- struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
49
+ for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY):
50
+ # 首先额外检查input_struct是否空,input_struct不可能为空
51
+ if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])):
52
+ return False
53
+ struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, [])))
54
54
  except CompareException as error:
55
55
  err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
56
56
  f'npu_dict: {npu_dict}' \
57
57
  f'bench_dict: {bench_dict}'
58
58
  logger.error(err_msg)
59
59
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
60
- is_match = struct_in_is_match and struct_out_is_match
60
+ is_match = all(struct_match_list)
61
61
  return is_match
62
62
 
63
63
 
64
64
  def check_type_shape_match(npu_struct, bench_struct):
65
- shape_type_match = False
65
+ """
66
+ further check dtypes with a dtype mapping list when dtypes are not entirely consistent.
67
+ """
68
+ if len(npu_struct) != len(bench_struct):
69
+ return False
70
+ if not npu_struct and not bench_struct:
71
+ return True
72
+
73
+ struct_match = False
66
74
  for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
67
75
  try:
68
76
  npu_type = npu_type_shape[0]
@@ -76,22 +84,14 @@ def check_type_shape_match(npu_struct, bench_struct):
76
84
  shape_match = npu_shape == bench_shape
77
85
  type_match = npu_type == bench_type
78
86
  if not type_match:
79
- ms_type = [
80
- [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
81
- [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
82
- ]
83
- torch_type = [
84
- [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
85
- [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
86
- ]
87
- if ([npu_type, bench_type] in ms_type) or ([npu_type, bench_type] in torch_type):
87
+ if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
88
88
  type_match = True
89
89
  else:
90
90
  type_match = False
91
- shape_type_match = shape_match and type_match
92
- if not shape_type_match:
91
+ struct_match = shape_match and type_match
92
+ if not struct_match:
93
93
  return False
94
- return shape_type_match
94
+ return struct_match
95
95
 
96
96
 
97
97
  def check_graph_mode(a_op_name, b_op_name):
@@ -103,6 +103,8 @@ def check_graph_mode(a_op_name, b_op_name):
103
103
 
104
104
 
105
105
  def fuzzy_check_op(npu_name_list, bench_name_list):
106
+ # 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0
107
+ # 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1
106
108
  if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
107
109
  return False
108
110
  is_match = True
@@ -148,11 +150,11 @@ def check_json_key_value(input_output, op_name, depth=0):
148
150
  return
149
151
  if isinstance(input_output, list):
150
152
  for item in input_output:
151
- check_json_key_value(item, op_name, depth+1)
153
+ check_json_key_value(item, op_name, depth + 1)
152
154
  elif isinstance(input_output, dict):
153
155
  for key, value in input_output.items():
154
156
  if isinstance(value, dict):
155
- check_json_key_value(value, op_name, depth+1)
157
+ check_json_key_value(value, op_name, depth + 1)
156
158
  else:
157
159
  valid_key_value(key, value, op_name)
158
160
 
@@ -38,40 +38,41 @@ def compare_cli(args):
38
38
  else:
39
39
  from msprobe.mindspore.compare.ms_compare import ms_compare
40
40
  from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
41
+
42
+ common_kwargs = {
43
+ "auto_analyze": auto_analyze,
44
+ "fuzzy_match": args.fuzzy_match,
45
+ "data_mapping": args.data_mapping,
46
+ }
47
+
41
48
  if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
42
- if "stack_path" not in input_param:
43
- logger.error(f"Missing stack_path in configuration file {args.input_path}, please check!")
44
- raise CompareException(CompareException.INVALID_PATH_ERROR)
45
49
  input_param["npu_json_path"] = input_param.pop("npu_path")
46
50
  input_param["bench_json_path"] = input_param.pop("bench_path")
47
- input_param["stack_json_path"] = input_param.pop("stack_path")
51
+ if "stack_path" not in input_param:
52
+ logger.warning(f"Missing stack_path in the configuration file. "
53
+ f"Automatically detecting stack.json to determine whether to display NPU_Stack_Info.")
54
+ else:
55
+ input_param["stack_json_path"] = input_param.pop("stack_path")
56
+
48
57
  if frame_name == Const.PT_FRAMEWORK:
49
- kwargs = {
50
- "data_mapping": args.data_mapping
51
- }
52
- compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
53
- fuzzy_match=args.fuzzy_match, **kwargs)
58
+ kwargs = {**common_kwargs, "stack_mode": args.stack_mode}
59
+ compare(input_param, args.output_path, **kwargs)
54
60
  else:
55
61
  kwargs = {
62
+ **common_kwargs,
56
63
  "stack_mode": args.stack_mode,
57
- "auto_analyze": auto_analyze,
58
- "fuzzy_match": args.fuzzy_match,
59
64
  "cell_mapping": args.cell_mapping,
60
65
  "api_mapping": args.api_mapping,
61
- "data_mapping": args.data_mapping,
62
66
  "layer_mapping": args.layer_mapping
63
67
  }
64
-
65
68
  ms_compare(input_param, args.output_path, **kwargs)
66
69
  elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
67
70
  kwargs = {
71
+ **common_kwargs,
68
72
  "stack_mode": args.stack_mode,
69
- "auto_analyze": auto_analyze,
70
- "fuzzy_match": args.fuzzy_match,
71
73
  "is_print_compare_log": input_param.get("is_print_compare_log", True),
72
74
  "cell_mapping": args.cell_mapping,
73
75
  "api_mapping": args.api_mapping,
74
- "data_mapping": args.data_mapping,
75
76
  "layer_mapping": args.layer_mapping
76
77
  }
77
78
  if input_param.get("rank_id") is not None: