mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,28 +1,45 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
1
15
 
2
16
  import os
3
17
  import re
18
+ import math
19
+ import zlib
20
+ from dataclasses import dataclass
21
+
4
22
  import numpy as np
23
+
5
24
  from msprobe.core.common.const import Const, CompareConst
6
- from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
25
+ from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
7
26
  from msprobe.core.common.file_utils import check_file_or_directory_path
8
27
 
9
28
 
10
29
  def extract_json(dirname, stack_json=False):
11
30
  json_path = ''
12
- for fname in os.listdir(dirname):
13
- if fname == "construct.json":
14
- continue
15
- full_path = os.path.join(dirname, fname)
16
- if full_path.endswith('.json'):
17
- json_path = full_path
18
- if not stack_json and 'stack' not in json_path:
19
- break
20
- if stack_json and 'stack' in json_path:
21
- break
31
+ for filename in os.listdir(dirname):
32
+ target_file_name = 'stack.json' if stack_json else 'dump.json'
33
+ if filename == target_file_name:
34
+ json_path = os.path.join(dirname, filename)
35
+ break
22
36
 
23
37
  # Provide robustness on invalid directory inputs
24
38
  if not json_path:
25
- logger.error(f'No file is found in dump dir {dirname}. ')
39
+ if stack_json:
40
+ logger.error(f'stack.json is not found in dump dir {dirname}.')
41
+ else:
42
+ logger.error(f'dump.json is not found in dump dir {dirname}.')
26
43
  raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
27
44
  return json_path
28
45
 
@@ -30,7 +47,7 @@ def extract_json(dirname, stack_json=False):
30
47
  def check_and_return_dir_contents(dump_dir, prefix):
31
48
  """
32
49
  check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
33
- pattern: ^{prefix}(?:0|[0-9][1-9]*)?$
50
+ pattern: ^{prefix}(?:0|[1-9][0-9]*)?$
34
51
 
35
52
  Args:
36
53
  dump_dir (str): dump dir
@@ -46,7 +63,7 @@ def check_and_return_dir_contents(dump_dir, prefix):
46
63
  check_regex_prefix_format_valid(prefix)
47
64
  check_file_or_directory_path(dump_dir, True)
48
65
  contents = os.listdir(dump_dir)
49
- pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
66
+ pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
50
67
  for name in contents:
51
68
  if not pattern.match(name):
52
69
  logger.error(
@@ -59,122 +76,100 @@ def check_and_return_dir_contents(dump_dir, prefix):
59
76
 
60
77
  def rename_api(npu_name, process):
61
78
  npu_split = npu_name.split(process)
62
- torch_func_index, in_out = npu_split[0], npu_split[1]
79
+ try:
80
+ torch_func_index, in_out = npu_split[0], npu_split[1]
81
+ except IndexError as error:
82
+ logger.error(f'{npu_name} can not be split with {process}, please check!')
83
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
63
84
  torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
64
85
  torch_func = str(torch_func_split[0]) + str(in_out)
65
86
  return torch_func
66
87
 
67
88
 
68
89
  def read_op(op_data, op_name):
69
- op_parsed_list = Const.DEFAULT_LIST
70
- if Const.FORWARD in op_name:
71
- if Const.INPUT_ARGS in op_data:
72
- input_item = op_data[Const.INPUT_ARGS]
73
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
74
- op_parsed_list = input_parsed_list.copy()
75
- input_parsed_list.clear()
76
- if Const.INPUT_KWARGS in op_data:
77
- kwargs_item = op_data[Const.INPUT_KWARGS]
78
- if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
79
- kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
80
- op_parsed_list += kwarg_parsed_list
81
- kwarg_parsed_list.clear()
82
- elif kwargs_item:
83
- for kwarg in kwargs_item:
84
- kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
85
- op_parsed_list += kwarg_parsed_list
86
- kwarg_parsed_list.clear()
87
- if Const.OUTPUT in op_data:
88
- output_item = op_data[Const.OUTPUT]
89
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
90
- op_parsed_list += output_parsed_list
91
- output_parsed_list.clear()
92
- if Const.BACKWARD in op_name:
93
- if Const.INPUT in op_data:
94
- input_item = op_data[Const.INPUT]
95
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
96
- op_parsed_list = input_parsed_list.copy()
97
- input_parsed_list.clear()
98
- if Const.OUTPUT in op_data:
99
- output_item = op_data[Const.OUTPUT]
100
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
101
- op_parsed_list += output_parsed_list
102
- output_parsed_list.clear()
90
+ io_name_mapping = {
91
+ Const.INPUT_ARGS: '.input',
92
+ Const.INPUT_KWARGS: '.input',
93
+ Const.INPUT: '.input',
94
+ Const.OUTPUT: '.output'
95
+ }
96
+
97
+ op_parsed_list = []
98
+ for name in io_name_mapping:
99
+ if name in op_data:
100
+ op_parsed_list.extend(op_item_parse(op_data[name], op_name + io_name_mapping[name]))
103
101
  return op_parsed_list
104
102
 
105
103
 
106
- def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
107
- if item_list is None:
108
- item_list = []
109
- if item is None or (isinstance(item, dict) and not item):
110
- if not top_bool:
111
- tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
112
- 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}
113
- else:
114
- tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
115
- 'shape': None, 'md5': None, 'data_name': '-1'}
116
- item_list.append(tmp)
117
- return item_list
118
- if index is None:
119
- if isinstance(item, dict):
120
- full_op_name = op_name + '.0'
121
- else:
122
- full_op_name = op_name
123
- else:
124
- full_op_name = op_name + Const.SEP + str(index)
125
- if isinstance(item, dict):
126
- if 'type' not in item:
127
- for kwarg in item:
128
- kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
129
- item_list += kwarg_parsed_list
130
- kwarg_parsed_list.clear()
131
- elif 'dtype' in item:
132
- parsed_item = item
133
- parsed_item['full_op_name'] = full_op_name
134
- item_list.append(parsed_item)
135
- elif 'type' in item:
136
- parsed_item = {}
137
- if item['type'] == 'torch.Size':
138
- parsed_item['full_op_name'] = full_op_name
139
- parsed_item['dtype'] = 'torch.Size'
140
- parsed_item['shape'] = str(item['value'])
141
- parsed_item['md5'] = None
142
- parsed_item['Max'] = None
143
- parsed_item['Min'] = None
144
- parsed_item['Mean'] = None
145
- parsed_item['Norm'] = None
146
- parsed_item['data_name'] = '-1'
147
- item_list.append(parsed_item)
148
- elif item['type'] == 'slice':
149
- parsed_item['full_op_name'] = full_op_name
150
- parsed_item['dtype'] = 'slice'
151
- parsed_item['shape'] = str(np.shape(np.array(item['value'])))
152
- parsed_item['md5'] = None
153
- parsed_item['Max'] = None
154
- parsed_item['Min'] = None
155
- parsed_item['Mean'] = None
156
- parsed_item['Norm'] = None
157
- parsed_item['data_name'] = '-1'
158
- item_list.append(parsed_item)
159
- else:
160
- parsed_item['full_op_name'] = full_op_name
161
- parsed_item['dtype'] = str(type(item['value']))
162
- parsed_item['shape'] = '[]'
163
- parsed_item['md5'] = None
164
- parsed_item['Max'] = item['value']
165
- parsed_item['Min'] = item['value']
166
- parsed_item['Mean'] = item['value']
167
- parsed_item['Norm'] = item['value']
168
- parsed_item['data_name'] = '-1'
169
- item_list.append(parsed_item)
170
- else:
171
- resolve_api_special_parameters(item, full_op_name, item_list)
172
- else:
173
- for j, item_spec in enumerate(item):
174
- op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
104
+ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
105
+ default_item = {
106
+ 'full_op_name': op_name,
107
+ 'type': None,
108
+ 'Max': None,
109
+ 'Min': None,
110
+ 'Mean': None,
111
+ 'Norm': None,
112
+ 'dtype': None,
113
+ 'shape': None,
114
+ 'md5': None,
115
+ 'value': None,
116
+ 'data_name': '-1'
117
+ }
118
+
119
+ if depth > Const.MAX_DEPTH:
120
+ logger.error(f'parse of api/module of {op_name} exceeds the recursion limit.')
121
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
122
+
123
+ if op_data is None:
124
+ return [default_item]
125
+ elif not op_data:
126
+ return []
127
+
128
+ item_list = []
129
+ if isinstance(op_data, list):
130
+ for i, data in enumerate(op_data):
131
+ item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
132
+ elif isinstance(op_data, dict):
133
+ if is_leaf_data(op_data):
134
+ return [gen_op_item(op_data, op_name)]
135
+ for sub_name, sub_data in op_data.items():
136
+ item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1))
175
137
  return item_list
176
138
 
177
139
 
140
+ def is_leaf_data(op_data):
141
+ return 'type' in op_data and isinstance(op_data['type'], str)
142
+
143
+
144
+ def gen_op_item(op_data, op_name):
145
+ op_item = {}
146
+ op_item.update(op_data)
147
+ op_item['full_op_name'] = op_name
148
+ op_item['data_name'] = op_data.get('data_name', '-1')
149
+
150
+ params = ['Max', 'Min', 'Mean', 'Norm']
151
+ for i in params:
152
+ if i not in op_item:
153
+ op_item[i] = None
154
+
155
+ if not op_item.get('dtype'):
156
+ if op_item.get('type') == 'torch.Size':
157
+ op_item['dtype'] = op_data.get('type')
158
+ op_item['shape'] = str(op_data.get('value'))
159
+ elif op_item.get('type') == 'slice':
160
+ op_item['dtype'] = op_data.get('type')
161
+ op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
162
+ else:
163
+ op_item['dtype'] = str(type(op_data.get('value')))
164
+ op_item['shape'] = '[]'
165
+ for i in params:
166
+ op_item[i] = op_data.get('value')
167
+ if not op_item.get('md5'):
168
+ op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
169
+
170
+ return op_item
171
+
172
+
178
173
  def resolve_api_special_parameters(data_dict, full_op_name, item_list):
179
174
  """
180
175
  Function Description:
@@ -206,139 +201,196 @@ def resolve_api_special_parameters(data_dict, full_op_name, item_list):
206
201
  item_list.append(parsed_item)
207
202
 
208
203
 
209
- def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
204
+ def process_summary_data(summary_data):
205
+ """处理summary_data中的nan值,返回处理后的列表"""
206
+ return [CompareConst.NAN if isinstance(x, float) and math.isnan(x) else x for x in summary_data]
207
+
208
+
209
+ def get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg):
210
+ start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
211
+ warning_flag = False
212
+ for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
213
+ if all(isinstance(val, (float, int)) and not isinstance(val, bool) for val in [npu_val, bench_val]):
214
+ diff = npu_val - bench_val
215
+ if math.isnan(diff):
216
+ diff = CompareConst.NAN
217
+ relative = CompareConst.NAN
218
+ else:
219
+ if bench_val != 0:
220
+ relative = str(abs((diff / bench_val) * 100)) + '%'
221
+ else:
222
+ relative = CompareConst.N_A
223
+ magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + CompareConst.EPSILON)
224
+ if magnitude_diff > CompareConst.MAGNITUDE:
225
+ warning_flag = True
226
+ result_item[start_idx + i] = diff
227
+ result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = relative
228
+ else:
229
+ result_item[start_idx + i] = CompareConst.N_A
230
+ result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = CompareConst.N_A
231
+
232
+ accuracy_check = CompareConst.WARNING if warning_flag else ""
233
+ err_msg += "Need double check api accuracy." if warning_flag else ""
234
+ for i in range(start_idx, len(result_item)):
235
+ if str(result_item[i]) in ('inf', '-inf', 'nan'):
236
+ result_item[i] = f'{result_item[i]}\t'
237
+ return result_item, accuracy_check, err_msg
238
+
239
+
240
+ @dataclass
241
+ class ApiItemInfo:
242
+ name: str
243
+ struct: tuple
244
+ stack_info: list
245
+
246
+
247
+ def stack_column_process(result_item, has_stack, index, key, npu_stack_info):
248
+ if has_stack and index == 0 and key == CompareConst.INPUT_STRUCT:
249
+ result_item.extend(npu_stack_info)
250
+ else:
251
+ result_item.append(CompareConst.NONE)
252
+ return result_item
253
+
254
+
255
+ def result_item_init(n_info, b_info, dump_mode):
256
+ n_len = len(n_info.struct)
257
+ b_len = len(b_info.struct)
258
+ struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1)
259
+ if struct_long_enough:
260
+ result_item = [
261
+ n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1]
262
+ ]
263
+ if dump_mode == Const.MD5:
264
+ md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
265
+ result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
266
+ elif dump_mode == Const.SUMMARY:
267
+ result_item.extend([" "] * 8)
268
+ else:
269
+ result_item.extend([" "] * 5)
270
+ else:
271
+ err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
272
+ f"npu_info_struct is {n_info.struct}\n" \
273
+ f"bench_info_struct is {b_info.struct}"
274
+ logger.error(err_msg)
275
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
276
+ return result_item
277
+
278
+
279
+ def get_accuracy(result, n_dict, b_dict, dump_mode):
210
280
  def get_accuracy_core(n_start, n_len, b_start, b_len, key):
211
281
  min_len = min(n_len, b_len)
212
282
  npu_stack_info = n_dict.get("stack_info", None)
213
283
  bench_stack_info = b_dict.get("stack_info", None)
214
284
  has_stack = npu_stack_info and bench_stack_info
215
285
 
216
- all_mode_bool = not (summary_compare or md5_compare)
217
- if all_mode_bool:
286
+ if dump_mode == Const.ALL:
218
287
  npu_data_name = n_dict.get("data_name", None)
219
288
  bench_data_name = b_dict.get("data_name", None)
220
289
 
221
290
  for index in range(min_len):
222
-
223
- n_name = n_dict['op_name'][n_start + index]
224
- b_name = b_dict['op_name'][b_start + index]
225
- n_struct = n_dict[key][index]
226
- b_struct = b_dict[key][index]
291
+ n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
292
+ b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name")
293
+ n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
294
+ b_struct = safe_get_value(b_dict, index, "b_dict", key=key)
227
295
  err_msg = ""
228
- if md5_compare:
229
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
230
- n_struct[2], b_struct[2],
231
- CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
232
- if has_stack and index == 0 and key == "input_struct":
233
- result_item.extend(npu_stack_info)
234
- else:
235
- result_item.append(CompareConst.NONE)
296
+
297
+ npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info)
298
+ bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info)
299
+ result_item = result_item_init(npu_info, bench_info, dump_mode)
300
+
301
+ if dump_mode == Const.MD5:
302
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
236
303
  result.append(result_item)
237
304
  continue
238
305
 
239
- if summary_compare:
240
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
241
- " ", " ", " ", " ", " ", " ", " ", " "]
242
- else:
243
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
244
- " ", " ", " ", " ", " "]
245
-
246
- npu_summary_data = n_dict.get("summary")[n_start + index]
247
- result_item.extend(npu_summary_data)
248
- bench_summary_data = b_dict.get("summary")[b_start + index]
249
- result_item.extend(bench_summary_data)
250
-
251
- if summary_compare:
252
- start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
253
- warning_flag = False
254
- for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
255
- if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
256
- diff = npu_val - bench_val
257
- if bench_val != 0:
258
- relative = str(abs((diff / bench_val) * 100)) + '%'
259
- else:
260
- relative = "N/A"
261
- result_item[start_idx + i] = diff
262
- result_item[start_idx + i + 4] = relative
263
- magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
264
- if magnitude_diff > 0.5:
265
- warning_flag = True
266
- else:
267
- result_item[start_idx + i] = CompareConst.NONE
268
- accuracy_check = CompareConst.WARNING if warning_flag else ""
269
- err_msg += "Need double check api accuracy." if warning_flag else ""
270
- for i in range(start_idx, len(result_item)):
271
- if str(result_item[i]) in ('inf', '-inf', 'nan'):
272
- result_item[i] = f'{result_item[i]}\t'
273
-
274
- result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
306
+ npu_summary_data = safe_get_value(n_dict, n_start + index, "n_dict", key=CompareConst.SUMMARY)
307
+ bench_summary_data = safe_get_value(b_dict, b_start + index, "b_dict", key=CompareConst.SUMMARY)
308
+ result_item.extend(process_summary_data(npu_summary_data))
309
+ result_item.extend(process_summary_data(bench_summary_data))
310
+
311
+ if dump_mode == Const.SUMMARY:
312
+ result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
313
+ bench_summary_data, err_msg)
314
+
315
+ result_item.append(accuracy_check if dump_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES)
275
316
  result_item.append(err_msg)
276
- if has_stack and index == 0 and key == "input_struct":
277
- result_item.extend(npu_stack_info)
278
- else:
279
- result_item.append(CompareConst.NONE)
280
- if all_mode_bool:
281
- result_item.append(npu_data_name[n_start + index])
317
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
318
+ if dump_mode == Const.ALL:
319
+ result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
282
320
 
283
321
  result.append(result_item)
284
322
 
285
323
  if n_len > b_len:
286
324
  for index in range(b_len, n_len):
287
- n_name = n_dict['op_name'][n_start + index]
288
- n_struct = n_dict[key][index]
289
- if md5_compare:
290
- result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
291
- n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
292
- result.append(result_item)
293
- continue
294
- result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
295
- n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
296
- summary_data = n_dict.get("summary")[n_start + index]
297
- result_item.extend(summary_data)
298
- summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
299
- result_item.extend(summary_data)
325
+ try:
326
+ n_name = n_dict['op_name'][n_start + index]
327
+ n_struct = n_dict[key][index]
328
+ if dump_mode == Const.MD5:
329
+ result_item = [
330
+ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
331
+ n_struct[2], CompareConst.NAN, CompareConst.NAN
332
+ ]
333
+ result.append(result_item)
334
+ continue
335
+ result_item = [
336
+ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
337
+ " ", " ", " ", " ", " "
338
+ ]
339
+ summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
340
+ result_item.extend(summary_data)
341
+ summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
342
+ result_item.extend(summary_data)
343
+ except IndexError as e:
344
+ err_msg = "index out of bounds error occurs, please check!\n" \
345
+ f"n_dict is {n_dict}"
346
+ logger.error(err_msg)
347
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
300
348
 
301
349
  err_msg = ""
302
350
  result_item.append(CompareConst.ACCURACY_CHECK_YES)
303
351
  result_item.append(err_msg)
304
-
305
- if has_stack and index == 0 and key == "input_struct":
306
- result_item.extend(npu_stack_info)
307
- else:
308
- result_item.append(CompareConst.NONE)
309
- if all_mode_bool:
310
- result_item.append(npu_data_name[n_start + index])
352
+ result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
353
+ if dump_mode == Const.ALL:
354
+ result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
311
355
 
312
356
  result.append(result_item)
313
357
 
314
358
  n_num = len(n_dict['op_name'])
315
359
  b_num = len(b_dict['op_name'])
316
- n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
317
- b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
318
- n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
319
- b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
320
- n_num_output = n_num - n_num_input - n_num_kwarg
321
- b_num_output = b_num - b_num_input - b_num_kwarg
360
+ n_num_input = len([name for name in n_dict['op_name']
361
+ if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
362
+ b_num_input = len([name for name in b_dict['op_name']
363
+ if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
364
+ n_num_output = n_num - n_num_input
365
+ b_num_output = b_num - b_num_input
322
366
  get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
323
- get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
324
- get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
367
+ get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
325
368
 
326
369
 
327
- def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
370
+ def get_un_match_accuracy(result, n_dict, dump_mode):
328
371
  index_out = 0
329
372
  npu_stack_info = n_dict.get("stack_info", None)
330
373
  bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
331
374
  err_msg = CompareConst.NO_BENCH
332
375
  accuracy_check_res = CompareConst.N_A
333
376
  for index, n_name in enumerate(n_dict["op_name"]):
334
- if n_name.find("input") != -1:
335
- n_struct = n_dict["input_struct"][index]
336
- else:
337
- n_struct = n_dict["output_struct"][index_out]
377
+ name_ele_list = n_name.split(Const.SEP)
378
+ if Const.INPUT in name_ele_list or Const.KWARGS in name_ele_list:
379
+ n_struct = safe_get_value(n_dict, index, "n_dict", key=CompareConst.INPUT_STRUCT)
380
+ if Const.OUTPUT in name_ele_list:
381
+ n_struct = safe_get_value(n_dict, index_out, "n_dict", key=CompareConst.OUTPUT_STRUCT)
338
382
  index_out += 1
339
383
 
340
- result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
341
- if md5_compare:
384
+ try:
385
+ result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
386
+ except IndexError as e:
387
+ err_msg = "index out of bounds error occurs, please check!\n" \
388
+ f"op_name of n_dict is {n_dict['op_name']}\n" \
389
+ f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
390
+ f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
391
+ logger.error(err_msg)
392
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
393
+ if dump_mode == Const.MD5:
342
394
  result_item.extend([CompareConst.N_A] * 3)
343
395
  if npu_stack_info and index == 0:
344
396
  result_item.extend(npu_stack_info)
@@ -346,11 +398,11 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
346
398
  result_item.append(CompareConst.NONE)
347
399
  result.append(result_item)
348
400
  continue
349
- if summary_compare:
401
+ if dump_mode == Const.SUMMARY:
350
402
  result_item.extend([CompareConst.N_A] * 8)
351
403
  else:
352
404
  result_item.extend([CompareConst.N_A] * 5)
353
- npu_summary_data = n_dict.get("summary")[index]
405
+ npu_summary_data = safe_get_value(n_dict, index, "n_dict", key=CompareConst.SUMMARY)
354
406
  result_item.extend(npu_summary_data)
355
407
  bench_summary_data = [CompareConst.N_A] * 4
356
408
  result_item.extend(bench_summary_data)
@@ -360,22 +412,21 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
360
412
  result_item.extend(npu_stack_info)
361
413
  else:
362
414
  result_item.append(CompareConst.NONE)
363
- if not md5_compare and not summary_compare and result_item[1] == CompareConst.N_A:
415
+ if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
364
416
  result_item.extend(["-1"])
365
417
  result.append(result_item)
366
418
 
367
419
 
368
- def merge_tensor(tensor_list, summary_compare, md5_compare):
420
+ def merge_tensor(tensor_list, dump_mode):
369
421
  op_dict = {}
370
422
  op_dict["op_name"] = []
371
- op_dict["input_struct"] = []
372
- op_dict["kwargs_struct"] = []
373
- op_dict["output_struct"] = []
374
- op_dict["summary"] = []
423
+ op_dict[CompareConst.INPUT_STRUCT] = []
424
+ op_dict[CompareConst.KWARGS_STRUCT] = []
425
+ op_dict[CompareConst.OUTPUT_STRUCT] = []
426
+ op_dict[Const.SUMMARY] = []
375
427
  op_dict["stack_info"] = []
376
428
 
377
- all_mode_bool = not (summary_compare or md5_compare)
378
- if all_mode_bool:
429
+ if dump_mode == Const.ALL:
379
430
  op_dict["data_name"] = []
380
431
 
381
432
  for tensor in tensor_list:
@@ -383,36 +434,45 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
383
434
  op_dict['stack_info'].append(tensor['full_info'])
384
435
  break
385
436
  op_dict["op_name"].append(tensor['full_op_name'])
386
- if not md5_compare:
387
- if tensor['full_op_name'].find("input") != -1:
388
- op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
389
- elif tensor['full_op_name'].find("kwarg") != -1:
390
- op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
391
- elif tensor['full_op_name'].find("output") != -1:
392
- op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
393
- else:
394
- if tensor['full_op_name'].find("input") != -1:
395
- op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
396
- elif tensor['full_op_name'].find("kwarg") != -1:
397
- op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
398
- elif tensor['full_op_name'].find("output") != -1:
399
- op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
400
-
401
- op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
437
+ name_ele_list = tensor['full_op_name'].split(Const.SEP)
438
+ name_to_struct_mapping = {
439
+ Const.INPUT: CompareConst.INPUT_STRUCT,
440
+ Const.KWARGS: CompareConst.KWARGS_STRUCT,
441
+ Const.OUTPUT: CompareConst.OUTPUT_STRUCT
442
+ }
443
+ for name_key, struct_key in name_to_struct_mapping.items():
444
+ if name_key in name_ele_list:
445
+ if dump_mode == Const.MD5:
446
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
447
+ else:
448
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
449
+ break
450
+ op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
402
451
 
403
- if all_mode_bool:
452
+ if dump_mode == Const.ALL:
404
453
  op_dict["data_name"].append(tensor['data_name'])
454
+ data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
455
+ if data_name != "-1":
456
+ op_dict["op_name"][-1] = data_name
405
457
 
406
- if not op_dict["kwargs_struct"]:
407
- del op_dict["kwargs_struct"]
458
+ if not op_dict[CompareConst.KWARGS_STRUCT]:
459
+ del op_dict[CompareConst.KWARGS_STRUCT]
408
460
  return op_dict if op_dict["op_name"] else {}
409
461
 
410
462
 
463
+ def print_compare_ends_info():
464
+ total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
465
+ logger.info('*' * total_len)
466
+ logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
467
+ logger.info('*' * total_len)
468
+
469
+
411
470
  def _compare_parser(parser):
412
471
  parser.add_argument("-i", "--input_path", dest="input_path", type=str,
413
- help="<Required> The compare input path, a dict json.", required=True)
472
+ help="<Required> The compare input path, a dict json.", required=True)
414
473
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
415
- help="<Required> The compare task result out path.", required=True)
474
+ help="<Required> The compare task result out path. Default path: ./output",
475
+ required=False, default="./output", nargs="?", const="./output")
416
476
  parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
417
477
  help="<optional> Whether to save stack info.", required=False)
418
478
  parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
@@ -423,8 +483,7 @@ def _compare_parser(parser):
423
483
  help="<optional> The cell mapping file path.", required=False)
424
484
  parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
425
485
  help="<optional> The api mapping file path.", required=False)
426
-
427
-
428
-
429
-
430
-
486
+ parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
487
+ help="<optional> The data mapping file path.", required=False)
488
+ parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
489
+ help="<optional> The layer mapping file path.", required=False)