mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -0,0 +1,149 @@
1
+
2
+ import multiprocessing
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ import numpy as np
6
+ import pandas as pd
7
+ from msprobe.core.common.log import logger
8
+ from msprobe.core.common.utils import CompareException
9
+ from msprobe.core.common.const import CompareConst
10
+
11
+
12
def _handle_multi_process(func, input_parma, result_df, lock):
    """Run ``func`` over row-chunks of ``result_df`` in a worker pool and merge the outputs.

    Args:
        func: picklable callable invoked in each worker as
            ``func(idx, op_name_mapping_dict, df_chunk, lock, input_parma)``, where
            ``idx`` is the row offset of the chunk inside ``result_df``. Its return
            values are handed to ``pd.concat``.
        input_parma: opaque argument forwarded unchanged to ``func``.
        result_df: DataFrame whose rows are distributed between the workers.
        lock: lock object forwarded unchanged to ``func``.

    Returns:
        The per-chunk results concatenated with a fresh 0..n-1 index.

    Raises:
        Whatever ``read_dump_data`` raises, or the exception raised inside a worker
        (re-raised by ``AsyncResult.get``). In both cases the pool is shut down first.
    """
    # Use about half of the available cores (always at least one).
    process_num = int((multiprocessing.cpu_count() + 1) / 2)
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        # Fewer rows than workers: handle everything as one chunk.
        df_chunks = [result_df]

    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        # Invoked by the pool when a task raises; stop the remaining work early.
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError:
            logger.error("pool terminate failed")

    try:
        results = []
        for process_idx, df_chunk in enumerate(df_chunks):
            idx = df_chunk_size * process_idx
            results.append(pool.apply_async(func,
                                            args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                                            error_callback=err_call))
        final_results = [r.get() for r in results]
    except BaseException:
        # A failed task (or an interrupt) must not leave worker processes behind.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        # join() is only legal after close()/terminate(); both paths above guarantee that.
        pool.join()
    return pd.concat(final_results, ignore_index=True)
42
+
43
+
44
def _ms_graph_handle_multi_process(func, result_df, mode):
    """Apply ``func`` to row-chunks of ``result_df`` in a worker pool and merge the outputs.

    Args:
        func: picklable callable invoked in each worker as ``func(df_chunk, mode)``.
            Its return values are handed to ``pd.concat``.
        result_df: DataFrame whose rows are distributed between the workers.
        mode: opaque argument forwarded unchanged to ``func``.

    Returns:
        The per-chunk results concatenated with a fresh 0..n-1 index.

    Raises:
        The exception raised inside a worker (re-raised by ``AsyncResult.get``).
        The pool is shut down before the exception propagates.
    """
    # Use about half of the available cores (always at least one).
    process_num = int((multiprocessing.cpu_count() + 1) // 2)
    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        # Fewer rows than workers: handle everything as one chunk.
        df_chunks = [result_df]

    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        # Invoked by the pool when a task raises; stop the remaining work early.
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError:
            logger.error("pool terminate failed")

    try:
        results = [
            pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
            for df_chunk in df_chunks
        ]
        final_results = [r.get() for r in results]
    except BaseException:
        # A failed task (or an interrupt) must not leave worker processes behind.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        # join() is only legal after close()/terminate(); both paths above guarantee that.
        pool.join()
    return pd.concat(final_results, ignore_index=True)
69
+
70
+
71
def read_dump_data(result_df):
    """Build a lookup from the first column of ``result_df`` to its last column.

    Args:
        result_df: DataFrame whose first column holds the dump names and whose
            last column holds the matching dump tensor entries.

    Returns:
        dict mapping each first-column value to a two-element list that contains
        the row's last-column value twice. Later rows win on duplicate names.

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, or
            INDEX_OUT_OF_BOUNDS_ERROR on IndexError while reading the columns.
    """
    try:
        names = result_df.iloc[0:, 0].tolist()
        tensors = result_df.iloc[0:, -1].tolist()
        # Both lists come from the same frame, so they always have equal length.
        return {name: [tensor, tensor] for name, tensor in zip(names, tensors)}
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
87
+
88
@dataclass
class ComparisonResult:
    """Per-row comparison metrics, stored as parallel lists.

    Element ``i`` of every list belongs to the same compared row;
    ``_save_cmp_result`` writes them to the matching ``CompareConst`` columns.
    """

    # Written to the CompareConst.COSINE column.
    cos_result: list
    # Written to the CompareConst.MAX_ABS_ERR column.
    max_err_result: list
    # Written to the CompareConst.MAX_RELATIVE_ERR column.
    max_relative_err_result: list
    # Written to the CompareConst.ERROR_MESSAGE column.
    err_msgs: list
    # Written to the CompareConst.ONE_THOUSANDTH_ERR_RATIO column.
    one_thousand_err_ratio_result: list
    # Written to the CompareConst.FIVE_THOUSANDTHS_ERR_RATIO column.
    five_thousand_err_ratio_result: list
96
+
97
+
98
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Copy the metrics held in ``result`` into ``result_df`` while holding ``lock``.

    Args:
        offset: added to each list position to obtain the row label in result_df
        result: ComparisonResult whose lists are read position by position
        result_df: DataFrame that receives the values (modified in place)
        lock: object exposing acquire()/release(), held for the whole update

    Returns:
        result_df, after the update

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, or
            INDEX_OUT_OF_BOUNDS_ERROR on IndexError raised during the update
    """

    lock.acquire()
    try:
        for position, cosine in enumerate(result.cos_result):
            row = position + offset
            result_df.loc[row, CompareConst.COSINE] = cosine
            max_abs_err = result.max_err_result[position]
            result_df.loc[row, CompareConst.MAX_ABS_ERR] = max_abs_err
            result_df.loc[row, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[position]
            result_df.loc[row, CompareConst.ERROR_MESSAGE] = result.err_msgs[position]
            # The verdict is derived from the two values stored just above.
            result_df.loc[row, CompareConst.ACCURACY] = check_accuracy(cosine, max_abs_err)
            result_df.loc[row, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = \
                result.one_thousand_err_ratio_result[position]
            result_df.loc[row, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = \
                result.five_thousand_err_ratio_result[position]
        return result_df
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    finally:
        # Always give the lock back, including when an exception is translated above.
        lock.release()
131
+
132
+
133
def check_accuracy(cos, max_abs_err):
    """Classify one comparison row from its cosine similarity and max absolute error.

    Args:
        cos: cosine value, either numeric (or a numeric string) or one of the
            marker values checked below.
        max_abs_err: maximum absolute error, numeric or a marker value.

    Returns:
        CompareConst.ACCURACY_CHECK_UNMATCH when ``cos`` is the shape-unmatch marker;
        CompareConst.NONE when either value is the NONE marker or cannot be
        converted to float; CompareConst.ACCURACY_CHECK_NO when either value is
        "N/A" or a threshold check fails; CompareConst.ACCURACY_CHECK_YES otherwise.
    """
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE:
        return CompareConst.NONE
    if cos == "N/A" or max_abs_err == "N/A":
        return CompareConst.ACCURACY_CHECK_NO
    try:
        cos, max_abs_err = float(cos), float(max_abs_err)
    except (ValueError, TypeError):
        # float() raises TypeError (not ValueError) for None and other non-numeric
        # objects; treat those the same way as unparsable strings.
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE
    if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    return CompareConst.ACCURACY_CHECK_YES
@@ -2,10 +2,10 @@ import abc
2
2
  import numpy as np
3
3
  from msprobe.core.common.utils import format_value
4
4
  from msprobe.core.common.const import Const, CompareConst
5
- from msprobe.pytorch.common.log import logger
5
+ from msprobe.core.common.log import logger
6
6
 
7
7
 
8
- def handle_inf_nan(n_value, b_value):
8
+ def handle_inf_nan(n_value, b_value):
9
9
  """处理inf和nan的数据"""
10
10
  n_inf = np.isinf(n_value)
11
11
  b_inf = np.isinf(b_value)
@@ -54,7 +54,7 @@ def reshape_value(n_value, b_value):
54
54
  return n_value, b_value
55
55
 
56
56
 
57
- def get_error_message(n_value, b_value, op_name, error_flag, error_file=None):
57
+ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
58
58
  """获取异常情况的错误信息"""
59
59
  if error_flag:
60
60
  if n_value == CompareConst.READ_NONE:
@@ -71,11 +71,62 @@ def get_error_message(n_value, b_value, op_name, error_flag, error_file=None):
71
71
  if not n_value.shape:
72
72
  return "This is type of scalar data, can not compare."
73
73
  if n_value.dtype != b_value.dtype:
74
- logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(op_name))
74
+ logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
75
75
  return "Dtype of NPU and bench Tensor do not match."
76
76
  return ""
77
77
 
78
78
 
79
def npy_data_check(n_value, b_value):
    """Check whether an NPU/bench pair of dumped arrays can be compared.

    Args:
        n_value: NPU data; an ndarray, or ``None`` / ``""`` when the dump is missing.
        b_value: bench data, same conventions as ``n_value``.

    Returns:
        ``(error_flag, error_message)``: ``error_flag`` is True when at least one
        problem was found, and ``error_message`` holds one newline-terminated
        sentence per problem (empty string when there is none).
    """
    def _is_empty_str(value):
        # Only genuine strings are compared with "": comparing an ndarray with a str
        # is element-wise on recent NumPy and the result has no single truth value.
        return isinstance(value, str) and value == ""

    def _is_nan_marker(value):
        # Same concern as above when looking for the NAN marker among the values.
        return not isinstance(value, np.ndarray) and value == CompareConst.NAN

    error_message = ""
    if n_value is None or b_value is None:
        error_message += "Dump file not found.\n"
    if _is_empty_str(n_value) or _is_empty_str(b_value):
        error_message += "Dump file not found.\n"

    # Check whether n_value or b_value is empty.
    if not error_message and (n_value.size == 0 or b_value.size == 0):
        error_message += "This is empty data, can not compare.\n"

    if not error_message:
        if not n_value.shape or not b_value.shape:
            error_message += "This is type of scalar data, can not compare.\n"
        if n_value.shape != b_value.shape:
            error_message += "Shape of NPU and bench Tensor do not match.\n"
        if n_value.dtype != b_value.dtype:
            error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"

    if not error_message:
        # Look for nan/inf data; presumably handle_inf_nan hands back the NAN marker
        # when the inf/nan positions differ -- confirm against its implementation.
        n_value, b_value = handle_inf_nan(n_value, b_value)
        if _is_nan_marker(n_value) or _is_nan_marker(b_value):
            error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"

    error_flag = error_message != ""
    return error_flag, error_message
107
+
108
+
109
def statistics_data_check(result_dict):
    """Check whether one row of summary statistics can be compared.

    Args:
        result_dict: mapping read through the ``CompareConst`` name, shape and
            dtype keys for both the NPU and the bench side.

    Returns:
        ``(error_flag, error_message)``: ``error_flag`` is True when at least one
        problem was found, and ``error_message`` holds one newline-terminated
        sentence per problem (empty string when there is none).
    """
    problems = []

    npu_name = result_dict.get(CompareConst.NPU_NAME)
    if npu_name is None or result_dict.get(CompareConst.BENCH_NAME) is None:
        problems.append("Dump file not found.\n")

    npu_shape = result_dict.get(CompareConst.NPU_SHAPE)
    bench_shape = result_dict.get(CompareConst.BENCH_SHAPE)
    if not (npu_shape and bench_shape):
        # A missing or empty shape on either side is reported as scalar data.
        problems.append("This is type of scalar data, can not compare.\n")
    elif npu_shape != bench_shape:
        problems.append("Tensor shapes do not match.\n")

    if result_dict.get(CompareConst.NPU_DTYPE) != result_dict.get(CompareConst.BENCH_DTYPE):
        problems.append("Dtype of NPU and bench Tensor do not match. Skipped.\n")

    error_message = "".join(problems)
    return error_message != "", error_message
128
+
129
+
79
130
  class TensorComparisonBasic(abc.ABC):
80
131
  """NPU和bench中npy数据的比较模板"""
81
132
  @abc.abstractmethod
@@ -0,0 +1,429 @@
1
+
2
+ import os
3
+ import re
4
+ import numpy as np
5
+ from msprobe.core.common.const import Const, CompareConst
6
+ from msprobe.core.common.utils import CompareException, check_file_or_directory_path, check_regex_prefix_format_valid, logger
7
+
8
+
9
def extract_json(dirname, stack_json=False):
    """Return the path of a dump ``.json`` file found directly inside ``dirname``.

    ``construct.json`` is always ignored. A file whose name contains ``'stack'``
    is preferred when ``stack_json`` is true, and a file whose name does not
    contain it is preferred otherwise. When no file of the preferred kind
    exists, the last ``.json`` file listed is returned as a fallback.

    Args:
        dirname: directory to scan (not recursive).
        stack_json: select the stack file instead of the data file.

    Returns:
        Path of the selected file, built with ``os.path.join(dirname, name)``.

    Raises:
        CompareException: NO_DUMP_FILE_ERROR when the directory has no usable
            ``.json`` file.
    """
    want_stack = bool(stack_json)
    json_path = ''
    for fname in os.listdir(dirname):
        if fname == "construct.json" or not fname.endswith('.json'):
            continue
        json_path = os.path.join(dirname, fname)
        # Decide on the file name alone: a parent directory that happens to contain
        # "stack" must not influence which file is picked.
        if ('stack' in fname) == want_stack:
            break

    # Provide robustness on invalid directory inputs
    if not json_path:
        logger.error(f'No file is found in dump dir {dirname}. ')
        raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
    return json_path
27
+
28
+
29
def check_and_return_dir_contents(dump_dir, prefix):
    """
    Check the given dump dir and require every entry in it to match the pattern
    ^{prefix}(?:0|[1-9][0-9]*)?$ -- the prefix alone, or the prefix followed by a
    non-negative integer without leading zeros (e.g. rank, rank0, rank10).

    Args:
        dump_dir (str): dump dir
        prefix (str): prefix for the pattern; validated by
            check_regex_prefix_format_valid (expected to be shorter than 20
            characters and alphanumeric/-/_ only)

    Returns:
        content [list]: dir contents, as returned by os.listdir
    Raises:
        CompareException: invalid path, or an entry that does not match the pattern
        ValueError: presumably raised by the prefix validation -- confirm against
            check_regex_prefix_format_valid

    """
    check_regex_prefix_format_valid(prefix)
    check_file_or_directory_path(dump_dir, True)
    contents = os.listdir(dump_dir)
    # The numeric suffix is "0" or a number starting with 1-9; the previous
    # "[0-9][1-9]*" had the two classes swapped and rejected names such as rank10.
    pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
    for name in contents:
        if not pattern.match(name):
            logger.error(
                f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
                f"output. Please check and delete irrelevant files in {dump_dir} and try again."
            )
            raise CompareException(CompareException.INVALID_PATH_ERROR)
    return contents
57
+
58
+
59
def rename_api(npu_name, process):
    """Strip the call-index part from a dumped API name.

    The name is split on ``process``. From the text before it, the last two
    ``Const.SEP``-separated fields are dropped; the text that followed
    ``process`` (up to its next occurrence, if any) is then appended.

    Args:
        npu_name: full dumped name that contains ``process``.
        process: marker substring used as the split point.

    Returns:
        The shortened name as a string.

    Raises:
        IndexError: when ``process`` does not occur in ``npu_name``.
    """
    pieces = npu_name.split(process)
    before, after = pieces[0], pieces[1]
    api_part = before.rsplit(Const.SEP, 2)[0]
    return str(api_part) + str(after)
65
+
66
+
67
def read_op(op_data, op_name):
    """Flatten the dumped inputs and outputs of one op into a list of records.

    Args:
        op_data: mapping for one op; read through the ``Const`` input/output keys.
        op_name: op name; whether it contains ``Const.FORWARD`` and/or
            ``Const.BACKWARD`` decides which keys of ``op_data`` are read.

    Returns:
        A new list with the records produced by ``op_item_parse``, inputs first
        and outputs last. It is empty when nothing matched.
    """
    # Work on a copy: the `+=` below must never modify the shared default constant.
    op_parsed_list = list(Const.DEFAULT_LIST)
    if Const.FORWARD in op_name:
        if Const.INPUT_ARGS in op_data:
            input_item = op_data[Const.INPUT_ARGS]
            op_parsed_list = list(op_item_parse(input_item, op_name + '.input', None))
        if Const.INPUT_KWARGS in op_data:
            kwargs_item = op_data[Const.INPUT_KWARGS]
            if (isinstance(kwargs_item, dict) and "type" in kwargs_item) or isinstance(kwargs_item, list):
                # A single typed value or a list is parsed like positional input.
                op_parsed_list += op_item_parse(kwargs_item, op_name + '.input', None)
            elif kwargs_item:
                # Otherwise every keyword gets its own "<op>.input.<keyword>" name.
                for kwarg in kwargs_item:
                    op_parsed_list += op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
        if Const.OUTPUT in op_data:
            output_item = op_data[Const.OUTPUT]
            op_parsed_list += op_item_parse(output_item, op_name + '.output', None)
    if Const.BACKWARD in op_name:
        if Const.INPUT in op_data:
            input_item = op_data[Const.INPUT]
            op_parsed_list = list(op_item_parse(input_item, op_name + '.input', None))
        if Const.OUTPUT in op_data:
            output_item = op_data[Const.OUTPUT]
            op_parsed_list += op_item_parse(output_item, op_name + '.output', None)
    return op_parsed_list
103
+
104
+
105
def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
    """Recursively flatten one dumped argument into per-value records.

    Args:
        item: dumped value. ``None`` or an empty dict yields a placeholder record;
            a dict is treated as a single value when it has a ``'type'`` key and as
            a keyword mapping otherwise; any other object is iterated as a sequence.
        op_name: name prefix for the records that are produced.
        index: position of ``item`` inside its parent sequence, or ``None`` at the
            top level.
        item_list: list to append to; a new list is created when ``None``.
        top_bool: False when called for an element of a sequence, which makes the
            placeholder record use ``index`` instead of ``0`` in its name.

    Returns:
        ``item_list`` (the list passed in, or the newly created one). Dicts that
        carry a ``'dtype'`` key are appended as-is after receiving a
        ``'full_op_name'`` entry, so the caller's data is modified in place.
    """
    if item_list is None:
        item_list = []

    if item is None or (isinstance(item, dict) and not item):
        suffix = '0' if top_bool else str(index)
        item_list.append({'full_op_name': op_name + '.' + suffix, 'Max': None, 'Min': None, 'Mean': None,
                          'Norm': None, 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'})
        return item_list

    if index is not None:
        full_op_name = op_name + Const.SEP + str(index)
    elif isinstance(item, dict):
        full_op_name = op_name + '.0'
    else:
        full_op_name = op_name

    if not isinstance(item, dict):
        # Sequence: every element is parsed into the same output list.
        for j, item_spec in enumerate(item):
            op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
        return item_list

    if 'type' not in item:
        # Keyword mapping: parse each value under "<op_name><SEP><keyword>".
        for kwarg in item:
            item_list += op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
    elif 'dtype' in item:
        # Tensor-like record that already carries its statistics.
        item['full_op_name'] = full_op_name
        item_list.append(item)
    else:
        # Typed non-tensor value. This covers every remaining dict, so no further
        # fallback branch is needed after it.
        value = item['value']
        if item['type'] == 'torch.Size':
            dtype, shape, stat = 'torch.Size', str(value), None
        elif item['type'] == 'slice':
            dtype, shape, stat = 'slice', str(np.shape(np.array(value))), None
        else:
            # Plain value: it doubles as each of its own statistics.
            dtype, shape, stat = str(type(value)), '[]', value
        item_list.append({'full_op_name': full_op_name, 'dtype': dtype, 'shape': shape, 'md5': None,
                          'Max': stat, 'Min': stat, 'Mean': stat, 'Norm': stat, 'data_name': '-1'})
    return item_list
175
+
176
+
177
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Function Description:
        Parse data in the format below, which is a special form of api parameters
        {
            "last_hidden_state": {
                "type": "torch.Tensor",
                "dtype": "torch.bfloat16",
                ...
            },
            "loss": {
                "type": "torch.Tensor",
                "dtype": "torch.float32",
                ...
            }
        }
    Parameter:
        data_dict: the data, in dict form
        full_op_name: full-name string of the parameter
        item_list: collection of parameter info; matching entries are appended in place
    """
    for key in data_dict:
        entry = data_dict[key]
        if not isinstance(entry, dict):
            # Only dict-valued entries describe a parameter; everything else is ignored.
            continue
        # The key is placed just before the last name component.
        name_parts = full_op_name.split(Const.SEP)
        name_parts.insert(-1, key)
        # NOTE(review): the name is split on Const.SEP but re-joined with a literal ".".
        entry['full_op_name'] = ".".join(name_parts)
        item_list.append(entry)
206
+
207
+
208
def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
    """
    Build the comparison rows for one matched pair of ops and append them to result.

    Parameter:
        result: list that receives one row (a list) per compared tensor; modified in place
        n_dict: merged NPU op info; read keys are 'op_name', 'summary', 'stack_info',
            'data_name' and the three '*_struct' lists
        b_dict: merged bench op info with the same layout
        summary_compare: when True, per-statistic differences and relative errors are filled in
        md5_compare: when True, rows only carry dtype/shape/md5 and a PASS/DIFF verdict
    Return:
        None; rows are appended to result
    """
    def get_accuracy_core(n_start, n_len, b_start, b_len, key):
        """
        Emit rows for one category of tensors (selected by key, a '*_struct' name).

        n_start/b_start are the offsets of this category inside 'op_name'/'summary'/'data_name';
        n_len/b_len are the number of entries of this category on each side.
        """
        min_len = min(n_len, b_len)
        npu_stack_info = n_dict.get("stack_info", None)
        bench_stack_info = b_dict.get("stack_info", None)
        has_stack = npu_stack_info and bench_stack_info

        all_mode_bool = not (summary_compare or md5_compare)
        if all_mode_bool:
            npu_data_name = n_dict.get("data_name", None)
            # NOTE(review): bench_data_name is never read below.
            bench_data_name = b_dict.get("data_name", None)

        # Entries present on both sides are paired up position by position.
        for index in range(min_len):

            n_name = n_dict['op_name'][n_start + index]
            b_name = b_dict['op_name'][b_start + index]
            # The struct lists are per category, so they are indexed without the start offset.
            n_struct = n_dict[key][index]
            b_struct = b_dict[key][index]
            err_msg = ""
            if md5_compare:
                # In md5 mode the third struct element is compared for equality.
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               n_struct[2], b_struct[2],
                               CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
                # Stack info is attached only to the first input row.
                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                result.append(result_item)
                continue

            # Placeholder cells: 8 in summary mode (4 diffs + 4 relative errors), 5 otherwise.
            if summary_compare:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " ", " ", " ", " "]
            else:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " "]

            npu_summary_data = n_dict.get("summary")[n_start + index]
            result_item.extend(npu_summary_data)
            bench_summary_data = b_dict.get("summary")[b_start + index]
            result_item.extend(bench_summary_data)

            if summary_compare:
                # Position of the first diff placeholder inside the row.
                start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
                warning_flag = False
                for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
                    if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
                        diff = npu_val - bench_val
                        if bench_val != 0:
                            relative = str(abs((diff / bench_val) * 100)) + '%'
                        else:
                            # Relative error is undefined against a zero bench value.
                            relative = "N/A"
                        result_item[start_idx + i] = diff
                        # The relative-error cell sits 4 cells after the matching diff cell.
                        result_item[start_idx + i + 4] = relative
                        # 1e-10 keeps the division defined when both values are zero.
                        magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
                        if magnitude_diff > 0.5:
                            warning_flag = True
                    else:
                        result_item[start_idx + i] = CompareConst.NONE
                accuracy_check = CompareConst.WARNING if warning_flag else ""
                err_msg += "Need double check api accuracy." if warning_flag else ""
                # Cells whose text is inf/-inf/nan get a trailing tab appended.
                # NOTE(review): presumably so spreadsheet tools keep them as text - confirm.
                for i in range(start_idx, len(result_item)):
                    if str(result_item[i]) in ('inf', '-inf', 'nan'):
                        result_item[i] = f'{result_item[i]}\t'

            result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
            result_item.append(err_msg)
            if has_stack and index == 0 and key == "input_struct":
                result_item.extend(npu_stack_info)
            else:
                result_item.append(CompareConst.NONE)
            if all_mode_bool:
                result_item.append(npu_data_name[n_start + index])

            result.append(result_item)

        # NPU entries with no bench counterpart get CompareConst.NAN in every bench cell.
        if n_len > b_len:
            for index in range(b_len, n_len):
                n_name = n_dict['op_name'][n_start + index]
                n_struct = n_dict[key][index]
                if md5_compare:
                    # NOTE(review): unlike the paired md5 rows, no stack/NONE cell is added here.
                    result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                                   n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
                    result.append(result_item)
                    continue
                # NOTE(review): 5 placeholders are used even when summary_compare is True,
                # while paired summary rows use 8 - confirm the column alignment is intended.
                result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                               n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
                summary_data = n_dict.get("summary")[n_start + index]
                result_item.extend(summary_data)
                # Bench statistics: one NAN per entry of the first NPU summary row.
                summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
                result_item.extend(summary_data)

                err_msg = ""
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
                result_item.append(err_msg)

                # index starts at b_len, so the stack is attached here only when b_len is 0.
                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                if all_mode_bool:
                    result_item.append(npu_data_name[n_start + index])

                result.append(result_item)

    # Categories are sized by substring match on the op names; whatever is neither
    # input nor kwarg is counted as output.
    n_num = len(n_dict['op_name'])
    b_num = len(b_dict['op_name'])
    n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
    b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
    n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
    b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
    n_num_output = n_num - n_num_input - n_num_kwarg
    b_num_output = b_num - b_num_input - b_num_kwarg
    # The offsets assume 'op_name' is ordered inputs, then kwargs, then outputs.
    get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
    get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
    get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
324
+
325
+
326
def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
    """
    Append rows to result for an NPU op that has no bench counterpart.

    Parameter:
        result: list that receives one row per NPU tensor; modified in place
        n_dict: merged NPU op info ('op_name', 'input_struct', 'output_struct',
            'summary' and optionally 'stack_info')
        md5_compare: when True, rows carry three N_A cells instead of statistics
        summary_compare: when True, rows carry eight N_A comparison cells instead of five
    Return:
        None; every bench-side cell is filled with CompareConst.N_A
    """
    stack_info = n_dict.get("stack_info", None)
    not_available = CompareConst.N_A
    no_bench_msg = CompareConst.NO_BENCH
    output_cursor = 0

    for position, npu_name in enumerate(n_dict["op_name"]):
        # Input structs are looked up by overall position; every other name consumes
        # the next entry of 'output_struct'.
        if "input" in npu_name:
            npu_struct = n_dict["input_struct"][position]
        else:
            npu_struct = n_dict["output_struct"][output_cursor]
            output_cursor += 1

        row = [npu_name, not_available, npu_struct[0], not_available, npu_struct[1], not_available]
        # Stack info is attached to the first row only.
        attach_stack = bool(stack_info) and position == 0

        if md5_compare:
            row += [not_available] * 3
            if attach_stack:
                row.extend(stack_info)
            else:
                row.append(CompareConst.NONE)
            result.append(row)
            continue

        row += [not_available] * (8 if summary_compare else 5)
        row.extend(n_dict.get("summary")[position])
        # Bench statistics are unavailable.
        row += [not_available] * 4
        row.append(not_available)
        row.append(no_bench_msg)
        if attach_stack:
            row.extend(stack_info)
        else:
            row.append(CompareConst.NONE)
        # md5_compare is already known to be falsy on this path.
        if not summary_compare and row[1] == CompareConst.N_A:
            row.extend(["-1"])
        result.append(row)
365
+
366
+
367
def merge_tensor(tensor_list, summary_compare, md5_compare):
    """
    Merge the parsed records of one op into a single dict of parallel lists.

    Parameter:
        tensor_list: parsed records; a record with exactly two keys is taken as the
            stack entry (its 'full_info' is stored) and ends the scan
        summary_compare: summary mode flag
        md5_compare: md5 mode flag; struct tuples then also carry the record's 'md5'
    Return:
        dict with 'op_name', 'input_struct', 'output_struct', 'summary', 'stack_info',
        plus 'kwargs_struct' when non-empty and 'data_name' when neither flag is set;
        an empty dict when no op name was collected
    """
    merged = {
        "op_name": [],
        "input_struct": [],
        "kwargs_struct": [],
        "output_struct": [],
        "summary": [],
        "stack_info": [],
    }

    full_mode = not (summary_compare or md5_compare)
    if full_mode:
        merged["data_name"] = []

    struct_fields = ('dtype', 'shape', 'md5') if md5_compare else ('dtype', 'shape')
    # Checked in this order; the first marker found in the name decides the category.
    struct_by_marker = (
        ("input", "input_struct"),
        ("kwarg", "kwargs_struct"),
        ("output", "output_struct"),
    )

    for tensor in tensor_list:
        if len(tensor) == 2:
            merged['stack_info'].append(tensor['full_info'])
            break

        name = tensor['full_op_name']
        merged["op_name"].append(name)
        for marker, struct_key in struct_by_marker:
            if marker in name:
                merged[struct_key].append(tuple(tensor[field] for field in struct_fields))
                break

        merged["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
        if full_mode:
            merged["data_name"].append(tensor['data_name'])

    # An empty kwargs category is dropped from the result altogether.
    if not merged["kwargs_struct"]:
        del merged["kwargs_struct"]
    return merged if merged["op_name"] else {}
408
+
409
+
410
def _compare_parser(parser):
    """
    Register the compare command-line options on parser.

    Parameter:
        parser: an argparse-style parser; options are added to it in place
    """
    # Each entry: (short flag, long flag, remaining add_argument settings).
    # The destination name is the long flag without its leading dashes.
    option_specs = (
        ("-i", "--input_path",
         dict(type=str, help="<Required> The compare input path, a dict json.", required=True)),
        ("-o", "--output_path",
         dict(type=str, help="<Required> The compare task result out path.", required=True)),
        ("-s", "--stack_mode",
         dict(action="store_true", help="<optional> Whether to save stack info.", required=False)),
        ("-c", "--compare_only",
         dict(action="store_true", help="<optional> Whether to give advisor.", required=False)),
        ("-f", "--fuzzy_match",
         dict(action="store_true", help="<optional> Whether to perform a fuzzy match on the api name.",
              required=False)),
        # The mapping options may be given bare (value becomes True) or with a file path.
        ("-cm", "--cell_mapping",
         dict(type=str, nargs='?', const=True, help="<optional> The cell mapping file path.", required=False)),
        ("-am", "--api_mapping",
         dict(type=str, nargs='?', const=True, help="<optional> The api mapping file path.", required=False)),
    )
    for short_flag, long_flag, settings in option_specs:
        parser.add_argument(short_flag, long_flag, dest=long_flag[2:], **settings)
425
+
426
+
427
+
428
+
429
+