mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -1,21 +1,47 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  import re
3
- import copy
4
- import sys
5
- from itertools import zip_longest
6
18
 
7
- from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
8
- task_dumppath_get, struct_json_get, add_time_with_yaml
9
- from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy, load_json, save_yaml, FileOpen
10
- from msprobe.core.common.const import Const, CompareConst
11
- from msprobe.core.common.log import logger
19
+ from collections import defaultdict
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+
24
+ from msprobe.core.common.const import CompareConst, Const
12
25
  from msprobe.core.common.exceptions import FileCheckException
26
+ from msprobe.core.common.file_utils import (FileOpen, create_directory, load_json,
27
+ load_npy, load_yaml)
28
+ from msprobe.core.common.log import logger
29
+ from msprobe.core.common.utils import (CompareException, check_compare_param,
30
+ check_configuration_param,
31
+ get_dump_mode, set_dump_path, check_op_str_pattern_valid)
32
+ from msprobe.core.compare.check import dtype_mapping
13
33
  from msprobe.core.compare.acc_compare import Comparator
14
- from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
15
- from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack
16
- from msprobe.mindspore.compare.layer_mapping import get_layer_mapping
34
+ from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
35
+
17
36
 
18
37
  class MSComparator(Comparator):
38
+ """
39
+ 用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
40
+ cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
41
+ api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
42
+ data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
43
+ is_cross_framework: 是否跨框架。
44
+ """
19
45
  def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
20
46
  self.frame_name = MSComparator.__name__
21
47
  self.cell_mapping = cell_mapping
@@ -37,10 +63,108 @@ class MSComparator(Comparator):
37
63
  else:
38
64
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
39
65
  f"{type(self.data_mapping)}")
66
+
67
+ @classmethod
68
+ def calc_accuracy(cls, result_df, dump_mode, header):
69
+ condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
70
+ result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
71
+ result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
72
+
73
+ def calc_summary_diff(data_type: str):
74
+ def type_check(val):
75
+ check_series = pd.Series(False, index=val.index)
76
+ val_str = val.astype(str)
77
+ check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
78
+ return check_series
79
+
80
+ def get_number(val):
81
+ return pd.to_numeric(val.astype(str), errors='coerce')
82
+
83
+ ms_val = result_df['NPU ' + data_type]
84
+ pt_val = result_df['Bench ' + data_type]
85
+ diff_name = data_type.capitalize() + ' diff'
86
+ rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
87
+ condition_na = ~type_check(ms_val) | ~type_check(pt_val)
88
+ result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
89
+ result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
90
+ condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
91
+ condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
92
+ result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
93
+ condition_pt_zero = pt_val == 0
94
+ result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
95
+ condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
96
+ result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
97
+ pt_val[condition_ref_err] * 100)
98
+ result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
99
+ .abs().astype(str) + '%')
100
+ magnitude = get_number(result_df[diff_name]).abs() / (
101
+ pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
102
+ return magnitude > CompareConst.MAGNITUDE
103
+
104
+ if dump_mode == Const.MD5:
105
+ condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
106
+ result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
107
+ result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
108
+ elif dump_mode == Const.SUMMARY:
109
+ warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
110
+ warning_flag = pd.DataFrame(warning_list).all()
111
+ result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
112
+ result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
113
+ result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
114
+ else:
115
+ fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
116
+ CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
117
+ CompareConst.ERROR_MESSAGE]
118
+ result_df.loc[~condition_no_bench, fill_cols] = ''
119
+ result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
120
+ return result_df[header]
121
+
122
+ @classmethod
123
+ def make_result_df(cls, result, stack_mode, dump_mode):
124
+ header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode]
125
+
126
+ if stack_mode:
127
+ header.append(CompareConst.STACK)
128
+ if dump_mode == Const.ALL:
129
+ header.append(CompareConst.DATA_NAME)
130
+ result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
131
+ 'op_name_y': CompareConst.BENCH_NAME,
132
+ 'dtype_x': CompareConst.NPU_DTYPE,
133
+ 'dtype_y': CompareConst.BENCH_DTYPE,
134
+ 'shape_x': CompareConst.NPU_SHAPE,
135
+ 'shape_y': CompareConst.BENCH_SHAPE,
136
+ 'md5_x': CompareConst.NPU_MD5,
137
+ 'md5_y': CompareConst.BENCH_MD5,
138
+ 'data_name_x': CompareConst.DATA_NAME,
139
+ 'stack_info_x': CompareConst.STACK}, inplace=True)
140
+
141
+ npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
142
+ bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
143
+ CompareConst.BENCH_NORM]
144
+ def set_summary(summary):
145
+ if summary == CompareConst.N_A:
146
+ return [CompareConst.N_A] * 4
147
+ summary_list = []
148
+ for i in summary:
149
+ if i is None:
150
+ summary_list.append(CompareConst.N_A)
151
+ elif str(i).lower() == 'nan':
152
+ summary_list.append(CompareConst.NAN)
153
+ else:
154
+ summary_list.append(i)
155
+ return summary_list
156
+
157
+ result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
158
+ result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
159
+ result_df = pd.DataFrame(columns=header)
160
+ for h in header:
161
+ if h in result.columns:
162
+ result_df[h] = result[h]
163
+ return cls.calc_accuracy(result_df, dump_mode, header)
40
164
 
41
165
  def load_internal_api(self):
42
166
  cur_path = os.path.dirname(os.path.realpath(__file__))
43
- yaml_path = os.path.join(cur_path, "ms_to_pt_api.yaml")
167
+ yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
44
168
  return load_yaml(yaml_path)
45
169
 
46
170
  def load_mapping_file(self, mapping_file):
@@ -51,42 +175,20 @@ class MSComparator(Comparator):
51
175
  return mapping_dict
52
176
 
53
177
  def process_cell_mapping(self, npu_op_name):
54
- npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
178
+ if not npu_op_name or not re.match(r'.+(?:for|back)ward\..+', npu_op_name):
179
+ return CompareConst.N_A
180
+ npu_op_name = npu_op_name.replace("Cell", "Module", 1)
55
181
  if self.cell_mapping_dict:
56
- for index, op_name in enumerate(npu_op_name):
57
- # get cell name & class name from op_name
58
- # Cell.fc1.Dense.forward.0.input.0
59
- cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
60
- if cell_name in self.cell_mapping_dict:
61
- npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
182
+ # get cell name & class name from op_name
183
+ # Cell.fc1.Dense.forward.0.input.0
184
+ cell_name = re.split(r'\.(?:for|back)ward\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
185
+ if cell_name in self.cell_mapping_dict:
186
+ npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
62
187
  return npu_op_name
63
188
 
64
- def check_op(self, npu_dict, bench_dict, fuzzy_match):
65
- npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
66
- npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
67
- if self.cell_mapping is not None:
68
- npu_op_name = self.process_cell_mapping(npu_op_name)
69
- if self.api_mapping is not None:
70
- npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
71
- if isinstance(self.api_mapping, str):
72
- npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
73
- bench_dict_new)
74
- if target_dict:
75
- bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
76
- npu_op_name = npu_dict_new.get(CompareConst.OP_NAME)
77
- bench_op_name = bench_dict_new.get(CompareConst.OP_NAME)
78
- struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
79
- if not fuzzy_match:
80
- return npu_op_name == bench_op_name and struct_match
81
- is_match = True
82
- try:
83
- is_match = fuzzy_check_op(npu_op_name, bench_op_name)
84
- except Exception as err:
85
- logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
86
- is_match = False
87
- return is_match and struct_match
88
-
89
189
  def read_npy_data(self, dir_path, file_name, load_pt_file=False):
190
+ if not file_name:
191
+ return None
90
192
  data_path = os.path.join(dir_path, file_name)
91
193
  if load_pt_file:
92
194
  import torch
@@ -97,34 +199,22 @@ class MSComparator(Comparator):
97
199
  data_value = data_value.numpy()
98
200
  else:
99
201
  data_value = load_npy(data_path)
100
- return data_value
202
+ return data_value
101
203
 
102
- def api_replace(self, npu_op_name, target, para):
103
- for idx, _ in enumerate(npu_op_name):
104
- npu_op_name[idx] = npu_op_name[idx].replace(target, para)
105
- return npu_op_name
106
-
107
- def process_internal_api_mapping(self, npu_op_name, bench_op_name):
204
+ def process_internal_api_mapping(self, npu_op_name):
108
205
  # get api name & class name from op_name
109
206
  # Functional.addcmul.0.forward.input.0
110
- npu_op_name, bench_op_name = npu_op_name.copy(), bench_op_name.copy()
111
- ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
112
- pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
207
+ ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
113
208
  class_name = ms_api_name.split(Const.SEP)[0]
114
209
  if class_name == "Mint":
115
- return self.api_replace(npu_op_name, "Mint", "Torch")
210
+ return npu_op_name.replace("Mint", "Torch")
116
211
  elif class_name == "MintFunctional":
117
- return self.api_replace(npu_op_name, "MintFunctional", "Functional")
118
- elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
119
- return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
212
+ return npu_op_name.replace("MintFunctional", "Functional")
213
+ elif self.ms_to_pt_mapping.get(ms_api_name):
214
+ return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
120
215
  else:
121
216
  return npu_op_name
122
217
 
123
- def remove_element(self, op_name, struct, summary, idx):
124
- del op_name[idx]
125
- del struct[idx]
126
- del summary[idx]
127
-
128
218
  def get_api_name(self, api_list):
129
219
  try:
130
220
  api_name = api_list[0] + Const.SEP + api_list[1]
@@ -132,184 +222,126 @@ class MSComparator(Comparator):
132
222
  logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
133
223
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
134
224
  return api_name
135
-
136
- def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
137
- """
138
- Transform user mapping API based on new NPU and benchmark dictionaries.
139
- Parameters:
140
- new_npu_dict (dict): New NPU operation dictionary.
141
- new_bench_dict (dict): New benchmark operation dictionary.
142
- Returns:
143
- tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
144
- """
145
- npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
146
- npu_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT)
147
- bench_struct_in = new_bench_dict.get(CompareConst.INPUT_STRUCT)
148
- npu_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT)
149
- bench_struct_out = new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
150
- npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
151
- npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
152
- npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
153
- ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
154
- ms_api_name = self.get_api_name(ms_api_list)
155
- pt_api_name = self.get_api_name(pt_api_list)
156
- target_dict = {}
157
- for api_dict in self.api_mapping_dict:
158
- if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
159
- ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
160
- ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
161
- if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
162
- logger.warning("The user-defined mapping table is incorrect,\
163
- make sure that the number of parameters is equal")
164
- break
165
- ms_out_list = api_dict.get("ms_output", [])
166
- for idx in reversed(range(npu_out_len)):
167
- if idx not in ms_out_list:
168
- del npu_struct_out[idx]
169
- if idx + npu_in_len < len(npu_summary) and idx + npu_in_len < len(npu_op_name):
170
- del npu_summary[idx + npu_in_len]
171
- del npu_op_name[idx + npu_in_len]
172
- pt_out_list = api_dict.get("pt_output", [])
173
- for idx in reversed(range(bench_out_len)):
174
- if idx not in pt_out_list:
175
- del bench_struct_out[idx]
176
- if idx + bench_in_len < len(bench_summary) and idx + bench_in_len < len(bench_op_name):
177
- del bench_summary[idx + bench_in_len]
178
- del bench_op_name[idx + bench_in_len]
179
- ms_para_list = api_dict.get("ms_args", [])
180
- for idx in reversed(range(npu_in_len)):
181
- if idx not in ms_para_list:
182
- self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
183
- pt_para_list = api_dict.get("pt_args", [])
184
- for idx in reversed(range(bench_in_len)):
185
- if idx not in pt_para_list:
186
- self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
187
- npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
188
- npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
189
- target_dict = api_dict
190
- break
191
- if target_dict:
192
- new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
193
- CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
194
- new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
195
- CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
196
- return new_npu_dict, new_bench_dict, target_dict
197
-
198
- def para_sequence_update(self, npu_op_name, bench_op_name):
199
- for idx, _ in enumerate(npu_op_name):
200
- bench_op_name_list = bench_op_name[idx].rsplit(Const.SEP, 1)
201
- if len(bench_op_name_list) != 0:
202
- npu_op_name[idx] = npu_op_name[idx][:-1] + bench_op_name_list[-1]
203
- return npu_op_name
204
225
 
205
- def reconstitution_bench_dict(self, npu_dict, del_bench_dict, api_dict):
206
- ms_user_args_list = api_dict.get("ms_args", [])
207
- ms_user_output_list = api_dict.get("ms_output", [])
208
- npu_struct_in = npu_dict.get(CompareConst.INPUT_STRUCT)
209
- npu_struct_out = npu_dict.get(CompareConst.OUTPUT_STRUCT)
210
- npu_in_len = len(npu_struct_in)
211
- npu_out_len = len(npu_struct_out)
212
- if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
213
- return del_bench_dict
214
- ms_input_args_list = [i for i in range(npu_in_len)]
215
- input_sub_list = list(set(ms_input_args_list) - set(ms_user_args_list))
216
- ms_output_args_list = [i for i in range(npu_out_len)]
217
- output_sub_list = list(set(ms_output_args_list) - set(ms_user_output_list))
218
- bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
219
- bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
220
- bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
221
- bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
222
- for idx in input_sub_list: # Fill in the blank value field in the pt dictionary
223
- bench_op_name.insert(idx, CompareConst.N_A)
224
- bench_struct_in.insert(idx, CompareConst.N_A)
225
- bench_summary.insert(idx, CompareConst.N_A)
226
- for idx in output_sub_list: # Fill in the blank value field in the pt dictionary
227
- bench_op_name.insert(npu_in_len + idx, CompareConst.N_A)
228
- bench_struct_out.insert(idx, CompareConst.N_A)
229
- bench_summary.insert(npu_in_len + idx, CompareConst.N_A)
230
- del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
231
- CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
232
- return del_bench_dict
233
-
226
+ def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
227
+ npu_json_path, bench_json_path, stack_json_path = file_lists
228
+ npu_json_data = load_json(npu_json_path)
229
+ bench_json_data = load_json(bench_json_path)
230
+ stack_json_data = load_json(stack_json_path)
234
231
 
235
- def sort_by_execution_sequence(npu_data, bench_data, mapping_list, flag):
236
- def generate_execution_sequence(data):
237
- sequence_map = {}
238
- for index, item in enumerate(data.keys()):
239
- if flag in item:
240
- item_split = item.split(Const.SEP)
241
- item_name = Const.SEP.join(item_split[0:-2])
242
- item_index = item_split[-1]
243
- if item_index == 'forward' or item_index == 'backward':
244
- item_index = item_split[-2]
245
- item_key = f"{item_name}.{item_index}"
246
- sequence_map[item_key] = index
247
- return sequence_map
232
+ npu_df = self.gen_data_df(npu_json_data, stack_json_data, dump_mode)
233
+ bench_df = self.gen_data_df(bench_json_data, stack_json_data, dump_mode)
234
+ if self.cell_mapping:
235
+ npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
236
+ elif self.api_mapping:
237
+ npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
238
+ if isinstance(self.api_mapping, str):
239
+ self.modify_compare_data_with_user_mapping(npu_df, bench_df)
240
+ else:
241
+ npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME]
242
+ npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
243
+ bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
244
+ npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
245
+ bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
246
+ bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
247
+ match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
248
+ how='outer')
249
+ match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
248
250
 
249
- npu_map = generate_execution_sequence(npu_data)
250
- bench_map = generate_execution_sequence(bench_data)
251
+ def gen_dtype_condition():
252
+ npu_dtype = match_result['dtype_x']
253
+ bench_dtype = match_result['dtype_y']
254
+ if self.cross_frame:
255
+ npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
256
+ return ((npu_dtype == bench_dtype) |
257
+ ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) |
258
+ ((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) |
259
+ ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) |
260
+ ((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) |
261
+ ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) |
262
+ ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
263
+ ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
264
+ ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
265
+
266
+ match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
267
+ return MSComparator.make_result_df(match_result, stack_mode, dump_mode)
251
268
 
252
- def sort_by_map(item):
253
- first_key = npu_map.get(item[0], sys.maxsize)
254
- second_key = bench_map.get(item[1], sys.maxsize)
255
- return first_key, second_key
269
+ def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
270
+ def get_api_indices_dict(op_name_df):
271
+ api_indices_dict = defaultdict(list)
272
+ for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
273
+ api = self.get_api_name(name.split(Const.SEP))
274
+ api_indices_dict[api].append(op_index)
275
+ return api_indices_dict
256
276
 
257
- return sorted(mapping_list, key=sort_by_map)
277
+ ms_api_indices_dict = get_api_indices_dict(npu_df)
278
+ pt_api_indices_dict = get_api_indices_dict(bench_df)
258
279
 
280
+ def gen_input_compare_key(pattern, term):
281
+ flag = True
282
+ for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
283
+ if op_name.split(pattern)[1].startswith(str(prefix)):
284
+ npu_df.loc[index, CompareConst.COMPARE_KEY] = (
285
+ op_name.replace(pattern + str(prefix),
286
+ pattern + str(mapping_dict.get(f'pt_{term}')[i])))
287
+ flag = False
288
+ return flag
259
289
 
260
- def generate_kernel_data(map_value, data, flag):
261
- if not map_value:
262
- return [], []
263
- inputs_name = []
264
- outputs_name = []
265
- map_split = map_value.split(Const.SEP)
266
- map_name = Const.SEP.join(map_split[0:-1])
267
- map_index = map_split[-1]
268
- for key, value in data.items():
269
- if key.find(flag) != -1 and key.find(map_name) != -1:
270
- if key.split(Const.SEP)[-1] != map_index and key.split(Const.SEP)[-2] != map_index :
290
+ for mapping_dict in self.api_mapping_dict:
291
+ if (len(mapping_dict.get('ms_args')) != len(mapping_dict.get('pt_args')) or
292
+ len(mapping_dict.get('ms_output')) != len(mapping_dict.get('pt_output'))):
293
+ logger.warning('The user-defined mapping table is incorrect,\
294
+ make sure that the number of parameters is equal')
271
295
  continue
272
- if flag == 'forward':
273
- input_args = value.get('input_args', {})
274
- else:
275
- input_args = value.get('input', {})
276
- output_args = value.get('output', {})
277
- for i in range(len(input_args)):
278
- inputs_name.append(f"{key}.input.{i}")
279
- for i in range(len(output_args)):
280
- outputs_name.append(f"{key}.output.{i}")
281
- return inputs_name, outputs_name
282
-
283
-
284
- def generate_file_mapping(npu_json_path, bench_json_path, mapping_list):
285
-
286
- npu_data = load_json(npu_json_path).get("data", {})
287
- bench_data = load_json(bench_json_path).get("data", {})
288
-
289
- forward_data = []
290
- mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.FORWARD)
291
- for map_value in mapping_list:
292
- npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "forward")
293
- bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "forward")
294
- inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
295
- outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
296
- forward_data.extend(inputs_zip)
297
- forward_data.extend(outputs_zip)
298
-
299
- backward_data = []
300
- mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.BACKWARD)
301
- for map_value in mapping_list:
302
- npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "backward")
303
- bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "backward")
304
- inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
305
- outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
306
- backward_data.extend(inputs_zip)
307
- backward_data.extend(outputs_zip)
308
-
309
- kernel_data = forward_data + backward_data
310
- result = {key: value for key, value in kernel_data if key is not None}
296
+ ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
297
+ if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
298
+ continue
299
+ for index in ms_api_indices_dict.get(ms_api):
300
+ op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
301
+ if CompareConst.INPUT_PATTERN in op_name:
302
+ is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
303
+ elif CompareConst.KWARGS_PATTERN in op_name:
304
+ is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
305
+ elif CompareConst.OUTPUT_PATTERN in op_name:
306
+ is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
307
+ else:
308
+ logger.error(f'Excepted op_name: {op_name}')
309
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
310
+ if is_abandoned:
311
+ npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
311
312
 
312
- return result
313
+ def gen_data_df(self, data_json, stack_json, dump_mode):
314
+ result = {
315
+ CompareConst.OP_NAME: [],
316
+ Const.DTYPE: [],
317
+ Const.SHAPE: [],
318
+ Const.SUMMARY: [],
319
+ 'stack_info': []
320
+ }
321
+ if dump_mode == Const.ALL:
322
+ result['data_name'] = []
323
+ elif dump_mode == Const.MD5:
324
+ result[Const.MD5] = []
325
+ for data_name in data_json['data']:
326
+ check_op_str_pattern_valid(data_name)
327
+ merge_list = self.gen_merge_list(data_json, data_name, stack_json, dump_mode)
328
+ if not merge_list:
329
+ continue
330
+ for op_name in merge_list[CompareConst.OP_NAME]:
331
+ result[CompareConst.OP_NAME].append(op_name)
332
+ if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
333
+ struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
334
+ else:
335
+ struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
336
+ result[Const.DTYPE].append(struct[0])
337
+ result[Const.SHAPE].append(struct[1])
338
+ if dump_mode == Const.MD5:
339
+ result[Const.MD5].append(struct[2])
340
+ result[Const.SUMMARY].append(merge_list[Const.SUMMARY].pop(0))
341
+ result['stack_info'].append(merge_list['stack_info'][0])
342
+ if dump_mode == Const.ALL:
343
+ result['data_name'].append(merge_list['data_name'].pop(0))
344
+ return pd.DataFrame(result)
313
345
 
314
346
 
315
347
  def check_cross_framework(bench_json_path):
@@ -330,28 +362,19 @@ def ms_compare(input_param, output_path, **kwargs):
330
362
  api_mapping = kwargs.get('api_mapping', None)
331
363
  data_mapping = kwargs.get('data_mapping', None)
332
364
  layer_mapping = kwargs.get('layer_mapping', None)
365
+ suffix = kwargs.get('suffix', '')
333
366
 
334
- summary_compare, md5_compare = task_dumppath_get(input_param)
367
+ set_dump_path(input_param)
368
+ dump_mode = get_dump_mode(input_param)
335
369
  check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
336
370
  create_directory(output_path)
337
- check_compare_param(input_param, output_path, summary_compare, md5_compare)
371
+ check_compare_param(input_param, output_path, dump_mode)
338
372
  except (CompareException, FileCheckException) as error:
339
373
  logger.error('Compare failed. Please check the arguments and do it again!')
340
374
  raise CompareException(error.code) from error
341
375
  if layer_mapping:
342
- pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK)
343
- ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK)
344
- mapping = load_yaml(layer_mapping)
345
- ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct)
346
- pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct)
347
- layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping)
348
- data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping)
349
-
350
- data_mapping_name = add_time_with_yaml(f"data_mapping")
351
- data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}")
352
- save_yaml(data_mapping_path, data_mapping)
376
+ data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
353
377
  is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
354
378
  ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
355
- ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
356
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
357
- md5_compare=md5_compare)
379
+ ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, suffix=suffix,
380
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)