mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
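The largest single change in this release is entry 13 above (msprobe/core/compare/acc_compare.py, +646 -428); the hunk below appears to come from that file and shows the Comparator rewrite: ModeConfig moves to msprobe/core/compare/config.py, a ComparisonConfig dataclass is added, and exact op matching is now done with pandas joins split across new ParseData, ProcessDf, Match, CreateTable and CalcStatsDiff helpers (fuzzy matching keeps a queue-based loop). As orientation only, here is a minimal, self-contained pandas sketch of that matching idea; the column names and sample values are illustrative stand-ins, not the msprobe API:

    import pandas as pd

    # Toy stand-ins for the parsed NPU / bench dump tables (hypothetical columns).
    npu_df = pd.DataFrame({
        "compare_key": ["Functional.conv2d.forward.input.0", "Functional.relu.forward.output.0"],
        "shape": ["(1, 3, 224, 224)", "(1, 64, 112, 112)"],
        "npu_max": [2.0, 5.0],
    })
    bench_df = pd.DataFrame({
        "compare_key": ["Functional.conv2d.forward.input.0"],
        "shape": ["(1, 3, 224, 224)"],
        "bench_max": [2.0],
    })

    # Exact matching reduces to an outer merge on (compare key, shape); NPU rows with no
    # bench counterpart come back as NaN and are filled with "N/A", mirroring the
    # match_result[...].fillna(CompareConst.N_A) step visible in the diff below.
    match_result = pd.merge(npu_df, bench_df, on=["compare_key", "shape"], how="outer")
    print(match_result.fillna("N/A"))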
@@ -13,111 +13,234 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import multiprocessing
  import os
  import re
- from copy import deepcopy
+ from dataclasses import dataclass
+ from collections import defaultdict

+ import numpy as np
  import pandas as pd
  from tqdm import tqdm

  from msprobe.core.advisor.advisor import Advisor
  from msprobe.core.common.const import CompareConst, Const
  from msprobe.core.common.exceptions import FileCheckException
- from msprobe.core.common.file_utils import load_json, remove_path
+ from msprobe.core.common.file_utils import load_json, remove_path, create_directory
  from msprobe.core.common.log import logger
- from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value
- from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \
- check_struct_match, fuzzy_check_op
- from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
- from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result
- from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
- from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \
- print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list
-
-
- class ModeConfig:
- def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None):
- self.stack_mode = stack_mode
- self.auto_analyze = auto_analyze
- self.fuzzy_match = fuzzy_match
- self.dump_mode = dump_mode
+ from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \
+ set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type
+ from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping
+ from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \
+ reorder_op_x_list, set_stack_json_path
+ from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict
+ from msprobe.core.compare.multiprocessing_compute import CompareRealData
+ from msprobe.core.compare.highlight import HighLight
+
+
+ @dataclass
+ class ComparisonConfig:
+ dump_mode: str
+ stack_mode: bool
+ auto_analyze: bool
+ fuzzy_match: bool
+ data_mapping: dict
+ suffix: str
+ cell_mapping: dict
+ api_mapping: dict
+ layer_mapping: dict
+ compared_file_type: str


  class Comparator:
- def __init__(self, mode_config: ModeConfig):
- self.stack_mode = mode_config.stack_mode
- self.auto_analyze = mode_config.auto_analyze
- self.fuzzy_match = mode_config.fuzzy_match
- self.dump_mode = mode_config.dump_mode
+ def __init__(self, file_reader, mode_config: ModeConfig, mapping_config: MappingConfig, is_cross_framework=False):
+ self.file_reader = file_reader
+ self.mode_config = mode_config
+ self.mapping_config = mapping_config
+ self.cross_frame = is_cross_framework
+
+ self.mapping_dict = MappingDict(mapping_config)

  @staticmethod
- def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
- npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
- bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
+ def process_output_file(output_path, suffix, compared_file_type):
+ file_name_prefix_mapping = {
+ Const.DUMP_JSON_FILE: "compare_result",
+ Const.DEBUG_JSON_FILE: "debug_compare_result"
+ }
+ file_name_prefix = file_name_prefix_mapping.get(compared_file_type, "compare_result")
+ file_name = add_time_with_xlsx(file_name_prefix + suffix)
+ file_path = os.path.join(os.path.realpath(output_path), file_name)
+ if os.path.exists(file_path):
+ logger.warning(f"{file_path} will be deleted.")
+ remove_path(file_path)
+ return file_path

- if len(npu_struct) < 3 or len(bench_struct) < 3:
- logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
- f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
+ def compare_core(self, input_param, output_path, **kwargs):
+ """
+ Compares data from multiple JSON files and generates a comparison report.

- result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
- npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
- CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
+ Args:
+ input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
+ "stack_path").
+ output_path (str): The path where the output Excel report will be saved.
+ **kwargs: Additional keyword arguments including:
+ - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
+ - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
+ - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
+ - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
+ - dump_mode (str): ALL, SUMMARY, MD5.

- if len(args) >= 2 and args[0]:
- result_item.extend(args[1])
- else:
- result_item.append(CompareConst.NONE)
- return result_item
+ Returns:
+ """
+ logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")

- @staticmethod
- def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
- err_msg = ""
- result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
- bench_summary_data, err_msg)
- result_item.append(accuracy_check)
- result_item.append(err_msg)
+ # get kwargs or set default value
+ suffix = kwargs.get('suffix', '')

- @staticmethod
- def _generate_na_data(ops_all):
- if not ops_all:
- return {}
- key = next(iter(ops_all))
- value = deepcopy(ops_all[key])
- for k, v in value.items():
- if isinstance(v, tuple):
- value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
- elif isinstance(v, list):
- value[k] = [CompareConst.N_A] * len(v)
- else:
- value[k] = CompareConst.N_A
- return value
+ # process output file
+ file_path = self.process_output_file(output_path, suffix, self.mode_config.compared_file_type)
+
+ # initialize the compare result table and compare general data(name, dtype, shape, statistics/md5, etc.)
+ npu_json = input_param.get("npu_json_path")
+ bench_json = input_param.get("bench_json_path")
+ stack_json = input_param.get("stack_json_path")
+ result_df = self.compare_statistics([npu_json, bench_json, stack_json])
+ if not result_df.values.tolist():
+ logger.warning("Can`t match any op. No compare result file generated.")
+ return

- def make_result_table(self, result):
- header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
+ # compare real data
+ if self.mode_config.dump_mode == Const.ALL:
+ compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame)
+ result_df = compare_real_data.do_multi_process(input_param, result_df)

- if self.stack_mode:
- header.append(CompareConst.STACK)
- if self.dump_mode == Const.ALL:
- header.append(CompareConst.DATA_NAME)
- else:
- if self.dump_mode == Const.ALL:
- for row in result:
- del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
- header.append(CompareConst.DATA_NAME)
- else:
- for row in result:
- del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
- result_df = pd.DataFrame(result, columns=header, dtype='object')
- return result_df
+ # highlight suspicious API
+ highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
+ highlight = HighLight(self.mode_config)
+ if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
+ highlight.find_compare_result_error_rows(result_df, highlight_dict)
+ highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path)
+
+ # output compare analysis suggestions
+ if self.mode_config.auto_analyze:
+ advisor = Advisor(result_df, output_path, suffix)
+ advisor.analysis()
+
+ print_compare_ends_info()
+
+ def compare_statistics(self, file_list):
+ # load and parse json data
+ parse_data = ParseData(self.mode_config)
+ npu_df, bench_df = parse_data.parse(file_list)
+
+ npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
+ bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
+
+ # create new columns for compare op_name and shape
+ # process npu_df's COMPARE_KEY whether same or different framework
+ process_df = ProcessDf(self.mode_config, self.mapping_config, self.mapping_dict)
+ npu_df, bench_df = process_df.process_compare_key_and_shape(npu_df, bench_df)
+
+ # match npu and bench, match_result contains both npu_info and bench_info
+ match = Match(self.mode_config, self.mapping_config, self.cross_frame)
+ match_result = match.match_api_infos(npu_df, bench_df)
+ # 筛选出npu_name存在的行并填充筛选出行中的缺失值为N/A
+ match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
+ bench_columns = [i + '_y' for i in bench_df.columns]
+ match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A
+
+ # organize compare result table by renaming columns
+ create_table = CreateTable(self.mode_config)
+ result_df, header = create_table.make_result_df(match_result)
+
+ # calculate statistics diff
+ calc_stats_diff = CalcStatsDiff(self.mode_config)
+ return calc_stats_diff.calc_accuracy(result_df, header)
+
+
+ class ParseData:
+ def __init__(self, mode_config: ModeConfig):
+ self.mode_config = mode_config
+
+ def parse(self, file_list):
+ npu_json_path, bench_json_path, stack_json_path = file_list
+ npu_json_data = load_json(npu_json_path)
+ bench_json_data = load_json(bench_json_path)
+ stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None
+
+ # parse json data and generate df
+ npu_df = self.gen_data_df(npu_json_data, stack_json_data)
+ bench_df = self.gen_data_df(bench_json_data, stack_json_data)
+
+ return npu_df, bench_df
+
+ def gen_data_df(self, data_json, stack_json_data):
+ result = {
+ CompareConst.OP_NAME: [],
+ Const.DTYPE: [],
+ Const.SHAPE: [],
+ Const.SUMMARY: [],
+ Const.STACK_INFO: []
+ }
+ if self.mode_config.dump_mode == Const.ALL:
+ result['data_name'] = []
+ elif self.mode_config.dump_mode == Const.MD5:
+ result[Const.MD5] = []
+
+ apis_data = data_json.get('data', None)
+ if not apis_data:
+ logger.warning('No APIs found in dump.json.')
+ return pd.DataFrame(result)
+
+ api_nums = len(apis_data)
+ progress_bar = tqdm(total=api_nums, desc="API/Module Read Progress", unit="api/module", ncols=100)
+
+ # 从json中循环解析API数据,遍历所有API
+ for data_name in apis_data:
+ check_op_str_pattern_valid(data_name)
+ merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
+ if not merge_list:
+ continue
+
+ op_name_list = merge_list.get(CompareConst.OP_NAME)
+ summary_list = merge_list.get(Const.SUMMARY)
+ data_name_list = merge_list.get('data_name')
+ op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
+ summary_list,
+ data_name_list)
+ # 遍历单个API的所有item
+ for index, op_name in enumerate(op_name_reorder):
+ result[CompareConst.OP_NAME].append(op_name)
+ if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
+ struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
+ elif CompareConst.OUTPUT_PATTERN in op_name:
+ struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
+ elif CompareConst.PARAMS_PATTERN in op_name:
+ struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
+ elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
+ struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
+ else:
+ struct = merge_list[CompareConst.DEBUG_STRUCT].pop(0)
+ result[Const.DTYPE].append(struct[0])
+ result[Const.SHAPE].append(struct[1])
+ if self.mode_config.dump_mode == Const.MD5:
+ result[Const.MD5].append(struct[2])
+ result[Const.SUMMARY].append(summary_reorder.pop(0))
+ result[Const.STACK_INFO].append(
+ merge_list[Const.STACK_INFO][0] if index == 0 and self.mode_config.stack_mode else None)
+ if self.mode_config.dump_mode == Const.ALL:
+ result['data_name'].append(data_name_reorder.pop(0))
+
+ progress_bar.update(1)
+ progress_bar.close()
+ return pd.DataFrame(result)

  def gen_merge_list(self, json_data, op_name, stack_json_data):
  op_data = json_data['data'][op_name]
- check_dump_json_str(op_data, op_name)
+ if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
+ check_dump_json_str(op_data, op_name)
  op_parsed_list = read_op(op_data, op_name)

- if self.stack_mode:
+ if self.mode_config.stack_mode:
  stack_info = stack_json_data.get(op_name)
  if stack_info is not None:
  check_stack_json_str(stack_info, op_name)
@@ -127,392 +250,487 @@ class Comparator:
127
250
  'full_info': stack_info
128
251
  })
129
252
 
130
- merge_list = merge_tensor(op_parsed_list, self.dump_mode)
253
+ merge_list = merge_tensor(op_parsed_list, self.mode_config.dump_mode)
131
254
  return merge_list
132
255
 
133
- def check_op(self, npu_dict, bench_dict):
134
- npu_op_name = npu_dict[CompareConst.OP_NAME]
135
- bench_op_name = bench_dict[CompareConst.OP_NAME]
136
- graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
137
- safe_get_value(bench_op_name, 0, "bench_op_name"))
138
-
139
- frame_name = getattr(self, "frame_name")
140
- if frame_name == "PTComparator":
141
- from msprobe.pytorch.compare.match import graph_mapping
142
- if graph_mode:
143
- return graph_mapping.match(npu_op_name[0], bench_op_name[0])
144
- struct_match = check_struct_match(npu_dict, bench_dict)
145
- if not self.fuzzy_match:
146
- name_match = npu_op_name == bench_op_name
147
- return name_match and struct_match
148
- try:
149
- name_match = fuzzy_check_op(npu_op_name, bench_op_name)
150
- except Exception as err:
151
- logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
152
- name_match = False
153
- return name_match and struct_match
154
256
 
155
- def match_op(self, npu_queue, bench_queue):
156
- for b_index, b_op in enumerate(bench_queue[0: -1]):
157
- if self.check_op(npu_queue[-1], b_op):
158
- return len(npu_queue) - 1, b_index
159
- if self.check_op(npu_queue[-1], bench_queue[-1]):
160
- return len(npu_queue) - 1, len(bench_queue) - 1
161
- for n_index, n_op in enumerate(npu_queue[0: -1]):
162
- if self.check_op(n_op, bench_queue[-1]):
163
- return n_index, len(bench_queue) - 1
164
- return -1, -1
257
+ class ProcessDf:
258
+ def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, mapping_dict: MappingDict):
259
+ self.mode_config = mode_config
260
+ self.mapping_config = mapping_config
261
+ self.mapping_dict = mapping_dict
165
262
 
166
- def compare_process(self, file_lists):
167
- npu_json_path, bench_json_path, stack_json_path = file_lists
168
- npu_json_data = load_json(npu_json_path)
169
- bench_json_data = load_json(bench_json_path)
170
- stack_json_data = load_json(stack_json_path) if self.stack_mode else None
263
+ @staticmethod
264
+ def get_api_name(api_list):
265
+ try:
266
+ api_name = api_list[0] + Const.SEP + api_list[1]
267
+ except IndexError as error:
268
+ logger.error('Failed to retrieve API name, please check if the dump data is reasonable')
269
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
270
+ return api_name
271
+
272
+ def process_compare_key_and_shape(self, npu_df, bench_df):
273
+ npu_df = self.assign_npu_df_compare_key(npu_df, bench_df)
274
+ npu_df[CompareConst.CMP_SHAPE] = npu_df[Const.SHAPE]
275
+ bench_df[CompareConst.CMP_KEY] = bench_df[CompareConst.OP_NAME]
276
+ bench_df[CompareConst.CMP_SHAPE] = bench_df[Const.SHAPE]
277
+ return npu_df, bench_df
278
+
279
+ def assign_npu_df_compare_key(self, npu_df, bench_df):
280
+ """
281
+ 处理 npu_df 的 COMPARE_KEY 赋值逻辑
171
282
 
172
- if self.fuzzy_match:
173
- logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
283
+ :param npu_df: DataFrame,NPU 对比数据
284
+ :param bench_df: DataFrame,Bench 对比数据
285
+ :return: compare_key(name)处理后的 npu_df
286
+ """
287
+ # 处理api_mapping映射
288
+ if self.mapping_config.api_mapping:
289
+ # 如果用户不传api_mapping.yaml,先使用内置api_mapping.yaml替换npu_op_name
290
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
291
+ # 如果用户传入api_mapping.yaml,再使用传入api_mapping.yaml进一步替换npu_op_name
292
+ if isinstance(self.mapping_config.api_mapping, str):
293
+ self.modify_compare_data_with_user_mapping(npu_df, bench_df)
294
+ # 处理cell_mapping映射
295
+ elif self.mapping_config.cell_mapping:
296
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
297
+ # 处理data_mapping映射
298
+ elif self.mapping_config.data_mapping:
299
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_data_mapping)
300
+ else:
301
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME]
302
+ return npu_df
303
+
304
+ def process_internal_api_mapping(self, npu_op_name):
305
+ # get api name & class name from op_name
306
+ ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
307
+ class_name = ms_api_name.split(Const.SEP)[0]
308
+ if class_name == "Mint":
309
+ return npu_op_name.replace("Mint", "Torch")
310
+ elif class_name == "MintFunctional":
311
+ return npu_op_name.replace("MintFunctional", "Functional")
312
+ elif self.mapping_dict.ms_to_pt_mapping.get(ms_api_name):
313
+ return npu_op_name.replace(ms_api_name, self.mapping_dict.ms_to_pt_mapping.get(ms_api_name))
314
+ else:
315
+ return npu_op_name
316
+
317
+ def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
318
+ def gen_input_compare_key(pattern, term):
319
+ is_unmatched = True
320
+ for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
321
+ if op_name.split(pattern)[1].startswith(str(prefix)):
322
+ npu_df.loc[index, CompareConst.CMP_KEY] = (
323
+ op_name.replace(pattern + str(prefix),
324
+ pattern + str(mapping_dict.get(f'pt_{term}')[i])))
325
+ is_unmatched = False
326
+ return is_unmatched
327
+
328
+ ms_api_indices_dict = self.get_api_indices_dict(npu_df)
329
+ pt_api_indices_dict = self.get_api_indices_dict(bench_df)
330
+
331
+ for mapping_dict in self.mapping_dict.api_mapping_dict:
332
+ all_length_equal = True
333
+ for k1, k2 in CompareConst.API_MAPPING_KEYS_TO_COMPARE:
334
+ if len(mapping_dict.get(k1, [])) != len(mapping_dict.get(k2, [])):
335
+ all_length_equal = False
336
+ if not all_length_equal:
337
+ logger.warning('The user-defined mapping table is incorrect,\
338
+ make sure that the number of parameters is equal')
339
+ continue
174
340
 
175
- npu_ops_queue = []
176
- bench_ops_queue = []
177
- result = []
341
+ ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
342
+ if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
343
+ continue
344
+ for index in ms_api_indices_dict.get(ms_api):
345
+ op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
346
+ if CompareConst.INPUT_PATTERN in op_name:
347
+ is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
348
+ elif CompareConst.KWARGS_PATTERN in op_name:
349
+ is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
350
+ elif CompareConst.OUTPUT_PATTERN in op_name:
351
+ is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
352
+ elif CompareConst.PARAMS_PATTERN in op_name:
353
+ is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
354
+ elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
355
+ is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
356
+ else:
357
+ logger.error(f'Excepted op_name: {op_name}')
358
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
359
+ if is_abandoned:
360
+ npu_df.loc[index, CompareConst.CMP_KEY] = op_name + 'abandoned'
178
361
 
179
- ops_npu_iter = iter(npu_json_data['data'])
180
- ops_bench_iter = iter(bench_json_data['data'])
181
- read_err_npu = True
182
- read_err_bench = True
183
- last_npu_ops_len = 0
184
- last_bench_ops_len = 0
362
+ def get_api_indices_dict(self, op_name_df):
363
+ """
364
+ 生成多个api对应的各自的所有的input、output等的index的键值对字典
365
+ 示例:
366
+ {'Functional.conv2d': [0, 1, 2, 3],
367
+ 'Functional.batch_norm': [4, 5, 6, 7, 8]
368
+ }
369
+ """
370
+ api_indices_dict = defaultdict(list)
371
+ for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
372
+ api_name = self.get_api_name(name.split(Const.SEP))
373
+ api_indices_dict[api_name].append(op_index)
374
+ return api_indices_dict
375
+
376
+ def process_cell_mapping(self, npu_op_name):
377
+ if not npu_op_name:
378
+ return CompareConst.N_A
379
+ param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
380
+ if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
381
+ return CompareConst.N_A
382
+ npu_op_name = npu_op_name.replace("Cell", "Module", 1)
383
+ if self.mapping_dict.cell_mapping_dict:
384
+ # get cell name & class name from op_name
385
+ # Cell.fc1.Dense.forward.0.input.0
386
+ cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
387
+ if cell_name in self.mapping_dict.cell_mapping_dict:
388
+ npu_op_name = npu_op_name.replace(cell_name, self.mapping_dict.cell_mapping_dict[cell_name], 1)
389
+ return npu_op_name
390
+
391
+ def process_data_mapping(self, npu_op_name):
392
+ return self.mapping_dict.data_mapping_dict.get(npu_op_name, npu_op_name)
393
+
394
+
395
+ class Match:
396
+ def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, cross_frame):
397
+ self.mode_config = mode_config
398
+ self.mapping_config = mapping_config
399
+ self.cross_frame = cross_frame
185
400
 
186
- npu_api_nums = len(npu_json_data['data'])
187
- progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100)
401
+ @staticmethod
402
+ def put_unmatched_in_table(match_result, npu_op_item):
403
+ npu_columns = npu_op_item.index.tolist()[:-2]
404
+ new_columns = [name[:-1] + 'y' for name in npu_columns]
405
+ na_series = pd.Series([CompareConst.N_A] * len(new_columns), index=new_columns)
406
+ new_result_item = pd.concat([npu_op_item, na_series]).to_frame().T
407
+ new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS
408
+ match_result = pd.concat([match_result, new_result_item])
409
+ return match_result
188
410
 
189
- while True:
190
- if not read_err_npu and not read_err_bench:
191
- break
192
- try:
193
- last_npu_ops_len = len(npu_ops_queue)
194
- op_name_npu = next(ops_npu_iter)
195
- check_op_str_pattern_valid(op_name_npu)
196
- npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data)
197
- if npu_merge_list:
198
- npu_ops_queue.append(npu_merge_list)
199
- except StopIteration:
200
- read_err_npu = False
201
- try:
202
- last_bench_ops_len = len(bench_ops_queue)
203
- op_name_bench = next(ops_bench_iter)
204
- check_op_str_pattern_valid(op_name_bench)
205
- bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data)
206
- if bench_merge_list:
207
- bench_ops_queue.append(bench_merge_list)
208
- except StopIteration:
209
- read_err_bench = False
411
+ @staticmethod
412
+ def put_matched_in_table(match_result, npu_op_item, bench_op_item):
413
+ head_len = len(CompareConst.MATCH_RESULT_COLUMNS)
414
+ new_result_item = pd.concat([npu_op_item, bench_op_item]).head(head_len).to_frame().T
415
+ new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS
416
+ match_result = pd.concat([match_result, new_result_item])
417
+ return match_result
210
418
 
211
- progress_bar.update(1)
419
+ @staticmethod
420
+ def rename_api(op_name):
421
+ """
422
+ 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
423
+ rename后: {api_type}.{api_name}.{前向反向}.{input/output}.{参数序号}
424
+ """
425
+ if Const.FORWARD not in op_name and Const.BACKWARD not in op_name:
426
+ return op_name
427
+ process = Const.FORWARD if Const.FORWARD in op_name else Const.BACKWARD
428
+ name_split = op_name.split(process)
429
+ try:
430
+ torch_func_index, in_out = name_split[0], name_split[1]
431
+ except IndexError as error:
432
+ logger.error(f'{op_name} can not be split with {process}, please check!')
433
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
434
+ torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
435
+ torch_func = str(torch_func_split[0]) + Const.SEP + process + str(in_out)
436
+ return torch_func
437
+
438
+ def check_op_item(self, npu_op_item, bench_op_item):
439
+ name_match = self.rename_api(npu_op_item[CompareConst.CMP_KEY]) == self.rename_api(
440
+ bench_op_item[CompareConst.CMP_KEY])
441
+ shape_match = npu_op_item[CompareConst.CMP_SHAPE] == bench_op_item[CompareConst.CMP_SHAPE]
442
+ if name_match and shape_match:
443
+ return True
444
+ else:
445
+ npu_op_name = npu_op_item[CompareConst.OP_NAME]
446
+ bench_op_name = bench_op_item[CompareConst.OP_NAME]
447
+ check_op_str_pattern_valid(npu_op_name)
448
+ check_op_str_pattern_valid(bench_op_name)
449
+ logger.warning(f"{npu_op_name} and {bench_op_name} can not fuzzy match")
450
+ return False
212
451
 
213
- # merge all boolean expressions
214
- both_empty = not npu_ops_queue and not bench_ops_queue
215
- no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
216
- if both_empty or no_change:
217
- continue
452
+ def match_api_infos(self, npu_df, bench_df):
453
+ """
454
+ 正常匹配和模糊匹配
455
+ """
456
+ if self.mapping_config.data_mapping:
457
+ match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY], how='left')
458
+
459
+ # reorder match_result by op_name of npu
460
+ op_name_order = npu_df[CompareConst.OP_NAME].tolist()
461
+ match_result[CompareConst.OP_NAME_X] = pd.Categorical(match_result[CompareConst.OP_NAME_X],
462
+ categories=op_name_order, ordered=True)
463
+ match_result = match_result.sort_values(CompareConst.OP_NAME_X).reset_index(drop=True)
464
+ match_result[CompareConst.OP_NAME_X] = match_result[CompareConst.OP_NAME_X].astype('object')
465
+ elif not self.mode_config.fuzzy_match:
466
+ match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY, CompareConst.CMP_SHAPE],
467
+ how='outer')
468
+ else:
469
+ match_result = self.process_fuzzy_match(npu_df, bench_df)
470
+ return match_result
218
471
 
219
- # APIs in NPU and Bench models unconsistent judgment
472
+ def process_fuzzy_match(self, npu_df, bench_df):
473
+ """
474
+ 模糊匹配通过循环方式匹配api
475
+ """
476
+ npu_ops_queue = []
477
+ bench_ops_queue = []
478
+ match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS)
479
+
480
+ max_len = max(len(npu_df), len(bench_df))
481
+ min_len = min(len(npu_df), len(bench_df))
482
+ for i in range(max_len):
483
+ if i < min_len:
484
+ npu_ops_queue.append(npu_df.iloc[i])
485
+ bench_ops_queue.append(bench_df.iloc[i])
486
+ else:
487
+ try:
488
+ npu_ops_queue.append(npu_df.iloc[i])
489
+ except IndexError:
490
+ pass
491
+ try:
492
+ bench_ops_queue.append(bench_df.iloc[i])
493
+ except IndexError:
494
+ pass
495
+
496
+ # 如果append之后queue状态不一致,则判断结束
220
497
  if bool(npu_ops_queue) ^ bool(bench_ops_queue):
221
- logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
222
498
  break
223
499
 
224
- n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
500
+ npu_match_point, bench_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
225
501
 
226
- # 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中
227
- if n_match_point == -1 and b_match_point == -1:
502
+ # 如果没有匹配到,数据放到队列中,跳过。直到后面匹配到,把匹配之前的api放到不匹配中
503
+ if npu_match_point == -1 and bench_match_point == -1:
228
504
  continue
229
505
 
230
- n_match_data = npu_ops_queue[n_match_point]
231
- b_match_data = bench_ops_queue[b_match_point]
232
- un_match_data = npu_ops_queue[0: n_match_point]
233
- for npu_data in un_match_data:
234
- get_un_match_accuracy(result, npu_data, self.dump_mode)
235
- get_accuracy(result, n_match_data, b_match_data, self.dump_mode)
236
- del npu_ops_queue[0: n_match_point + 1]
237
- del bench_ops_queue[0: b_match_point + 1]
238
- progress_bar.close()
506
+ npu_op_item = npu_ops_queue[npu_match_point]
507
+ bench_op_item = bench_ops_queue[bench_match_point]
508
+ unmatched_data = npu_ops_queue[0: npu_match_point]
509
+ for op_item in unmatched_data:
510
+ match_result = self.put_unmatched_in_table(match_result, op_item)
511
+ match_result = self.put_matched_in_table(match_result, npu_op_item, bench_op_item)
512
+ del npu_ops_queue[0: npu_match_point + 1]
513
+ del bench_ops_queue[0: bench_match_point + 1]
514
+
239
515
  if npu_ops_queue:
240
- for npu_data in npu_ops_queue:
241
- get_un_match_accuracy(result, npu_data, self.dump_mode)
242
-
243
- result_df = self.make_result_table(result)
244
- return result_df
245
-
246
- def merge_data(self, json_data, stack_json_data):
247
- ops_all = {}
248
- for op_name in json_data.get('data', {}):
249
- merge_list = self.gen_merge_list(json_data, op_name, stack_json_data)
250
- if merge_list:
251
- struct_to_index_mapping = {
252
- CompareConst.INPUT_STRUCT: 0,
253
- CompareConst.OUTPUT_STRUCT: 0,
254
- CompareConst.PARAMS_STRUCT: 0,
255
- CompareConst.PARAMS_GRAD_STRUCT: 0
256
- }
257
-
258
- op_name_list = merge_list.get(CompareConst.OP_NAME)
259
- summary_list = merge_list.get(Const.SUMMARY)
260
- data_name_list = merge_list.get('data_name')
261
- op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
262
- summary_list,
263
- data_name_list)
264
- for index, op_full_name in enumerate(op_name_reorder):
265
- data_name = data_name_reorder[index] if data_name_reorder else None
266
-
267
- _, state = get_name_and_state(op_full_name)
268
- struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
269
- if not struct_key:
270
- continue
271
- ops_all[op_full_name] = {
272
- CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key),
273
- "merge_list", key=struct_key),
274
- CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"),
275
- 'data_name': data_name,
276
- 'stack_info': merge_list.get('stack_info')
277
- }
278
- struct_to_index_mapping[struct_key] += 1
279
- return ops_all
280
-
281
- def get_accuracy(self, npu_ops_all, bench_ops_all):
282
- result = []
283
- bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
284
- for ms_op_name, bench_op_name in self.data_mapping_dict.items():
285
- check_op_str_pattern_valid(ms_op_name)
286
- check_op_str_pattern_valid(bench_op_name)
287
- if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
288
- npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
289
- bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
290
- has_stack = npu_stack_info and bench_stack_info
291
- if self.dump_mode == Const.MD5:
292
- result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
293
- bench_ops_all, has_stack, npu_stack_info))
294
- continue
295
-
296
- npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
297
- bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
298
-
299
- if len(npu_struct) < 2 or len(bench_struct) < 2:
300
- logger.error(
301
- f"The length of npu_struct and bench_struct must be >= 2, "
302
- f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
303
- f"Please check!"
304
- )
305
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
306
-
307
- base_result_item = [
308
- ms_op_name, bench_op_name,
309
- npu_struct[0],
310
- bench_struct[0],
311
- npu_struct[1],
312
- bench_struct[1]
313
- ]
314
-
315
- if self.dump_mode == Const.SUMMARY:
316
- result_item = base_result_item + [" "] * 8 # 8个统计量数据情况的比对指标
317
- else:
318
- result_item = base_result_item + [" "] * 6 # 6个真实数据情况的比对指标
319
-
320
- npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
321
- result_item.extend(npu_summary_data)
322
- bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
323
- result_item.extend(bench_summary_data)
324
- if self.dump_mode == Const.SUMMARY:
325
- self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
326
- else:
327
- result_item.append(CompareConst.ACCURACY_CHECK_YES)
328
- result_item.append("")
329
- if has_stack:
330
- result_item.extend(npu_stack_info)
331
- else:
332
- result_item.append(CompareConst.NONE)
333
- if self.dump_mode == Const.ALL:
334
- ms_data_name = npu_ops_all.get(ms_op_name).get("data_name", None)
335
- pt_data_name = bench_ops_all.get(bench_op_name).get("data_name", None)
336
- result_item.append([ms_data_name, pt_data_name])
337
- result.append(result_item)
338
- logger.info(f"{ms_op_name}, {bench_op_name} compared.")
339
- elif ms_op_name not in npu_ops_all:
340
- logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
341
- elif bench_op_name not in npu_ops_all:
342
- logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
343
- return result
516
+ for op_item in npu_ops_queue:
517
+ match_result = self.put_unmatched_in_table(match_result, op_item)
344
518
 
345
- def compare_process_custom(self, file_lists):
346
- npu_json_path, bench_json_path, stack_json_path = file_lists
347
- npu_json_data = load_json(npu_json_path)
348
- bench_json_data = load_json(bench_json_path)
349
- stack_json_data = load_json(stack_json_path) if self.stack_mode else None
350
- npu_ops_all = self.merge_data(npu_json_data, stack_json_data)
351
- bench_ops_all = self.merge_data(bench_json_data, stack_json_data)
519
+ match_result.reset_index(drop=True, inplace=True)
520
+ return match_result
352
521
 
353
- result = self.get_accuracy(npu_ops_all, bench_ops_all)
354
- result_df = self.make_result_table(result)
355
- return result_df
522
+ def match_op(self, npu_queue, bench_queue):
523
+ for b_index, b_op in enumerate(bench_queue[0: -1]):
524
+ if self.check_op_item(npu_queue[-1], b_op):
525
+ return len(npu_queue) - 1, b_index
526
+ if self.check_op_item(npu_queue[-1], bench_queue[-1]):
527
+ return len(npu_queue) - 1, len(bench_queue) - 1
528
+ for n_index, n_op in enumerate(npu_queue[0: -1]):
529
+ if self.check_op_item(n_op, bench_queue[-1]):
530
+ return n_index, len(bench_queue) - 1
531
+ return -1, -1
356
532
 
357
- def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
533
+ def gen_dtype_condition(self, match_result):
358
534
  """
359
- :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
360
- :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
361
- :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
362
- :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
363
- :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
364
- 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
365
- 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
535
+ dtype匹配条件为npu、bench的dtype一致或属于规定的映射关系
366
536
  """
367
- error_file, relative_err, error_flag = None, None, False
537
+ # 如果使用了data_mapping,不校验dtype,返回全True的DataFrame
538
+ if self.mapping_config.data_mapping:
539
+ return pd.Series(True, index=match_result.index)
540
+
541
+ npu_dtype = match_result['dtype_x']
542
+ bench_dtype = match_result['dtype_y']
543
+ npu_dtype = self.process_cross_frame_dtype(npu_dtype)
544
+ bench_dtype = self.process_cross_frame_dtype(bench_dtype)
545
+
546
+ equal_condition = npu_dtype == bench_dtype
547
+ match_condition = (
548
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
549
+ CompareConst.DTYPE_MATCH_GROUPS[0])) |
550
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
551
+ CompareConst.DTYPE_MATCH_GROUPS[1]))
552
+ )
553
+ return equal_condition | match_condition
368
554
 
369
- data_name_pair = op_name_mapping_dict.get(npu_op_name)
370
- npu_data_name = data_name_pair[0]
371
- bench_data_name = data_name_pair[1]
555
+ def process_cross_frame_dtype(self, dtype):
556
+ if self.cross_frame:
557
+ dtype = dtype.map(cross_dtype_mapping).fillna(dtype)
558
+ return dtype
372
559
 
373
- if str(npu_data_name) == '-1': # 没有npu真实数据
374
- n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
375
- elif str(bench_data_name) == '-1': # 没有bench真实数据
376
- n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
377
- error_file = 'no_bench_data'
378
- else:
379
- npu_dir = input_param.get("npu_dump_data_dir")
380
- bench_dir = input_param.get("bench_dump_data_dir")
381
- try:
382
- frame_name = getattr(self, "frame_name")
383
- read_npy_data = getattr(self, "read_npy_data")
384
- if frame_name == "MSComparator":
385
- n_value = read_npy_data(npu_dir, npu_data_name)
386
- if self.cross_frame:
387
- b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True)
388
- else:
389
- b_value = read_npy_data(bench_dir, bench_data_name)
390
- else:
391
- n_value = read_npy_data(npu_dir, npu_data_name)
392
- b_value = read_npy_data(bench_dir, bench_data_name)
393
- except IOError as error:
394
- error_file = error.filename
395
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
396
- error_flag = True
397
- except (FileCheckException, CompareException):
398
- error_file = npu_data_name
399
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
400
- error_flag = True
401
-
402
- # 通过n_value, b_value同时得到错误标志和错误信息
403
- n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
404
- error_flag=error_flag, error_file=error_file)
405
-
406
- result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
407
-
408
- if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
409
- err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
410
- result_list.append(err_msg)
411
- return result_list
412
560
 
413
- def compare_core(self, input_param, output_path, **kwargs):
414
- """
415
- Compares data from multiple JSON files and generates a comparison report.
416
-
417
- Args:
418
- input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
419
- "stack_path").
420
- output_path (str): The path where the output Excel report will be saved.
421
- **kwargs: Additional keyword arguments including:
422
- - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
423
- - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
424
- - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
425
- - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
426
- - dump_mode (str): ALL, SUMMARY, MD5.
561
+ class CreateTable:
562
+ def __init__(self, mode_config: ModeConfig):
563
+ self.mode_config = mode_config
427
564
 
428
- Returns:
429
- """
430
- # get kwargs or set default value
431
- suffix = kwargs.get('suffix', '')
565
+ @staticmethod
566
+ def process_data_name(result):
567
+ result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1)
568
+ return result
432
569
 
433
- logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
434
- file_name = add_time_with_xlsx("compare_result" + suffix)
435
- file_path = os.path.join(os.path.realpath(output_path), file_name)
436
- if os.path.exists(file_path):
437
- logger.warning(f"{file_path} will be deleted.")
438
- remove_path(file_path)
439
- highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
570
+ @staticmethod
571
+ def set_summary(summary):
572
+ if summary == CompareConst.N_A:
573
+ return [CompareConst.N_A] * 4 # 4为统计值个数
574
+ summary_list = []
575
+ for i in summary:
576
+ if str(i).lower() == 'nan':
577
+ summary_list.append(CompareConst.NAN)
578
+ else:
579
+ summary_list.append(i)
580
+ return summary_list
440
581
 
441
- npu_json = input_param.get("npu_json_path")
442
- bench_json = input_param.get("bench_json_path")
443
- stack_json = input_param.get("stack_json_path")
444
- if self.data_mapping:
445
- result_df = self.compare_process_custom([npu_json, bench_json, stack_json])
582
+ def make_result_df(self, result):
583
+ # get header
584
+ header = CompareConst.HEAD_OF_COMPARE_MODE[self.mode_config.dump_mode][:]
585
+ if self.mode_config.stack_mode:
586
+ header.append(CompareConst.STACK)
587
+ if self.mode_config.dump_mode == Const.ALL:
588
+ header.append(CompareConst.DATA_NAME)
589
+ result = self.process_data_name(result)
590
+
591
+ # rename match_result columns
592
+ result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
593
+ 'op_name_y': CompareConst.BENCH_NAME,
594
+ 'dtype_x': CompareConst.NPU_DTYPE,
595
+ 'dtype_y': CompareConst.BENCH_DTYPE,
596
+ 'shape_x': CompareConst.NPU_SHAPE,
597
+ 'shape_y': CompareConst.BENCH_SHAPE,
598
+ 'md5_x': CompareConst.NPU_MD5,
599
+ 'md5_y': CompareConst.BENCH_MD5,
600
+ 'data_name_x': CompareConst.DATA_NAME,
601
+ 'stack_info_x': CompareConst.STACK}, inplace=True)
602
+
603
+ # process summary data
604
+ npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
605
+ bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
606
+ CompareConst.BENCH_NORM]
607
+ if result.empty:
608
+ result[npu_summary] = pd.DataFrame(columns=npu_summary)
609
+ result[bench_summary] = pd.DataFrame(columns=bench_summary)
446
610
  else:
447
- result_df = self.compare_process([npu_json, bench_json, stack_json])
611
+ result[npu_summary] = result['summary_x'].apply(self.set_summary).tolist()
612
+ result[bench_summary] = result['summary_y'].apply(self.set_summary).tolist()
448
613
 
-         if not result_df.values.tolist():
-             logger.warning("Can`t match any op.")
-             return
+         result_df = pd.DataFrame(columns=header)
+         for h in header:
+             if h in result.columns:
+                 result_df[h] = result[h]
+         return result_df, header
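The new `CreateTable` class ends here. As a hedged illustration (not code from the package), the rename-then-project pattern in `make_result_df` can be reproduced standalone with pandas; the header strings below are stand-ins for the `CompareConst` column titles:

    import pandas as pd

    # Stand-in column titles; the real ones come from CompareConst.
    NPU_NAME, BENCH_NAME = 'NPU Name', 'Bench Name'
    header = [NPU_NAME, BENCH_NAME]

    matched = pd.DataFrame({
        'op_name_x': ['Functional.conv2d.0.forward.input.0'],
        'op_name_y': ['Functional.conv2d.0.forward.input.0'],
        'summary_x': [[1.0, -1.0, 0.0, 12.3]],   # extra columns are dropped by the projection
    })
    matched.rename(columns={'op_name_x': NPU_NAME, 'op_name_y': BENCH_NAME}, inplace=True)

    # Same projection as make_result_df: keep only the header columns, in header order.
    result_df = pd.DataFrame(columns=header)
    for h in header:
        if h in matched.columns:
            result_df[h] = matched[h]
    print(result_df.columns.tolist())  # ['NPU Name', 'Bench Name']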
 
-         if self.dump_mode == Const.ALL:
-             result_df = self.do_multi_process(input_param, result_df)
 
-         find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode)
-         highlight_rows_xlsx(result_df, highlight_dict, file_path)
-
-         if self.auto_analyze:
-             advisor = Advisor(result_df, output_path, suffix)
-             advisor.analysis()
+ class CalcStatsDiff:
+     def __init__(self, mode_config: ModeConfig):
+         self.mode_config = mode_config
 
-         print_compare_ends_info()
+     @staticmethod
+     def type_check(val):
+         """
+         Check whether each entry is numeric or a string-form nan; return True where it is.
+         """
+         check_series = pd.Series(False, index=val.index)
+         val_str = val.astype(str)
+         check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
+         return check_series
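A quick standalone check (illustrative only) of the predicate that `type_check` applies element-wise: anything that parses as a number, or spells "nan" in any case, counts as valid, while strings such as 'N/A' or 'unsupported' do not:

    import pandas as pd

    val = pd.Series([3.14, '2', 'nan', 'NaN', 'N/A', 'unsupported'])
    val_str = val.astype(str)
    check = pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')
    print(check.tolist())  # [True, True, True, True, False, False]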
 
-     def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
-         cos_result = []
-         euc_dist_result = []
-         max_err_result = []
-         max_relative_err_result = []
-         one_thousand_err_ratio_result = []
-         five_thousand_err_ratio_result = []
-         err_mess = []
-
-         is_print_compare_log = input_param.get("is_print_compare_log")
-
-         for i in range(len(result_df)):
-             npu_op_name = result_df.iloc[i, 0]
-             bench_op_name = result_df.iloc[i, 1]
-             if is_print_compare_log:
-                 logger.info("start compare: {}".format(npu_op_name))
-
-             cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
-                 = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
-
-             if is_print_compare_log:
-                 logger.info(
-                     "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
-                     one_thousand_err_ratio {}, "
-                     "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
-                                                         err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
-             cos_result.append(cos_sim)
-             euc_dist_result.append(euc_dist)
-             max_err_result.append(max_abs_err)
-             max_relative_err_result.append(max_relative_err)
-             one_thousand_err_ratio_result.append(one_thousand_err_ratio)
-             five_thousand_err_ratio_result.append(five_thousand_err_ratio)
-             err_mess.append(err_msg)
-
-         cr = ComparisonResult(
-             cos_result=cos_result,
-             euc_dist_result=euc_dist_result,
-             max_err_result=max_err_result,
-             max_relative_err_result=max_relative_err_result,
-             one_thousand_err_ratio_result=one_thousand_err_ratio_result,
-             five_thousand_err_ratio_result=five_thousand_err_ratio_result,
-             err_msgs=err_mess
+     @staticmethod
+     def get_number(val):
+         return pd.to_numeric(val.astype(str), errors='coerce')
+
+     def calc_summary_diff(self, result_df, cond_no_bench, stats_index: str):
+         npu_val = result_df['NPU ' + stats_index]
+         bench_val = result_df['Bench ' + stats_index]
+         diff_name = stats_index.capitalize() + ' diff'
+         rel_err_name = ('norm' if stats_index == 'l2norm' else stats_index).capitalize() + 'RelativeErr'
+
+         # both the NPU and bench statistics are numbers or nan
+         cond_num_nan = self.type_check(npu_val) & self.type_check(bench_val)
+
+         # if a statistic is neither a number nor nan, set the statistic diff to N/A
+         result_df.loc[~cond_num_nan, [diff_name, rel_err_name]] = CompareConst.N_A
+         cond_valid_stat = ~cond_no_bench & cond_num_nan  # valid: bench name is not N/A and both statistics are numbers or nan
+         result_df.loc[cond_valid_stat, diff_name] = self.get_number(npu_val) - self.get_number(bench_val)
+
+         cond_diff_nan = result_df[diff_name].isna()  # the statistic diff is nan
+         cond_nan_diff = cond_valid_stat & cond_diff_nan
+         result_df.loc[cond_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
+
+         cond_not_nan_diff = cond_valid_stat & ~cond_diff_nan
+         condition_pt_zero = bench_val == 0
+         result_df.loc[cond_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.N_A
+
+         # convert the relative error into a percentage string
+         cond_ref_err = cond_not_nan_diff & ~condition_pt_zero
+         result_df.loc[cond_ref_err, rel_err_name] = (
+             result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err] * 100)
+         result_df.loc[cond_ref_err, rel_err_name] = (result_df.loc[cond_ref_err, rel_err_name].abs().astype(str) + '%')
+
+         magnitude = self.get_number(result_df[diff_name]).abs() / (pd.Series(
+             np.maximum(self.get_number(npu_val), self.get_number(bench_val))).abs() + CompareConst.EPSILON)
+         return magnitude > CompareConst.MAGNITUDE
+
+     def calc_accuracy(self, result_df, header):
+         # bench name N/A represents no bench data, err_msg adds "No bench data matched."
+         condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
+         result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
+         result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
+
+         if self.mode_config.dump_mode == Const.MD5:
+             condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
+             result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
+             result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
+         elif self.mode_config.dump_mode == Const.SUMMARY:
+             warning_list = [
+                 self.calc_summary_diff(result_df, condition_no_bench, stats_index)
+                 for stats_index in ['max', 'min', 'mean', 'l2norm']
+             ]
+             warning_flag = pd.DataFrame(warning_list).any()
+             result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
+             result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
+             result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
+         else:
+             fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
+                          CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
+                          CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
+                          CompareConst.ERROR_MESSAGE]
+             result_df.loc[~condition_no_bench, fill_cols] = ''
+             result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
+
+         return result_df[header]
+
+
+ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig:
+     """Shared preprocessing logic; returns the assembled ComparisonConfig object."""
+     try:
+         config = ComparisonConfig(
+             dump_mode='',
+             stack_mode=False,
+             auto_analyze=kwargs.get('auto_analyze', True),
+             fuzzy_match=kwargs.get('fuzzy_match', False),
+             data_mapping=kwargs.get('data_mapping', {}),
+             suffix=kwargs.get('suffix', ''),
+             cell_mapping=kwargs.get('cell_mapping', {}),
+             api_mapping=kwargs.get('api_mapping', {}),
+             layer_mapping=kwargs.get('layer_mapping', {}),
+             compared_file_type='',
          )
 
-         return _save_cmp_result(idx, cr, result_df, lock)
+         set_dump_path(input_param)
+         config.dump_mode = get_dump_mode(input_param)
+         config.compared_file_type = get_file_type(input_param.get("npu_json_path", None))
 
-     def do_multi_process(self, input_param, result_df):
-         try:
-             result_df = _handle_multi_process(self.compare_ops, input_param, result_df,
-                                               multiprocessing.Manager().RLock())
-             return result_df
-         except ValueError as e:
-             logger.error('result dataframe is not found.')
-             raise CompareException(CompareException.INVALID_DATA_ERROR) from e
+         # set stack_mode and set "stack_json_path" in input_param
+         if 'stack_json_path' in input_param:
+             config.stack_mode = kwargs.get('stack_mode', False)
+         else:
+             config.stack_mode = set_stack_json_path(input_param)
+
+         check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match,
+                                   input_param.get('is_print_compare_log', True))
+         create_directory(output_path)
+         check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode)
+
+         return config
+
+     except (CompareException, FileCheckException) as error:
+         logger.error('Compare failed. Please check the arguments and do it again!')
+         raise CompareException(error.code) from error
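Taken together, the new flow moves the argument handling that previously lived inside the comparator into `setup_comparison`. A hedged usage sketch follows; the paths and keyword values are placeholders, and only the dictionary keys read above are assumed to matter:

    input_param = {
        'npu_json_path': './npu_dump/dump.json',      # placeholder paths
        'bench_json_path': './bench_dump/dump.json',
        'is_print_compare_log': True,
    }

    # config = setup_comparison(input_param, './compare_output', suffix='_rank0', fuzzy_match=False)
    # config.dump_mode and config.stack_mode are then populated from the input files, and the
    # resulting ComparisonConfig is handed to the mode-aware table/statistics classes added above.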