mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -13,111 +13,229 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import multiprocessing
17
16
  import os
18
17
  import re
19
- from copy import deepcopy
18
+ from dataclasses import dataclass
19
+ from collections import defaultdict
20
20
 
21
+ import numpy as np
21
22
  import pandas as pd
22
23
  from tqdm import tqdm
23
24
 
24
25
  from msprobe.core.advisor.advisor import Advisor
25
26
  from msprobe.core.common.const import CompareConst, Const
26
27
  from msprobe.core.common.exceptions import FileCheckException
27
- from msprobe.core.common.file_utils import load_json, remove_path
28
+ from msprobe.core.common.file_utils import load_json, remove_path, create_directory
28
29
  from msprobe.core.common.log import logger
29
- from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value
30
- from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \
31
- check_struct_match, fuzzy_check_op
32
- from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
33
- from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result
34
- from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
35
- from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \
36
- print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list
37
-
38
-
39
- class ModeConfig:
40
- def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None):
41
- self.stack_mode = stack_mode
42
- self.auto_analyze = auto_analyze
43
- self.fuzzy_match = fuzzy_match
44
- self.dump_mode = dump_mode
30
+ from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \
31
+ set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type
32
+ from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping
33
+ from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \
34
+ reorder_op_x_list, set_stack_json_path
35
+ from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict
36
+ from msprobe.core.compare.multiprocessing_compute import CompareRealData
37
+ from msprobe.core.compare.highlight import HighLight
38
+
39
+
40
+ @dataclass
41
+ class ComparisonConfig:
42
+ dump_mode: str
43
+ stack_mode: bool
44
+ auto_analyze: bool
45
+ fuzzy_match: bool
46
+ data_mapping: dict
47
+ suffix: str
48
+ cell_mapping: dict
49
+ api_mapping: dict
50
+ layer_mapping: dict
51
+ compared_file_type: str
45
52
 
46
53
 
47
54
  class Comparator:
48
- def __init__(self, mode_config: ModeConfig):
49
- self.stack_mode = mode_config.stack_mode
50
- self.auto_analyze = mode_config.auto_analyze
51
- self.fuzzy_match = mode_config.fuzzy_match
52
- self.dump_mode = mode_config.dump_mode
55
+ def __init__(self, file_reader, mode_config: ModeConfig, mapping_config: MappingConfig, is_cross_framework=False):
56
+ self.file_reader = file_reader
57
+ self.mode_config = mode_config
58
+ self.mapping_config = mapping_config
59
+ self.cross_frame = is_cross_framework
60
+
61
+ self.mapping_dict = MappingDict(mapping_config)
53
62
 
54
63
  @staticmethod
55
- def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
56
- npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
57
- bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
64
+ def process_output_file(output_path, suffix, compared_file_type):
65
+ file_name_prefix_mapping = {
66
+ Const.DUMP_JSON_FILE: "compare_result",
67
+ Const.DEBUG_JSON_FILE: "debug_compare_result"
68
+ }
69
+ file_name_prefix = file_name_prefix_mapping.get(compared_file_type, "compare_result")
70
+ file_name = add_time_with_xlsx(file_name_prefix + suffix)
71
+ file_path = os.path.join(os.path.realpath(output_path), file_name)
72
+ if os.path.exists(file_path):
73
+ logger.warning(f"{file_path} will be deleted.")
74
+ remove_path(file_path)
75
+ return file_path
58
76
 
59
- if len(npu_struct) < 3 or len(bench_struct) < 3:
60
- logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
61
- f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
62
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
77
+ def compare_core(self, input_param, output_path, **kwargs):
78
+ """
79
+ Compares data from multiple JSON files and generates a comparison report.
63
80
 
64
- result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
65
- npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
66
- CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
81
+ Args:
82
+ input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
83
+ "stack_path").
84
+ output_path (str): The path where the output Excel report will be saved.
85
+ **kwargs: Additional keyword arguments including:
86
+ - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
87
+ - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
88
+ - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
89
+ - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
90
+ - dump_mode (str): ALL, SUMMARY, MD5.
67
91
 
68
- if len(args) >= 2 and args[0]:
69
- result_item.extend(args[1])
70
- else:
71
- result_item.append(CompareConst.NONE)
72
- return result_item
92
+ Returns:
93
+ """
94
+ logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
73
95
 
74
- @staticmethod
75
- def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
76
- err_msg = ""
77
- result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
78
- bench_summary_data, err_msg)
79
- result_item.append(accuracy_check)
80
- result_item.append(err_msg)
96
+ # get kwargs or set default value
97
+ suffix = kwargs.get('suffix', '')
81
98
 
82
- @staticmethod
83
- def _generate_na_data(ops_all):
84
- if not ops_all:
85
- return {}
86
- key = next(iter(ops_all))
87
- value = deepcopy(ops_all[key])
88
- for k, v in value.items():
89
- if isinstance(v, tuple):
90
- value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
91
- elif isinstance(v, list):
92
- value[k] = [CompareConst.N_A] * len(v)
93
- else:
94
- value[k] = CompareConst.N_A
95
- return value
99
+ # process output file
100
+ file_path = self.process_output_file(output_path, suffix, self.mode_config.compared_file_type)
96
101
 
97
- def make_result_table(self, result):
98
- header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
102
+ # initialize the compare result table and compare general data(name, dtype, shape, statistics/md5, etc.)
103
+ npu_json = input_param.get("npu_json_path")
104
+ bench_json = input_param.get("bench_json_path")
105
+ stack_json = input_param.get("stack_json_path")
106
+ result_df = self.compare_statistics([npu_json, bench_json, stack_json])
107
+ if not result_df.values.tolist():
108
+ logger.warning("Can`t match any op.")
109
+ return
99
110
 
100
- if self.stack_mode:
101
- header.append(CompareConst.STACK)
102
- if self.dump_mode == Const.ALL:
103
- header.append(CompareConst.DATA_NAME)
104
- else:
105
- if self.dump_mode == Const.ALL:
106
- for row in result:
107
- del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
108
- header.append(CompareConst.DATA_NAME)
109
- else:
110
- for row in result:
111
- del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
112
- result_df = pd.DataFrame(result, columns=header, dtype='object')
113
- return result_df
111
+ # compare real data
112
+ if self.mode_config.dump_mode == Const.ALL:
113
+ compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame)
114
+ result_df = compare_real_data.do_multi_process(input_param, result_df)
115
+
116
+ # highlight suspicious API
117
+ highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
118
+ highlight = HighLight(self.mode_config)
119
+ if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
120
+ highlight.find_compare_result_error_rows(result_df, highlight_dict)
121
+ highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path)
122
+
123
+ # output compare analysis suggestions
124
+ if self.mode_config.auto_analyze:
125
+ advisor = Advisor(result_df, output_path, suffix)
126
+ advisor.analysis()
127
+
128
+ print_compare_ends_info()
129
+
130
+ def compare_statistics(self, file_list):
131
+ # load and parse json data
132
+ parse_data = ParseData(self.mode_config)
133
+ npu_df, bench_df = parse_data.parse(file_list)
134
+
135
+ npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
136
+ bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
137
+
138
+ # create new columns for compare op_name and shape
139
+ # process npu_df's COMPARE_KEY whether same or different framework
140
+ process_df = ProcessDf(self.mode_config, self.mapping_config, self.mapping_dict)
141
+ npu_df, bench_df = process_df.process_compare_key_and_shape(npu_df, bench_df)
142
+
143
+ # match npu and bench, match_result contains both npu_info and bench_info
144
+ match = Match(self.mode_config, self.mapping_config, self.cross_frame)
145
+ match_result = match.match_api_infos(npu_df, bench_df)
146
+ # 筛选出npu_name存在的行并填充筛选出行中的缺失值为N/A
147
+ match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
148
+ bench_columns = [i + '_y' for i in bench_df.columns]
149
+ match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A
150
+
151
+ # organize compare result table by renaming columns
152
+ create_table = CreateTable(self.mode_config)
153
+ result_df, header = create_table.make_result_df(match_result)
154
+
155
+ # calculate statistics diff
156
+ calc_stats_diff = CalcStatsDiff(self.mode_config)
157
+ return calc_stats_diff.calc_accuracy(result_df, header)
158
+
159
+
160
+ class ParseData:
161
+ def __init__(self, mode_config: ModeConfig):
162
+ self.mode_config = mode_config
163
+
164
+ def parse(self, file_list):
165
+ npu_json_path, bench_json_path, stack_json_path = file_list
166
+ npu_json_data = load_json(npu_json_path)
167
+ bench_json_data = load_json(bench_json_path)
168
+ stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None
169
+
170
+ # parse json data and generate df
171
+ npu_df = self.gen_data_df(npu_json_data, stack_json_data)
172
+ bench_df = self.gen_data_df(bench_json_data, stack_json_data)
173
+
174
+ return npu_df, bench_df
175
+
176
+ def gen_data_df(self, data_json, stack_json_data):
177
+ result = {
178
+ CompareConst.OP_NAME: [],
179
+ Const.DTYPE: [],
180
+ Const.SHAPE: [],
181
+ Const.SUMMARY: [],
182
+ Const.STACK_INFO: []
183
+ }
184
+ if self.mode_config.dump_mode == Const.ALL:
185
+ result['data_name'] = []
186
+ elif self.mode_config.dump_mode == Const.MD5:
187
+ result[Const.MD5] = []
188
+
189
+ api_nums = len(data_json['data'])
190
+ progress_bar = tqdm(total=api_nums, desc="API/Module Read Progress", unit="api/module", ncols=100)
191
+
192
+ # 从json中循环解析API数据,遍历所有API
193
+ for data_name in data_json['data']:
194
+ check_op_str_pattern_valid(data_name)
195
+ merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
196
+ if not merge_list:
197
+ continue
198
+
199
+ op_name_list = merge_list.get(CompareConst.OP_NAME)
200
+ summary_list = merge_list.get(Const.SUMMARY)
201
+ data_name_list = merge_list.get('data_name')
202
+ op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
203
+ summary_list,
204
+ data_name_list)
205
+ # 遍历单个API的所有item
206
+ for index, op_name in enumerate(op_name_reorder):
207
+ result[CompareConst.OP_NAME].append(op_name)
208
+ if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
209
+ struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
210
+ elif CompareConst.OUTPUT_PATTERN in op_name:
211
+ struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
212
+ elif CompareConst.PARAMS_PATTERN in op_name:
213
+ struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
214
+ elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
215
+ struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
216
+ else:
217
+ struct = merge_list[CompareConst.DEBUG_STRUCT].pop(0)
218
+ result[Const.DTYPE].append(struct[0])
219
+ result[Const.SHAPE].append(struct[1])
220
+ if self.mode_config.dump_mode == Const.MD5:
221
+ result[Const.MD5].append(struct[2])
222
+ result[Const.SUMMARY].append(summary_reorder.pop(0))
223
+ result[Const.STACK_INFO].append(
224
+ merge_list[Const.STACK_INFO][0] if index == 0 and self.mode_config.stack_mode else None)
225
+ if self.mode_config.dump_mode == Const.ALL:
226
+ result['data_name'].append(data_name_reorder.pop(0))
227
+
228
+ progress_bar.update(1)
229
+ progress_bar.close()
230
+ return pd.DataFrame(result)
114
231
 
115
232
  def gen_merge_list(self, json_data, op_name, stack_json_data):
116
233
  op_data = json_data['data'][op_name]
117
- check_dump_json_str(op_data, op_name)
234
+ if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
235
+ check_dump_json_str(op_data, op_name)
118
236
  op_parsed_list = read_op(op_data, op_name)
119
237
 
120
- if self.stack_mode:
238
+ if self.mode_config.stack_mode:
121
239
  stack_info = stack_json_data.get(op_name)
122
240
  if stack_info is not None:
123
241
  check_stack_json_str(stack_info, op_name)
@@ -127,423 +245,483 @@ class Comparator:
127
245
  'full_info': stack_info
128
246
  })
129
247
 
130
- merge_list = merge_tensor(op_parsed_list, self.dump_mode)
248
+ merge_list = merge_tensor(op_parsed_list, self.mode_config.dump_mode)
131
249
  return merge_list
132
250
 
133
- def check_op(self, npu_dict, bench_dict):
134
- npu_op_name = npu_dict[CompareConst.OP_NAME]
135
- bench_op_name = bench_dict[CompareConst.OP_NAME]
136
- graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
137
- safe_get_value(bench_op_name, 0, "bench_op_name"))
138
-
139
- frame_name = getattr(self, "frame_name")
140
- if frame_name == "PTComparator":
141
- from msprobe.pytorch.compare.match import graph_mapping
142
- if graph_mode:
143
- return graph_mapping.match(npu_op_name[0], bench_op_name[0])
144
- struct_match = check_struct_match(npu_dict, bench_dict)
145
- if not self.fuzzy_match:
146
- name_match = npu_op_name == bench_op_name
147
- return name_match and struct_match
148
- try:
149
- name_match = fuzzy_check_op(npu_op_name, bench_op_name)
150
- except Exception as err:
151
- logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
152
- name_match = False
153
- return name_match and struct_match
154
-
155
- def match_op(self, npu_queue, bench_queue):
156
- for b_index, b_op in enumerate(bench_queue[0: -1]):
157
- if self.check_op(npu_queue[-1], b_op):
158
- return len(npu_queue) - 1, b_index
159
- if self.check_op(npu_queue[-1], bench_queue[-1]):
160
- return len(npu_queue) - 1, len(bench_queue) - 1
161
- for n_index, n_op in enumerate(npu_queue[0: -1]):
162
- if self.check_op(n_op, bench_queue[-1]):
163
- return n_index, len(bench_queue) - 1
164
- return -1, -1
165
251
 
166
- def compare_process(self, file_lists):
167
- npu_json_path, bench_json_path, stack_json_path = file_lists
168
- npu_json_data = load_json(npu_json_path)
169
- bench_json_data = load_json(bench_json_path)
170
- stack_json_data = load_json(stack_json_path) if self.stack_mode else None
252
+ class ProcessDf:
253
def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, mapping_dict: MappingDict):
    """Keep references to the three configuration objects used by this class.

    :param mode_config: stored as ``self.mode_config``
    :param mapping_config: stored as ``self.mapping_config``; its ``api_mapping``,
        ``cell_mapping`` and ``data_mapping`` attributes select the key strategy
    :param mapping_dict: stored as ``self.mapping_dict``; holds the loaded mapping tables
    """
    # Plain attribute storage only; nothing is validated or loaded here.
    self.mapping_dict = mapping_dict
    self.mapping_config = mapping_config
    self.mode_config = mode_config
171
257
 
172
- if self.fuzzy_match:
173
- logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
258
@staticmethod
def get_api_name(api_list):
    """Join the first two components of a split op name with ``Const.SEP``.

    :param api_list: op name already split on the separator
    :return: ``api_list[0] + Const.SEP + api_list[1]``
    :raises CompareException: INDEX_OUT_OF_BOUNDS_ERROR when fewer than two
        components are present
    """
    try:
        api_type, api = api_list[0], api_list[1]
    except IndexError as error:
        logger.error('Failed to retrieve API name, please check if the dump data is reasonable')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
    return api_type + Const.SEP + api
266
+
267
def process_compare_key_and_shape(self, npu_df, bench_df):
    """Add the compare-key and compare-shape columns to both frames.

    The NPU key comes from ``assign_npu_df_compare_key`` (mapping-aware); the
    bench key is simply its op name. Both shape columns are copies of the
    ``Const.SHAPE`` column.

    :return: tuple ``(npu_df, bench_df)`` with the new columns appended
    """
    keyed_npu_df = self.assign_npu_df_compare_key(npu_df, bench_df)
    bench_df[CompareConst.CMP_KEY] = bench_df[CompareConst.OP_NAME]
    # Key first, shape second on both frames: the column order is preserved.
    for frame in (keyed_npu_df, bench_df):
        frame[CompareConst.CMP_SHAPE] = frame[Const.SHAPE]
    return keyed_npu_df, bench_df
273
+
274
+ def assign_npu_df_compare_key(self, npu_df, bench_df):
275
+ """
276
+ 处理 npu_df 的 COMPARE_KEY 赋值逻辑
174
277
 
175
- npu_ops_queue = []
176
- bench_ops_queue = []
177
- result = []
278
+ :param npu_df: DataFrame,NPU 对比数据
279
+ :param bench_df: DataFrame,Bench 对比数据
280
+ :return: compare_key(name)处理后的 npu_df
281
+ """
282
+ # 处理api_mapping映射
283
+ if self.mapping_config.api_mapping:
284
+ # 如果用户不传api_mapping.yaml,先使用内置api_mapping.yaml替换npu_op_name
285
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
286
+ # 如果用户传入api_mapping.yaml,再使用传入api_mapping.yaml进一步替换npu_op_name
287
+ if isinstance(self.mapping_config.api_mapping, str):
288
+ self.modify_compare_data_with_user_mapping(npu_df, bench_df)
289
+ # 处理cell_mapping映射
290
+ elif self.mapping_config.cell_mapping:
291
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
292
+ # 处理data_mapping映射
293
+ elif self.mapping_config.data_mapping:
294
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_data_mapping)
295
+ else:
296
+ npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME]
297
+ return npu_df
298
+
299
def process_internal_api_mapping(self, npu_op_name):
    """Rewrite the API prefix of an NPU op name using the built-in mapping rules.

    The API name is the first two separator-delimited components of the op
    name. Rules, in order:

    * class ``Mint``           -> prefix becomes ``Torch``
    * class ``MintFunctional`` -> prefix becomes ``Functional``
    * API name present (with a truthy value) in
      ``self.mapping_dict.ms_to_pt_mapping`` -> replaced by the mapped name
    * otherwise the op name is returned unchanged

    Only the leading occurrence is replaced (``count=1``), so a token that
    happens to reappear later in the op name is left alone. This matches the
    bounded replacements used by the other mapping helpers in this class.

    :param npu_op_name: full op name of an NPU item
    :return: op name to use as the compare key
    :raises CompareException: propagated from ``get_api_name`` when the op
        name has fewer than two components
    """
    # get api name & class name from op_name
    ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
    class_name = ms_api_name.split(Const.SEP)[0]
    if class_name == "Mint":
        return npu_op_name.replace("Mint", "Torch", 1)
    if class_name == "MintFunctional":
        return npu_op_name.replace("MintFunctional", "Functional", 1)
    # Single lookup; an absent or empty mapping leaves the name untouched.
    pt_api_name = self.mapping_dict.ms_to_pt_mapping.get(ms_api_name)
    if pt_api_name:
        return npu_op_name.replace(ms_api_name, pt_api_name, 1)
    return npu_op_name
311
+
312
+ def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
313
+ def gen_input_compare_key(pattern, term):
314
+ is_unmatched = True
315
+ for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
316
+ if op_name.split(pattern)[1].startswith(str(prefix)):
317
+ npu_df.loc[index, CompareConst.CMP_KEY] = (
318
+ op_name.replace(pattern + str(prefix),
319
+ pattern + str(mapping_dict.get(f'pt_{term}')[i])))
320
+ is_unmatched = False
321
+ return is_unmatched
322
+
323
+ ms_api_indices_dict = self.get_api_indices_dict(npu_df)
324
+ pt_api_indices_dict = self.get_api_indices_dict(bench_df)
325
+
326
+ for mapping_dict in self.mapping_dict.api_mapping_dict:
327
+ all_length_equal = True
328
+ for k1, k2 in CompareConst.API_MAPPING_KEYS_TO_COMPARE:
329
+ if len(mapping_dict.get(k1, [])) != len(mapping_dict.get(k2, [])):
330
+ all_length_equal = False
331
+ if not all_length_equal:
332
+ logger.warning('The user-defined mapping table is incorrect,\
333
+ make sure that the number of parameters is equal')
334
+ continue
178
335
 
179
- ops_npu_iter = iter(npu_json_data['data'])
180
- ops_bench_iter = iter(bench_json_data['data'])
181
- read_err_npu = True
182
- read_err_bench = True
183
- last_npu_ops_len = 0
184
- last_bench_ops_len = 0
336
+ ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
337
+ if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
338
+ continue
339
+ for index in ms_api_indices_dict.get(ms_api):
340
+ op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
341
+ if CompareConst.INPUT_PATTERN in op_name:
342
+ is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
343
+ elif CompareConst.KWARGS_PATTERN in op_name:
344
+ is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
345
+ elif CompareConst.OUTPUT_PATTERN in op_name:
346
+ is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
347
+ elif CompareConst.PARAMS_PATTERN in op_name:
348
+ is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
349
+ elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
350
+ is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
351
+ else:
352
+ logger.error(f'Excepted op_name: {op_name}')
353
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
354
+ if is_abandoned:
355
+ npu_df.loc[index, CompareConst.CMP_KEY] = op_name + 'abandoned'
185
356
 
186
- npu_api_nums = len(npu_json_data['data'])
187
- progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100)
357
def get_api_indices_dict(self, op_name_df):
    """Group the row labels of ``op_name_df`` by API name.

    Each API maps to the index labels of all of its rows (inputs, outputs,
    and so on). Example::

        {'Functional.conv2d': [0, 1, 2, 3],
         'Functional.batch_norm': [4, 5, 6, 7, 8]}

    Index *labels* are collected rather than enumeration positions because
    the result is consumed through ``DataFrame.loc``, which is label-based.
    For a default ``RangeIndex`` the two are identical.

    :param op_name_df: frame with a ``CompareConst.OP_NAME`` column
    :return: ``defaultdict(list)`` mapping API name to row labels
    :raises CompareException: propagated from ``get_api_name`` for malformed names
    """
    api_indices_dict = defaultdict(list)
    for row_label, name in op_name_df[CompareConst.OP_NAME].items():
        api_name = self.get_api_name(name.split(Const.SEP))
        api_indices_dict[api_name].append(row_label)
    return api_indices_dict
370
+
371
def process_cell_mapping(self, npu_op_name):
    """Build the compare key for an NPU op name under cell mapping.

    Returns ``CompareConst.N_A`` for an empty name, or for a name that neither
    contains a ``Const.PARAMS_GRAD`` component nor matches
    ``Const.REGEX_FORWARD_BACKWARD``. Otherwise the first ``"Cell"`` becomes
    ``"Module"`` and, when a cell mapping table is loaded and knows the cell
    name, the first occurrence of that cell name is substituted.
    """
    if not npu_op_name:
        return CompareConst.N_A

    has_param_grad = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
    if not (has_param_grad or re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name)):
        return CompareConst.N_A

    mapped_name = npu_op_name.replace("Cell", "Module", 1)
    cell_table = self.mapping_dict.cell_mapping_dict
    if not cell_table:
        return mapped_name

    # get cell name & class name from op_name
    # Cell.fc1.Dense.forward.0.input.0
    name_without_type = mapped_name.split(Const.SEP, 1)[-1]
    cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', name_without_type)[0]
    if cell_name in cell_table:
        mapped_name = mapped_name.replace(cell_name, cell_table[cell_name], 1)
    return mapped_name
385
+
386
def process_data_mapping(self, npu_op_name):
    """Look up ``npu_op_name`` in the data mapping table.

    :return: the mapped name, or ``npu_op_name`` itself when it has no entry
    """
    data_table = self.mapping_dict.data_mapping_dict
    return data_table.get(npu_op_name, npu_op_name)
388
+
389
+
390
+ class Match:
391
def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, cross_frame):
    """Keep the settings that drive matching.

    :param mode_config: stored as ``self.mode_config`` (``fuzzy_match`` is read from it)
    :param mapping_config: stored as ``self.mapping_config`` (``data_mapping`` is read from it)
    :param cross_frame: stored as ``self.cross_frame``; truthy enables dtype
        translation in ``process_cross_frame_dtype``
    """
    self.cross_frame = cross_frame
    self.mapping_config = mapping_config
    self.mode_config = mode_config
188
395
 
189
- while True:
190
- if not read_err_npu and not read_err_bench:
191
- break
192
- try:
193
- last_npu_ops_len = len(npu_ops_queue)
194
- op_name_npu = next(ops_npu_iter)
195
- check_op_str_pattern_valid(op_name_npu)
196
- npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data)
197
- if npu_merge_list:
198
- npu_ops_queue.append(npu_merge_list)
199
- except StopIteration:
200
- read_err_npu = False
201
- try:
202
- last_bench_ops_len = len(bench_ops_queue)
203
- op_name_bench = next(ops_bench_iter)
204
- check_op_str_pattern_valid(op_name_bench)
205
- bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data)
206
- if bench_merge_list:
207
- bench_ops_queue.append(bench_merge_list)
208
- except StopIteration:
209
- read_err_bench = False
396
@staticmethod
def put_unmatched_in_table(match_result, npu_op_item):
    """Append an NPU row that has no bench partner to ``match_result``.

    Bench-side column names are derived from the NPU row's index labels,
    excluding the last two, by replacing each label's final character with
    ``'y'``; every bench-side value is ``CompareConst.N_A``.

    :param match_result: accumulated match table
    :param npu_op_item: one NPU row as a Series
    :return: a new frame with the extra row appended
    """
    npu_labels = npu_op_item.index.tolist()[:-2]
    bench_labels = [label[:-1] + 'y' for label in npu_labels]
    bench_placeholder = pd.Series([CompareConst.N_A] * len(bench_labels), index=bench_labels)

    row_frame = pd.concat([npu_op_item, bench_placeholder]).to_frame().T
    row_frame.columns = CompareConst.MATCH_RESULT_COLUMNS
    return pd.concat([match_result, row_frame])
210
405
 
211
- progress_bar.update(1)
406
@staticmethod
def put_matched_in_table(match_result, npu_op_item, bench_op_item):
    """Append a matched NPU/bench pair to ``match_result`` as a single row.

    The two Series are concatenated and truncated to the number of columns
    in ``CompareConst.MATCH_RESULT_COLUMNS`` before being relabelled.

    :return: a new frame with the extra row appended
    """
    column_count = len(CompareConst.MATCH_RESULT_COLUMNS)
    paired = pd.concat([npu_op_item, bench_op_item])
    row_frame = paired.head(column_count).to_frame().T
    row_frame.columns = CompareConst.MATCH_RESULT_COLUMNS
    return pd.concat([match_result, row_frame])
212
413
 
213
- # merge all boolean expressions
214
- both_empty = not npu_ops_queue and not bench_ops_queue
215
- no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
216
- if both_empty or no_change:
217
- continue
414
@staticmethod
def rename_api(op_name):
    """Strip the call-count component from an op name.

    Original: {api_type}.{api_name}.{call count}.{forward/backward}.{input/output}.{arg index}
    Renamed:  {api_type}.{api_name}.{forward/backward}.{input/output}.{arg index}

    Names containing neither ``Const.FORWARD`` nor ``Const.BACKWARD`` are
    returned unchanged; ``Const.FORWARD`` takes precedence when both occur.
    """
    if Const.FORWARD in op_name:
        process = Const.FORWARD
    elif Const.BACKWARD in op_name:
        process = Const.BACKWARD
    else:
        return op_name

    pieces = op_name.split(process)
    try:
        api_with_count, in_out = pieces[0], pieces[1]
    except IndexError as error:
        logger.error(f'{op_name} can not be split with {process}, please check!')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error

    # Dropping the last two separator-delimited fields removes the call count
    # and the empty tail left by the separator that preceded `process`.
    api_only = api_with_count.rsplit(Const.SEP, 2)[0]
    return str(api_only) + Const.SEP + process + str(in_out)
432
+
433
def check_op_item(self, npu_op_item, bench_op_item):
    """Tell whether an NPU row and a bench row describe the same item.

    Two rows match when their compare keys are equal after ``rename_api``
    and their compare shapes are equal. On a mismatch both op names are
    validated with ``check_op_str_pattern_valid`` and a warning is logged.

    :return: ``True`` on a match, ``False`` otherwise
    """
    npu_key = self.rename_api(npu_op_item[CompareConst.CMP_KEY])
    bench_key = self.rename_api(bench_op_item[CompareConst.CMP_KEY])
    same_name = npu_key == bench_key
    same_shape = npu_op_item[CompareConst.CMP_SHAPE] == bench_op_item[CompareConst.CMP_SHAPE]
    if same_name and same_shape:
        return True

    npu_op_name = npu_op_item[CompareConst.OP_NAME]
    bench_op_name = bench_op_item[CompareConst.OP_NAME]
    check_op_str_pattern_valid(npu_op_name)
    check_op_str_pattern_valid(bench_op_name)
    logger.warning(f"{npu_op_name} and {bench_op_name} can not fuzzy match")
    return False
446
+
447
def match_api_infos(self, npu_df, bench_df):
    """Pair NPU rows with bench rows (exact matching and fuzzy matching).

    * data mapping configured: left merge on the compare key, then the rows
      are put back into the order of the NPU op names
    * fuzzy matching disabled: outer merge on compare key and compare shape
    * otherwise: delegate to ``process_fuzzy_match``
    """
    if self.mapping_config.data_mapping:
        merged = npu_df.merge(bench_df, on=[CompareConst.CMP_KEY], how='left')

        # reorder match_result by op_name of npu: sort through an ordered
        # categorical, then restore the plain object dtype
        npu_order = npu_df[CompareConst.OP_NAME].tolist()
        merged[CompareConst.OP_NAME_X] = pd.Categorical(merged[CompareConst.OP_NAME_X],
                                                        categories=npu_order, ordered=True)
        merged = merged.sort_values(CompareConst.OP_NAME_X).reset_index(drop=True)
        merged[CompareConst.OP_NAME_X] = merged[CompareConst.OP_NAME_X].astype('object')
        return merged

    if self.mode_config.fuzzy_match:
        return self.process_fuzzy_match(npu_df, bench_df)

    return npu_df.merge(bench_df, on=[CompareConst.CMP_KEY, CompareConst.CMP_SHAPE], how='outer')
218
466
 
219
- # APIs in NPU and Bench models unconsistent judgment
467
+ def process_fuzzy_match(self, npu_df, bench_df):
468
+ """
469
+ 模糊匹配通过循环方式匹配api
470
+ """
471
+ npu_ops_queue = []
472
+ bench_ops_queue = []
473
+ match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS)
474
+
475
+ max_len = max(len(npu_df), len(bench_df))
476
+ min_len = min(len(npu_df), len(bench_df))
477
+ for i in range(max_len):
478
+ if i < min_len:
479
+ npu_ops_queue.append(npu_df.iloc[i])
480
+ bench_ops_queue.append(bench_df.iloc[i])
481
+ else:
482
+ try:
483
+ npu_ops_queue.append(npu_df.iloc[i])
484
+ except IndexError:
485
+ pass
486
+ try:
487
+ bench_ops_queue.append(bench_df.iloc[i])
488
+ except IndexError:
489
+ pass
490
+
491
+ # 如果append之后queue状态不一致,则判断结束
220
492
  if bool(npu_ops_queue) ^ bool(bench_ops_queue):
221
- logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
222
493
  break
223
494
 
224
- n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
495
+ npu_match_point, bench_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
225
496
 
226
- # 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中
227
- if n_match_point == -1 and b_match_point == -1:
497
+ # 如果没有匹配到,数据放到队列中,跳过。直到后面匹配到,把匹配之前的api放到不匹配中
498
+ if npu_match_point == -1 and bench_match_point == -1:
228
499
  continue
229
500
 
230
- n_match_data = npu_ops_queue[n_match_point]
231
- b_match_data = bench_ops_queue[b_match_point]
232
- un_match_data = npu_ops_queue[0: n_match_point]
233
- for npu_data in un_match_data:
234
- get_un_match_accuracy(result, npu_data, self.dump_mode)
235
- get_accuracy(result, n_match_data, b_match_data, self.dump_mode)
236
- del npu_ops_queue[0: n_match_point + 1]
237
- del bench_ops_queue[0: b_match_point + 1]
238
- progress_bar.close()
501
+ npu_op_item = npu_ops_queue[npu_match_point]
502
+ bench_op_item = bench_ops_queue[bench_match_point]
503
+ unmatched_data = npu_ops_queue[0: npu_match_point]
504
+ for op_item in unmatched_data:
505
+ match_result = self.put_unmatched_in_table(match_result, op_item)
506
+ match_result = self.put_matched_in_table(match_result, npu_op_item, bench_op_item)
507
+ del npu_ops_queue[0: npu_match_point + 1]
508
+ del bench_ops_queue[0: bench_match_point + 1]
509
+
239
510
  if npu_ops_queue:
240
- for npu_data in npu_ops_queue:
241
- get_un_match_accuracy(result, npu_data, self.dump_mode)
242
-
243
- result_df = self.make_result_table(result)
244
- return result_df
245
-
246
- def merge_data(self, json_data, stack_json_data):
247
- ops_all = {}
248
- for op_name in json_data.get('data', {}):
249
- merge_list = self.gen_merge_list(json_data, op_name, stack_json_data)
250
- if merge_list:
251
- struct_to_index_mapping = {
252
- CompareConst.INPUT_STRUCT: 0,
253
- CompareConst.OUTPUT_STRUCT: 0,
254
- CompareConst.PARAMS_STRUCT: 0,
255
- CompareConst.PARAMS_GRAD_STRUCT: 0
256
- }
257
-
258
- op_name_list = merge_list.get(CompareConst.OP_NAME)
259
- summary_list = merge_list.get(Const.SUMMARY)
260
- data_name_list = merge_list.get('data_name')
261
- op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
262
- summary_list,
263
- data_name_list)
264
- for index, op_full_name in enumerate(op_name_reorder):
265
- data_name = data_name_reorder[index] if data_name_reorder else None
266
-
267
- _, state = get_name_and_state(op_full_name)
268
- struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
269
- if not struct_key:
270
- continue
271
- ops_all[op_full_name] = {
272
- CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key),
273
- "merge_list", key=struct_key),
274
- CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"),
275
- 'data_name': data_name,
276
- 'stack_info': merge_list.get('stack_info')
277
- }
278
- struct_to_index_mapping[struct_key] += 1
279
- return ops_all
280
-
281
- def get_accuracy(self, npu_ops_all, bench_ops_all):
282
- result = []
283
- bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
284
- for ms_op_name, bench_op_name in self.data_mapping_dict.items():
285
- if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
286
- npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
287
- bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
288
- has_stack = npu_stack_info and bench_stack_info
289
- if self.dump_mode == Const.MD5:
290
- result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
291
- bench_ops_all, has_stack, npu_stack_info))
292
- continue
293
-
294
- npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
295
- bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
296
-
297
- if len(npu_struct) < 2 or len(bench_struct) < 2:
298
- logger.error(
299
- f"The length of npu_struct and bench_struct must be >= 2, "
300
- f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
301
- f"Please check!"
302
- )
303
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
304
-
305
- base_result_item = [
306
- ms_op_name, bench_op_name,
307
- npu_struct[0],
308
- bench_struct[0],
309
- npu_struct[1],
310
- bench_struct[1]
311
- ]
312
-
313
- if self.dump_mode == Const.SUMMARY:
314
- result_item = base_result_item + [" "] * 8
315
- else:
316
- result_item = base_result_item + [" "] * 5
317
-
318
- npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
319
- result_item.extend(npu_summary_data)
320
- bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
321
- result_item.extend(bench_summary_data)
322
- if self.dump_mode == Const.SUMMARY:
323
- self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
324
- else:
325
- result_item.append(CompareConst.ACCURACY_CHECK_YES)
326
- result_item.append("")
327
- if has_stack:
328
- result_item.extend(npu_stack_info)
329
- else:
330
- result_item.append(CompareConst.NONE)
331
- if self.dump_mode == Const.ALL:
332
- result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
333
- result.append(result_item)
334
- elif ms_op_name not in npu_ops_all:
335
- logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
336
- elif bench_op_name not in npu_ops_all:
337
- logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
338
- return result
511
+ for op_item in npu_ops_queue:
512
+ match_result = self.put_unmatched_in_table(match_result, op_item)
339
513
 
340
- def compare_process_custom(self, file_lists):
341
- npu_json_path, bench_json_path, stack_json_path = file_lists
342
- npu_json_data = load_json(npu_json_path)
343
- bench_json_data = load_json(bench_json_path)
344
- stack_json_data = load_json(stack_json_path) if self.stack_mode else None
345
- npu_ops_all = self.merge_data(npu_json_data, stack_json_data)
346
- bench_ops_all = self.merge_data(bench_json_data, stack_json_data)
514
+ match_result.reset_index(drop=True, inplace=True)
515
+ return match_result
347
516
 
348
- result = self.get_accuracy(npu_ops_all, bench_ops_all)
349
- result_df = self.make_result_table(result)
350
- return result_df
517
def match_op(self, npu_queue, bench_queue):
    """Find the first matching pair that involves the newest queue entry.

    The newest NPU item is tried against every bench item from oldest to
    newest; if none matches, the older NPU items are tried against the
    newest bench item.

    :return: ``(npu_index, bench_index)`` of the match, or ``(-1, -1)``
    """
    newest_npu, newest_bench = npu_queue[-1], bench_queue[-1]
    newest_npu_pos = len(npu_queue) - 1
    newest_bench_pos = len(bench_queue) - 1

    for bench_pos, bench_item in enumerate(bench_queue):
        if self.check_op_item(newest_npu, bench_item):
            return newest_npu_pos, bench_pos

    for npu_pos, npu_item in enumerate(npu_queue[:-1]):
        if self.check_op_item(npu_item, newest_bench):
            return npu_pos, newest_bench_pos

    return -1, -1
351
527
 
352
- def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
528
+ def gen_dtype_condition(self, match_result):
353
529
  """
354
- :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
355
- :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
356
- :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
357
- :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
358
- :param bench_data: bench的dump数据中"data"字段
359
- :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
360
- 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
361
- 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
530
+ dtype匹配条件为npu、bench的dtype一致或属于规定的映射关系
362
531
  """
363
- npu_bench_name_list = op_name_mapping_dict[npu_op_name]
364
- data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
365
- error_file, relative_err, error_flag = None, None, False
366
- bench_data_name = get_bench_data_name(bench_op_name, bench_data)
367
- if data_name == '-1' or data_name == -1: # 没有真实数据路径
368
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
369
- error_flag = True
370
- elif not bench_data_name:
371
- n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
372
- error_file = 'no_bench_data'
373
- else:
374
- try:
375
- read_npy_data = getattr(self, "read_npy_data")
376
- frame_name = getattr(self, "frame_name")
377
- if frame_name == "MSComparator":
378
- n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
379
- if self.cross_frame:
380
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
381
- load_pt_file=True)
382
- else:
383
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
384
- else:
385
- n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
386
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
387
- except IOError as error:
388
- error_file = error.filename
389
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
390
- error_flag = True
391
- except (FileCheckException, CompareException):
392
- error_file = data_name
393
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
394
- error_flag = True
395
-
396
- # 通过n_value, b_value同时得到错误标志和错误信息
397
- n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
398
- error_flag=error_flag, error_file=error_file)
399
-
400
- result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
401
-
402
- if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
403
- err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
404
- result_list.append(err_msg)
405
- return result_list
532
+ # 如果使用了data_mapping,不校验dtype,返回全True的DataFrame
533
+ if self.mapping_config.data_mapping:
534
+ return pd.Series(True, index=match_result.index)
535
+
536
+ npu_dtype = match_result['dtype_x']
537
+ bench_dtype = match_result['dtype_y']
538
+ npu_dtype = self.process_cross_frame_dtype(npu_dtype)
539
+ bench_dtype = self.process_cross_frame_dtype(bench_dtype)
540
+
541
+ equal_condition = npu_dtype == bench_dtype
542
+ match_condition = (
543
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
544
+ CompareConst.DTYPE_MATCH_GROUPS[0])) |
545
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
546
+ CompareConst.DTYPE_MATCH_GROUPS[1]))
547
+ )
548
+ return equal_condition | match_condition
406
549
 
407
- def compare_core(self, input_param, output_path, **kwargs):
408
- """
409
- Compares data from multiple JSON files and generates a comparison report.
550
def process_cross_frame_dtype(self, dtype):
    """Translate a dtype Series through ``cross_dtype_mapping`` for cross-frame runs.

    When ``self.cross_frame`` is falsy the Series is returned untouched.
    Otherwise values present in the mapping are replaced and all other
    values keep their original entry.
    """
    if not self.cross_frame:
        return dtype
    translated = dtype.map(cross_dtype_mapping)
    return translated.fillna(dtype)
410
554
 
411
- Args:
412
- input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path",
413
- "stack_path").
414
- output_path (str): The path where the output Excel report will be saved.
415
- **kwargs: Additional keyword arguments including:
416
- - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
417
- - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
418
- - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
419
- - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
420
- - dump_mode (str): ALL, SUMMARY, MD5.
421
555
 
422
- Returns:
423
- """
424
- # get kwargs or set default value
425
- suffix = kwargs.get('suffix', '')
556
+ class CreateTable:
557
def __init__(self, mode_config: ModeConfig):
    """Remember the mode configuration.

    :param mode_config: stored as ``self.mode_config``; ``make_result_df``
        reads its ``dump_mode`` and ``stack_mode`` attributes
    """
    self.mode_config = mode_config
426
559
 
427
- logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
428
- file_name = add_time_with_xlsx("compare_result" + suffix)
429
- file_path = os.path.join(os.path.realpath(output_path), file_name)
430
- remove_path(file_path)
431
- highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
560
@staticmethod
def process_data_name(result):
    """Fold the NPU and bench data names into one column of pairs.

    Every ``data_name_x`` cell is replaced, in place, by the two-element list
    ``[data_name_x, data_name_y]`` of that row. The pairs are built
    column-wise with ``zip`` rather than a row-wise ``DataFrame.apply``,
    which avoids constructing a Series per row.

    :param result: frame with ``data_name_x`` and ``data_name_y`` columns
    :return: the same frame object, modified in place
    """
    pairs = [[npu_name, bench_name]
             for npu_name, bench_name in zip(result['data_name_x'], result['data_name_y'])]
    # dtype=object and an explicit index keep each pair as one cell on the
    # row it came from.
    result['data_name_x'] = pd.Series(pairs, index=result.index, dtype=object)
    return result
432
564
 
433
- npu_json = input_param.get("npu_json_path")
434
- bench_json = input_param.get("bench_json_path")
435
- stack_json = input_param.get("stack_json_path")
436
- if self.data_mapping:
437
- result_df = self.compare_process_custom([npu_json, bench_json, stack_json])
438
- else:
439
- result_df = self.compare_process([npu_json, bench_json, stack_json])
565
@staticmethod
def set_summary(summary):
    """Normalise one summary cell into a list of statistic values.

    :param summary: either ``CompareConst.N_A`` or an iterable of statistics
    :return: four ``CompareConst.N_A`` placeholders when there is no summary;
        otherwise the statistics with every NaN-like entry (string form
        ``'nan'``, case-insensitive) replaced by ``CompareConst.NAN``
    """
    if summary == CompareConst.N_A:
        return [CompareConst.N_A] * 4  # 4 is the number of statistics
    return [CompareConst.NAN if str(stat).lower() == 'nan' else stat for stat in summary]
440
576
 
441
- if not result_df.values.tolist():
442
- logger.warning("Can`t match any op.")
443
- return
577
def make_result_df(self, result):
    """Turn the raw match table into the result frame for the current dump mode.

    :param result: match table with ``*_x`` (NPU) and ``*_y`` (bench) columns;
        it is renamed and extended in place
    :return: tuple ``(result_df, header)``, where ``header`` is the ordered
        column list and ``result_df`` has those columns, filled wherever
        ``result`` provides them
    """
    # get header (copy, so the shared constant is never mutated)
    header = CompareConst.HEAD_OF_COMPARE_MODE[self.mode_config.dump_mode][:]
    if self.mode_config.stack_mode:
        header.append(CompareConst.STACK)
    if self.mode_config.dump_mode == Const.ALL:
        header.append(CompareConst.DATA_NAME)
        result = self.process_data_name(result)

    # rename match_result columns
    column_renames = {
        'op_name_x': CompareConst.NPU_NAME,
        'op_name_y': CompareConst.BENCH_NAME,
        'dtype_x': CompareConst.NPU_DTYPE,
        'dtype_y': CompareConst.BENCH_DTYPE,
        'shape_x': CompareConst.NPU_SHAPE,
        'shape_y': CompareConst.BENCH_SHAPE,
        'md5_x': CompareConst.NPU_MD5,
        'md5_y': CompareConst.BENCH_MD5,
        'data_name_x': CompareConst.DATA_NAME,
        'stack_info_x': CompareConst.STACK,
    }
    result.rename(columns=column_renames, inplace=True)

    # process summary data: expand each summary cell into four stat columns
    npu_stat_cols = [CompareConst.NPU_MAX, CompareConst.NPU_MIN,
                     CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
    bench_stat_cols = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
                       CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM]
    for stat_cols, summary_col in ((npu_stat_cols, 'summary_x'), (bench_stat_cols, 'summary_y')):
        result[stat_cols] = result[summary_col].apply(self.set_summary).tolist()

    result_df = pd.DataFrame(columns=header)
    for column in header:
        if column in result.columns:
            result_df[column] = result[column]
    return result_df, header
610
+
611
+
612
+ class CalcStatsDiff:
613
def __init__(self, mode_config: ModeConfig):
    """Remember the mode configuration.

    :param mode_config: stored as ``self.mode_config``; ``calc_accuracy``
        branches on its ``dump_mode`` attribute
    """
    self.mode_config = mode_config
444
615
 
445
- if self.dump_mode == Const.ALL:
446
- result_df = self.do_multi_process(input_param, result_df)
616
@staticmethod
def type_check(val):
    """Flag the entries of ``val`` that are numeric or a textual ``nan``.

    :param val: Series of arbitrary values
    :return: boolean Series on the same index; ``True`` where the string form
        of the value parses as a number or equals ``'nan'`` (case-insensitive)
    """
    as_text = val.astype(str)
    parses_as_number = pd.to_numeric(as_text, errors='coerce').notna()
    is_nan_text = as_text.str.lower().eq('nan')

    flags = pd.Series(False, index=val.index)
    flags[parses_as_number | is_nan_text] = True
    return flags
447
625
 
448
- find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode)
449
- highlight_rows_xlsx(result_df, highlight_dict, file_path)
626
@staticmethod
def get_number(val):
    """Convert a Series to numbers via its string form; unparsable entries become NaN."""
    as_text = val.astype(str)
    return pd.to_numeric(as_text, errors='coerce')
629
+
630
def calc_summary_diff(self, result_df, cond_no_bench, stats_index: str):
    """Fill the diff and relative-error columns of one statistic, in place.

    :param result_df: result table; reads the ``'NPU <stat>'`` and
        ``'Bench <stat>'`` columns and writes ``'<Stat> diff'`` and
        ``'<Stat>RelativeErr'`` (``'l2norm'`` uses ``'NormRelativeErr'``)
    :param cond_no_bench: boolean Series, ``True`` on rows without bench data
    :param stats_index: statistic name, e.g. ``'max'`` or ``'l2norm'``
    :return: boolean Series, ``True`` where the diff relative to the larger
        of the two values exceeds ``CompareConst.MAGNITUDE``
    """
    npu_stat = result_df['NPU ' + stats_index]
    bench_stat = result_df['Bench ' + stats_index]
    diff_col = stats_index.capitalize() + ' diff'
    rel_err_base = 'norm' if stats_index == 'l2norm' else stats_index
    rel_err_col = rel_err_base.capitalize() + 'RelativeErr'

    # Both the NPU and the bench statistic are a number or nan.
    both_numeric = self.type_check(npu_stat) & self.type_check(bench_stat)

    # A statistic that is neither a number nor nan yields N/A for both outputs.
    result_df.loc[~both_numeric, [diff_col, rel_err_col]] = CompareConst.N_A
    # Valid rows: bench data exists and both statistics are numbers or nan.
    valid_rows = ~cond_no_bench & both_numeric
    result_df.loc[valid_rows, diff_col] = self.get_number(npu_stat) - self.get_number(bench_stat)

    # A valid row whose diff is nan reports nan in both outputs.
    diff_is_nan = result_df[diff_col].isna()
    result_df.loc[valid_rows & diff_is_nan, [diff_col, rel_err_col]] = CompareConst.NAN

    # A zero bench value makes the relative error undefined.
    finite_diff_rows = valid_rows & ~diff_is_nan
    bench_is_zero = bench_stat == 0
    result_df.loc[finite_diff_rows & bench_is_zero, rel_err_col] = CompareConst.N_A

    # Convert the relative error into a percentage string.
    rel_err_rows = finite_diff_rows & ~bench_is_zero
    result_df.loc[rel_err_rows, rel_err_col] = (
        result_df.loc[rel_err_rows, diff_col] / bench_stat[rel_err_rows] * 100)
    result_df.loc[rel_err_rows, rel_err_col] = (
        result_df.loc[rel_err_rows, rel_err_col].abs().astype(str) + '%')

    larger_value = pd.Series(np.maximum(self.get_number(npu_stat), self.get_number(bench_stat)))
    magnitude = self.get_number(result_df[diff_col]).abs() / (larger_value.abs() + CompareConst.EPSILON)
    return magnitude > CompareConst.MAGNITUDE
661
+
662
def calc_accuracy(self, result_df, header):
    """Fill the verdict columns of ``result_df`` according to the dump mode.

    Rows whose bench name is ``CompareConst.N_A`` have their missing cells
    filled with N/A and get the ``CompareConst.NO_BENCH`` error message.
    The remaining rows are handled per mode:

    * MD5: ``RESULT`` is PASS when the two md5 values are equal, DIFF otherwise
    * SUMMARY: per-statistic diffs are computed; rows flagged by any
      statistic get a WARNING result and a double-check message
    * otherwise: the metric and message columns are blanked and
      ``ACCURACY`` is set to ``ACCURACY_CHECK_YES``

    :return: ``result_df`` restricted to, and ordered by, ``header``
    """
    # bench name N/A represents no bench data, err_msg adds "No bench data matched."
    no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
    has_bench = ~no_bench
    result_df[no_bench] = result_df[no_bench].fillna(CompareConst.N_A)
    result_df.loc[no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH

    dump_mode = self.mode_config.dump_mode
    if dump_mode == Const.MD5:
        md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
        result_df.loc[md5_equal, CompareConst.RESULT] = CompareConst.PASS
        result_df.loc[~md5_equal & has_bench, CompareConst.RESULT] = CompareConst.DIFF
    elif dump_mode == Const.SUMMARY:
        per_stat_flags = []
        for stats_index in ['max', 'min', 'mean', 'l2norm']:
            per_stat_flags.append(self.calc_summary_diff(result_df, no_bench, stats_index))
        # One row per statistic; any() collapses them into one flag per result row.
        needs_double_check = pd.DataFrame(per_stat_flags).any()
        result_df.loc[has_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
        result_df.loc[needs_double_check, CompareConst.RESULT] = CompareConst.WARNING
        result_df.loc[needs_double_check, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
    else:
        blank_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
                      CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
                      CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
                      CompareConst.ERROR_MESSAGE]
        result_df.loc[has_bench, blank_cols] = ''
        result_df.loc[has_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES

    return result_df[header]
690
+
691
+
692
def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig:
    """Run the shared pre-processing steps and return the populated ComparisonConfig.

    Builds a ``ComparisonConfig`` from ``kwargs``, derives the dump mode and
    compared file type from ``input_param``, decides the stack mode, validates
    the parameters and creates ``output_path``.

    Args:
        input_param: mapping of comparison inputs; read here for
            ``"npu_json_path"``, ``'stack_json_path'`` and
            ``'is_print_compare_log'`` and passed on to the helper functions.
        output_path: directory that is created and validated for the results.
        **kwargs: optional settings ``auto_analyze``, ``fuzzy_match``,
            ``data_mapping``, ``suffix``, ``cell_mapping``, ``api_mapping``,
            ``layer_mapping`` and ``stack_mode``.

    Returns:
        The configured ``ComparisonConfig``.

    Raises:
        CompareException: re-raised with the original error code when a
            ``CompareException`` or ``FileCheckException`` occurs.
    """
    try:
        # dump_mode, stack_mode and compared_file_type start as placeholders
        # and are filled in once input_param has been inspected.
        options = {
            'dump_mode': '',
            'stack_mode': False,
            'auto_analyze': kwargs.get('auto_analyze', True),
            'fuzzy_match': kwargs.get('fuzzy_match', False),
            'data_mapping': kwargs.get('data_mapping', {}),
            'suffix': kwargs.get('suffix', ''),
            'cell_mapping': kwargs.get('cell_mapping', {}),
            'api_mapping': kwargs.get('api_mapping', {}),
            'layer_mapping': kwargs.get('layer_mapping', {}),
            'compared_file_type': '',
        }
        config = ComparisonConfig(**options)

        set_dump_path(input_param)
        config.dump_mode = get_dump_mode(input_param)
        config.compared_file_type = get_file_type(input_param.get("npu_json_path"))

        # Set stack_mode; when "stack_json_path" is absent from input_param,
        # set_stack_json_path supplies both the path and the mode.
        config.stack_mode = (
            kwargs.get('stack_mode', False)
            if 'stack_json_path' in input_param
            else set_stack_json_path(input_param)
        )

        print_compare_log = input_param.get('is_print_compare_log', True)
        check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match,
                                  print_compare_log)
        create_directory(output_path)
        check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode)
    except (CompareException, FileCheckException) as exc:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(exc.code) from exc

    return config