mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -15,28 +15,31 @@
15
15
 
16
16
  import os
17
17
  import re
18
+ import math
19
+ import zlib
20
+ from dataclasses import dataclass
21
+
18
22
  import numpy as np
23
+
19
24
  from msprobe.core.common.const import Const, CompareConst
20
- from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
25
+ from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
21
26
  from msprobe.core.common.file_utils import check_file_or_directory_path
22
27
 
23
28
 
24
29
  def extract_json(dirname, stack_json=False):
25
30
  json_path = ''
26
- for fname in os.listdir(dirname):
27
- if fname == "construct.json":
28
- continue
29
- full_path = os.path.join(dirname, fname)
30
- if full_path.endswith('.json'):
31
- json_path = full_path
32
- if not stack_json and 'stack' not in json_path:
33
- break
34
- if stack_json and 'stack' in json_path:
35
- break
31
+ for filename in os.listdir(dirname):
32
+ target_file_name = 'stack.json' if stack_json else 'dump.json'
33
+ if filename == target_file_name:
34
+ json_path = os.path.join(dirname, filename)
35
+ break
36
36
 
37
37
  # Provide robustness on invalid directory inputs
38
38
  if not json_path:
39
- logger.error(f'No file is found in dump dir {dirname}. ')
39
+ if stack_json:
40
+ logger.error(f'stack.json is not found in dump dir {dirname}.')
41
+ else:
42
+ logger.error(f'dump.json is not found in dump dir {dirname}.')
40
43
  raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
41
44
  return json_path
42
45
 
@@ -44,7 +47,7 @@ def extract_json(dirname, stack_json=False):
44
47
  def check_and_return_dir_contents(dump_dir, prefix):
45
48
  """
46
49
  check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
47
- pattern: ^{prefix}(?:0|[0-9][1-9]*)?$
50
+ pattern: ^{prefix}(?:0|[1-9][0-9]*)?$
48
51
 
49
52
  Args:
50
53
  dump_dir (str): dump dir
@@ -60,7 +63,7 @@ def check_and_return_dir_contents(dump_dir, prefix):
60
63
  check_regex_prefix_format_valid(prefix)
61
64
  check_file_or_directory_path(dump_dir, True)
62
65
  contents = os.listdir(dump_dir)
63
- pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
66
+ pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
64
67
  for name in contents:
65
68
  if not pattern.match(name):
66
69
  logger.error(
@@ -84,122 +87,89 @@ def rename_api(npu_name, process):
84
87
 
85
88
 
86
89
  def read_op(op_data, op_name):
90
+ io_name_mapping = {
91
+ Const.INPUT_ARGS: '.input',
92
+ Const.INPUT_KWARGS: '.input',
93
+ Const.INPUT: '.input',
94
+ Const.OUTPUT: '.output'
95
+ }
96
+
87
97
  op_parsed_list = []
88
- if Const.FORWARD in op_name:
89
- if Const.INPUT_ARGS in op_data:
90
- input_item = op_data[Const.INPUT_ARGS]
91
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
92
- op_parsed_list = input_parsed_list.copy()
93
- input_parsed_list.clear()
94
- if Const.INPUT_KWARGS in op_data:
95
- kwargs_item = op_data[Const.INPUT_KWARGS]
96
- if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
97
- kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
98
- op_parsed_list += kwarg_parsed_list
99
- kwarg_parsed_list.clear()
100
- elif kwargs_item:
101
- for kwarg in kwargs_item:
102
- kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
103
- op_parsed_list += kwarg_parsed_list
104
- kwarg_parsed_list.clear()
105
- if Const.OUTPUT in op_data:
106
- output_item = op_data[Const.OUTPUT]
107
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
108
- op_parsed_list += output_parsed_list
109
- output_parsed_list.clear()
110
- if Const.BACKWARD in op_name:
111
- if Const.INPUT in op_data:
112
- input_item = op_data[Const.INPUT]
113
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
114
- op_parsed_list = input_parsed_list.copy()
115
- input_parsed_list.clear()
116
- if Const.OUTPUT in op_data:
117
- output_item = op_data[Const.OUTPUT]
118
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
119
- op_parsed_list += output_parsed_list
120
- output_parsed_list.clear()
98
+ for name in io_name_mapping:
99
+ if name in op_data:
100
+ op_parsed_list.extend(op_item_parse(op_data[name], op_name + io_name_mapping[name]))
121
101
  return op_parsed_list
122
102
 
123
103
 
124
- def op_item_parse(item, op_name, index, item_list=None, top_bool=True, depth=0):
104
+ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
105
+ default_item = {
106
+ 'full_op_name': op_name,
107
+ 'type': None,
108
+ 'Max': None,
109
+ 'Min': None,
110
+ 'Mean': None,
111
+ 'Norm': None,
112
+ 'dtype': None,
113
+ 'shape': None,
114
+ 'md5': None,
115
+ 'value': None,
116
+ 'data_name': '-1'
117
+ }
118
+
125
119
  if depth > Const.MAX_DEPTH:
126
- logger.error(f"parse of api/module of {op_name} exceeds the recursion limit.")
120
+ logger.error(f'parse of api/module of {op_name} exceeds the recursion limit.')
127
121
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
128
- if item_list is None:
129
- item_list = []
130
- if item is None or (isinstance(item, dict) and not item):
131
- if not top_bool:
132
- tmp = {
133
- 'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
134
- 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'
135
- }
136
- else:
137
- tmp = {
138
- 'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
139
- 'shape': None, 'md5': None, 'data_name': '-1'
140
- }
141
- item_list.append(tmp)
142
- return item_list
143
- if index is None:
144
- if isinstance(item, dict):
145
- full_op_name = op_name + '.0'
146
- else:
147
- full_op_name = op_name
148
- else:
149
- full_op_name = op_name + Const.SEP + str(index)
150
- if isinstance(item, dict):
151
- if 'type' not in item:
152
- for kwarg in item:
153
- kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None, depth=depth+1)
154
- item_list += kwarg_parsed_list
155
- kwarg_parsed_list.clear()
156
- elif 'dtype' in item:
157
- parsed_item = item
158
- parsed_item['full_op_name'] = full_op_name
159
- item_list.append(parsed_item)
160
- elif 'type' in item:
161
- parsed_item = {}
162
- if item['type'] == 'torch.Size':
163
- parsed_item['full_op_name'] = full_op_name
164
- parsed_item['dtype'] = 'torch.Size'
165
- parsed_item['shape'] = str(item['value'])
166
- parsed_item['md5'] = None
167
- parsed_item['Max'] = None
168
- parsed_item['Min'] = None
169
- parsed_item['Mean'] = None
170
- parsed_item['Norm'] = None
171
- parsed_item['data_name'] = '-1'
172
- item_list.append(parsed_item)
173
- elif item['type'] == 'slice':
174
- parsed_item['full_op_name'] = full_op_name
175
- parsed_item['dtype'] = 'slice'
176
- parsed_item['shape'] = str(np.shape(np.array(item['value'])))
177
- parsed_item['md5'] = None
178
- parsed_item['Max'] = None
179
- parsed_item['Min'] = None
180
- parsed_item['Mean'] = None
181
- parsed_item['Norm'] = None
182
- parsed_item['data_name'] = '-1'
183
- item_list.append(parsed_item)
184
- else:
185
- parsed_item['full_op_name'] = full_op_name
186
- parsed_item['dtype'] = str(type(item['value']))
187
- parsed_item['shape'] = '[]'
188
- parsed_item['md5'] = None
189
- parsed_item['Max'] = item['value']
190
- parsed_item['Min'] = item['value']
191
- parsed_item['Mean'] = item['value']
192
- parsed_item['Norm'] = item['value']
193
- parsed_item['data_name'] = '-1'
194
- item_list.append(parsed_item)
195
- else:
196
- resolve_api_special_parameters(item, full_op_name, item_list)
197
- else:
198
- for j, item_spec in enumerate(item):
199
- op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False, depth=depth+1)
122
+
123
+ if op_data is None:
124
+ return [default_item]
125
+ elif not op_data:
126
+ return []
127
+
128
+ item_list = []
129
+ if isinstance(op_data, list):
130
+ for i, data in enumerate(op_data):
131
+ item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
132
+ elif isinstance(op_data, dict):
133
+ if is_leaf_data(op_data):
134
+ return [gen_op_item(op_data, op_name)]
135
+ for sub_name, sub_data in op_data.items():
136
+ item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1))
200
137
  return item_list
201
138
 
202
139
 
140
+ def is_leaf_data(op_data):
141
+ return 'type' in op_data and isinstance(op_data['type'], str)
142
+
143
+
144
+ def gen_op_item(op_data, op_name):
145
+ op_item = {}
146
+ op_item.update(op_data)
147
+ op_item['full_op_name'] = op_name
148
+ op_item['data_name'] = op_data.get('data_name', '-1')
149
+
150
+ params = ['Max', 'Min', 'Mean', 'Norm']
151
+ for i in params:
152
+ if i not in op_item:
153
+ op_item[i] = None
154
+
155
+ if not op_item.get('dtype'):
156
+ if op_item.get('type') == 'torch.Size':
157
+ op_item['dtype'] = op_data.get('type')
158
+ op_item['shape'] = str(op_data.get('value'))
159
+ elif op_item.get('type') == 'slice':
160
+ op_item['dtype'] = op_data.get('type')
161
+ op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
162
+ else:
163
+ op_item['dtype'] = str(type(op_data.get('value')))
164
+ op_item['shape'] = '[]'
165
+ for i in params:
166
+ op_item[i] = op_data.get('value')
167
+ if not op_item.get('md5'):
168
+ op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
169
+
170
+ return op_item
171
+
172
+
203
173
  def resolve_api_special_parameters(data_dict, full_op_name, item_list):
204
174
  """
205
175
  Function Description:
@@ -231,131 +201,173 @@ def resolve_api_special_parameters(data_dict, full_op_name, item_list):
231
201
  item_list.append(parsed_item)
232
202
 
233
203
 
234
- def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
204
+ def process_summary_data(summary_data):
205
+ """处理summary_data中的nan值,返回处理后的列表"""
206
+ return [CompareConst.NAN if isinstance(x, float) and math.isnan(x) else x for x in summary_data]
207
+
208
+
209
+ def get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg):
210
+ start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
211
+ warning_flag = False
212
+ for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
213
+ if all(isinstance(val, (float, int)) and not isinstance(val, bool) for val in [npu_val, bench_val]):
214
+ diff = npu_val - bench_val
215
+ if math.isnan(diff):
216
+ diff = CompareConst.NAN
217
+ relative = CompareConst.NAN
218
+ else:
219
+ if bench_val != 0:
220
+ relative = str(abs((diff / bench_val) * 100)) + '%'
221
+ else:
222
+ relative = CompareConst.N_A
223
+ magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + CompareConst.EPSILON)
224
+ if magnitude_diff > CompareConst.MAGNITUDE:
225
+ warning_flag = True
226
+ result_item[start_idx + i] = diff
227
+ result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = relative
228
+ else:
229
+ result_item[start_idx + i] = CompareConst.N_A
230
+ result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = CompareConst.N_A
231
+
232
+ accuracy_check = CompareConst.WARNING if warning_flag else ""
233
+ err_msg += "Need double check api accuracy." if warning_flag else ""
234
+ for i in range(start_idx, len(result_item)):
235
+ if str(result_item[i]) in ('inf', '-inf', 'nan'):
236
+ result_item[i] = f'{result_item[i]}\t'
237
+ return result_item, accuracy_check, err_msg
238
+
239
+
240
+ @dataclass
241
+ class ApiItemInfo:
242
+ name: str
243
+ struct: tuple
244
+ stack_info: list
245
+
246
+
247
+ def stack_column_process(result_item, has_stack, index, key, npu_stack_info):
248
+ if has_stack and index == 0 and key == CompareConst.INPUT_STRUCT:
249
+ result_item.extend(npu_stack_info)
250
+ else:
251
+ result_item.append(CompareConst.NONE)
252
+ return result_item
253
+
254
+
255
+ def result_item_init(n_info, b_info, dump_mode):
256
+ n_len = len(n_info.struct)
257
+ b_len = len(b_info.struct)
258
+ struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1)
259
+ if struct_long_enough:
260
+ result_item = [
261
+ n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1]
262
+ ]
263
+ if dump_mode == Const.MD5:
264
+ md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
265
+ result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
266
+ elif dump_mode == Const.SUMMARY:
267
+ result_item.extend([" "] * 8)
268
+ else:
269
+ result_item.extend([" "] * 5)
270
+ else:
271
+ err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
272
+ f"npu_info_struct is {n_info.struct}\n" \
273
+ f"bench_info_struct is {b_info.struct}"
274
+ logger.error(err_msg)
275
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
276
+ return result_item
277
+
278
+
279
+ def get_accuracy(result, n_dict, b_dict, dump_mode):
235
280
  def get_accuracy_core(n_start, n_len, b_start, b_len, key):
236
281
  min_len = min(n_len, b_len)
237
282
  npu_stack_info = n_dict.get("stack_info", None)
238
283
  bench_stack_info = b_dict.get("stack_info", None)
239
284
  has_stack = npu_stack_info and bench_stack_info
240
285
 
241
- all_mode_bool = not (summary_compare or md5_compare)
242
- if all_mode_bool:
286
+ if dump_mode == Const.ALL:
243
287
  npu_data_name = n_dict.get("data_name", None)
244
288
  bench_data_name = b_dict.get("data_name", None)
245
289
 
246
290
  for index in range(min_len):
247
-
248
- n_name = n_dict['op_name'][n_start + index]
249
- b_name = b_dict['op_name'][b_start + index]
250
- n_struct = n_dict[key][index]
251
- b_struct = b_dict[key][index]
291
+ n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
292
+ b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name")
293
+ n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
294
+ b_struct = safe_get_value(b_dict, index, "b_dict", key=key)
252
295
  err_msg = ""
253
- if md5_compare:
254
- result_item = [
255
- n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], n_struct[2], b_struct[2],
256
- CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF
257
- ]
258
- if has_stack and index == 0 and key == "input_struct":
259
- result_item.extend(npu_stack_info)
260
- else:
261
- result_item.append(CompareConst.NONE)
296
+
297
+ npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info)
298
+ bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info)
299
+ result_item = result_item_init(npu_info, bench_info, dump_mode)
300
+
301
+ if dump_mode == Const.MD5:
302
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
262
303
  result.append(result_item)
263
304
  continue
264
305
 
265
- if summary_compare:
266
- result_item = [
267
- n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
268
- " ", " ", " ", " ", " ", " ", " ", " "
269
- ]
270
- else:
271
- result_item = [
272
- n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
273
- " ", " ", " ", " ", " "
274
- ]
275
-
276
- npu_summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
277
- result_item.extend(npu_summary_data)
278
- bench_summary_data = b_dict.get(CompareConst.SUMMARY)[b_start + index]
279
- result_item.extend(bench_summary_data)
280
-
281
- if summary_compare:
282
- start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
283
- warning_flag = False
284
- for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
285
- if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
286
- diff = npu_val - bench_val
287
- if bench_val != 0:
288
- relative = str(abs((diff / bench_val) * 100)) + '%'
289
- else:
290
- relative = CompareConst.N_A
291
- result_item[start_idx + i] = diff
292
- result_item[start_idx + i + 4] = relative
293
- magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
294
- if magnitude_diff > 0.5:
295
- warning_flag = True
296
- else:
297
- result_item[start_idx + i] = CompareConst.NONE
298
- accuracy_check = CompareConst.WARNING if warning_flag else ""
299
- err_msg += "Need double check api accuracy." if warning_flag else ""
300
- for i in range(start_idx, len(result_item)):
301
- if str(result_item[i]) in ('inf', '-inf', 'nan'):
302
- result_item[i] = f'{result_item[i]}\t'
303
-
304
- result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
306
+ npu_summary_data = safe_get_value(n_dict, n_start + index, "n_dict", key=CompareConst.SUMMARY)
307
+ bench_summary_data = safe_get_value(b_dict, b_start + index, "b_dict", key=CompareConst.SUMMARY)
308
+ result_item.extend(process_summary_data(npu_summary_data))
309
+ result_item.extend(process_summary_data(bench_summary_data))
310
+
311
+ if dump_mode == Const.SUMMARY:
312
+ result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
313
+ bench_summary_data, err_msg)
314
+
315
+ result_item.append(accuracy_check if dump_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES)
305
316
  result_item.append(err_msg)
306
- if has_stack and index == 0 and key == "input_struct":
307
- result_item.extend(npu_stack_info)
308
- else:
309
- result_item.append(CompareConst.NONE)
310
- if all_mode_bool:
311
- result_item.append(npu_data_name[n_start + index])
317
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
318
+ if dump_mode == Const.ALL:
319
+ result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
312
320
 
313
321
  result.append(result_item)
314
322
 
315
323
  if n_len > b_len:
316
324
  for index in range(b_len, n_len):
317
- n_name = n_dict['op_name'][n_start + index]
318
- n_struct = n_dict[key][index]
319
- if md5_compare:
325
+ try:
326
+ n_name = n_dict['op_name'][n_start + index]
327
+ n_struct = n_dict[key][index]
328
+ if dump_mode == Const.MD5:
329
+ result_item = [
330
+ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
331
+ n_struct[2], CompareConst.NAN, CompareConst.NAN
332
+ ]
333
+ result.append(result_item)
334
+ continue
320
335
  result_item = [
321
336
  n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
322
- n_struct[2], CompareConst.NAN, CompareConst.NAN
337
+ " ", " ", " ", " ", " "
323
338
  ]
324
- result.append(result_item)
325
- continue
326
- result_item = [
327
- n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
328
- " ", " ", " ", " ", " "
329
- ]
330
- summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
331
- result_item.extend(summary_data)
332
- summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
333
- result_item.extend(summary_data)
339
+ summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
340
+ result_item.extend(summary_data)
341
+ summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
342
+ result_item.extend(summary_data)
343
+ except IndexError as e:
344
+ err_msg = "index out of bounds error occurs, please check!\n" \
345
+ f"n_dict is {n_dict}"
346
+ logger.error(err_msg)
347
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
334
348
 
335
349
  err_msg = ""
336
350
  result_item.append(CompareConst.ACCURACY_CHECK_YES)
337
351
  result_item.append(err_msg)
338
-
339
- if has_stack and index == 0 and key == "input_struct":
340
- result_item.extend(npu_stack_info)
341
- else:
342
- result_item.append(CompareConst.NONE)
343
- if all_mode_bool:
344
- result_item.append(npu_data_name[n_start + index])
352
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
353
+ if dump_mode == Const.ALL:
354
+ result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
345
355
 
346
356
  result.append(result_item)
347
357
 
348
358
  n_num = len(n_dict['op_name'])
349
359
  b_num = len(b_dict['op_name'])
350
- n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
351
- b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
360
+ n_num_input = len([name for name in n_dict['op_name']
361
+ if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
362
+ b_num_input = len([name for name in b_dict['op_name']
363
+ if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
352
364
  n_num_output = n_num - n_num_input
353
365
  b_num_output = b_num - b_num_input
354
366
  get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
355
367
  get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
356
368
 
357
369
 
358
- def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
370
+ def get_un_match_accuracy(result, n_dict, dump_mode):
359
371
  index_out = 0
360
372
  npu_stack_info = n_dict.get("stack_info", None)
361
373
  bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
@@ -363,14 +375,22 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
363
375
  accuracy_check_res = CompareConst.N_A
364
376
  for index, n_name in enumerate(n_dict["op_name"]):
365
377
  name_ele_list = n_name.split(Const.SEP)
366
- if "input" in name_ele_list:
367
- n_struct = n_dict["input_struct"][index]
368
- else:
369
- n_struct = n_dict["output_struct"][index_out]
378
+ if Const.INPUT in name_ele_list or Const.KWARGS in name_ele_list:
379
+ n_struct = safe_get_value(n_dict, index, "n_dict", key=CompareConst.INPUT_STRUCT)
380
+ if Const.OUTPUT in name_ele_list:
381
+ n_struct = safe_get_value(n_dict, index_out, "n_dict", key=CompareConst.OUTPUT_STRUCT)
370
382
  index_out += 1
371
383
 
372
- result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
373
- if md5_compare:
384
+ try:
385
+ result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
386
+ except IndexError as e:
387
+ err_msg = "index out of bounds error occurs, please check!\n" \
388
+ f"op_name of n_dict is {n_dict['op_name']}\n" \
389
+ f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
390
+ f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
391
+ logger.error(err_msg)
392
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
393
+ if dump_mode == Const.MD5:
374
394
  result_item.extend([CompareConst.N_A] * 3)
375
395
  if npu_stack_info and index == 0:
376
396
  result_item.extend(npu_stack_info)
@@ -378,11 +398,11 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
378
398
  result_item.append(CompareConst.NONE)
379
399
  result.append(result_item)
380
400
  continue
381
- if summary_compare:
401
+ if dump_mode == Const.SUMMARY:
382
402
  result_item.extend([CompareConst.N_A] * 8)
383
403
  else:
384
404
  result_item.extend([CompareConst.N_A] * 5)
385
- npu_summary_data = n_dict.get("summary")[index]
405
+ npu_summary_data = safe_get_value(n_dict, index, "n_dict", key=CompareConst.SUMMARY)
386
406
  result_item.extend(npu_summary_data)
387
407
  bench_summary_data = [CompareConst.N_A] * 4
388
408
  result_item.extend(bench_summary_data)
@@ -392,22 +412,21 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
392
412
  result_item.extend(npu_stack_info)
393
413
  else:
394
414
  result_item.append(CompareConst.NONE)
395
- if not md5_compare and not summary_compare and result_item[1] == CompareConst.N_A:
415
+ if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
396
416
  result_item.extend(["-1"])
397
417
  result.append(result_item)
398
418
 
399
419
 
400
- def merge_tensor(tensor_list, summary_compare, md5_compare):
420
+ def merge_tensor(tensor_list, dump_mode):
401
421
  op_dict = {}
402
422
  op_dict["op_name"] = []
403
- op_dict["input_struct"] = []
404
- op_dict["kwargs_struct"] = []
405
- op_dict["output_struct"] = []
406
- op_dict["summary"] = []
423
+ op_dict[CompareConst.INPUT_STRUCT] = []
424
+ op_dict[CompareConst.KWARGS_STRUCT] = []
425
+ op_dict[CompareConst.OUTPUT_STRUCT] = []
426
+ op_dict[Const.SUMMARY] = []
407
427
  op_dict["stack_info"] = []
408
428
 
409
- all_mode_bool = not (summary_compare or md5_compare)
410
- if all_mode_bool:
429
+ if dump_mode == Const.ALL:
411
430
  op_dict["data_name"] = []
412
431
 
413
432
  for tensor in tensor_list:
@@ -416,38 +435,44 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
416
435
  break
417
436
  op_dict["op_name"].append(tensor['full_op_name'])
418
437
  name_ele_list = tensor['full_op_name'].split(Const.SEP)
419
- if not md5_compare:
420
- if "input" in name_ele_list:
421
- op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
422
- elif "kwarg" in name_ele_list:
423
- op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
424
- elif "output" in name_ele_list:
425
- op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
426
- else:
427
- if "input" in name_ele_list:
428
- op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
429
- if "kwarg" in name_ele_list:
430
- op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
431
- elif "output" in name_ele_list:
432
- op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
433
- op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
434
-
435
- if all_mode_bool:
438
+ name_to_struct_mapping = {
439
+ Const.INPUT: CompareConst.INPUT_STRUCT,
440
+ Const.KWARGS: CompareConst.KWARGS_STRUCT,
441
+ Const.OUTPUT: CompareConst.OUTPUT_STRUCT
442
+ }
443
+ for name_key, struct_key in name_to_struct_mapping.items():
444
+ if name_key in name_ele_list:
445
+ if dump_mode == Const.MD5:
446
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
447
+ else:
448
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
449
+ break
450
+ op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
451
+
452
+ if dump_mode == Const.ALL:
436
453
  op_dict["data_name"].append(tensor['data_name'])
437
- data_name = op_dict["data_name"][-1].rsplit(Const.SEP, 1)[0]
454
+ data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
438
455
  if data_name != "-1":
439
456
  op_dict["op_name"][-1] = data_name
440
457
 
441
- if not op_dict["kwargs_struct"]:
442
- del op_dict["kwargs_struct"]
458
+ if not op_dict[CompareConst.KWARGS_STRUCT]:
459
+ del op_dict[CompareConst.KWARGS_STRUCT]
443
460
  return op_dict if op_dict["op_name"] else {}
444
461
 
445
462
 
463
+ def print_compare_ends_info():
464
+ total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
465
+ logger.info('*' * total_len)
466
+ logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
467
+ logger.info('*' * total_len)
468
+
469
+
446
470
  def _compare_parser(parser):
447
471
  parser.add_argument("-i", "--input_path", dest="input_path", type=str,
448
472
  help="<Required> The compare input path, a dict json.", required=True)
449
473
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
450
- help="<Required> The compare task result out path.", required=True)
474
+ help="<Required> The compare task result out path. Default path: ./output",
475
+ required=False, default="./output", nargs="?", const="./output")
451
476
  parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
452
477
  help="<optional> Whether to save stack info.", required=False)
453
478
  parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
@@ -457,8 +482,8 @@ def _compare_parser(parser):
457
482
  parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
458
483
  help="<optional> The cell mapping file path.", required=False)
459
484
  parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
460
- help="<optional> The api mapping file path.", required=False)
485
+ help="<optional> The api mapping file path.", required=False)
461
486
  parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
462
487
  help="<optional> The data mapping file path.", required=False)
463
- parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
488
+ parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
464
489
  help="<optional> The layer mapping file path.", required=False)