mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/精度工具数据dump标准性能基线报告.md +0 -0
@@ -0,0 +1,299 @@
1
+ import multiprocessing
2
+ import os
3
+ import json
4
+ import pandas as pd
5
+ from msprobe.core.common.file_check import FileOpen
6
+ from msprobe.core.common.const import CompareConst, Const
7
+ from msprobe.core.common.exceptions import FileCheckException
8
+ from msprobe.core.common.log import logger
9
+ from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_file_not_exists
10
+ from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op
11
+ from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
12
+ from msprobe.core.compare.utils import read_op, merge_tensor, CompareException, get_un_match_accuracy, get_accuracy
13
+ from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
14
+ from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
15
+ get_error_message
16
+ from msprobe.core.advisor.advisor import Advisor
17
+
18
+
19
class Comparator:
    """Framework-agnostic driver for dump-data accuracy comparison.

    Subclasses are expected to provide ``frame_name`` (compared against
    "PTComparator" / "MSComparator") and ``read_npy_data``; MindSpore
    comparators additionally expose ``cross_frame``.
    """

    def __init__(self):
        pass

    @classmethod
    def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
        """Build the result DataFrame using the header of the active compare mode.

        Args:
            result: list of row lists. When ``stack_mode`` is False the stack
                entry is removed from every row **in place**.
            md5_compare: use the MD5 result header.
            summary_compare: use the summary result header (ignored when
                ``md5_compare`` is set).
            stack_mode: keep the stack column in the output.

        Returns:
            pandas.DataFrame with ``result`` as rows.
        """
        if md5_compare:
            header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
        elif summary_compare:
            header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
        else:
            header = CompareConst.COMPARE_RESULT_HEADER[:]

        # Only the full ("all") mode carries a trailing data-name column.
        all_mode_bool = not (summary_compare or md5_compare)
        if stack_mode:
            header.append(CompareConst.STACK)
            if all_mode_bool:
                header.append(CompareConst.DATA_NAME)
        elif all_mode_bool:
            # Per the stack-mode header order, rows end with [..., stack, data_name];
            # drop the stack entry.
            for row in result:
                del row[-2]
            header.append(CompareConst.DATA_NAME)
        else:
            # Rows end with the stack entry; drop it.
            for row in result:
                del row[-1]
        return pd.DataFrame(result, columns=header)

    @classmethod
    def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
        """Parse one op from a dump JSON and merge its tensors into a compare record.

        The op's stack info (or None when ``op_name`` is absent from the stack
        JSON) is appended before merging.
        """
        op_data = json_data['data'][op_name]
        op_parsed_list = read_op(op_data, op_name)
        if op_name in stack_json_data:
            op_parsed_list.append({'full_op_name': op_name, 'full_info': stack_json_data[op_name]})
        else:
            op_parsed_list.append({'full_op_name': op_name, 'full_info': None})

        return merge_tensor(op_parsed_list, summary_compare, md5_compare)

    def check_op(self, npu_dict, bench_dict, fuzzy_match):
        """Return whether an NPU op record and a bench op record match.

        For "PTComparator" in graph mode the decision is delegated to
        ``graph_mapping``. Otherwise the op structures must match and, unless
        ``fuzzy_match`` is set, the op names must be equal; with fuzzy
        matching ``fuzzy_check_op`` decides (a failure counts as no match).
        """
        a_op_name = npu_dict["op_name"]
        b_op_name = bench_dict["op_name"]
        graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])

        frame_name = getattr(self, "frame_name")
        if frame_name == "PTComparator":
            from msprobe.pytorch.compare.match import graph_mapping
            if graph_mode:
                return graph_mapping.match(a_op_name[0], b_op_name[0])
        # Forward cross_frame so NPU (MindSpore) dtype names are translated to
        # their torch equivalents before comparing against a PyTorch bench.
        # Comparators without the attribute keep the previous behavior.
        struct_match = check_struct_match(npu_dict, bench_dict,
                                          cross_frame=getattr(self, "cross_frame", False))
        if not fuzzy_match:
            return a_op_name == b_op_name and struct_match
        try:
            is_match = fuzzy_check_op(a_op_name, b_op_name)
        except Exception:
            logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
            is_match = False
        return is_match and struct_match

    def match_op(self, npu_queue, bench_queue, fuzzy_match):
        """Find the first matching (npu_index, bench_index) pair between the queues.

        The newest NPU op is tried against every bench op, then the newest
        bench op against the older NPU ops. Returns ``(-1, -1)`` if none match.
        """
        for b_index, b_op in enumerate(bench_queue[0: -1]):
            if self.check_op(npu_queue[-1], b_op, fuzzy_match):
                return len(npu_queue) - 1, b_index
        if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
            return len(npu_queue) - 1, len(bench_queue) - 1
        for n_index, n_op in enumerate(npu_queue[0: -1]):
            if self.check_op(n_op, bench_queue[-1], fuzzy_match):
                return n_index, len(bench_queue) - 1
        return -1, -1

    def compare_process(self, file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
        """Align NPU and bench ops from the dump JSONs and build the result table.

        Ops are read from both ``data`` sections in lock-step and queued; on
        each step ``match_op`` searches the queues for a match. NPU ops that
        precede a match, or remain queued at the end, are recorded as
        unmatched.

        Args:
            file_handles: ``(npu_json, bench_json, stack_json)`` open file objects.
            stack_mode: keep the stack column in the result.
            fuzzy_match: match op names fuzzily.
            summary_compare: summary compare mode.
            md5_compare: MD5 compare mode.

        Returns:
            pandas.DataFrame built by ``make_result_table``.
        """
        npu_json_handle, bench_json_handle, stack_json_handle = file_handles
        npu_json_data = json.load(npu_json_handle)
        bench_json_data = json.load(bench_json_handle)
        stack_json_data = json.load(stack_json_handle)

        if fuzzy_match:
            logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")

        npu_ops_queue = []
        bench_ops_queue = []
        result = []

        ops_npu_iter = iter(npu_json_data['data'])
        ops_bench_iter = iter(bench_json_data['data'])
        # Each flag stays True until its iterator raises StopIteration.
        read_err_npu = True
        read_err_bench = True
        last_npu_ops_len = 0
        last_bench_ops_len = 0

        while True:
            if not read_err_npu and not read_err_bench:
                break
            try:
                last_npu_ops_len = len(npu_ops_queue)
                op_name_npu = next(ops_npu_iter)
                read_err_npu = True
                npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
                                                     summary_compare, md5_compare)
                if npu_merge_list:
                    npu_ops_queue.append(npu_merge_list)
            except StopIteration:
                read_err_npu = False
            try:
                last_bench_ops_len = len(bench_ops_queue)
                op_name_bench = next(ops_bench_iter)
                bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
                                                       summary_compare, md5_compare)
                if bench_merge_list:
                    bench_ops_queue.append(bench_merge_list)
            except StopIteration:
                read_err_bench = False

            # Nothing new to match on this step.
            both_empty = not npu_ops_queue and not bench_ops_queue
            no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
            if both_empty or no_change:
                continue

            # Exactly one side has pending ops, so the two dumps cannot be
            # aligned any further: stop matching.
            if bool(npu_ops_queue) ^ bool(bench_ops_queue):
                logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
                break

            n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
            if n_match_point == -1 and b_match_point == -1:
                continue
            n_match_data = npu_ops_queue[n_match_point]
            b_match_data = bench_ops_queue[b_match_point]
            un_match_data = npu_ops_queue[0: n_match_point]
            for npu_data in un_match_data:
                get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
            get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
            del npu_ops_queue[0: n_match_point + 1]
            del bench_ops_queue[0: b_match_point + 1]
        if npu_ops_queue:
            for npu_data in npu_ops_queue:
                get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)

        return self.make_result_table(result, md5_compare, summary_compare, stack_mode)

    def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
        """Load the real tensors of one op pair and compute its accuracy metrics.

        Args:
            npu_op_name: NPU op name (key into ``op_name_mapping_dict``).
            bench_op_name: matched bench op name, or ``CompareConst.N_A``.
            op_name_mapping_dict: maps an NPU op name to a list whose index 1
                is the data name ('-1' / -1 when no data file exists).
            input_param: dict providing "npu_dump_data_dir" and
                "bench_dump_data_dir".

        Returns:
            The metric list from ``compare_ops_apply`` with the error message
            appended as the last element.
        """
        npu_bench_name_list = op_name_mapping_dict[npu_op_name]
        data_name = npu_bench_name_list[1]
        error_file, relative_err, error_flag = None, None, False
        if data_name == '-1' or data_name == -1:  # no real data path available
            n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
            error_flag = True
        else:
            try:
                read_npy_data = getattr(self, "read_npy_data")
                frame_name = getattr(self, "frame_name")
                if frame_name == "MSComparator":
                    n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
                    if self.cross_frame:
                        # MindSpore NPU data compared against a PyTorch (.pt) bench.
                        b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
                                                bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
                    else:
                        b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
                                                bench_op_name + Const.NUMPY_SUFFIX)
                else:
                    n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
                    b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
            except IOError as error:
                error_file = error.filename
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True
            except FileCheckException:
                error_file = data_name
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True

        n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
        if not error_flag:
            relative_err = get_relative_err(n_value, b_value)
            n_value, b_value = reshape_value(n_value, b_value)

        err_msg = get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=error_file)
        result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)

        if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
            err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
        result_list.append(err_msg)
        return result_list

    def compare_core(self, input_parma, output_path, **kwargs):
        """
        Compares data from multiple JSON files and generates a comparison report.

        Args:
            input_parma (dict): Paths to the JSON files under the keys "npu_json_path",
                "bench_json_path" and "stack_json_path". The same dict is forwarded to the
                per-op comparison, which also reads "npu_dump_data_dir",
                "bench_dump_data_dir" and "is_print_compare_log".
            output_path (str): Directory where the output Excel report will be saved.
            **kwargs: Additional keyword arguments including:
                - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
                - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
                - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
                - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
                - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
                - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.

        Returns:
            None. The report is written to a timestamped .xlsx file under ``output_path``.
        """
        stack_mode = kwargs.get('stack_mode', False)
        auto_analyze = kwargs.get('auto_analyze', True)
        suffix = kwargs.get('suffix', '')
        fuzzy_match = kwargs.get('fuzzy_match', False)
        summary_compare = kwargs.get('summary_compare', False)
        md5_compare = kwargs.get('md5_compare', False)

        logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
        file_name = add_time_with_xlsx("compare_result" + suffix)
        file_path = os.path.join(os.path.realpath(output_path), file_name)
        check_file_not_exists(file_path)
        highlight_dict = {'red_rows': [], 'yellow_rows': []}

        with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
                FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
                FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
            result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
                                             summary_compare, md5_compare)

        # Only the full mode compares real tensor data (in worker processes).
        if not md5_compare and not summary_compare:
            result_df = self._do_multi_process(input_parma, result_df)
        find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
        highlight_rows_xlsx(result_df, highlight_dict, file_path)
        if auto_analyze:
            advisor = Advisor(result_df, output_path)
            advisor.analysis()

    def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
        """Compute tensor metrics for every row of one DataFrame chunk and save them.

        Column 0 of ``result_df`` holds the NPU op name and column 1 the bench
        op name. Metrics come from ``compare_by_op`` and are written back via
        ``_save_cmp_result`` under ``lock``.
        """
        cos_result = []
        max_err_result = []
        max_relative_err_result = []
        err_mess = []
        one_thousand_err_ratio_result = []
        five_thousand_err_ratio_result = []
        is_print_compare_log = input_param.get("is_print_compare_log")
        for npu_op_name, bench_op_name in zip(result_df.iloc[:, 0], result_df.iloc[:, 1]):
            if is_print_compare_log:
                logger.info("start compare: {}".format(npu_op_name))
            cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
                self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
            if is_print_compare_log:
                logger.info(
                    "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, "
                    "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
                                                        one_thousand_err_ratio, five_thousand_err_ratio))
            cos_result.append(cos_sim)
            max_err_result.append(max_abs_err)
            max_relative_err_result.append(max_relative_err)
            err_mess.append(err_msg)
            one_thousand_err_ratio_result.append(one_thousand_err_ratio)
            five_thousand_err_ratio_result.append(five_thousand_err_ratio)

        cr = ComparisonResult(
            cos_result=cos_result,
            max_err_result=max_err_result,
            max_relative_err_result=max_relative_err_result,
            err_msgs=err_mess,
            one_thousand_err_ratio_result=one_thousand_err_ratio_result,
            five_thousand_err_ratio_result=five_thousand_err_ratio_result
        )

        return _save_cmp_result(idx, cr, result_df, lock)

    def _do_multi_process(self, input_parma, result_df):
        """Run ``compare_ops`` over ``result_df`` in worker processes.

        Raises:
            CompareException: INVALID_DATA_ERROR when the workers raise ValueError.
        """
        try:
            # Use the manager as a context manager so its server process is shut
            # down once the comparison finishes instead of lingering until GC.
            with multiprocessing.Manager() as manager:
                return _handle_multi_process(self.compare_ops, input_parma, result_df, manager.RLock())
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
299
+
@@ -0,0 +1,95 @@
1
+ from msprobe.core.common.log import logger
2
+ from msprobe.core.compare.utils import rename_api
3
+
4
+
5
# Translates dtype names spelled like "Float16" into the corresponding
# "torch.<dtype>" strings, so NPU structures can be compared against PyTorch
# bench structures when comparing across frameworks (see check_struct_match).
dtype_mapping = dict(
    Int8="torch.int8",
    UInt8="torch.uint8",
    Int16="torch.int16",
    UInt16="torch.uint16",
    Int32="torch.int32",
    UInt32="torch.uint32",
    Int64="torch.int64",
    UInt64="torch.uint64",
    Float16="torch.float16",
    Float32="torch.float32",
    Float64="torch.float64",
    Bool="torch.bool",
    BFloat16="torch.bfloat16",
    Complex64="torch.complex64",
    Complex128="torch.complex128",
)
22
+
23
+
24
def _get_struct(op_dict, key):
    """Return ``op_dict[key]``, treating a missing or ``None`` entry as empty."""
    value = op_dict.get(key)
    return [] if value is None else value


def check_struct_match(npu_dict, bench_dict, cross_frame=False):
    """Return True if the input/output structures of two ops are compatible.

    Each dict's "input_struct" / "output_struct" entry is a sequence of items
    whose first element is a dtype and second element a shape. A missing or
    ``None`` entry is treated as an empty sequence; previously this crashed
    with ``TypeError``.

    Args:
        npu_dict: op description from the NPU side.
        bench_dict: op description from the bench side.
        cross_frame: when True, NPU dtype names are translated through
            ``dtype_mapping`` before comparison.

    Returns:
        True when the structures are exactly equal. Otherwise True only when
        the input structures are non-empty and of equal length, and both the
        inputs and the outputs pass ``check_type_shape_match``.
    """
    npu_struct_in = _get_struct(npu_dict, "input_struct")
    bench_struct_in = _get_struct(bench_dict, "input_struct")
    npu_struct_out = _get_struct(npu_dict, "output_struct")
    bench_struct_out = _get_struct(bench_dict, "output_struct")

    if cross_frame:
        npu_struct_in = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_in]
        npu_struct_out = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_out]

    if npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out:
        return True
    if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
        return False
    return (check_type_shape_match(npu_struct_in, bench_struct_in)
            and check_type_shape_match(npu_struct_out, bench_struct_out))
41
+
42
+
43
# (npu_dtype, bench_dtype) pairs accepted as a match even though they differ:
# float16 <-> float32 and float16 <-> bfloat16, in both the "Float16"-style
# and the "torch.float16"-style spellings.
_FUZZY_MATCH_DTYPE_PAIRS = frozenset({
    ("Float16", "Float32"), ("Float32", "Float16"),
    ("Float16", "BFloat16"), ("BFloat16", "Float16"),
    ("torch.float16", "torch.float32"), ("torch.float32", "torch.float16"),
    ("torch.float16", "torch.bfloat16"), ("torch.bfloat16", "torch.float16"),
})


def check_type_shape_match(npu_struct, bench_struct):
    """Return True if two sequences of (dtype, shape, ...) items match element-wise.

    Shapes (item[1]) must be equal. Dtypes (item[0]) must be equal or form
    one of the pairs in ``_FUZZY_MATCH_DTYPE_PAIRS``.

    Empty sequences never match. Sequences of different lengths never match:
    the previous ``zip``-only loop silently ignored trailing items, so a
    longer output list with a matching prefix was reported as a match.
    """
    if len(npu_struct) == 0 or len(npu_struct) != len(bench_struct):
        return False
    for npu_item, bench_item in zip(npu_struct, bench_struct):
        # Index rather than unpack: items may carry extra fields beyond dtype/shape.
        npu_type, npu_shape = npu_item[0], npu_item[1]
        bench_type, bench_shape = bench_item[0], bench_item[1]
        if npu_shape != bench_shape:
            return False
        if npu_type != bench_type and (npu_type, bench_type) not in _FUZZY_MATCH_DTYPE_PAIRS:
            return False
    return True
64
+
65
+
66
def check_graph_mode(a_op_name, b_op_name):
    """Return True when exactly one of the two op names contains "Aten"."""
    a_is_aten = "Aten" in a_op_name
    b_is_aten = "Aten" in b_op_name
    return a_is_aten != b_is_aten
72
+
73
+
74
def fuzzy_check_op(npu_name_list, bench_name_list):
    """Return True if two equally long, non-empty name lists match pairwise.

    Each pair is compared with ``fuzzy_check_name``; checking stops at the
    first pair that does not match.
    """
    count = len(npu_name_list)
    if count == 0 or count != len(bench_name_list):
        return False
    return all(fuzzy_check_name(npu_name, bench_name)
               for npu_name, bench_name in zip(npu_name_list, bench_name_list))
83
+
84
+
85
def fuzzy_check_name(npu_name, bench_name):
    """Compare two op names, normalising them with ``rename_api`` when possible.

    If both names contain "forward" (checked first) or both contain
    "backward", the names are compared after ``rename_api(name, direction)``.
    Otherwise they must be exactly equal.
    """
    for direction in ("forward", "backward"):
        if direction in npu_name and direction in bench_name:
            return rename_api(npu_name, direction) == rename_api(bench_name, direction)
    return npu_name == bench_name
93
+
94
+
95
+
@@ -0,0 +1,49 @@
1
+ import json
2
+ from msprobe.core.common.file_check import FileOpen, check_file_type
3
+ from msprobe.core.common.const import FileCheckConst, Const
4
+ from msprobe.core.common.utils import CompareException
5
+ from msprobe.core.common.log import logger
6
+
7
+
8
def compare_cli(args):
    """Run an accuracy comparison from parsed command-line arguments.

    ``args.input_path`` names a JSON config supplying ``npu_path`` and
    ``bench_path``, plus optionally ``stack_path`` for file compares.
    ``args.framework`` equal to ``Const.PT_FRAMEWORK`` selects PyTorch; any
    other value selects MindSpore.

    * Both paths are files: a single compare runs with the framework's compare
      function.
    * Both paths are directories: a distributed compare runs. For MindSpore
      configs that carry ``rank_id``, a graph compare runs instead.

    Raises:
        CompareException: ``INVALID_COMPARE_MODE`` when the two paths are not
            both files or both directories.
    """
    with FileOpen(args.input_path, "r") as file:
        input_param = json.load(file)
    npu_path = input_param.get("npu_path", None)
    bench_path = input_param.get("bench_path", None)
    frame_name = args.framework
    is_pytorch = frame_name == Const.PT_FRAMEWORK
    auto_analyze = not args.compare_only
    if is_pytorch:
        from msprobe.pytorch.compare.pt_compare import compare
        from msprobe.pytorch.compare.distributed_compare import compare_distributed
    else:
        from msprobe.mindspore.compare.ms_compare import ms_compare
        from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare

    if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
        input_param["npu_json_path"] = input_param.pop("npu_path")
        input_param["bench_json_path"] = input_param.pop("bench_path")
        # stack_path may be omitted from the config; avoid a KeyError and pass None instead.
        input_param["stack_json_path"] = input_param.pop("stack_path", None)
        if is_pytorch:
            compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
                    fuzzy_match=args.fuzzy_match)
        else:
            kwargs = {
                "stack_mode": args.stack_mode,
                "auto_analyze": auto_analyze,
                "fuzzy_match": args.fuzzy_match,
                "cell_mapping": args.cell_mapping,
                "api_mapping": args.api_mapping,
            }
            ms_compare(input_param, args.output_path, **kwargs)
    elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
        kwargs = {"stack_mode": args.stack_mode, "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match}
        # ms_graph_compare is only imported on the MindSpore path. Calling it for a
        # PyTorch config that has rank_id used to raise UnboundLocalError.
        if not is_pytorch and input_param.get("rank_id") is not None:
            ms_graph_compare(input_param, args.output_path)
            return
        if is_pytorch:
            compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
        else:
            ms_compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
    else:
        logger.error("The npu_path and bench_path need to be of the same type.")
        raise CompareException(CompareException.INVALID_COMPARE_MODE)
@@ -0,0 +1,222 @@
1
+ import math
2
+ import abc
3
+ from collections import namedtuple
4
+ import numpy as np
5
+ import openpyxl
6
+ from openpyxl.styles import PatternFill
7
+ from msprobe.core.common.utils import get_header_index, save_workbook
8
+ from msprobe.core.common.log import logger
9
+ from msprobe.core.common.const import CompareConst
10
+
11
+
12
class HighlightCheck(abc.ABC):
    """Abstract base class for one highlighting rule.

    Concrete rules inspect ``info`` and append row numbers to
    ``color_columns.red`` or ``color_columns.yellow``.
    """

    @abc.abstractmethod
    def apply(self, info, color_columns, summary_compare):
        """Evaluate the rule for ``info`` and record flagged rows in ``color_columns``."""
        raise NotImplementedError
16
+
17
+
18
class CheckOrderMagnitude(HighlightCheck):
    """Check the order-of-magnitude gap between input and output Max diff."""

    def apply(self, info, color_columns, summary_compare=True):
        """Mark the output row yellow when its |Max diff| exceeds the input's by enough orders of magnitude.

        Magnitudes below 1 count as order 0; otherwise the order is log10 of
        the magnitude. Nothing is flagged if the input's |Max diff| is larger
        than the output's.
        """
        api_in, api_out, num = info
        column = 'Max diff' if summary_compare else 'MaxAbsErr'
        idx = get_header_index(column, summary_compare)
        in_abs = abs(api_in[idx])
        out_abs = abs(api_out[idx])
        if in_abs > out_abs:
            return
        in_order = 0 if in_abs < 1 else math.log10(in_abs)
        out_order = 0 if out_abs < 1 else math.log10(out_abs)
        if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
            color_columns.yellow.append(num)
29
+
30
+
31
class CheckOneThousandErrorRatio(HighlightCheck):
    """Check the one-thousandth error ratio of inputs versus outputs."""

    def apply(self, info, color_columns, summary_compare=True):
        """Flag the output row red or yellow based on the one-thousandth error ratios.

        Rows where either ratio is not a number are ignored. Red: the input
        ratio is above ONE_THOUSAND_ERROR_IN_RED while the output ratio is
        below ONE_THOUSAND_ERROR_OUT_RED. Yellow: otherwise, when the ratio
        drops from input to output by more than ONE_THOUSAND_ERROR_DIFF_YELLOW.
        """
        api_in, api_out, num = info
        idx = get_header_index('One Thousandth Err Ratio', summary_compare)
        in_ratio, out_ratio = api_in[idx], api_out[idx]
        numeric = (float, int)
        if not (isinstance(in_ratio, numeric) and isinstance(out_ratio, numeric)):
            return
        in_is_high = in_ratio > CompareConst.ONE_THOUSAND_ERROR_IN_RED
        out_is_low = out_ratio < CompareConst.ONE_THOUSAND_ERROR_OUT_RED
        if in_is_high and out_is_low:
            color_columns.red.append(num)
        elif in_ratio - out_ratio > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
            color_columns.yellow.append(num)
42
+
43
+
44
class CheckCosineSimilarity(HighlightCheck):
    """Check the drop in cosine similarity from input to output."""

    def apply(self, info, color_columns, summary_compare=True):
        """Mark the output row yellow if cosine similarity falls by more than COSINE_DIFF_YELLOW.

        Rows where either cosine value is not a number are ignored.
        """
        api_in, api_out, num = info
        idx = get_header_index('Cosine', summary_compare)
        in_cos, out_cos = api_in[idx], api_out[idx]
        both_numeric = isinstance(in_cos, (float, int)) and isinstance(out_cos, (float, int))
        if both_numeric and in_cos - out_cos > CompareConst.COSINE_DIFF_YELLOW:
            color_columns.yellow.append(num)
53
+
54
+
55
class CheckMaxRelativeDiff(HighlightCheck):
    """Check the maximum relative difference (Max diff / Bench max)."""

    def apply(self, info, color_columns, summary_compare=True):
        """Flag the output row based on relative differences of input and output.

        The relative difference is |Max diff / max(0.01, Bench max)|.
        Red: the output value exceeds MAX_RELATIVE_OUT_RED. Yellow: the output
        value exceeds MAX_RELATIVE_OUT_YELLOW while the input value is below
        MAX_RELATIVE_IN_YELLOW.
        """
        api_in, api_out, num = info
        diff_idx = get_header_index('Max diff', summary_compare)
        bench_idx = get_header_index('Bench max', summary_compare)

        def relative_diff(row):
            # The bench max is floored at 0.01 before dividing.
            return np.abs(np.divide(row[diff_idx], max(0.01, row[bench_idx])))

        in_rel = relative_diff(api_in)
        out_rel = relative_diff(api_out)
        numeric = (float, int)
        if not (isinstance(in_rel, numeric) and isinstance(out_rel, numeric)):
            return
        if out_rel > CompareConst.MAX_RELATIVE_OUT_RED:
            color_columns.red.append(num)
        elif out_rel > CompareConst.MAX_RELATIVE_OUT_YELLOW and in_rel < CompareConst.MAX_RELATIVE_IN_YELLOW:
            color_columns.yellow.append(num)
70
+
71
+
72
class CheckOverflow(HighlightCheck):
    """Check a single row for overflow."""

    def apply(self, info, color_columns, summary_compare=True):
        """Mark the row red if NPU max/min is an overflow value or Max diff is too large.

        A row is red when str(NPU max) or str(NPU min) appears in
        CompareConst.OVERFLOW_LIST, or when a numeric Max diff (MaxAbsErr
        outside summary mode) exceeds CompareConst.MAX_DIFF_RED.
        """
        line, num = info
        npu_max_idx = get_header_index('NPU max', summary_compare)
        npu_min_idx = get_header_index('NPU min', summary_compare)
        diff_idx = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
        if any(str(line[idx]) in CompareConst.OVERFLOW_LIST for idx in (npu_max_idx, npu_min_idx)):
            color_columns.red.append(num)
            return
        max_diff = line[diff_idx]
        # Huge Max diff threshold (the original comment cites 1e+10; value lives in CompareConst).
        if isinstance(max_diff, (float, int)) and max_diff > CompareConst.MAX_DIFF_RED:
            color_columns.red.append(num)
86
+
87
+
88
class HighlightRules:
    """Collection of highlight rules used to check API errors."""

    # Rules applied to every individual row (input or output).
    basic_rules = dict(
        check_overflow=CheckOverflow(),
    )

    # Rules comparing each output row with each input row (non-summary compare).
    compare_rules = dict(
        check_order_magnitude=CheckOrderMagnitude(),
        check_one_thousand_error=CheckOneThousandErrorRatio(),
        check_cosine_similarity=CheckCosineSimilarity(),
    )

    # Rules comparing each output row with each input row (summary compare).
    summary_compare_rules = dict(
        check_order_magnitude=CheckOrderMagnitude(),
        check_max_relative_diff=CheckMaxRelativeDiff(),
    )
105
+
106
+
107
# Records handed to the highlight rules (built once at import, not per call).
_LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
_ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
_ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])


def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
    """Find the rows of a single API that need highlighting.

    Args:
        result: the rows of one API; the first ``n_num_input`` rows are its
            inputs and the rest its outputs.
        last_len: absolute index of ``result[0]`` within the full result table.
        n_num_input: number of input rows at the start of ``result``.
        highlight_dict: accumulator. Absolute row indices are appended to its
            'red_rows' and 'yellow_rows' lists, which are created if missing.
            Yellow excludes rows already red for this API.
        summary_compare: use the summary-compare columns and rules.
        md5_compare: when True, nothing is highlighted.
    """
    if md5_compare:
        return
    npu_max_index = get_header_index('NPU max', summary_compare)
    bench_max_index = get_header_index('Bench max', summary_compare)
    max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)

    def has_numeric_stats(row):
        return all(isinstance(row[idx], (float, int))
                   for idx in (npu_max_index, bench_max_index, max_diff_index))

    red_lines, yellow_lines = [], []
    color_columns = _ColorColumns(red=red_lines, yellow=yellow_lines)

    # Per-row checks on every input and output row.
    for i, line in enumerate(result):
        line_info = _LineInfo(line_data=line, num_pointer=last_len + i)
        for rule in HighlightRules.basic_rules.values():
            rule.apply(line_info, color_columns, summary_compare)

    # Compare each output row against every input row with numeric statistics.
    compare_rules = HighlightRules.summary_compare_rules if summary_compare else HighlightRules.compare_rules
    api_inputs = [api_in for api_in in result[:n_num_input] if has_numeric_stats(api_in)]
    for n, api_out in enumerate(result[n_num_input:]):
        num = last_len + n_num_input + n
        if num in red_lines or not has_numeric_stats(api_out):
            continue
        for api_in in api_inputs:
            api_info = _ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
            for rule in compare_rules.values():
                rule.apply(api_info, color_columns, summary_compare)

    # setdefault: .get(key, []) silently discarded results when the key was missing.
    red_set = set(red_lines)
    highlight_dict.setdefault('red_rows', []).extend(sorted(red_set))
    highlight_dict.setdefault('yellow_rows', []).extend(sorted(set(yellow_lines) - red_set))
153
+
154
+
155
def get_name_and_state(name):
    """Split a row name into its api/module prefix and its state.

    The state is "input" if the name contains "input", otherwise "output".
    The prefix is everything before the first occurrence of that state word,
    or the whole name if the word is absent.
    """
    state = "input" if "input" in name else "output"
    api_name, _, _ = name.partition(state)
    return api_name, state
164
+
165
+
166
def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
    """Group result rows by API and collect the rows that should be highlighted.

    Consecutive rows with the same API name (from ``get_name_and_state`` on
    column 0) form one group. The leading "input" rows of a group are its
    inputs and the remaining rows its outputs. Each group is passed to
    ``find_error_rows`` together with its absolute starting row index.

    This replaces a state machine that reset the input count to 1 for every new
    API. That made a group starting with "output" rows overrun into the next
    API by one row, and the next API then skipped that row. It also raised
    UnboundLocalError on an empty dataframe.
    """
    result = result_df.values
    total = len(result)
    start = 0
    while start < total:
        api_name, _ = get_name_and_state(result[start][0])
        end = start
        input_num = 0
        in_input_prefix = True
        while end < total:
            row_api_name, state = get_name_and_state(result[end][0])
            if row_api_name != api_name:
                break
            if in_input_prefix and state == "input":
                input_num += 1
            else:
                in_input_prefix = False
            end += 1
        find_error_rows(result[start:end], start, input_num, highlight_dict, summary_compare, md5_compare)
        start = end
196
+
197
+
198
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
    """Write *result_df* to an Excel workbook at *file_path*, colouring flagged rows.

    Row 1 holds the column names. Data row k (0-based) is written to sheet
    row k + 2 and filled red if k is in ``highlight_dict['red_rows']``,
    otherwise yellow if k is in ``highlight_dict['yellow_rows']``. A value
    whose ``str`` is 'inf', '-inf' or 'nan' is written as that text plus a
    trailing tab. Any other non-numeric value is written as its ``str``.
    """
    logger.info('Compare result is %s' % file_path)

    wb = openpyxl.Workbook()
    ws = wb.active

    # write header
    for j, col_name in enumerate(result_df.columns, start=1):
        ws.cell(row=1, column=j, value=col_name)

    # Sets and fills are built once: membership was previously a list scan per cell.
    red_rows = set(highlight_dict.get('red_rows', ()))
    yellow_rows = set(highlight_dict.get('yellow_rows', ()))
    red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid")
    yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid")

    for i, (_, row) in enumerate(result_df.iterrows(), start=2):
        data_idx = i - 2
        if data_idx in red_rows:
            fill = red_fill
        elif data_idx in yellow_rows:
            fill = yellow_fill
        else:
            fill = None
        for j, value in enumerate(row, start=1):
            text = str(value)
            if text in ('inf', '-inf', 'nan'):
                value = f'{text}\t'
            elif not isinstance(value, (float, int)):
                value = text
            cell = ws.cell(row=i, column=j, value=value)
            if fill is not None:
                cell.fill = fill

    save_workbook(wb, file_path)