mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,1024 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ import json
19
+ import multiprocessing
20
+ import os.path
21
+ import sys
22
+ import torch
23
+ import numpy as np
24
+ import pandas as pd
25
+ import openpyxl
26
+ from openpyxl.styles import PatternFill
27
+ from collections import namedtuple
28
+ from dataclasses import dataclass
29
+
30
+ from msprobe.pytorch.compare.match import graph_mapping
31
+ from msprobe.pytorch.compare.highlight import HighlightRules, get_header_index
32
+ from msprobe.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
33
+ get_error_message
34
+ from msprobe.pytorch.advisor.advisor import Advisor
35
+ from msprobe.pytorch.common.log import logger
36
+ from msprobe.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \
37
+ format_value, check_file_not_exists, check_configuration_param, task_dumppath_get
38
+ from msprobe.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory
39
+ from msprobe.core.common.const import Const, CompareConst, FileCheckConst
40
+
41
+
42
def check_graph_mode(a_op_name, b_op_name):
    """
    Tell whether exactly one of the two op names contains the text "Aten".
    Args:
        a_op_name: first op name
        b_op_name: second op name

    Returns:
        True when "Aten" occurs in one name but not in the other, else False
    """
    a_has_aten = "Aten" in a_op_name
    b_has_aten = "Aten" in b_op_name
    return a_has_aten != b_has_aten
48
+
49
+
50
def check_op(npu_dict, bench_dict, fuzzy_match):
    """
    Decide whether an NPU op record and a bench op record describe the same op.
    Args:
        npu_dict: op record with an "op_name" list plus the struct lists read by check_struct_match
        bench_dict: op record of the same layout for the bench side
        fuzzy_match: when falsy the name lists must be equal; otherwise fuzzy_check_op is used

    Returns:
        the result of graph_mapping.match when only one of the first names contains "Aten",
        otherwise the name match combined (``and``) with the struct match
    """
    npu_names = npu_dict["op_name"]
    bench_names = bench_dict["op_name"]

    # Mixed Aten / non-Aten pairs are resolved through the graph mapping only.
    if check_graph_mode(npu_names[0], bench_names[0]):
        return graph_mapping.match(npu_names[0], bench_names[0])

    structs_agree = check_struct_match(npu_dict, bench_dict)
    if not fuzzy_match:
        return npu_names == bench_names and structs_agree

    try:
        names_agree = fuzzy_check_op(npu_names, bench_names)
    except Exception:
        # A name that cannot be parsed for fuzzy matching counts as "no match".
        logger.warning("%s and %s can not fuzzy match." % (npu_names, bench_names))
        names_agree = False
    return names_agree and structs_agree
66
+
67
+
68
def check_struct_match(npu_dict, bench_dict):
    """
    Compare the "input_struct" and "output_struct" entries of two op records.
    Args:
        npu_dict: op record of the NPU side
        bench_dict: op record of the bench side

    Returns:
        True when both struct lists are equal. Otherwise False when either input list is
        empty or their lengths differ, else whether check_type_shape_match accepts both
        the input pair and the output pair.
    """
    npu_inputs = npu_dict.get("input_struct")
    bench_inputs = bench_dict.get("input_struct")
    npu_outputs = npu_dict.get("output_struct")
    bench_outputs = bench_dict.get("output_struct")

    # Fast path: identical structures need no element-wise inspection.
    if npu_inputs == bench_inputs and npu_outputs == bench_outputs:
        return True

    if len(npu_inputs) == 0 or len(bench_inputs) == 0 or len(npu_inputs) != len(bench_inputs):
        return False

    inputs_ok = check_type_shape_match(npu_inputs, bench_inputs)
    outputs_ok = check_type_shape_match(npu_outputs, bench_outputs)
    return inputs_ok and outputs_ok
81
+
82
+
83
def check_type_shape_match(npu_struct, bench_struct):
    """
    Check pairwise that every (dtype, shape) entry of two struct lists is compatible.
    Args:
        npu_struct: sequence of entries whose item 0 is the dtype and item 1 the shape
        bench_struct: sequence of the same layout for the bench side

    Returns:
        False when there is nothing to pair up or any pair differs in shape or has
        incompatible dtypes; True otherwise. Pairing stops at the shorter sequence.
    """
    # dtype pairs that are accepted even though the names differ
    interchangeable_types = (
        ("torch.float16", "torch.float32"),
        ("torch.float32", "torch.float16"),
        ("torch.float16", "torch.bfloat16"),
        ("torch.bfloat16", "torch.float16"),
    )
    pair_ok = False
    for npu_entry, bench_entry in zip(npu_struct, bench_struct):
        npu_type, npu_shape = npu_entry[0], npu_entry[1]
        bench_type, bench_shape = bench_entry[0], bench_entry[1]
        same_shape = npu_shape == bench_shape
        compatible_type = npu_type == bench_type or (npu_type, bench_type) in interchangeable_types
        pair_ok = same_shape and compatible_type
        if not pair_ok:
            return False
    return pair_ok
102
+
103
+
104
def fuzzy_check_op(npu_name_list, bench_name_list):
    """
    Fuzzy-compare two lists of op names position by position.
    Args:
        npu_name_list: op names of the NPU side
        bench_name_list: op names of the bench side

    Returns:
        False when either list is empty or the lengths differ; otherwise whether
        fuzzy_check_name accepts every pair (checking stops at the first mismatch)
    """
    if len(npu_name_list) == 0 or len(bench_name_list) == 0:
        return False
    if len(npu_name_list) != len(bench_name_list):
        return False
    return all(
        fuzzy_check_name(npu_name, bench_name)
        for npu_name, bench_name in zip(npu_name_list, bench_name_list)
    )
113
+
114
+
115
def fuzzy_check_name(npu_name, bench_name):
    """
    Compare two op names, normalising them with rename_api when both name the same stage.
    Args:
        npu_name: op name of the NPU side
        bench_name: op name of the bench side

    Returns:
        equality of the rename_api results when both names contain "forward" (checked
        first) or both contain "backward"; plain equality of the names otherwise
    """
    for stage in ("forward", "backward"):
        if stage in npu_name and stage in bench_name:
            return rename_api(npu_name, stage) == rename_api(bench_name, stage)
    return npu_name == bench_name
123
+
124
+
125
def rename_api(npu_name, process):
    """
    Build a comparison key from an op name by cutting out the part around `process`.
    Args:
        npu_name: full op name
        process: marker text to split on (callers in this file pass "forward" or "backward")

    Returns:
        the text before the first `process`, cut at its second-to-last Const.SEP
        (``rsplit(Const.SEP, 2)[0]``), concatenated with the text that follows
        that first `process`
    """
    pieces = npu_name.split(process)
    before_marker = pieces[0]
    after_marker = pieces[1]
    api_part = before_marker.rsplit(Const.SEP, 2)[0]
    return str(api_part) + str(after_marker)
131
+
132
+
133
def merge_tensor(tensor_list, summary_compare, md5_compare):
    """
    Fold the per-tensor records of one op into a single dict of parallel lists.
    Args:
        tensor_list: records of one op; a record with exactly two keys is treated as the
            stack-info record (its 'full_info' is stored) and ends the scan
        summary_compare: summary mode flag
        md5_compare: md5 mode flag; when set each struct entry also carries the record's 'md5'

    Returns:
        dict with "op_name", "input_struct", "output_struct", "summary", "stack_info",
        "kwargs_struct" (only when non-empty) and "data_name" (only when neither mode flag
        is set); an empty dict when no op name was collected
    """
    collect_data_name = not (summary_compare or md5_compare)

    merged = {
        "op_name": [],
        "input_struct": [],
        "kwargs_struct": [],
        "output_struct": [],
        "summary": [],
        "stack_info": [],
    }
    if collect_data_name:
        merged["data_name"] = []

    # Fields copied into every struct entry; md5 mode appends the checksum.
    struct_fields = ('dtype', 'shape', 'md5') if md5_compare else ('dtype', 'shape')
    # Checked in this order; the first marker found in the name decides the bucket.
    buckets = (("input", "input_struct"), ("kwarg", "kwargs_struct"), ("output", "output_struct"))

    for tensor in tensor_list:
        if len(tensor) == 2:
            merged['stack_info'].append(tensor['full_info'])
            break

        full_name = tensor['full_op_name']
        merged["op_name"].append(full_name)
        for marker, bucket in buckets:
            if marker in full_name:
                merged[bucket].append(tuple(tensor[field] for field in struct_fields))
                break

        merged["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])

        if collect_data_name:
            merged["data_name"].append(tensor['data_name'])

    if not merged["kwargs_struct"]:
        del merged["kwargs_struct"]
    return merged if merged["op_name"] else {}
174
+
175
+
176
def match_op(npu_queue, bench_queue, fuzzy_match):
    """
    Look for a matching pair that involves the newest entry of either queue.
    Args:
        npu_queue: op records of the NPU side
        bench_queue: op records of the bench side
        fuzzy_match: forwarded to check_op

    Returns:
        (npu_index, bench_index) of the first match found, trying in order: newest NPU op
        against every earlier bench op, newest against newest, every earlier NPU op against
        the newest bench op; (-1, -1) when nothing matches
    """
    newest_npu_index = len(npu_queue) - 1
    newest_bench_index = len(bench_queue) - 1

    for bench_index, bench_op in enumerate(bench_queue[0: -1]):
        if check_op(npu_queue[-1], bench_op, fuzzy_match):
            return newest_npu_index, bench_index

    if check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
        return newest_npu_index, newest_bench_index

    for npu_index, npu_op in enumerate(npu_queue[0: -1]):
        if check_op(npu_op, bench_queue[-1], fuzzy_match):
            return npu_index, newest_bench_index

    return -1, -1
186
+
187
+
188
def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
    """
    Append one comparison row per input / kwarg / output entry of a matched op pair.
    Args:
        result: list that receives the generated rows (mutated in place)
        n_dict: merged op record of the NPU side ("op_name", "*_struct", "summary", ...)
        b_dict: merged op record of the bench side
        summary_compare: summary mode; rows get diff / relative-diff columns computed here
        md5_compare: md5 mode; rows carry the md5 values and a pass/diff verdict

    Returns:
        None; rows are appended to `result`
    """
    def get_accuracy_core(n_start, n_len, b_start, b_len, key):
        # n_start / b_start: offset of this group inside "op_name" / "summary";
        # n_len / b_len: number of entries of this group; key: which "*_struct" list to read.
        min_len = min(n_len, b_len)
        npu_stack_info = n_dict.get("stack_info", None)
        bench_stack_info = b_dict.get("stack_info", None)
        has_stack = npu_stack_info and bench_stack_info

        all_mode_bool = not (summary_compare or md5_compare)
        if all_mode_bool:
            npu_data_name = n_dict.get("data_name", None)
            bench_data_name = b_dict.get("data_name", None)

        # Entries present on both sides.
        for index in range(min_len):

            n_name = n_dict['op_name'][n_start + index]
            b_name = b_dict['op_name'][b_start + index]
            n_struct = n_dict[key][index]
            b_struct = b_dict[key][index]
            err_msg = ""
            if md5_compare:
                # md5 rows: names, dtypes, shapes, md5 values, verdict, then stack info.
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               n_struct[2], b_struct[2],
                               CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
                # Stack info is attached only to the first input row of the op.
                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                result.append(result_item)
                continue

            # Blank placeholders for the metric columns: 8 in summary mode, 5 otherwise.
            if summary_compare:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " ", " ", " ", " "]
            else:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " "]

            npu_summary_data = n_dict.get("summary")[n_start + index]
            result_item.extend(npu_summary_data)
            bench_summary_data = b_dict.get("summary")[b_start + index]
            result_item.extend(bench_summary_data)

            if summary_compare:
                # Column position of the first diff placeholder in the summary header.
                start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
                warning_flag = False
                for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
                    if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
                        diff = npu_val - bench_val
                        if bench_val != 0:
                            relative = str(abs((diff / bench_val) * 100)) + '%'
                        else:
                            relative = "N/A"
                        # Absolute diff goes to slot i, relative diff four slots further.
                        result_item[start_idx + i] = diff
                        result_item[start_idx + i + 4] = relative
                        # 1e-10 keeps the division defined when both values are zero.
                        magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
                        if magnitude_diff > 0.5:
                            warning_flag = True
                    else:
                        result_item[start_idx + i] = CompareConst.NONE
                accuracy_check = CompareConst.WARNING if warning_flag else ""
                err_msg += "Need double check api accuracy." if warning_flag else ""
                # inf / -inf / nan cells get a trailing tab
                # (presumably so spreadsheet tools keep them as text - TODO confirm).
                for i in range(start_idx, len(result_item)):
                    if str(result_item[i]) in ('inf', '-inf', 'nan'):
                        result_item[i] = f'{result_item[i]}\t'

            result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
            result_item.append(err_msg)
            if has_stack and index == 0 and key == "input_struct":
                result_item.extend(npu_stack_info)
            else:
                result_item.append(CompareConst.NONE)
            if all_mode_bool:
                result_item.append(npu_data_name[n_start + index])

            result.append(result_item)

        # Entries that exist only on the NPU side: bench columns are filled with NAN.
        if n_len > b_len:
            for index in range(b_len, n_len):
                n_name = n_dict['op_name'][n_start + index]
                n_struct = n_dict[key][index]
                if md5_compare:
                    result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                                   n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
                    result.append(result_item)
                    continue
                # NOTE(review): 5 placeholders are used here even in summary mode, whereas the
                # paired branch above uses 8 - confirm the resulting column count is intended.
                result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                               n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
                summary_data = n_dict.get("summary")[n_start + index]
                result_item.extend(summary_data)
                summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
                result_item.extend(summary_data)

                err_msg = ""
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
                result_item.append(err_msg)

                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                if all_mode_bool:
                    result_item.append(npu_data_name[n_start + index])

                result.append(result_item)

    # Group sizes are derived from the op names: names containing 'input' or 'kwarg' are
    # counted, everything else is treated as output.
    n_num = len(n_dict['op_name'])
    b_num = len(b_dict['op_name'])
    n_num_input = len([name for name in n_dict['op_name'] if 'input' in name])
    b_num_input = len([name for name in b_dict['op_name'] if 'input' in name])
    n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
    b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
    n_num_output = n_num - n_num_input - n_num_kwarg
    b_num_output = b_num - b_num_input - b_num_kwarg
    get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
    get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
    get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
304
+
305
+
306
def _do_multi_process(input_parma, result_df):
    """
    Run compare_ops over the result DataFrame through _handle_multi_process.
    Args:
        input_parma: parameter dict forwarded to the workers
        result_df: DataFrame holding the rows to compare

    Returns:
        the DataFrame returned by _handle_multi_process

    Raises:
        CompareException: INVALID_DATA_ERROR when a ValueError is raised underneath
    """
    try:
        shared_lock = multiprocessing.Manager().RLock()
        return _handle_multi_process(compare_ops, input_parma, result_df, shared_lock)
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
313
+
314
+
315
def read_dump_data(result_df):
    """
    Map every value of the first DataFrame column to a pair built from the last column.
    Args:
        result_df: DataFrame whose first column holds the dump names and whose last column
            holds the dump tensor entry

    Returns:
        dict of first-column value -> [last-column value, last-column value]

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR on IndexError
    """
    try:
        dump_names = result_df.iloc[0:, 0].tolist()
        dump_tensors = result_df.iloc[0:, -1].tolist()
        # The same entry is stored twice, once per slot of the pair.
        return {name: [tensor, tensor] for name, tensor in zip(dump_names, dump_tensors)}
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
331
+
332
+
333
def _handle_multi_process(func, input_parma, result_df, lock):
    """
    Split the result DataFrame into chunks and run `func` on each chunk in a process pool.
    Args:
        func: worker called as func(idx, op_name_mapping_dict, df_chunk, lock, input_parma)
        input_parma: parameter dict forwarded to every worker
        result_df: DataFrame to split
        lock: lock object forwarded to every worker

    Returns:
        the worker results concatenated with ignore_index=True

    Raises:
        whatever a worker raised (re-raised by AsyncResult.get); the pool is always
        shut down and joined before the exception leaves this function
    """
    process_num = int((multiprocessing.cpu_count() + 1) / 2)
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        # Fewer rows than processes: hand everything to a single worker.
        df_chunks = [result_df]

    pool = multiprocessing.Pool(process_num)

    def terminate_pool():
        try:
            pool.terminate()
        except OSError:
            logger.error("pool terminate failed")

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        terminate_pool()

    # NOTE(review): results are collected in submission order; if a later chunk fails while
    # an earlier one is still pending, the earlier get() may never return after the pool is
    # terminated - confirm whether a timeout is wanted here.
    try:
        results = []
        for process_idx, df_chunk in enumerate(df_chunks):
            # Offset of the chunk's first row inside result_df.
            idx = df_chunk_size * process_idx
            result = pool.apply_async(func,
                                      args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                                      error_callback=err_call)
            results.append(result)
        final_results = [r.get() for r in results]
    except BaseException:
        # Do not leave worker processes behind when submitting or collecting fails.
        terminate_pool()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    return pd.concat(final_results, ignore_index=True)
363
+
364
+
365
def compare_ops(idx, dump_path_dict, result_df, lock, input_parma):
    """
    Compare every op listed in the first column of a DataFrame chunk and store the metrics.
    Args:
        idx: row offset of this chunk, forwarded to _save_cmp_result
        dump_path_dict: mapping forwarded to compare_by_op
        result_df: DataFrame chunk; column 0 holds the op names
        lock: lock forwarded to _save_cmp_result
        input_parma: parameter dict; "is_print_compare_log" switches per-op logging on

    Returns:
        the value returned by _save_cmp_result
    """
    verbose = input_parma.get("is_print_compare_log")
    collected = ComparisonResult(
        cos_result=[],
        max_err_result=[],
        max_relative_err_result=[],
        err_msgs=[],
        one_thousand_err_ratio_result=[],
        five_thousand_err_ratio_result=[]
    )

    for row in range(len(result_df)):
        op_name = result_df.iloc[row, 0]
        if verbose:
            logger.info("start compare: {}".format(op_name))
        metrics = compare_by_op(op_name, dump_path_dict, input_parma)
        cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = metrics
        if verbose:
            logger.info(
                "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, "
                "five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
                                                    one_thousand_err_ratio, five_thousand_err_ratio))
        collected.cos_result.append(cos_sim)
        collected.max_err_result.append(max_abs_err)
        collected.max_relative_err_result.append(max_relative_err)
        collected.err_msgs.append(err_msg)
        collected.one_thousand_err_ratio_result.append(one_thousand_err_ratio)
        collected.five_thousand_err_ratio_result.append(five_thousand_err_ratio)

    return _save_cmp_result(idx, collected, result_df, lock)
401
+
402
+
403
@dataclass
class ComparisonResult:
    """
    Parallel metric lists filled by compare_ops; position i of every list belongs to the
    i-th op of the compared DataFrame chunk.
    """
    # cosine value returned by compare_by_op for each op
    cos_result: list
    # max absolute error for each op
    max_err_result: list
    # max relative error for each op
    max_relative_err_result: list
    # error message for each op
    err_msgs: list
    # "one thousand" error-ratio value for each op (exact definition lives in compare_by_op)
    one_thousand_err_ratio_result: list
    # "five thousand" error-ratio value for each op (exact definition lives in compare_by_op)
    five_thousand_err_ratio_result: list
411
+
412
+
413
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Write the collected metrics into the result DataFrame while holding the lock.
    Args:
        offset: added to each list position to obtain the DataFrame row label
        result: ComparisonResult holding one value per row for every metric
        result_df: DataFrame that receives the values (modified in place)
        lock: lock held for the whole update

    Returns:
        result_df with the metric and accuracy columns filled in

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR on IndexError
    """
    with lock:
        try:
            for position in range(len(result.cos_result)):
                row = position + offset
                result_df.loc[row, CompareConst.COSINE] = result.cos_result[position]
                result_df.loc[row, CompareConst.MAX_ABS_ERR] = result.max_err_result[position]
                result_df.loc[row, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[position]
                result_df.loc[row, CompareConst.ERROR_MESSAGE] = result.err_msgs[position]
                # The verdict is derived from the cosine and max-abs-error values only.
                verdict = check_accuracy(result.cos_result[position], result.max_err_result[position])
                result_df.loc[row, CompareConst.ACCURACY] = verdict
                result_df.loc[row, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = \
                    result.one_thousand_err_ratio_result[position]
                result_df.loc[row, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = \
                    result.five_thousand_err_ratio_result[position]
            return result_df
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
446
+
447
+
448
def check_accuracy(cos, max_abs_err):
    """Turn a cosine value and a max-abs-error value into an accuracy verdict.

    Sentinel inputs are handled first (shape mismatch, ``CompareConst.NONE``,
    ``"N/A"``); anything else is converted to float and checked against the
    ``CompareConst`` thresholds.

    Returns:
        One of ``CompareConst.ACCURACY_CHECK_UNMATCH``, ``CompareConst.NONE``,
        ``CompareConst.ACCURACY_CHECK_NO`` or ``CompareConst.ACCURACY_CHECK_YES``.
    """
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE:
        return CompareConst.NONE
    if cos == "N/A" or max_abs_err == "N/A":
        return CompareConst.ACCURACY_CHECK_NO

    try:
        cos_value = float(cos)
        abs_err_value = float(max_abs_err)
    except ValueError:
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE

    # Both threshold checks lead to the same "No" verdict.
    failed = (
        (cos_value < CompareConst.COS_THRESHOLD and abs_err_value > CompareConst.MAX_ABS_ERR_THRESHOLD)
        or cos_value < CompareConst.COS_MAX_THRESHOLD
        or abs_err_value > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD
    )
    return CompareConst.ACCURACY_CHECK_NO if failed else CompareConst.ACCURACY_CHECK_YES
465
+
466
+
467
def read_npy_data(dir_path, file_name):
    """Load a dumped tensor file and return its contents as a NumPy array.

    The path ``dir_path/file_name`` is validated with ``FileChecker`` (readable
    file with the ``PT_SUFFIX`` suffix) before it is loaded with ``torch.load``
    onto the CPU. bfloat16 tensors are converted to float32 before ``.numpy()``.
    """
    checker = FileChecker(os.path.join(dir_path, file_name), FileCheckConst.FILE, FileCheckConst.READ_ABLE,
                          FileCheckConst.PT_SUFFIX, False)
    checked_path = checker.common_check()

    # NOTE(review): torch.load unpickles the file, so it must only be used on dump
    # files from a trusted source.
    tensor = torch.load(checked_path, map_location=torch.device('cpu')).detach()  # detach for less memory
    if tensor.dtype == torch.bfloat16:
        # Presumably needed because NumPy has no bfloat16 dtype.
        tensor = tensor.to(torch.float32)
    return tensor.numpy()
477
+
478
+
479
def compare_by_op(op_name, op_name_mapping_dict, input_parma):
    """Compare the NPU and bench data of one op.

    Args:
        op_name: key into ``op_name_mapping_dict``.
        op_name_mapping_dict: maps an op name to a pair ``[npu_data_name, bench_data_name]``.
        input_parma: dict providing ``npu_dump_data_dir`` and ``bench_dump_data_dir``.

    Returns:
        The list produced by ``compare_ops_apply`` with the error message appended as last item.
    """
    data_names = op_name_mapping_dict[op_name]
    bench_data_name = data_names[1]
    error_file, relative_err = None, None

    # A data name of -1 means there is no real data file for this op.
    unreadable = bench_data_name == '-1' or bench_data_name == -1
    if not unreadable:
        try:
            n_value = read_npy_data(input_parma.get("npu_dump_data_dir"), data_names[0])
            b_value = read_npy_data(input_parma.get("bench_dump_data_dir"), data_names[1])
        except IOError as error:
            error_file = error.filename
            unreadable = True
    if unreadable:
        n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
    error_flag = unreadable

    n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
    if not error_flag:
        relative_err = get_relative_err(n_value, b_value)
        n_value, b_value = reshape_value(n_value, b_value)

    err_msg = get_error_message(n_value, b_value, op_name, error_flag, error_file=error_file)
    result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)

    # Differently named NPU/bench data means the pair came from fuzzy matching.
    if data_names[0] != data_names[1]:
        err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
    result_list.append(err_msg)
    return result_list
507
+
508
+
509
def handle_inf_nan(n_value, b_value):
    """Neutralise inf/nan entries that occur at the same positions in both arrays.

    If neither array contains inf or nan, both are returned untouched. If the
    inf masks and the nan masks of the two arrays are equal, those entries are
    set to 0 in place and the arrays are returned. Otherwise
    ``(CompareConst.NAN, CompareConst.NAN)`` is returned.
    """
    n_inf, b_inf = np.isinf(n_value), np.isinf(b_value)
    n_nan, b_nan = np.isnan(n_value), np.isnan(b_value)

    has_special = np.any(n_inf) or np.any(b_inf) or np.any(n_nan) or np.any(b_nan)
    if not has_special:
        return n_value, b_value

    if not (np.array_equal(n_inf, b_inf) and np.array_equal(n_nan, b_nan)):
        return CompareConst.NAN, CompareConst.NAN

    # Masks match: zero the special entries (mutates the caller's arrays).
    for target, mask in ((n_value, n_inf), (b_value, b_inf), (n_value, n_nan), (b_value, b_nan)):
        target[mask] = 0
    return n_value, b_value
527
+
528
+
529
def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
    """Find the rows of a single API that need to be highlighted.

    Args:
        result: rows of one API; the first ``n_num_input`` rows are treated as inputs,
            the remaining rows as outputs.
        last_len: offset added to a row's position in ``result`` to get the recorded row number.
        n_num_input: number of leading input rows in ``result``.
        highlight_dict: dict whose ``'red_rows'`` / ``'yellow_rows'`` lists are extended in
            place. A missing key is created, so found rows are never silently dropped.
        summary_compare: selects the summary header layout and the summary rule set.
        md5_compare: when true nothing is highlighted and the function returns at once.
    """
    if md5_compare:
        return
    npu_max_index = get_header_index('NPU max', summary_compare)
    bench_max_index = get_header_index('Bench max', summary_compare)
    max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
    checked_columns = (npu_max_index, bench_max_index, max_diff_index)

    def _all_numeric(row):
        """Return True when every checked column of ``row`` holds a float or an int."""
        return all(isinstance(row[column], (float, int)) for column in checked_columns)

    red_lines, yellow_lines = [], []
    LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
    ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
    ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
    color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)

    # Check every single row (input or output) of the API on its own.
    for i, line in enumerate(result):
        num = last_len + i
        line_info = LineInfo(line_data=line, num_pointer=num)
        for rule in HighlightRules.basic_rules.values():
            rule.apply(line_info, color_columns, summary_compare)

    # Check each output row against each input row of the API.
    api_rules = HighlightRules.summary_compare_rules if summary_compare else HighlightRules.compare_rules
    for n, api_out in enumerate(result[n_num_input:]):
        num = last_len + n_num_input + n
        if num in red_lines:
            # Already flagged red by the per-row rules.
            continue
        if not _all_numeric(api_out):
            continue
        for api_in in result[0:n_num_input]:
            if not _all_numeric(api_in):
                continue
            api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
            for rule in api_rules.values():
                rule.apply(api_info, color_columns, summary_compare)

    # setdefault (rather than get) keeps the rows even if the caller's dict lacks the key.
    # A row flagged red is never reported as yellow as well.
    highlight_dict.setdefault('red_rows', []).extend(list(set(red_lines)))
    highlight_dict.setdefault('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines)))
575
+
576
+
577
def get_name_and_state(name):
    """Split a row name into its api/module prefix and its state.

    The state is ``"input"`` when that word occurs in ``name`` and ``"output"``
    otherwise; the prefix is everything before the first occurrence of the state word.
    """
    state = "input" if "input" in name else "output"
    return name.split(state)[0], state
586
+
587
+
588
def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
    """Group the DataFrame rows by API and find the rows that should be highlighted.

    Consecutive rows that share an api name (see ``get_name_and_state``) form one
    group. Each group is handed to ``find_error_rows`` together with its start
    offset and the number of input rows.

    Args:
        result_df: comparison result; column 0 holds the row names.
        highlight_dict: dict passed through to ``find_error_rows``.
        summary_compare: passed through to ``find_error_rows``.
        md5_compare: passed through to ``find_error_rows``.
    """
    result = result_df.values
    if len(result) == 0:
        # Nothing to group. Without this guard the code after the loop would read
        # ``state`` before it was ever assigned.
        return

    start, input_num, output_num = 0, 0, 0
    last_api_name, last_state = None, None
    num = 0
    for res_i in result:
        api_name, state = get_name_and_state(res_i[0])
        if last_api_name:
            if api_name == last_api_name:
                if state == last_state:
                    num += 1
                else:
                    # Same API, state switched: the rows counted so far were the inputs.
                    input_num = num
                    num, last_state = 1, state
            else:
                # A new API starts: flush the finished group.
                output_num = num
                find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
                                summary_compare, md5_compare)
                num, last_api_name, last_state = 1, api_name, state
                start += input_num + output_num
                input_num, output_num = 1, 0
        else:
            num, last_api_name, last_state = 1, api_name, state

    # Flush the last group.
    if state == "input":
        input_num = num
    else:
        output_num = num
    find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
                    summary_compare, md5_compare)
618
+
619
+
620
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
    """Write the results to an Excel workbook and colour the flagged rows.

    Args:
        result_df: DataFrame to write; the column names become the header row.
        highlight_dict: dict with ``'red_rows'`` and ``'yellow_rows'`` holding 0-based
            DataFrame row positions; red takes precedence over yellow.
        file_path: destination of the workbook; its mode is changed after saving.
    """
    logger.info('Compare result is %s' % file_path)

    wb = openpyxl.Workbook()
    ws = wb.active

    # write header
    for j, col_name in enumerate(result_df.columns, start=1):
        ws.cell(row=1, column=j, value=col_name)

    # Build the lookup sets and fill styles once instead of once per cell.
    red_rows = set(highlight_dict['red_rows'])
    yellow_rows = set(highlight_dict['yellow_rows'])
    red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid")
    yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid")
    special_texts = ('inf', '-inf', 'nan')

    for i, row in enumerate(result_df.iterrows(), start=2):
        # Worksheet row i corresponds to DataFrame position i - 2 (header + 1-based rows).
        if (i - 2) in red_rows:
            row_fill = red_fill
        elif (i - 2) in yellow_rows:
            row_fill = yellow_fill
        else:
            row_fill = None

        for j, value in enumerate(row[1], start=1):
            text = str(value)
            if text in special_texts:
                # inf/nan are written as text with a trailing tab, as before.
                cell_value = f'{text}\t'
            elif isinstance(value, (float, int)):
                cell_value = value
            else:
                cell_value = text
            cell = ws.cell(row=i, column=j, value=cell_value)
            if row_fill is not None:
                cell.fill = row_fill

    wb.save(file_path)
    change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
645
+
646
+
647
def compare(input_parma, output_path, stack_mode=False, auto_analyze=True,
            fuzzy_match=False):
    """Validate the arguments, then run ``compare_core``.

    If any validation step raises ``CompareException`` an error is logged and
    the process exits with the exception's ``code``.
    """
    try:
        summary_compare, md5_compare = task_dumppath_get(input_parma)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
        create_directory(output_path)
        check_compare_param(input_parma, output_path, stack_mode, summary_compare, md5_compare)
    except CompareException as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        sys.exit(error.code)

    options = {
        'stack_mode': stack_mode,
        'auto_analyze': auto_analyze,
        'fuzzy_match': fuzzy_match,
        'summary_compare': summary_compare,
        'md5_compare': md5_compare,
    }
    compare_core(input_parma, output_path, **options)
660
+
661
+
662
def compare_core(input_parma, output_path, **kwargs):
    """
    Compare the dumps described by ``input_parma`` and write an Excel report.

    Args:
        input_parma (dict): provides "npu_json_path", "bench_json_path" and "stack_json_path".
        output_path (str): directory in which the report file is created.
        **kwargs: optional settings
            - stack_mode (bool): forwarded to compare_process. Defaults to False.
            - auto_analyze (bool): run the Advisor analysis after the report is written. Defaults to True.
            - suffix (str): appended to the "compare_result" file-name stem. Defaults to ''.
            - fuzzy_match (bool): forwarded to compare_process. Defaults to False.
            - summary_compare (bool): summary comparison mode. Defaults to False.
            - md5_compare (bool): MD5 comparison mode. Defaults to False.

    Returns:
        None
    """
    # Resolve the optional settings.
    md5_compare = kwargs.get('md5_compare', False)
    summary_compare = kwargs.get('summary_compare', False)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    suffix = kwargs.get('suffix', '')
    auto_analyze = kwargs.get('auto_analyze', True)
    stack_mode = kwargs.get('stack_mode', False)

    logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
    report_name = add_time_with_xlsx("compare_result" + suffix)
    file_path = os.path.join(os.path.realpath(output_path), report_name)
    check_file_not_exists(file_path)
    highlight_dict = {'red_rows': [], 'yellow_rows': []}

    with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
            FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
            FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
        handles = [npu_json, bench_json, stack_json]
        result_df = compare_process(handles, stack_mode, fuzzy_match, summary_compare, md5_compare)

    # The multi-process step runs only in the full (neither md5 nor summary) mode.
    if not (md5_compare or summary_compare):
        result_df = _do_multi_process(input_parma, result_df)
    find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
    highlight_rows_xlsx(result_df, highlight_dict, file_path)
    if auto_analyze:
        Advisor(result_df, output_path).analysis()
707
+
708
+
709
def parse(pkl_file, module_name_prefix):
    """Log the stack and statistic records of ``pkl_file`` whose name starts with a prefix.

    The file is read line by line; every non-empty line is decoded with
    ``json.loads`` and its first element is matched against ``module_name_prefix``.

    Raises:
        CompareException: when ``module_name_prefix`` is not a string.
    """
    if not isinstance(module_name_prefix, str):
        logger.error("The parameter:module_name_prefix is not a string.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    with FileOpen(pkl_file, "r") as f:
        title_printed = False
        while True:
            pkl_line = f.readline()
            if len(pkl_line) == 0:
                # readline() returns an empty string at end of file.
                break
            if pkl_line == '\n':
                continue

            msg = json.loads(pkl_line)
            info_prefix = msg[0]
            if not info_prefix.startswith(module_name_prefix):
                continue

            if "stack_info" in info_prefix:
                logger.info("\nTrace back({}):".format(info_prefix))
                for frame in reversed(msg[1]):
                    logger.info(" File \"{}\", line {}, in {}".format(frame[0], frame[1], frame[2]))
                    logger.info(" {}".format(frame[3]))
                continue

            if len(msg) > 5:
                summary_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \
                    .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2])
                if not title_printed:
                    # The title is logged once, before the first statistic line.
                    logger.info("\nStatistic Info:")
                    title_printed = True
                logger.info(summary_info)
742
+
743
+
744
def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
    """Flatten one (possibly nested) op argument into a list of description dicts.

    Args:
        item: ``None``, a dict describing one value, or an iterable of such items.
        op_name: name prefix for the generated ``full_op_name`` values.
        index: position of ``item`` inside its parent, or ``None`` at the top level.
        item_list: list to append to; a new list is created when omitted.
        top_bool: ``True`` for the top-level call, ``False`` for recursive calls.

    Returns:
        ``item_list`` with one dict appended per parsed value. Dicts that already
        carry a ``'dtype'`` key are appended as-is after ``'full_op_name'`` is set on them.
    """
    if item_list is None:
        item_list = []

    if item is None or (isinstance(item, dict) and not item):
        # Missing value: record a placeholder without statistics.
        suffix = '.0' if top_bool else '.' + str(index)
        item_list.append({'full_op_name': op_name + suffix, 'Max': None, 'Min': None, 'Mean': None,
                          'Norm': None, 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'})
        return item_list

    if index is not None:
        full_op_name = op_name + '.' + str(index)
    elif isinstance(item, dict):
        full_op_name = op_name + '.0'
    else:
        full_op_name = op_name

    if not isinstance(item, dict):
        # A container: parse each member with its position as index.
        for position, member in enumerate(item):
            op_item_parse(member, full_op_name, position, item_list=item_list, top_bool=False)
        return item_list

    if 'dtype' in item:
        item['full_op_name'] = full_op_name
        item_list.append(item)
    elif 'type' in item:
        kind = item['type']
        if kind == 'torch.Size':
            dtype, shape, stat = 'torch.Size', str(item['value']), None
        elif kind == 'slice':
            dtype, shape, stat = 'slice', str(np.shape(np.array(item['value']))), None
        else:
            # Any other typed value: its value doubles as every statistic.
            dtype, shape, stat = str(type(item['value'])), '[]', item['value']
        item_list.append({'full_op_name': full_op_name, 'dtype': dtype, 'shape': shape, 'md5': None,
                          'Max': stat, 'Min': stat, 'Mean': stat, 'Norm': stat, 'data_name': '-1'})
    else:
        resolve_api_special_parameters(item, full_op_name, item_list)
    return item_list
809
+
810
+
811
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Function Description:
        Parse data in the following format, a special layout of api parameters:
        {
            "last_hidden_state": {
                "type": "torch.Tensor",
                "dtype": "torch.bfloat16",
                ...
            },
            "loss": {
                "type": "torch.Tensor",
                "dtype": "torch.float32",
                ...
            }
        }
        Each dict value gets a ``'full_op_name'`` built by inserting its key before
        the last dot-separated part of ``full_op_name`` and is appended to
        ``item_list``; non-dict values are skipped.
    Parameter:
        data_dict: the data, as a dict
        full_op_name: full name string of the parameter
        item_list: collection of parameter information, extended in place
    """
    # The name is split once; only the inserted key differs per entry.
    name_parts = full_op_name.split(".")
    leading, trailing = name_parts[:-1], name_parts[-1:]
    for key, value in data_dict.items():
        if not isinstance(value, dict):
            continue
        value['full_op_name'] = ".".join(leading + [key] + trailing)
        item_list.append(value)
840
+
841
+
842
def read_op(op_data, op_name):
    """Parse the inputs and outputs recorded for one op into a flat list.

    For names containing ``'forward'`` the ``input_args``, ``input_kwargs`` and
    ``output`` entries are parsed; for names containing ``'backward'`` the
    ``grad_input`` and ``grad_output`` entries are parsed. Every entry goes
    through ``op_item_parse``.

    Returns:
        The list of parsed items (empty when nothing matched).
    """
    parsed = []
    input_prefix = op_name + '_input'
    output_prefix = op_name + '_output'

    if 'forward' in op_name:
        if 'input_args' in op_data:
            parsed = op_item_parse(op_data['input_args'], input_prefix, None)
        if 'input_kwargs' in op_data:
            kwargs_item = op_data['input_kwargs']
            is_typed_dict = isinstance(kwargs_item, dict) and "type" in kwargs_item
            if is_typed_dict or isinstance(kwargs_item, list):
                # A single typed value or a list is parsed as one item.
                parsed.extend(op_item_parse(kwargs_item, input_prefix, None))
            elif kwargs_item:
                # Otherwise each keyword is parsed under its own name.
                for kwarg in kwargs_item:
                    parsed.extend(op_item_parse(kwargs_item[kwarg], op_name + '_input.' + kwarg, None))
        if 'output' in op_data:
            parsed.extend(op_item_parse(op_data['output'], output_prefix, None))

    if 'backward' in op_name:
        if 'grad_input' in op_data:
            # Replaces (does not extend) whatever was collected so far.
            parsed = op_item_parse(op_data['grad_input'], input_prefix, None)
        if 'grad_output' in op_data:
            parsed.extend(op_item_parse(op_data['grad_output'], output_prefix, None))
    return parsed
878
+
879
+
880
def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
    """Pair up the ops of the NPU and bench dumps and build the comparison DataFrame.

    Args:
        file_handles: three open handles, in order: NPU json, bench json, stack json.
        stack_mode: whether the stack column is kept in the result.
        fuzzy_match: forwarded to ``match_op``; a warning is logged when it is set.
        summary_compare: selects the summary header and is forwarded to the helpers.
        md5_compare: selects the md5 header and is forwarded to the helpers.

    Returns:
        A pandas DataFrame built from the collected result rows and the selected header.
    """
    npu_json_handle, bench_json_handle, stack_json_handle = file_handles
    npu_json_data = json.load(npu_json_handle)
    bench_json_data = json.load(bench_json_handle)
    stack_json_data = json.load(stack_json_handle)

    if fuzzy_match:
        logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")

    # Ops that were read but not matched yet, one queue per side.
    npu_ops_queue = []
    bench_ops_queue = []
    result = []

    ops_npu_iter = iter(npu_json_data['data'])
    ops_bench_iter = iter(bench_json_data['data'])
    # Despite the name, these flags are True while the iterator may still yield ops;
    # they are set to False once StopIteration is hit.
    read_err_npu = True
    read_err_bench = True
    last_npu_ops_len = 0
    last_bench_ops_len = 0

    while True:
        # Stop once both sides are exhausted.
        if not read_err_npu and not read_err_bench:
            break
        try:
            # Remember the queue length to detect below whether this round added anything.
            last_npu_ops_len = len(npu_ops_queue)
            op_name_npu = next(ops_npu_iter)
            read_err_npu = True

            npu_op_data = npu_json_data['data'][op_name_npu]
            npu_op_parsed_list = read_op(npu_op_data, op_name_npu)
            # The stack entry of the op (or None) is appended as the last parsed item.
            if op_name_npu in stack_json_data:
                npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': stack_json_data[op_name_npu]})
            else:
                npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': None})

            npu_merge_list = merge_tensor(npu_op_parsed_list, summary_compare, md5_compare)
            if npu_merge_list:
                npu_ops_queue.append(npu_merge_list)
        except StopIteration:
            read_err_npu = False
        try:
            last_bench_ops_len = len(bench_ops_queue)
            op_name_bench = next(ops_bench_iter)

            bench_op_data = bench_json_data['data'][op_name_bench]
            bench_op_parsed_list = read_op(bench_op_data, op_name_bench)
            # The bench side also looks its stack entry up in the same stack json.
            if op_name_bench in stack_json_data:
                bench_op_parsed_list.append(
                    {'full_op_name': op_name_bench, 'full_info': stack_json_data[op_name_bench]})
            else:
                bench_op_parsed_list.append({'full_op_name': op_name_bench, 'full_info': None})

            bench_merge_list = merge_tensor(bench_op_parsed_list, summary_compare, md5_compare)
            if bench_merge_list:
                bench_ops_queue.append(bench_merge_list)
        except StopIteration:
            read_err_bench = False

        # merge all boolean expressions
        # Skip matching when there is nothing queued or neither queue grew this round.
        both_empty = not npu_ops_queue and not bench_ops_queue
        no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
        if both_empty or no_change:
            continue

        n_match_point, b_match_point = match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
        # (-1, -1) is treated as "no match yet"; keep reading.
        if n_match_point == -1 and b_match_point == -1:
            continue
        n_match_data = npu_ops_queue[n_match_point]
        b_match_data = bench_ops_queue[b_match_point]
        # NPU ops queued before the match point are reported as unmatched.
        un_match_data = npu_ops_queue[0: n_match_point]
        for npu_data in un_match_data:
            get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
        get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
        # Drop everything up to and including the matched entries from both queues.
        del npu_ops_queue[0: n_match_point + 1]
        del bench_ops_queue[0: b_match_point + 1]
    # NPU ops still queued after both inputs are exhausted never found a bench partner.
    if npu_ops_queue:
        for npu_data in npu_ops_queue:
            get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)

    # Copy the header constant ([:]) so the appends below do not modify it.
    header = []
    if md5_compare:
        header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
    elif summary_compare:
        header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
    else:
        header = CompareConst.COMPARE_RESULT_HEADER[:]

    # True in the full comparison mode (neither summary nor md5).
    all_mode_bool = not (summary_compare or md5_compare)
    if stack_mode:
        if all_mode_bool:
            header.append(CompareConst.STACK)
            header.append(CompareConst.DATA_NAME)
        else:
            header.append(CompareConst.STACK)
    else:
        if all_mode_bool:
            # Without stack mode one column is dropped from every row; in full mode it is
            # the second-to-last one (presumably the stack info, since the data-name
            # header is appended afterwards - TODO confirm).
            for row in result:
                del row[-2]
            header.append(CompareConst.DATA_NAME)
        else:
            for row in result:
                del row[-1]

    result_df = pd.DataFrame(result, columns=header)
    return result_df
985
+
986
+
987
def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
    """Append one result row per entry of an NPU op that has no bench counterpart.

    All bench-side fields and comparison metrics are filled with ``CompareConst.NAN``
    and the error message is ``CompareConst.NO_BENCH``.

    Args:
        result: list of result rows, extended in place.
        n_dict: merged NPU op data; reads ``op_name``, ``input_struct``,
            ``output_struct``, ``summary`` and the optional ``stack_info``.
        md5_compare: build the shorter md5 row layout.
        summary_compare: use 8 instead of 5 placeholder metric columns.
    """
    # Output structs are indexed separately from the overall op_name position.
    index_out = 0
    npu_stack_info = n_dict.get("stack_info", None)
    bench_name, bench_type, bench_shape = CompareConst.NAN, CompareConst.NAN, CompareConst.NAN
    err_msg = CompareConst.NO_BENCH
    accuracy_check_res = CompareConst.NAN
    for index, n_name in enumerate(n_dict["op_name"]):
        if n_name.find("input") != -1:
            n_struct = n_dict["input_struct"][index]
        else:
            n_struct = n_dict["output_struct"][index_out]
            index_out += 1

        result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
        if md5_compare:
            # md5 layout: three placeholder columns, stack info only on the first row.
            result_item.extend([CompareConst.NAN] * 3)
            if npu_stack_info and index == 0:
                result_item.extend(npu_stack_info)
            result.append(result_item)
            continue
        if summary_compare:
            result_item.extend([CompareConst.NAN] * 8)
        else:
            result_item.extend([CompareConst.NAN] * 5)
        # NPU summary values, followed by four placeholders for the missing bench side.
        summary_data = n_dict.get("summary")[index]
        result_item.extend(summary_data)
        summary_data = [CompareConst.NAN] * 4
        result_item.extend(summary_data)
        result_item.append(accuracy_check_res)
        result_item.append(err_msg)
        if npu_stack_info and index == 0:
            result_item.extend(npu_stack_info)
        if not md5_compare and not summary_compare and result_item[1] == CompareConst.NAN:
            # Full mode: append "-1" as data name. Rows after the first get
            # CompareConst.NONE first (presumably to fill the stack column they lack - TODO confirm).
            if index == 0:
                result_item.extend(["-1"])
            else:
                result_item.extend([CompareConst.NONE, "-1"])
        result.append(result_item)