mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,429 +1,464 @@
1
-
2
- import os
3
- import re
4
- import numpy as np
5
- from msprobe.core.common.const import Const, CompareConst
6
- from msprobe.core.common.utils import CompareException, check_file_or_directory_path, check_regex_prefix_format_valid, logger
7
-
8
-
9
def extract_json(dirname, stack_json=False):
    """
    Locate the dump JSON file (or the stack JSON file) inside a dump directory.

    ``construct.json`` is always ignored. A file counts as a stack file when its
    *file name* contains ``'stack'``. Only the file name is inspected, so a
    directory path that happens to contain ``'stack'`` does not misclassify files.

    Args:
        dirname (str): directory to search.
        stack_json (bool): True to return the stack JSON file; False to return a
            JSON file that is not a stack file.

    Returns:
        str: full path of the matching JSON file.

    Raises:
        CompareException: no JSON file of the requested kind exists in ``dirname``.
    """
    # Sort so the choice is deterministic; os.listdir order is arbitrary.
    for fname in sorted(os.listdir(dirname)):
        if fname == "construct.json" or not fname.endswith('.json'):
            continue
        if ('stack' in fname) == stack_json:
            return os.path.join(dirname, fname)

    # Provide robustness on invalid directory inputs: never fall back to a JSON
    # file of the wrong kind (e.g. a stack file when dump data was requested).
    logger.error(f'No file is found in dump dir {dirname}. ')
    raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
27
-
28
-
29
def check_and_return_dir_contents(dump_dir, prefix):
    """
    Check the given dump dir and validate every entry against the pattern
    ``^{prefix}(?:0|[1-9][0-9]*)?$``.

    An entry must be the prefix alone, or the prefix followed by a non-negative
    integer without leading zeros (e.g. ``rank``, ``rank0``, ``rank7``, ``rank10``).

    Args:
        dump_dir (str): dump dir
        prefix (str): prefix for the pattern; it should be less than 20 characters and
            alphanumeric/-/_ only (validated by check_regex_prefix_format_valid)

    Returns:
        list: dir contents, as returned by os.listdir

    Raises:
        CompareException: invalid path, or an entry whose name does not match the pattern
        ValueError: prefix does not satisfy the format check (per the prefix validator)
    """
    check_regex_prefix_format_valid(prefix)
    check_file_or_directory_path(dump_dir, True)
    contents = os.listdir(dump_dir)
    # The numeric suffix is "0" or a number without a leading zero. A class like
    # "[0-9][1-9]*" would wrongly reject "rank10" and accept "rank01".
    pattern = re.compile(rf'^{re.escape(prefix)}(?:0|[1-9][0-9]*)?$')
    for name in contents:
        if not pattern.match(name):
            logger.error(
                f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
                f"output. Please check and delete irrelevant files in {dump_dir} and try again."
            )
            raise CompareException(CompareException.INVALID_PATH_ERROR)
    return contents
57
-
58
-
59
def rename_api(npu_name, process):
    """
    Strip the trailing call-index segments from an API name around ``process``.

    ``npu_name`` is split on ``process``. From the text before the first occurrence,
    everything after the second-to-last ``Const.SEP`` is dropped (fewer segments are
    dropped if there are fewer separators). The text between the first and second
    occurrence of ``process`` is then appended.

    For example, assuming ``Const.SEP`` is ``'.'``,
    ``rename_api('Functional.conv2d.0.forward.input.0', 'forward')`` gives
    ``'Functional.conv2d.input.0'``.

    Raises:
        IndexError: ``process`` does not occur in ``npu_name``.
    """
    pieces = npu_name.split(process)
    head = pieces[0]
    tail = pieces[1]
    stem = head.rsplit(Const.SEP, 2)[0]
    return f"{stem}{tail}"
65
-
66
-
67
def read_op(op_data, op_name):
    """
    Flatten the dumped data of one call into a list of parsed items.

    Forward calls (``Const.FORWARD`` in ``op_name``) contribute positional inputs,
    keyword inputs and outputs. Backward calls (``Const.BACKWARD`` in ``op_name``)
    contribute inputs and outputs. Every entry is produced by ``op_item_parse``.

    Args:
        op_data (dict): dump record of the call; the keys read are Const.INPUT_ARGS,
            Const.INPUT_KWARGS, Const.INPUT and Const.OUTPUT.
        op_name (str): full name of the call.

    Returns:
        list: parsed items (a fresh list owned by the caller).
    """
    # Work on a copy: "+=" below extends in place and must never mutate the shared
    # module-level default.
    op_parsed_list = list(Const.DEFAULT_LIST)
    if Const.FORWARD in op_name:
        if Const.INPUT_ARGS in op_data:
            # op_item_parse returns a new list per call, so it can be used directly.
            op_parsed_list = op_item_parse(op_data[Const.INPUT_ARGS], op_name + '.input', None)
        if Const.INPUT_KWARGS in op_data:
            op_parsed_list += _parse_forward_kwargs(op_data[Const.INPUT_KWARGS], op_name)
        if Const.OUTPUT in op_data:
            op_parsed_list += op_item_parse(op_data[Const.OUTPUT], op_name + '.output', None)
    if Const.BACKWARD in op_name:
        if Const.INPUT in op_data:
            op_parsed_list = op_item_parse(op_data[Const.INPUT], op_name + '.input', None)
        if Const.OUTPUT in op_data:
            op_parsed_list += op_item_parse(op_data[Const.OUTPUT], op_name + '.output', None)
    return op_parsed_list


def _parse_forward_kwargs(kwargs_item, op_name):
    """Parse the keyword-argument record of a forward call (helper of read_op)."""
    if (isinstance(kwargs_item, dict) and "type" in kwargs_item) or isinstance(kwargs_item, list):
        # A single serialized value (dict carrying "type") or a sequence: parse as a whole.
        return op_item_parse(kwargs_item, op_name + '.input', None)
    parsed = []
    if kwargs_item:
        # A mapping of keyword name -> value: each keyword gets its own name segment.
        for kwarg in kwargs_item:
            parsed += op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
    return parsed
103
-
104
-
105
def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
    """
    Recursively flatten one dumped argument/output into summary items.

    Args:
        item: dumped data. It is None, a dict, or an iterable of nested items.
        op_name (str): name prefix for the produced items.
        index (int | None): position of ``item`` inside its parent sequence, or None.
        item_list (list | None): accumulator; a new list is created when None.
        top_bool (bool): False when called recursively for a sequence element.

    Returns:
        list: ``item_list`` with the items for ``item`` appended.

    Note:
        Dicts that carry a 'dtype' key are appended as-is after 'full_op_name' is
        written into them, i.e. the caller's dict is updated in place.
    """
    if item_list is None:
        item_list = []
    if item is None or (isinstance(item, dict) and not item):
        # Empty argument: still emit a placeholder item so it gets its own row.
        suffix = '.0' if top_bool else '.' + str(index)
        item_list.append(_empty_item_summary(op_name + suffix))
        return item_list

    if index is not None:
        full_op_name = op_name + Const.SEP + str(index)
    elif isinstance(item, dict):
        full_op_name = op_name + '.0'
    else:
        full_op_name = op_name

    if not isinstance(item, dict):
        # Sequence: each element is parsed under its position index.
        for j, item_spec in enumerate(item):
            op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
    elif 'type' not in item:
        # Mapping of named sub-arguments: parse each one under its own name.
        for kwarg in item:
            item_list += op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
    elif 'dtype' in item:
        item['full_op_name'] = full_op_name
        item_list.append(item)
    else:
        # Serialized non-tensor value ({"type": ..., "value": ...}). Together the three
        # branches above and this one cover every dict, so no further fallback is reachable.
        item_list.append(_typed_value_summary(item, full_op_name))
    return item_list


def _empty_item_summary(full_op_name):
    """Return the placeholder summary item for an argument that carried no data."""
    return {'full_op_name': full_op_name, 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
            'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}


def _typed_value_summary(item, full_op_name):
    """Build the summary item for a value serialized as {"type": ..., "value": ...}."""
    value = item['value']
    if item['type'] == 'torch.Size':
        dtype, shape, stat = 'torch.Size', str(value), None
    elif item['type'] == 'slice':
        dtype, shape, stat = 'slice', str(np.shape(np.array(value))), None
    else:
        # Any other typed value: the value itself is reported as every statistic.
        dtype, shape, stat = str(type(value)), '[]', value
    return {'full_op_name': full_op_name, 'dtype': dtype, 'shape': shape, 'md5': None,
            'Max': stat, 'Min': stat, 'Mean': stat, 'Norm': stat, 'data_name': '-1'}
175
-
176
-
177
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Parse a special format of API parameters in which each key maps to a dict, e.g.

        {
            "last_hidden_state": {"type": "torch.Tensor", "dtype": "torch.bfloat16", ...},
            "loss": {"type": "torch.Tensor", "dtype": "torch.float32", ...}
        }

    For every dict-valued entry, the key is inserted as a new name segment just before
    the last segment of ``full_op_name``. The value dict receives that name under
    'full_op_name' (it is updated in place) and is appended to ``item_list``.
    Non-dict values are ignored.

    Parameters:
        data_dict: the dict-formatted data
        full_op_name: full name string of the parameter
        item_list: collection of parameter info; appended to in place
    """
    # Split once, and join with the same separator used to split.
    *head, last = full_op_name.split(Const.SEP)
    for key, value in data_dict.items():
        if isinstance(value, dict):
            value['full_op_name'] = Const.SEP.join(head + [key, last])
            item_list.append(value)
206
-
207
-
208
def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
    """
    Append comparison rows for one matched NPU/bench op pair to ``result``.

    Inputs, kwargs and outputs are compared separately. Position k of one group on
    the NPU side is paired with position k on the bench side. Surplus NPU entries get
    a row whose bench columns are CompareConst.NAN. Every row of a given mode has
    the same column layout:
      * md5 mode: names, struct fields [0], [1], [2], PASS/DIFF, stack column
      * summary mode: names, struct fields [0], [1], 8 diff columns, NPU and bench
        summaries, result, error message, stack column
      * all mode: names, struct fields [0], [1], 5 placeholder columns, NPU and bench
        summaries, accuracy flag, error message, stack column, NPU data name

    Args:
        result (list): output rows; appended to in place.
        n_dict (dict): merged NPU op info ('op_name', 'input_struct', 'kwargs_struct',
            'output_struct', 'summary', 'stack_info', and 'data_name' in all mode).
        b_dict (dict): merged bench op info with the same layout.
        summary_compare (bool): summary (statistics) comparison mode.
        md5_compare (bool): md5 comparison mode.
    """
    npu_stack_info = n_dict.get("stack_info", None)
    bench_stack_info = b_dict.get("stack_info", None)
    has_stack = npu_stack_info and bench_stack_info
    all_mode_bool = not (summary_compare or md5_compare)
    npu_data_name = n_dict.get("data_name", None) if all_mode_bool else None
    # Columns between the shape columns and the summary data. Summary mode holds
    # differences and, 4 columns later, relative differences (8 in total). All mode holds 5.
    placeholder_count = 8 if summary_compare else 5

    def append_stack_info(result_item, index, key):
        """Attach NPU stack info to the first input row only; other rows get NONE."""
        if has_stack and index == 0 and key == "input_struct":
            result_item.extend(npu_stack_info)
        else:
            result_item.append(CompareConst.NONE)

    def get_accuracy_core(n_start, n_len, b_start, b_len, key):
        for index in range(min(n_len, b_len)):
            n_name = n_dict['op_name'][n_start + index]
            b_name = b_dict['op_name'][b_start + index]
            n_struct = n_dict[key][index]
            b_struct = b_dict[key][index]
            err_msg = ""
            if md5_compare:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               n_struct[2], b_struct[2],
                               CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
                append_stack_info(result_item, index, key)
                result.append(result_item)
                continue

            result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1]]
            result_item.extend([" "] * placeholder_count)

            npu_summary_data = n_dict.get("summary")[n_start + index]
            result_item.extend(npu_summary_data)
            bench_summary_data = b_dict.get("summary")[b_start + index]
            result_item.extend(bench_summary_data)

            if summary_compare:
                start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
                warning_flag = False
                for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
                    if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
                        diff = npu_val - bench_val
                        if bench_val != 0:
                            relative = str(abs((diff / bench_val) * 100)) + '%'
                        else:
                            relative = "N/A"
                        result_item[start_idx + i] = diff
                        result_item[start_idx + i + 4] = relative
                        # Flag when the values differ by more than half their magnitude.
                        magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
                        if magnitude_diff > 0.5:
                            warning_flag = True
                    else:
                        result_item[start_idx + i] = CompareConst.NONE
                accuracy_check = CompareConst.WARNING if warning_flag else ""
                err_msg += "Need double check api accuracy." if warning_flag else ""
                # Suffix a tab to inf/nan values so they are kept as text.
                for i in range(start_idx, len(result_item)):
                    if str(result_item[i]) in ('inf', '-inf', 'nan'):
                        result_item[i] = f'{result_item[i]}\t'

            result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
            result_item.append(err_msg)
            append_stack_info(result_item, index, key)
            if all_mode_bool:
                result_item.append(npu_data_name[n_start + index])
            result.append(result_item)

        # Surplus NPU entries without a bench counterpart (empty range when n_len <= b_len).
        for index in range(b_len, n_len):
            n_name = n_dict['op_name'][n_start + index]
            n_struct = n_dict[key][index]
            if md5_compare:
                result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                               n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
                # Keep the stack column so md5 rows all have the same width.
                append_stack_info(result_item, index, key)
                result.append(result_item)
                continue
            result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                           n_struct[1], CompareConst.NAN]
            # Same placeholder width as matched rows, so summary-mode columns stay aligned.
            result_item.extend([" "] * placeholder_count)
            result_item.extend(n_dict.get("summary")[n_start + index])
            result_item.extend([CompareConst.NAN] * len(n_dict.get("summary")[0]))
            result_item.append(CompareConst.ACCURACY_CHECK_YES)
            result_item.append("")
            append_stack_info(result_item, index, key)
            if all_mode_bool:
                result_item.append(npu_data_name[n_start + index])
            result.append(result_item)

    n_num = len(n_dict['op_name'])
    b_num = len(b_dict['op_name'])
    n_num_input = sum(1 for name in n_dict['op_name'] if Const.INPUT in name)
    b_num_input = sum(1 for name in b_dict['op_name'] if Const.INPUT in name)
    n_num_kwarg = sum(1 for name in n_dict['op_name'] if 'kwarg' in name)
    b_num_kwarg = sum(1 for name in b_dict['op_name'] if 'kwarg' in name)
    n_num_output = n_num - n_num_input - n_num_kwarg
    b_num_output = b_num - b_num_input - b_num_kwarg
    get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
    get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
    get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
324
-
325
-
326
def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
    """
    Append one row per NPU entry of an op that has no bench counterpart.

    Bench columns are filled with CompareConst.N_A and the error message is
    CompareConst.NO_BENCH. Entries whose name contains "input" take their struct from
    n_dict["input_struct"] at their overall position. Every other entry takes the next
    unused element of n_dict["output_struct"]. NPU stack info goes on the first row only.

    Args:
        result (list): output rows; appended to in place.
        n_dict (dict): merged NPU op info.
        md5_compare (bool): md5 comparison mode.
        summary_compare (bool): summary comparison mode.
    """
    stack_info = n_dict.get("stack_info", None)
    na = CompareConst.N_A
    output_pos = 0
    for pos, op_name in enumerate(n_dict["op_name"]):
        if "input" in op_name:
            struct = n_dict["input_struct"][pos]
        else:
            struct = n_dict["output_struct"][output_pos]
            output_pos += 1

        row = [op_name, na, struct[0], na, struct[1], na]
        if md5_compare:
            row += [na] * 3
        else:
            row += [na] * (8 if summary_compare else 5)
            row += n_dict.get("summary")[pos]
            row += [na] * 4
            row += [na, CompareConst.NO_BENCH]

        if stack_info and pos == 0:
            row += stack_info
        else:
            row.append(CompareConst.NONE)

        # All mode only: unmatched rows carry the "-1" data-name placeholder.
        if not (md5_compare or summary_compare) and row[1] == na:
            row.append("-1")
        result.append(row)
365
-
366
-
367
def merge_tensor(tensor_list, summary_compare, md5_compare):
    """
    Group the flat tensor records of one op into per-column lists.

    A record with exactly two keys is treated as the op's stack record: its 'full_info' is
    stored under 'stack_info' and the scan stops there. Every other record contributes its
    name, a struct tuple (dtype, shape[, md5]) filed under the first of "input", "kwarg",
    "output" found as a substring of its name, its Max/Min/Mean/Norm statistics and, in all
    mode, its data name.

    Args:
        tensor_list (list): flat tensor records of one op.
        summary_compare (bool): summary mode (no data names are kept).
        md5_compare (bool): md5 mode (structs include the md5 value).

    Returns:
        dict: the grouped lists ('kwargs_struct' only when non-empty, 'data_name' only in all
        mode), or {} when no tensor record was seen.
    """
    keep_data_name = not (summary_compare or md5_compare)
    merged = {
        field: []
        for field in ("op_name", "input_struct", "kwargs_struct", "output_struct", "summary", "stack_info")
    }
    if keep_data_name:
        merged["data_name"] = []

    buckets = (("input", "input_struct"), ("kwarg", "kwargs_struct"), ("output", "output_struct"))
    for tensor in tensor_list:
        if len(tensor) == 2:
            merged['stack_info'].append(tensor['full_info'])
            break
        full_name = tensor['full_op_name']
        merged["op_name"].append(full_name)
        for marker, bucket in buckets:
            if marker in full_name:
                entry = (tensor['dtype'], tensor['shape'])
                if md5_compare:
                    entry += (tensor['md5'],)
                merged[bucket].append(entry)
                break
        merged["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
        if keep_data_name:
            merged["data_name"].append(tensor['data_name'])

    if not merged["kwargs_struct"]:
        del merged["kwargs_struct"]
    return merged if merged["op_name"] else {}
408
-
409
-
410
- def _compare_parser(parser):
411
- parser.add_argument("-i", "--input_path", dest="input_path", type=str,
412
- help="<Required> The compare input path, a dict json.", required=True)
413
- parser.add_argument("-o", "--output_path", dest="output_path", type=str,
414
- help="<Required> The compare task result out path.", required=True)
415
- parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
416
- help="<optional> Whether to save stack info.", required=False)
417
- parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
418
- help="<optional> Whether to give advisor.", required=False)
419
- parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
420
- help="<optional> Whether to perform a fuzzy match on the api name.", required=False)
421
- parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
422
- help="<optional> The cell mapping file path.", required=False)
423
- parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
424
- help="<optional> The api mapping file path.", required=False)
425
-
426
-
427
-
428
-
429
-
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ import numpy as np
19
+ from msprobe.core.common.const import Const, CompareConst
20
+ from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
21
+ from msprobe.core.common.file_utils import check_file_or_directory_path
22
+
23
+
24
def extract_json(dirname, stack_json=False):
    """
    Return the path of the dump json (or, with ``stack_json=True``, the stack json) in ``dirname``.

    ``construct.json`` and non-json entries are skipped. A json file counts as a stack file
    when its file name contains "stack". If no file of the requested kind exists, the last
    other json file seen is returned (os.listdir order, which is arbitrary).

    Args:
        dirname (str): dump directory to search.
        stack_json (bool): look for the stack json instead of the dump json.

    Returns:
        str: full path of the selected json file.

    Raises:
        CompareException: NO_DUMP_FILE_ERROR when ``dirname`` holds no candidate json file.
    """
    json_path = ''
    want_stack = bool(stack_json)
    for fname in os.listdir(dirname):
        if fname == "construct.json" or not fname.endswith('.json'):
            continue
        json_path = os.path.join(dirname, fname)
        # Inspect the file name only: testing the full path made every json look like a
        # stack file whenever a parent directory name contained "stack".
        if ('stack' in fname) == want_stack:
            break

    # Provide robustness on invalid directory inputs
    if not json_path:
        logger.error(f'No file is found in dump dir {dirname}. ')
        raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
    return json_path
42
+
43
+
44
def check_and_return_dir_contents(dump_dir, prefix):
    """
    Check the given dump dir and validate every entry in it against a pattern built from the
    given prefix: ^{prefix}(?:0|[1-9][0-9]*)?$, i.e. the bare prefix, or the prefix followed by
    a non-negative integer without leading zeros (``rank``, ``rank0``, ``rank10``; not ``rank01``).

    Args:
        dump_dir (str): dump dir
        prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only

    Returns:
        content [list]: dir contents
    Raises:
        CompareException: invalid path, or an entry whose name does not match the pattern
        ValueError: prefix not match the patterns

    """
    check_regex_prefix_format_valid(prefix)
    check_file_or_directory_path(dump_dir, True)
    contents = os.listdir(dump_dir)
    # The numeric suffix is either a lone 0 or a number whose first digit is 1-9, so
    # multi-digit suffixes containing a 0 (rank10, step100) are accepted as well.
    pattern = re.compile(rf'^{re.escape(prefix)}(?:0|[1-9][0-9]*)?$')
    for name in contents:
        if not pattern.match(name):
            logger.error(
                f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
                f"output. Please check and delete irrelevant files in {dump_dir} and try again."
            )
            raise CompareException(CompareException.INVALID_PATH_ERROR)
    return contents
72
+
73
+
74
def rename_api(npu_name, process):
    """
    Drop the call-index fields that precede ``process`` in an API name.

    The name is split on ``process``. The part before the first occurrence loses its last two
    Const.SEP-separated fields, and the part between the first and any second occurrence is
    appended. Example (assuming Const.SEP is '.'):
    "Functional.conv2d.0.forward.input.0" with process "forward" -> "Functional.conv2d.input.0".

    Raises:
        CompareException: INDEX_OUT_OF_BOUNDS_ERROR when ``process`` does not occur in ``npu_name``.
    """
    pieces = npu_name.split(process)
    try:
        head, tail = pieces[0], pieces[1]
    except IndexError as error:
        logger.error(f'{npu_name} can not be split with {process}, please check!')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
    api_base = head.rsplit(Const.SEP, 2)[0]
    return str(api_base) + str(tail)
84
+
85
+
86
def read_op(op_data, op_name):
    """
    Flatten one API/module record of a dump json into a list of tensor records.

    For forward names (containing Const.FORWARD), the records are built from three sources:
    positional inputs (Const.INPUT_ARGS), keyword inputs (Const.INPUT_KWARGS) and outputs
    (Const.OUTPUT). Keyword inputs are named "<op_name>.input.<kwarg>" unless they are a
    single typed value or a list. For backward names (containing Const.BACKWARD), the records
    come from Const.INPUT and Const.OUTPUT. If a name contains both markers, the backward
    inputs replace the forward records.

    Args:
        op_data (dict): the op's entry from the dump json.
        op_name (str): the op's full name.

    Returns:
        list: tensor records as produced by op_item_parse.
    """
    parsed = []
    if Const.FORWARD in op_name:
        if Const.INPUT_ARGS in op_data:
            parsed = list(op_item_parse(op_data[Const.INPUT_ARGS], op_name + '.input', None))
        if Const.INPUT_KWARGS in op_data:
            kwargs_item = op_data[Const.INPUT_KWARGS]
            is_typed_value = isinstance(kwargs_item, dict) and "type" in kwargs_item
            if is_typed_value or isinstance(kwargs_item, list):
                parsed.extend(op_item_parse(kwargs_item, op_name + '.input', None))
            elif kwargs_item:
                for kwarg_name in kwargs_item:
                    parsed.extend(
                        op_item_parse(kwargs_item[kwarg_name], op_name + '.input.' + kwarg_name, None)
                    )
        if Const.OUTPUT in op_data:
            parsed.extend(op_item_parse(op_data[Const.OUTPUT], op_name + '.output', None))
    if Const.BACKWARD in op_name:
        if Const.INPUT in op_data:
            parsed = list(op_item_parse(op_data[Const.INPUT], op_name + '.input', None))
        if Const.OUTPUT in op_data:
            parsed.extend(op_item_parse(op_data[Const.OUTPUT], op_name + '.output', None))
    return parsed
122
+
123
+
124
def _make_value_record(full_op_name, dtype, shape, stat):
    """Build the record of a non-tensor value; ``stat`` fills all four statistic columns."""
    return {
        'full_op_name': full_op_name, 'dtype': dtype, 'shape': shape, 'md5': None,
        'Max': stat, 'Min': stat, 'Mean': stat, 'Norm': stat, 'data_name': '-1',
    }


def op_item_parse(item, op_name, index, item_list=None, top_bool=True, depth=0):
    """
    Recursively flatten one argument/output entry of a dump json into tensor records.

    Records carry the keys 'full_op_name', 'dtype', 'shape', 'md5', 'Max', 'Min', 'Mean',
    'Norm' and 'data_name'. Dicts holding a 'dtype' are used as records directly and are
    mutated in place to gain 'full_op_name'.

    Args:
        item: the value to parse (None, a dict or a sequence).
        op_name (str): name prefix of the produced records.
        index: position of ``item`` in its parent sequence, or None.
        item_list (list): accumulator to append to; a new list is created when None.
        top_bool (bool): False for sequence elements; picks the ".{index}" suffix for empty
            placeholders instead of ".0".
        depth (int): current recursion depth, bounded by Const.MAX_DEPTH.

    Returns:
        list: the accumulator with the new records appended.

    Raises:
        CompareException: RECURSION_LIMIT_ERROR when depth exceeds Const.MAX_DEPTH.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f"parse of api/module of {op_name} exceeds the recursion limit.")
        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
    records = [] if item_list is None else item_list

    # None or an empty dict becomes a placeholder record without statistics.
    if item is None or (isinstance(item, dict) and not item):
        placeholder_name = op_name + ('.0' if top_bool else '.' + str(index))
        records.append({
            'full_op_name': placeholder_name, 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
            'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'
        })
        return records

    if index is not None:
        full_op_name = op_name + Const.SEP + str(index)
    else:
        full_op_name = (op_name + '.0') if isinstance(item, dict) else op_name

    if not isinstance(item, dict):
        # Sequence: each element is parsed straight into the shared accumulator.
        for position, element in enumerate(item):
            op_item_parse(element, full_op_name, position, item_list=records, top_bool=False, depth=depth + 1)
        return records

    if 'type' not in item:
        # Mapping of named sub-values, each parsed under "<op_name><SEP><key>".
        for sub_name in item:
            records.extend(op_item_parse(item[sub_name], op_name + Const.SEP + sub_name, None, depth=depth + 1))
    elif 'dtype' in item:
        item['full_op_name'] = full_op_name
        records.append(item)
    elif 'type' in item:
        value_type = item['type']
        if value_type == 'torch.Size':
            records.append(_make_value_record(full_op_name, 'torch.Size', str(item['value']), None))
        elif value_type == 'slice':
            slice_shape = str(np.shape(np.array(item['value'])))
            records.append(_make_value_record(full_op_name, 'slice', slice_shape, None))
        else:
            # Plain scalar: its value is reported as every statistic.
            records.append(_make_value_record(full_op_name, str(type(item['value'])), '[]', item['value']))
    else:
        # NOTE(review): unreachable -- every dict here either lacks 'type' (first branch) or has it.
        resolve_api_special_parameters(item, full_op_name, records)
    return records
201
+
202
+
203
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Function Description:
        Parse the following data format, a special form of API parameters:
        {
            "last_hidden_state": {
                "type": "torch.Tensor",
                "dtype": "torch.bfloat16",
                ...
            },
            "loss": {
                "type": "torch.Tensor",
                "dtype": "torch.float32",
                ...
            }
        }
        For every value that is a dict, its key is inserted before the last Const.SEP-separated
        field of ``full_op_name`` (fields re-joined with "."), the value dict is tagged in place
        with that name under 'full_op_name' and appended to ``item_list``. Non-dict values are ignored.
    Parameter:
        data_dict: data in dict format
        full_op_name: full name string of the parameter
        item_list: collection of parameter records (appended in place)
    """
    name_fields = full_op_name.split(Const.SEP)
    for key, value in data_dict.items():
        if not isinstance(value, dict):
            continue
        renamed = name_fields[:-1] + [key] + name_fields[-1:]
        value['full_op_name'] = ".".join(renamed)
        item_list.append(value)
232
+
233
+
234
def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
    """
    Append compare-result rows for one matched NPU/bench op pair to ``result``.

    Tensors are split into two groups. The input group holds names whose Const.SEP-separated
    fields contain Const.INPUT or Const.KWARGS; the output group holds everything else.
    Tensors are paired positionally inside each group. NPU tensors without a bench partner
    get rows whose bench columns are CompareConst.NAN; surplus bench tensors produce no rows.

    Row layout:
        md5 mode:     6 name/dtype/shape columns, NPU md5, bench md5, PASS/DIFF result
        summary mode: 6 name/dtype/shape columns, 4 diffs + 4 relative errors,
                      NPU Max/Min/Mean/Norm, bench Max/Min/Mean/Norm, result, error message
        all mode:     6 name/dtype/shape columns, 5 placeholder columns, NPU statistics,
                      bench statistics, result, error message
    Every row then gets a stack column and, in all mode, the NPU data name. The stack column
    holds the NPU stack info on the first input row when both sides carry stack info, and
    CompareConst.NONE otherwise. Matched and unmatched rows therefore share one layout per mode.

    Args:
        result (list): output rows, appended in place.
        n_dict (dict): NPU op as produced by merge_tensor.
        b_dict (dict): bench op as produced by merge_tensor.
        summary_compare (bool): fill diff / relative-error columns from the statistics.
        md5_compare (bool): compare the md5 values stored in the structs.
    """
    all_mode_bool = not (summary_compare or md5_compare)
    # Metric columns between the shape columns and the NPU statistics. Unmatched rows must use
    # the same count as matched rows, otherwise their statistics shift into the wrong columns.
    placeholder_count = 8 if summary_compare else 5
    diff_start_idx = (
        CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF) if summary_compare else None
    )

    def fill_summary_diffs(result_item, npu_summary_data, bench_summary_data):
        """Write diffs / relative errors into result_item; return True when any statistic
        differs by more than half of the larger magnitude."""
        warning_flag = False
        for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
            if not (isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int))):
                result_item[diff_start_idx + i] = CompareConst.NONE
                continue
            diff = npu_val - bench_val
            if bench_val != 0:
                relative = str(abs((diff / bench_val) * 100)) + '%'
            else:
                relative = CompareConst.N_A
            result_item[diff_start_idx + i] = diff
            result_item[diff_start_idx + i + 4] = relative
            # The 1e-10 term avoids dividing by zero when both values are 0.
            magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
            if magnitude_diff > 0.5:
                warning_flag = True
        # inf/nan get a trailing tab (presumably to keep them as literal text in the exported table).
        for i in range(diff_start_idx, len(result_item)):
            if str(result_item[i]) in ('inf', '-inf', 'nan'):
                result_item[i] = f'{result_item[i]}\t'
        return warning_flag

    def get_accuracy_core(n_start, n_len, b_start, b_len, key):
        npu_stack_info = n_dict.get("stack_info", None)
        bench_stack_info = b_dict.get("stack_info", None)
        has_stack = npu_stack_info and bench_stack_info
        npu_data_name = n_dict.get("data_name", None) if all_mode_bool else None

        def append_tail(result_item, index):
            """Append the stack column and, in all mode, the NPU data name."""
            if has_stack and index == 0 and key == "input_struct":
                result_item.extend(npu_stack_info)
            else:
                result_item.append(CompareConst.NONE)
            if all_mode_bool:
                result_item.append(npu_data_name[n_start + index])

        # Tensors present on both sides.
        for index in range(min(n_len, b_len)):
            n_name = n_dict['op_name'][n_start + index]
            b_name = b_dict['op_name'][b_start + index]
            n_struct = n_dict[key][index]
            b_struct = b_dict[key][index]
            if md5_compare:
                result_item = [
                    n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], n_struct[2], b_struct[2],
                    CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF
                ]
                append_tail(result_item, index)
                result.append(result_item)
                continue

            result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1]]
            result_item.extend([" "] * placeholder_count)
            npu_summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
            result_item.extend(npu_summary_data)
            bench_summary_data = b_dict.get(CompareConst.SUMMARY)[b_start + index]
            result_item.extend(bench_summary_data)

            if summary_compare:
                warning_flag = fill_summary_diffs(result_item, npu_summary_data, bench_summary_data)
                result_item.append(CompareConst.WARNING if warning_flag else "")
                result_item.append("Need double check api accuracy." if warning_flag else "")
            else:
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
                result_item.append("")
            append_tail(result_item, index)
            result.append(result_item)

        # NPU tensors without a bench partner (empty range when n_len <= b_len).
        for index in range(b_len, n_len):
            n_name = n_dict['op_name'][n_start + index]
            n_struct = n_dict[key][index]
            if md5_compare:
                result_item = [
                    n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
                    n_struct[2], CompareConst.NAN, CompareConst.NAN
                ]
                # Same trailing stack column as every other row of this mode.
                append_tail(result_item, index)
                result.append(result_item)
                continue
            result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN]
            result_item.extend([" "] * placeholder_count)
            result_item.extend(n_dict.get(CompareConst.SUMMARY)[n_start + index])
            result_item.extend([CompareConst.NAN] * len(n_dict.get(CompareConst.SUMMARY)[0]))
            result_item.append(CompareConst.ACCURACY_CHECK_YES)
            result_item.append("")
            append_tail(result_item, index)
            result.append(result_item)

    def count_inputs(op_names):
        # Keyword arguments count as inputs; everything else is an output.
        return sum(
            1 for name in op_names
            if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)
        )

    n_num = len(n_dict['op_name'])
    b_num = len(b_dict['op_name'])
    n_num_input = count_inputs(n_dict['op_name'])
    b_num_input = count_inputs(b_dict['op_name'])
    get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
    get_accuracy_core(n_num_input, n_num - n_num_input, b_num_input, b_num - b_num_input, 'output_struct')
356
+
357
+
358
def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
    """
    Append one result row per NPU tensor of an op that has no bench counterpart.

    Bench name/dtype/shape, every metric column and the bench statistics are CompareConst.N_A,
    the result column is CompareConst.N_A and the error message is CompareConst.NO_BENCH.
    The NPU stack info (when present) is attached to the first row only; other rows get
    CompareConst.NONE in the stack column. Rows in all mode (neither md5 nor summary) end
    with a "-1" data-name column.

    Args:
        result (list): rows are appended to it in place.
        n_dict (dict): NPU op as produced by merge_tensor.
        md5_compare (bool): md5 mode; rows stop after three N_A columns and the stack column.
        summary_compare (bool): summary mode; eight N_A metric columns instead of five.
    """
    na = CompareConst.N_A
    stack_info = n_dict.get("stack_info", None)
    output_cursor = 0
    for position, npu_name in enumerate(n_dict["op_name"]):
        # Inputs are indexed by row position, outputs by a separate running counter
        # (relies on input names being listed before output names, as read_op emits them).
        if "input" in npu_name.split(Const.SEP):
            npu_struct = n_dict["input_struct"][position]
        else:
            npu_struct = n_dict["output_struct"][output_cursor]
            output_cursor += 1

        row = [npu_name, na, npu_struct[0], na, npu_struct[1], na]
        stack_column = stack_info if (stack_info and position == 0) else [CompareConst.NONE]
        if md5_compare:
            row += [na] * 3
            row += stack_column
            result.append(row)
            continue

        row += [na] * (8 if summary_compare else 5)
        row += n_dict.get("summary")[position]
        row += [na] * 4
        row += [na, CompareConst.NO_BENCH]
        row += stack_column
        if not summary_compare:
            row.append("-1")
        result.append(row)
398
+
399
+
400
def merge_tensor(tensor_list, summary_compare, md5_compare):
    """
    Group the flat tensor records of one op into per-column lists.

    Args:
        tensor_list (list): records as produced by read_op/op_item_parse. A record with exactly
            two keys is taken as the op's stack record: its 'full_info' goes to 'stack_info'
            and the scan stops.
        summary_compare (bool): summary mode (no 'data_name' list is kept).
        md5_compare (bool): md5 mode; structs are (dtype, shape, md5) instead of (dtype, shape).

    Returns:
        dict: 'op_name', 'input_struct', 'output_struct', 'summary', 'stack_info', plus
        'kwargs_struct' when non-empty and 'data_name' in all mode; {} when no tensor record
        was seen. In all mode an op name is replaced by its data name with the last
        Const.SEP-separated field removed (presumably the file extension), unless that is "-1".
    """
    op_dict = {
        "op_name": [],
        "input_struct": [],
        "kwargs_struct": [],
        "output_struct": [],
        "summary": [],
        "stack_info": [],
    }
    all_mode_bool = not (summary_compare or md5_compare)
    if all_mode_bool:
        op_dict["data_name"] = []

    for tensor in tensor_list:
        if len(tensor) == 2:
            op_dict['stack_info'].append(tensor['full_info'])
            break
        op_dict["op_name"].append(tensor['full_op_name'])
        name_ele_list = tensor['full_op_name'].split(Const.SEP)
        # Each record goes into exactly one struct list, checked in this order, in every mode.
        # read_op names keyword arguments "<op>.input.<kwarg>", so a kwarg called e.g. "output"
        # must not also be counted as an output (that would shift later output_struct indices).
        if "input" in name_ele_list:
            struct_key = "input_struct"
        elif "kwarg" in name_ele_list:
            struct_key = "kwargs_struct"
        elif "output" in name_ele_list:
            struct_key = "output_struct"
        else:
            struct_key = None
        if struct_key is not None:
            if md5_compare:
                op_dict[struct_key].append((tensor['dtype'], tensor['shape'], tensor['md5']))
            else:
                op_dict[struct_key].append((tensor['dtype'], tensor['shape']))
        op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])

        if all_mode_bool:
            op_dict["data_name"].append(tensor['data_name'])
            data_name = op_dict["data_name"][-1].rsplit(Const.SEP, 1)[0]
            if data_name != "-1":
                op_dict["op_name"][-1] = data_name

    if not op_dict["kwargs_struct"]:
        del op_dict["kwargs_struct"]
    return op_dict if op_dict["op_name"] else {}
444
+
445
+
446
+ def _compare_parser(parser):
447
+ parser.add_argument("-i", "--input_path", dest="input_path", type=str,
448
+ help="<Required> The compare input path, a dict json.", required=True)
449
+ parser.add_argument("-o", "--output_path", dest="output_path", type=str,
450
+ help="<Required> The compare task result out path.", required=True)
451
+ parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
452
+ help="<optional> Whether to save stack info.", required=False)
453
+ parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
454
+ help="<optional> Whether to give advisor.", required=False)
455
+ parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
456
+ help="<optional> Whether to perform a fuzzy match on the api name.", required=False)
457
+ parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
458
+ help="<optional> The cell mapping file path.", required=False)
459
+ parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
460
+ help="<optional> The api mapping file path.", required=False)
461
+ parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
462
+ help="<optional> The data mapping file path.", required=False)
463
+ parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
464
+ help="<optional> The layer mapping file path.", required=False)