mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,429 +1,430 @@
1
-
2
- import os
3
- import re
4
- import numpy as np
5
- from msprobe.core.common.const import Const, CompareConst
6
- from msprobe.core.common.utils import CompareException, check_file_or_directory_path, check_regex_prefix_format_valid, logger
7
-
8
-
9
- def extract_json(dirname, stack_json=False):
10
- json_path = ''
11
- for fname in os.listdir(dirname):
12
- if fname == "construct.json":
13
- continue
14
- full_path = os.path.join(dirname, fname)
15
- if full_path.endswith('.json'):
16
- json_path = full_path
17
- if not stack_json and 'stack' not in json_path:
18
- break
19
- if stack_json and 'stack' in json_path:
20
- break
21
-
22
- # Provide robustness on invalid directory inputs
23
- if not json_path:
24
- logger.error(f'No file is found in dump dir {dirname}. ')
25
- raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
26
- return json_path
27
-
28
-
29
- def check_and_return_dir_contents(dump_dir, prefix):
30
- """
31
- check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
32
- pattern: ^{prefix}(?:0|[0-9][1-9]*)?$
33
-
34
- Args:
35
- dump_dir (str): dump dir
36
- prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only
37
-
38
- Returns:
39
- content [list]: dir contents
40
- Raises:
41
- CompareException: invalid path
42
- ValueError: prefix not match the patterns
43
-
44
- """
45
- check_regex_prefix_format_valid(prefix)
46
- check_file_or_directory_path(dump_dir, True)
47
- contents = os.listdir(dump_dir)
48
- pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
49
- for name in contents:
50
- if not pattern.match(name):
51
- logger.error(
52
- f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
53
- f"output. Please check and delete irrelevant files in {dump_dir} and try again."
54
- )
55
- raise CompareException(CompareException.INVALID_PATH_ERROR)
56
- return contents
57
-
58
-
59
- def rename_api(npu_name, process):
60
- npu_split = npu_name.split(process)
61
- torch_func_index, in_out = npu_split[0], npu_split[1]
62
- torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
63
- torch_func = str(torch_func_split[0]) + str(in_out)
64
- return torch_func
65
-
66
-
67
- def read_op(op_data, op_name):
68
- op_parsed_list = Const.DEFAULT_LIST
69
- if Const.FORWARD in op_name:
70
- if Const.INPUT_ARGS in op_data:
71
- input_item = op_data[Const.INPUT_ARGS]
72
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
73
- op_parsed_list = input_parsed_list.copy()
74
- input_parsed_list.clear()
75
- if Const.INPUT_KWARGS in op_data:
76
- kwargs_item = op_data[Const.INPUT_KWARGS]
77
- if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
78
- kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
79
- op_parsed_list += kwarg_parsed_list
80
- kwarg_parsed_list.clear()
81
- elif kwargs_item:
82
- for kwarg in kwargs_item:
83
- kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
84
- op_parsed_list += kwarg_parsed_list
85
- kwarg_parsed_list.clear()
86
- if Const.OUTPUT in op_data:
87
- output_item = op_data[Const.OUTPUT]
88
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
89
- op_parsed_list += output_parsed_list
90
- output_parsed_list.clear()
91
- if Const.BACKWARD in op_name:
92
- if Const.INPUT in op_data:
93
- input_item = op_data[Const.INPUT]
94
- input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
95
- op_parsed_list = input_parsed_list.copy()
96
- input_parsed_list.clear()
97
- if Const.OUTPUT in op_data:
98
- output_item = op_data[Const.OUTPUT]
99
- output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
100
- op_parsed_list += output_parsed_list
101
- output_parsed_list.clear()
102
- return op_parsed_list
103
-
104
-
105
- def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
106
- if item_list is None:
107
- item_list = []
108
- if item is None or (isinstance(item, dict) and not item):
109
- if not top_bool:
110
- tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
111
- 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}
112
- else:
113
- tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
114
- 'shape': None, 'md5': None, 'data_name': '-1'}
115
- item_list.append(tmp)
116
- return item_list
117
- if index is None:
118
- if isinstance(item, dict):
119
- full_op_name = op_name + '.0'
120
- else:
121
- full_op_name = op_name
122
- else:
123
- full_op_name = op_name + Const.SEP + str(index)
124
- if isinstance(item, dict):
125
- if 'type' not in item:
126
- for kwarg in item:
127
- kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
128
- item_list += kwarg_parsed_list
129
- kwarg_parsed_list.clear()
130
- elif 'dtype' in item:
131
- parsed_item = item
132
- parsed_item['full_op_name'] = full_op_name
133
- item_list.append(parsed_item)
134
- elif 'type' in item:
135
- parsed_item = {}
136
- if item['type'] == 'torch.Size':
137
- parsed_item['full_op_name'] = full_op_name
138
- parsed_item['dtype'] = 'torch.Size'
139
- parsed_item['shape'] = str(item['value'])
140
- parsed_item['md5'] = None
141
- parsed_item['Max'] = None
142
- parsed_item['Min'] = None
143
- parsed_item['Mean'] = None
144
- parsed_item['Norm'] = None
145
- parsed_item['data_name'] = '-1'
146
- item_list.append(parsed_item)
147
- elif item['type'] == 'slice':
148
- parsed_item['full_op_name'] = full_op_name
149
- parsed_item['dtype'] = 'slice'
150
- parsed_item['shape'] = str(np.shape(np.array(item['value'])))
151
- parsed_item['md5'] = None
152
- parsed_item['Max'] = None
153
- parsed_item['Min'] = None
154
- parsed_item['Mean'] = None
155
- parsed_item['Norm'] = None
156
- parsed_item['data_name'] = '-1'
157
- item_list.append(parsed_item)
158
- else:
159
- parsed_item['full_op_name'] = full_op_name
160
- parsed_item['dtype'] = str(type(item['value']))
161
- parsed_item['shape'] = '[]'
162
- parsed_item['md5'] = None
163
- parsed_item['Max'] = item['value']
164
- parsed_item['Min'] = item['value']
165
- parsed_item['Mean'] = item['value']
166
- parsed_item['Norm'] = item['value']
167
- parsed_item['data_name'] = '-1'
168
- item_list.append(parsed_item)
169
- else:
170
- resolve_api_special_parameters(item, full_op_name, item_list)
171
- else:
172
- for j, item_spec in enumerate(item):
173
- op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
174
- return item_list
175
-
176
-
177
- def resolve_api_special_parameters(data_dict, full_op_name, item_list):
178
- """
179
- Function Description:
180
- 解析下面格式的数据, 是api参数的一种特殊格式
181
- {
182
- "last_hidden_state": {
183
- "type": "torch.Tensor",
184
- "dtype": "torch.bfloat16",
185
- ...
186
- },
187
- "loss": {
188
- "type": "torch.Tensor",
189
- "dtype": "torch.float32",
190
- ...
191
- }
192
- }
193
- Parameter:
194
- data_dict: 字典格式的数据
195
- full_op_name: 参数的全名字符串
196
- item_list: 参数信息集合
197
- """
198
- for key, value in data_dict.items():
199
- if isinstance(value, dict):
200
- parsed_item = value
201
- parts = full_op_name.split(Const.SEP)
202
- parts.insert(-1, key)
203
- full_op_name_new = ".".join(parts)
204
- parsed_item['full_op_name'] = full_op_name_new
205
- item_list.append(parsed_item)
206
-
207
-
208
- def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
209
- def get_accuracy_core(n_start, n_len, b_start, b_len, key):
210
- min_len = min(n_len, b_len)
211
- npu_stack_info = n_dict.get("stack_info", None)
212
- bench_stack_info = b_dict.get("stack_info", None)
213
- has_stack = npu_stack_info and bench_stack_info
214
-
215
- all_mode_bool = not (summary_compare or md5_compare)
216
- if all_mode_bool:
217
- npu_data_name = n_dict.get("data_name", None)
218
- bench_data_name = b_dict.get("data_name", None)
219
-
220
- for index in range(min_len):
221
-
222
- n_name = n_dict['op_name'][n_start + index]
223
- b_name = b_dict['op_name'][b_start + index]
224
- n_struct = n_dict[key][index]
225
- b_struct = b_dict[key][index]
226
- err_msg = ""
227
- if md5_compare:
228
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
229
- n_struct[2], b_struct[2],
230
- CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
231
- if has_stack and index == 0 and key == "input_struct":
232
- result_item.extend(npu_stack_info)
233
- else:
234
- result_item.append(CompareConst.NONE)
235
- result.append(result_item)
236
- continue
237
-
238
- if summary_compare:
239
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
240
- " ", " ", " ", " ", " ", " ", " ", " "]
241
- else:
242
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
243
- " ", " ", " ", " ", " "]
244
-
245
- npu_summary_data = n_dict.get("summary")[n_start + index]
246
- result_item.extend(npu_summary_data)
247
- bench_summary_data = b_dict.get("summary")[b_start + index]
248
- result_item.extend(bench_summary_data)
249
-
250
- if summary_compare:
251
- start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
252
- warning_flag = False
253
- for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
254
- if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
255
- diff = npu_val - bench_val
256
- if bench_val != 0:
257
- relative = str(abs((diff / bench_val) * 100)) + '%'
258
- else:
259
- relative = "N/A"
260
- result_item[start_idx + i] = diff
261
- result_item[start_idx + i + 4] = relative
262
- magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
263
- if magnitude_diff > 0.5:
264
- warning_flag = True
265
- else:
266
- result_item[start_idx + i] = CompareConst.NONE
267
- accuracy_check = CompareConst.WARNING if warning_flag else ""
268
- err_msg += "Need double check api accuracy." if warning_flag else ""
269
- for i in range(start_idx, len(result_item)):
270
- if str(result_item[i]) in ('inf', '-inf', 'nan'):
271
- result_item[i] = f'{result_item[i]}\t'
272
-
273
- result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
274
- result_item.append(err_msg)
275
- if has_stack and index == 0 and key == "input_struct":
276
- result_item.extend(npu_stack_info)
277
- else:
278
- result_item.append(CompareConst.NONE)
279
- if all_mode_bool:
280
- result_item.append(npu_data_name[n_start + index])
281
-
282
- result.append(result_item)
283
-
284
- if n_len > b_len:
285
- for index in range(b_len, n_len):
286
- n_name = n_dict['op_name'][n_start + index]
287
- n_struct = n_dict[key][index]
288
- if md5_compare:
289
- result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
290
- n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
291
- result.append(result_item)
292
- continue
293
- result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
294
- n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
295
- summary_data = n_dict.get("summary")[n_start + index]
296
- result_item.extend(summary_data)
297
- summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
298
- result_item.extend(summary_data)
299
-
300
- err_msg = ""
301
- result_item.append(CompareConst.ACCURACY_CHECK_YES)
302
- result_item.append(err_msg)
303
-
304
- if has_stack and index == 0 and key == "input_struct":
305
- result_item.extend(npu_stack_info)
306
- else:
307
- result_item.append(CompareConst.NONE)
308
- if all_mode_bool:
309
- result_item.append(npu_data_name[n_start + index])
310
-
311
- result.append(result_item)
312
-
313
- n_num = len(n_dict['op_name'])
314
- b_num = len(b_dict['op_name'])
315
- n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
316
- b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
317
- n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
318
- b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
319
- n_num_output = n_num - n_num_input - n_num_kwarg
320
- b_num_output = b_num - b_num_input - b_num_kwarg
321
- get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
322
- get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
323
- get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
324
-
325
-
326
def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
    """
    Append one row per NPU op entry to ``result`` for an op that has no bench counterpart.

    Every bench-side column is filled with ``CompareConst.N_A``.

    Args:
        result (list): row accumulator, extended in place.
        n_dict (dict): merged NPU op info; reads "op_name", "input_struct", "output_struct",
            "summary" and the optional "stack_info".
        md5_compare (bool): emit the short md5 row layout (three N/A result columns).
        summary_compare (bool): use eight N/A result columns instead of five.
    """
    stack_lines = n_dict.get("stack_info", None)
    na = CompareConst.N_A
    no_bench_msg = CompareConst.NO_BENCH
    output_cursor = 0

    for position, npu_op in enumerate(n_dict["op_name"]):
        # Input structs share the global position; output structs are counted separately.
        if "input" in npu_op:
            struct = n_dict["input_struct"][position]
        else:
            struct = n_dict["output_struct"][output_cursor]
            output_cursor += 1

        row = [npu_op, na, struct[0], na, struct[1], na]
        if md5_compare:
            row += [na] * 3
        else:
            row += [na] * (8 if summary_compare else 5)
            row += n_dict.get("summary")[position]
            row += [na] * 4          # bench summary columns
            row.append(na)           # accuracy check result
            row.append(no_bench_msg)

        # Stack info is attached to the first row only.
        if stack_lines and position == 0:
            row.extend(stack_lines)
        else:
            row.append(CompareConst.NONE)

        # Full-data mode carries one extra trailing column.
        if not md5_compare and not summary_compare and row[1] == na:
            row.extend(["-1"])
        result.append(row)
365
-
366
-
367
def merge_tensor(tensor_list, summary_compare, md5_compare):
    """
    Collapse the parsed tensor entries of one op into a dict of parallel lists.

    Args:
        tensor_list (list[dict]): parsed entries; a two-key entry is treated as the stack
            record (its 'full_info' is stored) and ends the scan.
        summary_compare (bool): summary mode flag.
        md5_compare (bool): md5 mode flag; struct tuples then carry the md5 as third element.

    Returns:
        dict: keys "op_name", "input_struct", "output_struct", "summary", "stack_info",
        plus "kwargs_struct" when non-empty and "data_name" when neither mode flag is set.
        An empty dict is returned when no op name was collected.
    """
    keep_data_name = not (summary_compare or md5_compare)
    merged = {
        "op_name": [],
        "input_struct": [],
        "kwargs_struct": [],
        "output_struct": [],
        "summary": [],
        "stack_info": [],
    }
    if keep_data_name:
        merged["data_name"] = []

    struct_fields = ('dtype', 'shape', 'md5') if md5_compare else ('dtype', 'shape')
    # Checked in this order; the first marker found in the name decides the bucket.
    buckets = (("input", "input_struct"), ("kwarg", "kwargs_struct"), ("output", "output_struct"))

    for tensor in tensor_list:
        if len(tensor) == 2:
            merged['stack_info'].append(tensor['full_info'])
            break
        full_name = tensor['full_op_name']
        merged["op_name"].append(full_name)

        target = next((bucket for marker, bucket in buckets if full_name.find(marker) != -1), None)
        if target is not None:
            merged[target].append(tuple(tensor[field] for field in struct_fields))

        merged["summary"].append([tensor[stat] for stat in ('Max', 'Min', 'Mean', 'Norm')])
        if keep_data_name:
            merged["data_name"].append(tensor['data_name'])

    if not merged["kwargs_struct"]:
        del merged["kwargs_struct"]
    return merged if merged["op_name"] else {}
408
-
409
-
410
- def _compare_parser(parser):
411
- parser.add_argument("-i", "--input_path", dest="input_path", type=str,
412
- help="<Required> The compare input path, a dict json.", required=True)
413
- parser.add_argument("-o", "--output_path", dest="output_path", type=str,
414
- help="<Required> The compare task result out path.", required=True)
415
- parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
416
- help="<optional> Whether to save stack info.", required=False)
417
- parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
418
- help="<optional> Whether to give advisor.", required=False)
419
- parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
420
- help="<optional> Whether to perform a fuzzy match on the api name.", required=False)
421
- parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
422
- help="<optional> The cell mapping file path.", required=False)
423
- parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
424
- help="<optional> The api mapping file path.", required=False)
425
-
426
-
427
-
428
-
429
-
1
+
2
+ import os
3
+ import re
4
+ import numpy as np
5
+ from msprobe.core.common.const import Const, CompareConst
6
+ from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
7
+ from msprobe.core.common.file_utils import check_file_or_directory_path
8
+
9
+
10
def extract_json(dirname, stack_json=False):
    """
    Pick the dump json (or the stack json) out of a dump directory.

    ``construct.json`` is always ignored. The scan stops at the first ``*.json`` file whose
    name matches the requested kind: a name containing 'stack' when ``stack_json`` is true,
    a name without 'stack' otherwise.

    Args:
        dirname (str): directory to scan.
        stack_json (bool): select the stack json instead of the dump json.

    Returns:
        str: path of the selected file. When no file of the requested kind exists, the last
        other json file seen is returned (fallback kept from the original implementation).

    Raises:
        CompareException: NO_DUMP_FILE_ERROR when the directory holds no json file other
            than ``construct.json``.
    """
    want_stack = bool(stack_json)
    json_path = ''
    for fname in os.listdir(dirname):
        if fname == "construct.json" or not fname.endswith('.json'):
            continue
        json_path = os.path.join(dirname, fname)
        # Classify by file name, not by the joined path: a parent directory whose own
        # path contains "stack" must not make every file look like a stack file.
        if want_stack == ('stack' in fname):
            break

    # Provide robustness on invalid directory inputs
    if not json_path:
        logger.error(f'No file is found in dump dir {dirname}. ')
        raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
    return json_path
28
+
29
+
30
def check_and_return_dir_contents(dump_dir, prefix):
    """
    Check the given dump dir and validate every entry in it against a pattern built from
    the given prefix: ^{prefix}(?:0|[1-9][0-9]*)?$

    An entry is accepted when it is the bare prefix or the prefix followed by a decimal
    number without leading zeros (e.g. ``rank``, ``rank0``, ``rank7``, ``rank10``).

    Args:
        dump_dir (str): dump dir
        prefix (str): prefix for the patterns, prefix should be less than 20 characters and
            alphanumeric/-/_ only

    Returns:
        content [list]: dir contents
    Raises:
        CompareException: invalid path, or an entry that does not match the pattern
        ValueError: prefix not match the patterns (presumably raised by
            check_regex_prefix_format_valid - confirm against that helper)

    """
    check_regex_prefix_format_valid(prefix)
    check_file_or_directory_path(dump_dir, True)
    contents = os.listdir(dump_dir)
    # The digit classes were swapped before ([0-9][1-9]*), which rejected any index with a
    # zero after the first digit (rank10, rank20, rank100) and accepted leading zeros.
    pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
    for name in contents:
        if not pattern.match(name):
            logger.error(
                f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
                f"output. Please check and delete irrelevant files in {dump_dir} and try again."
            )
            raise CompareException(CompareException.INVALID_PATH_ERROR)
    return contents
58
+
59
+
60
def rename_api(npu_name, process):
    """
    Rebuild an op name without the fields that directly precede ``process``.

    ``npu_name`` is split on ``process``. The piece before the first occurrence is cut at
    its second-to-last ``Const.SEP`` (keeping only what is left of it), and the piece that
    follows the first occurrence is appended unchanged.

    Example (assuming Const.SEP is '.'):
        rename_api('Functional.add.0.forward.input.0', 'forward') -> 'Functional.add.input.0'

    Raises:
        IndexError: ``process`` does not occur in ``npu_name``.
    """
    segments = npu_name.split(process)
    api_with_index = segments[0]
    direction_suffix = segments[1]
    api_name = api_with_index.rsplit(Const.SEP, 2)[0]
    return str(api_name) + str(direction_suffix)
66
+
67
+
68
def read_op(op_data, op_name):
    """
    Flatten the dumped arguments and results of one op into a list of parsed items.

    For a name containing ``Const.FORWARD`` the positional inputs, keyword inputs and
    outputs are parsed in that order; for a name containing ``Const.BACKWARD`` the inputs
    and outputs are parsed. Each section is parsed with ``op_item_parse``.

    Args:
        op_data (dict): dumped data of the op, keyed by the ``Const`` section names.
        op_name (str): full op name; decides which sections are read and prefixes item names.

    Returns:
        list: parsed item dicts; empty when no section is present.
    """
    # Work on a private copy. Const.DEFAULT_LIST is a shared object (assumed to be a list):
    # when the op has no positional inputs the in-place `+=` below used to append to the
    # shared constant itself and leak entries into every later call.
    op_parsed_list = list(Const.DEFAULT_LIST)
    if Const.FORWARD in op_name:
        if Const.INPUT_ARGS in op_data:
            op_parsed_list = op_item_parse(op_data[Const.INPUT_ARGS], op_name + '.input', None)
        if Const.INPUT_KWARGS in op_data:
            kwargs_item = op_data[Const.INPUT_KWARGS]
            # A typed dict or a list is one value; any other non-empty mapping is name -> value.
            if (isinstance(kwargs_item, dict) and "type" in kwargs_item) or isinstance(kwargs_item, list):
                op_parsed_list += op_item_parse(kwargs_item, op_name + '.input', None)
            elif kwargs_item:
                for kwarg in kwargs_item:
                    op_parsed_list += op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
        if Const.OUTPUT in op_data:
            op_parsed_list += op_item_parse(op_data[Const.OUTPUT], op_name + '.output', None)
    if Const.BACKWARD in op_name:
        if Const.INPUT in op_data:
            op_parsed_list = op_item_parse(op_data[Const.INPUT], op_name + '.input', None)
        if Const.OUTPUT in op_data:
            op_parsed_list += op_item_parse(op_data[Const.OUTPUT], op_name + '.output', None)
    return op_parsed_list
104
+
105
+
106
def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
    """
    Recursively flatten one dumped argument into a list of per-tensor/per-value dicts.

    Args:
        item: dumped value - None, a dict (typed value, tensor record or name -> value
            mapping) or a sequence of such values.
        op_name (str): name prefix for the produced 'full_op_name'.
        index: position of ``item`` inside its parent sequence, or None at the top level.
        item_list (list | None): accumulator to extend; a new list is created when None.
        top_bool (bool): True for a top-level call; only affects the name given to an
            empty item ('.0' instead of '.<index>').

    Returns:
        list: ``item_list`` with the parsed entries appended. Tensor records (dicts with
        'type' and 'dtype') are appended as the same dict object, with 'full_op_name' added.
    """
    if item_list is None:
        item_list = []

    # Missing value: emit an all-None placeholder so positions stay aligned.
    if item is None or (isinstance(item, dict) and not item):
        suffix = '.0' if top_bool else '.' + str(index)
        item_list.append({'full_op_name': op_name + suffix, 'Max': None, 'Min': None, 'Mean': None,
                          'Norm': None, 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'})
        return item_list

    if index is None:
        full_op_name = (op_name + '.0') if isinstance(item, dict) else op_name
    else:
        full_op_name = op_name + Const.SEP + str(index)

    if not isinstance(item, dict):
        # Sequence: every element is parsed into the same accumulator under its index.
        for j, item_spec in enumerate(item):
            op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
        return item_list

    if 'type' not in item:
        # Untyped mapping: each key becomes a name component.
        for kwarg in item:
            item_list += op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
    elif 'dtype' in item:
        # Tensor record: reused as is (the caller's dict gains 'full_op_name').
        item['full_op_name'] = full_op_name
        item_list.append(item)
    else:
        # Typed non-tensor value. The former trailing `else` could never run once the
        # "'type' not in item" case is handled above, so it has been dropped.
        item_type = item['type']
        if item_type == 'torch.Size':
            parsed_item = _build_value_item(full_op_name, 'torch.Size', str(item['value']), None)
        elif item_type == 'slice':
            parsed_item = _build_value_item(full_op_name, 'slice',
                                            str(np.shape(np.array(item['value']))), None)
        else:
            # Plain scalar: the value itself stands in for every statistic.
            parsed_item = _build_value_item(full_op_name, str(type(item['value'])), '[]', item['value'])
        item_list.append(parsed_item)
    return item_list


def _build_value_item(full_op_name, dtype, shape, stat):
    """Build the parsed entry of a non-tensor value; ``stat`` fills Max/Min/Mean/Norm."""
    return {'full_op_name': full_op_name, 'dtype': dtype, 'shape': shape, 'md5': None,
            'Max': stat, 'Min': stat, 'Mean': stat, 'Norm': stat, 'data_name': '-1'}
176
+
177
+
178
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Function Description:
        Parse data in the format below, a special layout of api parameters:
        {
            "last_hidden_state": {
                "type": "torch.Tensor",
                "dtype": "torch.bfloat16",
                ...
            },
            "loss": {
                "type": "torch.Tensor",
                "dtype": "torch.float32",
                ...
            }
        }
        Each dict-valued entry is appended to item_list (the same dict object) after
        receiving a 'full_op_name' in which its key is inserted before the last field of
        full_op_name. Non-dict values are skipped.
    Parameter:
        data_dict: the data, as a dict
        full_op_name: full name string of the parameter
        item_list: collection of parameter info, extended in place
    """
    for key, value in data_dict.items():
        if not isinstance(value, dict):
            continue
        fields = full_op_name.split(Const.SEP)
        # Equivalent to inserting the key just before the final field.
        value['full_op_name'] = ".".join(fields[:-1] + [key] + fields[-1:])
        item_list.append(value)
207
+
208
+
209
def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
    """
    Append comparison rows for one matched NPU/bench op pair to ``result``.

    The entries of both ops are split into inputs, kwargs and outputs (by substring match
    on the op names) and each category is compared position by position.

    Args:
        result (list): row accumulator, extended in place.
        n_dict (dict): merged NPU op info; reads 'op_name', the '*_struct' lists, 'summary'
            and the optional 'stack_info' / 'data_name'.
        b_dict (dict): merged bench op info with the same keys.
        summary_compare (bool): emit summary rows (eight diff columns, filled in here).
        md5_compare (bool): emit md5 rows (md5 values plus a PASS/DIFF verdict).
    """
    def get_accuracy_core(n_start, n_len, b_start, b_len, key):
        # n_start/b_start: offset of this category in the flat 'op_name'/'summary' lists.
        # n_len/b_len: number of entries of this category; key: which '*_struct' list to read.
        min_len = min(n_len, b_len)
        npu_stack_info = n_dict.get("stack_info", None)
        bench_stack_info = b_dict.get("stack_info", None)
        has_stack = npu_stack_info and bench_stack_info

        all_mode_bool = not (summary_compare or md5_compare)
        if all_mode_bool:
            npu_data_name = n_dict.get("data_name", None)
            # bench_data_name is read but not used below.
            bench_data_name = b_dict.get("data_name", None)

        # Entries present on both sides.
        for index in range(min_len):

            # Names and summaries are indexed globally, struct lists per category.
            n_name = n_dict['op_name'][n_start + index]
            b_name = b_dict['op_name'][b_start + index]
            n_struct = n_dict[key][index]
            b_struct = b_dict[key][index]
            err_msg = ""
            if md5_compare:
                # Element 2 of each struct is compared for equality to decide PASS/DIFF.
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               n_struct[2], b_struct[2],
                               CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
                # Stack info goes on the first input row only.
                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                result.append(result_item)
                continue

            # Blank placeholders for the result columns: 8 in summary mode, 5 otherwise.
            if summary_compare:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " ", " ", " ", " "]
            else:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " "]

            npu_summary_data = n_dict.get("summary")[n_start + index]
            result_item.extend(npu_summary_data)
            bench_summary_data = b_dict.get("summary")[b_start + index]
            result_item.extend(bench_summary_data)

            if summary_compare:
                # Column position of the first diff value, taken from the summary header.
                start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
                warning_flag = False
                for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
                    if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
                        diff = npu_val - bench_val
                        if bench_val != 0:
                            relative = str(abs((diff / bench_val) * 100)) + '%'
                        else:
                            relative = "N/A"
                        # Absolute diff in column start_idx + i, relative diff 4 columns later.
                        result_item[start_idx + i] = diff
                        result_item[start_idx + i + 4] = relative
                        # 1e-10 keeps the ratio defined when both values are zero.
                        magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
                        if magnitude_diff > 0.5:
                            warning_flag = True
                    else:
                        # Non-numeric statistic: no diff can be computed.
                        result_item[start_idx + i] = CompareConst.NONE
                accuracy_check = CompareConst.WARNING if warning_flag else ""
                err_msg += "Need double check api accuracy." if warning_flag else ""
                # inf/nan values get a trailing tab appended as text
                # (presumably to stop spreadsheet tools reinterpreting them - confirm).
                for i in range(start_idx, len(result_item)):
                    if str(result_item[i]) in ('inf', '-inf', 'nan'):
                        result_item[i] = f'{result_item[i]}\t'

            result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
            result_item.append(err_msg)
            if has_stack and index == 0 and key == "input_struct":
                result_item.extend(npu_stack_info)
            else:
                result_item.append(CompareConst.NONE)
            if all_mode_bool:
                result_item.append(npu_data_name[n_start + index])

            result.append(result_item)

        # NPU entries without a bench counterpart: bench columns are filled with NAN.
        if n_len > b_len:
            for index in range(b_len, n_len):
                n_name = n_dict['op_name'][n_start + index]
                n_struct = n_dict[key][index]
                if md5_compare:
                    # NOTE(review): unlike the matched md5 rows, no stack/NONE column is appended here - confirm.
                    result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                                   n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
                    result.append(result_item)
                    continue
                # NOTE(review): five placeholders are used even in summary mode (matched rows use eight) - confirm.
                result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                               n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
                summary_data = n_dict.get("summary")[n_start + index]
                result_item.extend(summary_data)
                # One NAN per statistic, sized from the first NPU summary entry.
                summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
                result_item.extend(summary_data)

                err_msg = ""
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
                result_item.append(err_msg)

                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                if all_mode_bool:
                    result_item.append(npu_data_name[n_start + index])

                result.append(result_item)

    # Count the entries per category; outputs are whatever is neither input nor kwarg.
    n_num = len(n_dict['op_name'])
    b_num = len(b_dict['op_name'])
    n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
    b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
    n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
    b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
    n_num_output = n_num - n_num_input - n_num_kwarg
    b_num_output = b_num - b_num_input - b_num_kwarg
    get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
    get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
    get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
325
+
326
+
327
def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
    """
    Append one row per NPU op entry to ``result`` for an op that has no bench counterpart.

    Every bench-side column is filled with ``CompareConst.N_A``.

    Args:
        result (list): row accumulator, extended in place.
        n_dict (dict): merged NPU op info; reads "op_name", "input_struct", "output_struct",
            "summary" and the optional "stack_info".
        md5_compare (bool): emit the short md5 row layout (three N/A result columns).
        summary_compare (bool): use eight N/A result columns instead of five.
    """
    stack_lines = n_dict.get("stack_info", None)
    na = CompareConst.N_A
    no_bench_msg = CompareConst.NO_BENCH
    output_cursor = 0

    for position, npu_op in enumerate(n_dict["op_name"]):
        # Input structs share the global position; output structs are counted separately.
        if "input" in npu_op:
            struct = n_dict["input_struct"][position]
        else:
            struct = n_dict["output_struct"][output_cursor]
            output_cursor += 1

        row = [npu_op, na, struct[0], na, struct[1], na]
        if md5_compare:
            row += [na] * 3
        else:
            row += [na] * (8 if summary_compare else 5)
            row += n_dict.get("summary")[position]
            row += [na] * 4          # bench summary columns
            row.append(na)           # accuracy check result
            row.append(no_bench_msg)

        # Stack info is attached to the first row only.
        if stack_lines and position == 0:
            row.extend(stack_lines)
        else:
            row.append(CompareConst.NONE)

        # Full-data mode carries one extra trailing column.
        if not md5_compare and not summary_compare and row[1] == na:
            row.extend(["-1"])
        result.append(row)
366
+
367
+
368
def merge_tensor(tensor_list, summary_compare, md5_compare):
    """
    Collapse the parsed tensor entries of one op into a dict of parallel lists.

    Args:
        tensor_list (list[dict]): parsed entries; a two-key entry is treated as the stack
            record (its 'full_info' is stored) and ends the scan.
        summary_compare (bool): summary mode flag.
        md5_compare (bool): md5 mode flag; struct tuples then carry the md5 as third element.

    Returns:
        dict: keys "op_name", "input_struct", "output_struct", "summary", "stack_info",
        plus "kwargs_struct" when non-empty and "data_name" when neither mode flag is set.
        An empty dict is returned when no op name was collected.
    """
    keep_data_name = not (summary_compare or md5_compare)
    merged = {
        "op_name": [],
        "input_struct": [],
        "kwargs_struct": [],
        "output_struct": [],
        "summary": [],
        "stack_info": [],
    }
    if keep_data_name:
        merged["data_name"] = []

    struct_fields = ('dtype', 'shape', 'md5') if md5_compare else ('dtype', 'shape')
    # Checked in this order; the first marker found in the name decides the bucket.
    buckets = (("input", "input_struct"), ("kwarg", "kwargs_struct"), ("output", "output_struct"))

    for tensor in tensor_list:
        if len(tensor) == 2:
            merged['stack_info'].append(tensor['full_info'])
            break
        full_name = tensor['full_op_name']
        merged["op_name"].append(full_name)

        target = next((bucket for marker, bucket in buckets if full_name.find(marker) != -1), None)
        if target is not None:
            merged[target].append(tuple(tensor[field] for field in struct_fields))

        merged["summary"].append([tensor[stat] for stat in ('Max', 'Min', 'Mean', 'Norm')])
        if keep_data_name:
            merged["data_name"].append(tensor['data_name'])

    if not merged["kwargs_struct"]:
        del merged["kwargs_struct"]
    return merged if merged["op_name"] else {}
409
+
410
+
411
+ def _compare_parser(parser):
412
+ parser.add_argument("-i", "--input_path", dest="input_path", type=str,
413
+ help="<Required> The compare input path, a dict json.", required=True)
414
+ parser.add_argument("-o", "--output_path", dest="output_path", type=str,
415
+ help="<Required> The compare task result out path.", required=True)
416
+ parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
417
+ help="<optional> Whether to save stack info.", required=False)
418
+ parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
419
+ help="<optional> Whether to give advisor.", required=False)
420
+ parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
421
+ help="<optional> Whether to perform a fuzzy match on the api name.", required=False)
422
+ parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
423
+ help="<optional> The cell mapping file path.", required=False)
424
+ parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
425
+ help="<optional> The api mapping file path.", required=False)
426
+
427
+
428
+
429
+
430
+