mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/mindspore/compare/ms_compare.py
@@ -1,117 +1,357 @@
1
- import os.path
2
- from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
3
- task_dumppath_get, load_yaml, load_npy
4
- from msprobe.core.common.file_check import create_directory
5
- from msprobe.core.common.const import Const
6
- from msprobe.core.common.log import logger
7
- from msprobe.core.common.exceptions import FileCheckException
8
- from msprobe.core.compare.acc_compare import Comparator
9
- from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
10
-
11
-
12
class MSComparator(Comparator):
    """Comparator for MindSpore dump data (the msprobe 1.0.3 version of this class).

    When a cell mapping and/or an API mapping is configured, NPU-side op names are
    rewritten (e.g. ``Cell`` -> ``Module``, ``Mint`` -> ``Torch``) before they are
    matched against the bench-side op names.
    """

    def __init__(self, cell_mapping=None, api_mapping=None):
        """Record the mapping options and preload the tables they need.

        :param cell_mapping: str path to a YAML cell-name mapping; any other value loads nothing.
        :param api_mapping: when not None, the built-in ``ms_to_pt_api.yaml`` table is loaded.
        """
        self.frame_name = MSComparator.__name__
        self.cell_mapping = cell_mapping
        self.api_mapping = api_mapping
        # cross_frame is enabled as soon as either mapping option is supplied.
        self.cross_frame = any(option is not None for option in (cell_mapping, api_mapping))
        self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
        self.api_mapping_dict = {}
        if api_mapping is not None:
            self.ms_to_pt_mapping = self.load_internal_api()

    def load_internal_api(self):
        """Load ``ms_to_pt_api.yaml`` from the directory containing this module."""
        module_dir = os.path.dirname(os.path.realpath(__file__))
        return load_yaml(os.path.join(module_dir, "ms_to_pt_api.yaml"))

    def load_mapping_file(self, mapping_file):
        """Return the YAML content at ``mapping_file`` when it is a str path, else ``{}``."""
        return load_yaml(mapping_file) if isinstance(mapping_file, str) else {}

    def process_cell_mapping(self, npu_op_name):
        """Rename the first ``Cell`` to ``Module`` in each name, then apply the cell-name table."""
        renamed = [name.replace("Cell", "Module", 1) for name in npu_op_name]
        if not self.cell_mapping_dict:
            return renamed
        for pos, name in enumerate(renamed):
            # e.g. "Module.fc1.Dense.forward.0.input.0": drop the leading class token,
            # then cut the last four fields, leaving the cell name "fc1.Dense".
            cell_name = name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
            if cell_name in self.cell_mapping_dict:
                renamed[pos] = name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
        return renamed

    def check_op(self, npu_dict, bench_dict, fuzzy_match):
        """Return whether the NPU and bench op entries match by (mapped) name and structure."""
        npu_names = npu_dict["op_name"].copy()
        bench_names = bench_dict["op_name"].copy()
        if self.api_mapping is not None:
            npu_names = self.process_api_mapping(npu_names, bench_names)
        if self.cell_mapping is not None:
            npu_names = self.process_cell_mapping(npu_names)

        struct_match = check_struct_match(npu_dict, bench_dict, cross_frame=self.cross_frame)
        if fuzzy_match:
            try:
                name_match = fuzzy_check_op(npu_names, bench_names)
            except Exception:
                logger.warning("%s and %s can not fuzzy match." % (npu_names, bench_names))
                name_match = False
        else:
            name_match = npu_names == bench_names
        return name_match and struct_match

    def read_npy_data(self, dir_path, file_name, load_pt_file=False):
        """Load ``dir_path/file_name`` as a numpy array (from a .pt file when ``load_pt_file``)."""
        data_path = os.path.join(dir_path, file_name)
        if not load_pt_file:
            return load_npy(data_path)
        # torch is imported lazily so MindSpore-only use does not require it.
        import torch
        from msprobe.pytorch.common.utils import load_pt
        tensor = load_pt(data_path).detach()
        if tensor.dtype == torch.bfloat16:
            # bfloat16 tensors are upcast to float32 before .numpy() is called.
            tensor = tensor.to(torch.float32)
        return tensor.numpy()

    def api_replace(self, npu_op_name, target, para):
        """Replace ``target`` with ``para`` in every name, mutating and returning the same list."""
        for pos, name in enumerate(npu_op_name):
            npu_op_name[pos] = name.replace(target, para)
        return npu_op_name

    def process_api_mapping(self, npu_op_name, bench_op_name):
        """Rewrite MindSpore API names in ``npu_op_name`` to PyTorch names where a mapping applies."""
        # e.g. "Functional.addcmul.0.forward.input.0" -> API name "Functional.addcmul"
        ms_api_name = npu_op_name[0].rsplit(Const.SEP, 4)[0]
        pt_api_name = bench_op_name[0].rsplit(Const.SEP, 4)[0]
        mint_prefixes = {"Mint": "Torch", "MintFunctional": "Functional"}
        class_name = ms_api_name.split(Const.SEP)[0]
        if class_name in mint_prefixes:
            return self.api_replace(npu_op_name, class_name, mint_prefixes[class_name])
        if self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
            return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
        return npu_op_name
98
-
99
-
100
def ms_compare(input_param, output_path, **kwargs):
    """Compare MindSpore dump data described by ``input_param``, writing into ``output_path``.

    Recognised keyword arguments (defaults): ``stack_mode`` (False), ``auto_analyze`` (True),
    ``fuzzy_match`` (False), ``cell_mapping`` (None), ``api_mapping`` (None).

    Raises:
        CompareException: when argument or path validation fails (the original
            CompareException/FileCheckException is chained as the cause).
    """
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    cell_mapping = kwargs.get('cell_mapping')
    api_mapping = kwargs.get('api_mapping')
    try:
        summary_compare, md5_compare = task_dumppath_get(input_param)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
        create_directory(output_path)
        check_compare_param(input_param, output_path, summary_compare, md5_compare)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error
    comparator = MSComparator(cell_mapping, api_mapping)
    comparator.compare_core(
        input_param,
        output_path,
        stack_mode=stack_mode,
        auto_analyze=auto_analyze,
        fuzzy_match=fuzzy_match,
        summary_compare=summary_compare,
        md5_compare=md5_compare,
    )
1
+ import os
2
+ import re
3
+ import copy
4
+ import sys
5
+ from itertools import zip_longest
6
+
7
+ from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
8
+ task_dumppath_get, struct_json_get, add_time_with_yaml
9
+ from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy, load_json, save_yaml, FileOpen
10
+ from msprobe.core.common.const import Const, CompareConst
11
+ from msprobe.core.common.log import logger
12
+ from msprobe.core.common.exceptions import FileCheckException
13
+ from msprobe.core.compare.acc_compare import Comparator
14
+ from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
15
+ from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack
16
+ from msprobe.mindspore.compare.layer_mapping import get_layer_mapping
17
+
18
+ class MSComparator(Comparator):
19
+ def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
20
+ self.frame_name = MSComparator.__name__
21
+ self.cell_mapping = cell_mapping
22
+ self.api_mapping = api_mapping
23
+ self.data_mapping = data_mapping
24
+ if data_mapping:
25
+ self.cross_frame = is_cross_framework
26
+ else:
27
+ self.cross_frame = cell_mapping is not None or api_mapping is not None
28
+ self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
29
+ self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
30
+ if api_mapping is not None:
31
+ self.ms_to_pt_mapping = self.load_internal_api()
32
+
33
+ if isinstance(self.data_mapping, str) or self.data_mapping is None:
34
+ self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
35
+ elif isinstance(self.data_mapping, dict):
36
+ self.data_mapping_dict = self.data_mapping
37
+ else:
38
+ raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
39
+ f"{type(self.data_mapping)}")
40
+
41
+ def load_internal_api(self):
42
+ cur_path = os.path.dirname(os.path.realpath(__file__))
43
+ yaml_path = os.path.join(cur_path, "ms_to_pt_api.yaml")
44
+ return load_yaml(yaml_path)
45
+
46
+ def load_mapping_file(self, mapping_file):
47
+ if isinstance(mapping_file, str):
48
+ mapping_dict = load_yaml(mapping_file)
49
+ else:
50
+ mapping_dict = {}
51
+ return mapping_dict
52
+
53
+ def process_cell_mapping(self, npu_op_name):
54
+ npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
55
+ if self.cell_mapping_dict:
56
+ for index, op_name in enumerate(npu_op_name):
57
+ # get cell name & class name from op_name
58
+ # Cell.fc1.Dense.forward.0.input.0
59
+ cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
60
+ if cell_name in self.cell_mapping_dict:
61
+ npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
62
+ return npu_op_name
63
+
64
+ def check_op(self, npu_dict, bench_dict, fuzzy_match):
65
+ npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
66
+ npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
67
+ if self.cell_mapping is not None:
68
+ npu_op_name = self.process_cell_mapping(npu_op_name)
69
+ if self.api_mapping is not None:
70
+ npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
71
+ if isinstance(self.api_mapping, str):
72
+ npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
73
+ bench_dict_new)
74
+ if target_dict:
75
+ bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
76
+ npu_op_name = npu_dict_new.get(CompareConst.OP_NAME)
77
+ bench_op_name = bench_dict_new.get(CompareConst.OP_NAME)
78
+ struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
79
+ if not fuzzy_match:
80
+ return npu_op_name == bench_op_name and struct_match
81
+ is_match = True
82
+ try:
83
+ is_match = fuzzy_check_op(npu_op_name, bench_op_name)
84
+ except Exception as err:
85
+ logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
86
+ is_match = False
87
+ return is_match and struct_match
88
+
89
+ def read_npy_data(self, dir_path, file_name, load_pt_file=False):
90
+ data_path = os.path.join(dir_path, file_name)
91
+ if load_pt_file:
92
+ import torch
93
+ from msprobe.pytorch.common.utils import load_pt
94
+ data_value = load_pt(data_path, True).detach()
95
+ if data_value.dtype == torch.bfloat16:
96
+ data_value = data_value.to(torch.float32)
97
+ data_value = data_value.numpy()
98
+ else:
99
+ data_value = load_npy(data_path)
100
+ return data_value
101
+
102
+ def api_replace(self, npu_op_name, target, para):
103
+ for idx, _ in enumerate(npu_op_name):
104
+ npu_op_name[idx] = npu_op_name[idx].replace(target, para)
105
+ return npu_op_name
106
+
107
+ def process_internal_api_mapping(self, npu_op_name, bench_op_name):
108
+ # get api name & class name from op_name
109
+ # Functional.addcmul.0.forward.input.0
110
+ npu_op_name, bench_op_name = npu_op_name.copy(), bench_op_name.copy()
111
+ ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
112
+ pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
113
+ class_name = ms_api_name.split(Const.SEP)[0]
114
+ if class_name == "Mint":
115
+ return self.api_replace(npu_op_name, "Mint", "Torch")
116
+ elif class_name == "MintFunctional":
117
+ return self.api_replace(npu_op_name, "MintFunctional", "Functional")
118
+ elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
119
+ return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
120
+ else:
121
+ return npu_op_name
122
+
123
+ def remove_element(self, op_name, struct, summary, idx):
124
+ del op_name[idx]
125
+ del struct[idx]
126
+ del summary[idx]
127
+
128
+ def get_api_name(self, api_list):
129
+ try:
130
+ api_name = api_list[0] + Const.SEP + api_list[1]
131
+ except IndexError as error:
132
+ logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
133
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
134
+ return api_name
135
+
136
def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
    """
    Transform user mapping API based on new NPU and benchmark dictionaries.

    The ms/pt API names are derived from the first entry of each dict's op-name list.
    When ``self.api_mapping_dict`` contains an entry whose ``ms_api``/``pt_api`` match
    them, every input/output position not listed in that entry (``ms_args``/``ms_output``
    for the NPU side, ``pt_args``/``pt_output`` for the bench side) is removed from the
    op-name, struct and summary lists. The NPU op names are then passed through
    ``self.api_replace`` and ``self.para_sequence_update``. An entry whose NPU and
    bench list lengths differ is rejected with a warning.

    Parameters:
        new_npu_dict (dict): New NPU operation dictionary.
        new_bench_dict (dict): New benchmark operation dictionary.
    Returns:
        tuple: Updated NPU and benchmark dictionaries, along with the target dictionary
        (the matched mapping entry, or ``{}`` when no usable entry was found).
    Raises:
        CompareException: raised by ``get_api_name`` when an op name has fewer than two
        ``Const.SEP``-separated components.
    """
    npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
    npu_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT)
    bench_struct_in = new_bench_dict.get(CompareConst.INPUT_STRUCT)
    npu_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT)
    bench_struct_out = new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
    npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
    npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
    npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
    ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
    pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))

    def drop_unmapped_outputs(op_names, struct_out, summary, in_len, out_len, keep_list):
        # Remove output positions not in keep_list. Go from the highest index down so
        # earlier indices stay valid. In op_names/summary the outputs follow the
        # in_len inputs, hence the offset.
        for idx in reversed(range(out_len)):
            if idx in keep_list:
                continue
            del struct_out[idx]
            pos = idx + in_len
            if pos < len(summary) and pos < len(op_names):
                del summary[pos]
                del op_names[pos]

    target_dict = {}
    for api_dict in self.api_mapping_dict:
        if api_dict.get("pt_api") != pt_api_name or api_dict.get("ms_api") != ms_api_name:
            continue
        # Use the same missing/null default for the length check and for the filtering
        # below, so an absent key does not raise TypeError in len().
        ms_para_list = api_dict.get("ms_args") or []
        pt_para_list = api_dict.get("pt_args") or []
        ms_out_list = api_dict.get("ms_output") or []
        pt_out_list = api_dict.get("pt_output") or []
        if len(ms_para_list) != len(pt_para_list) or len(ms_out_list) != len(pt_out_list):
            logger.warning("The user-defined mapping table is incorrect, "
                           "make sure that the number of parameters is equal")
            break
        # Remove outputs before inputs. The output offset (idx + in_len) is only
        # correct while every input entry is still present.
        drop_unmapped_outputs(npu_op_name, npu_struct_out, npu_summary, npu_in_len, npu_out_len, ms_out_list)
        drop_unmapped_outputs(bench_op_name, bench_struct_out, bench_summary, bench_in_len, bench_out_len,
                              pt_out_list)
        for idx in reversed(range(npu_in_len)):
            if idx not in ms_para_list:
                self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
        for idx in reversed(range(bench_in_len)):
            if idx not in pt_para_list:
                self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
        npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
        npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
        target_dict = api_dict
        break
    if target_dict:
        new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
                             CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
        new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
                               CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
    return new_npu_dict, new_bench_dict, target_dict
197
+
198
def para_sequence_update(self, npu_op_name, bench_op_name):
    """
    Replace the trailing index of each NPU op name with that of the bench op name.

    For every position ``idx``, the last ``Const.SEP``-separated component of
    ``npu_op_name[idx]`` is replaced by the last component of ``bench_op_name[idx]``.
    The whole component is replaced, so multi-digit indices (e.g. ``input.10``) are
    handled. ``npu_op_name`` is modified in place and also returned.

    Parameters:
        npu_op_name (list[str]): NPU op names to update.
        bench_op_name (list[str]): Bench op names supplying the trailing components.
    Returns:
        list[str]: ``npu_op_name``.
    Raises:
        IndexError: if ``bench_op_name`` is shorter than ``npu_op_name``.
    """
    for idx, npu_name in enumerate(npu_op_name):
        bench_suffix = bench_op_name[idx].rsplit(Const.SEP, 1)[-1]
        npu_head, sep, _ = npu_name.rpartition(Const.SEP)
        # Names without a separator keep the legacy "drop last character" behavior.
        npu_prefix = npu_head + sep if sep else npu_name[:-1]
        npu_op_name[idx] = npu_prefix + bench_suffix
    return npu_op_name
204
+
205
def reconstitution_bench_dict(self, npu_dict, del_bench_dict, api_dict):
    """
    Pad a trimmed bench dict with N/A placeholders so it lines up with the NPU dict.

    NPU input positions not listed in ``api_dict["ms_args"]`` and output positions not
    listed in ``api_dict["ms_output"]`` each get a ``CompareConst.N_A`` entry inserted
    at that position in the bench op-name, struct and summary lists. Output positions
    in op-name/summary are offset by the NPU input count. Nothing changes when the
    user lists already cover every NPU input and output.

    Parameters:
        npu_dict (dict): NPU operation dictionary (only its input/output struct lengths are read).
        del_bench_dict (dict): Bench operation dictionary to pad.
        api_dict (dict): Matched user mapping entry.
    Returns:
        dict: ``del_bench_dict``, updated with the padded lists.
    """
    ms_user_args_list = api_dict.get("ms_args") or []
    ms_user_output_list = api_dict.get("ms_output") or []
    npu_in_len = len(npu_dict.get(CompareConst.INPUT_STRUCT))
    npu_out_len = len(npu_dict.get(CompareConst.OUTPUT_STRUCT))
    if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
        return del_bench_dict
    # Insert in ascending index order so each placeholder lands at its final position.
    # Set iteration order is not guaranteed, so sort the missing indices explicitly.
    input_sub_list = sorted(set(range(npu_in_len)).difference(ms_user_args_list))
    output_sub_list = sorted(set(range(npu_out_len)).difference(ms_user_output_list))
    bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
    bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
    bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
    bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
    for idx in input_sub_list:  # Fill in the blank value field in the pt dictionary
        bench_op_name.insert(idx, CompareConst.N_A)
        bench_struct_in.insert(idx, CompareConst.N_A)
        bench_summary.insert(idx, CompareConst.N_A)
    for idx in output_sub_list:  # Fill in the blank value field in the pt dictionary
        bench_op_name.insert(npu_in_len + idx, CompareConst.N_A)
        bench_struct_out.insert(idx, CompareConst.N_A)
        bench_summary.insert(npu_in_len + idx, CompareConst.N_A)
    del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
                           CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
    return del_bench_dict
233
+
234
+
235
def sort_by_execution_sequence(npu_data, bench_data, mapping_list, flag):
    """
    Order ``(npu_name, bench_name)`` pairs by where each name appears in its dump.

    Only dump keys containing ``flag`` are indexed. Each key is reduced to
    ``"<name>.<index>"``: the last two ``Const.SEP`` components are dropped from the
    name, and the index is the last component, or the second-to-last when the last is
    ``'forward'``/``'backward'``. Pairs whose names are not found sort last.
    ``sorted`` is stable, so ties keep their incoming order.

    Parameters:
        npu_data (dict): NPU dump ``data`` section.
        bench_data (dict): Bench dump ``data`` section.
        mapping_list (list): Pairs whose element 0 is an NPU name and element 1 a bench name.
        flag (str): Substring a dump key must contain to be indexed.
    Returns:
        list: A new sorted list of the pairs.
    """
    def build_order(dump_data):
        order = {}
        for position, dump_key in enumerate(dump_data.keys()):
            if flag not in dump_key:
                continue
            parts = dump_key.split(Const.SEP)
            base_name = Const.SEP.join(parts[:-2])
            call_index = parts[-2] if parts[-1] in ('forward', 'backward') else parts[-1]
            order[f"{base_name}.{call_index}"] = position
        return order

    npu_order = build_order(npu_data)
    bench_order = build_order(bench_data)
    return sorted(
        mapping_list,
        key=lambda pair: (npu_order.get(pair[0], sys.maxsize), bench_order.get(pair[1], sys.maxsize)),
    )
258
+
259
+
260
def generate_kernel_data(map_value, data, flag):
    """
    Collect the input and output tensor names of one mapped layer from a dump.

    ``map_value`` is split at its last ``Const.SEP``: the head is the name to search
    for and the tail is the call index. A dump key is used when it contains both
    ``flag`` and that name, and its last or second-to-last component equals the index.

    Parameters:
        map_value (str): Mapped layer name ending in a call index.
        data (dict): Dump ``data`` section keyed by entry name.
        flag (str): ``'forward'`` reads inputs from ``'input_args'``; any other value
            reads them from ``'input'``. Outputs always come from ``'output'``.
    Returns:
        tuple[list[str], list[str]]: Names ``"<key>.input.<i>"`` and ``"<key>.output.<i>"``.
        Both lists are empty when ``map_value`` is falsy.
    """
    if not map_value:
        return [], []
    map_name, _, map_index = map_value.rpartition(Const.SEP)
    args_key = 'input_args' if flag == 'forward' else 'input'
    inputs_name, outputs_name = [], []
    for key, value in data.items():
        # NOTE(review): plain substring matching, so a name that is a prefix of another
        # (e.g. "conv2d" vs "conv2d_transpose") can match both. Confirm this is intended.
        if flag not in key or map_name not in key:
            continue
        key_parts = key.split(Const.SEP)
        # Accept the entry when the call index is its last or second-to-last component.
        if key_parts[-1] != map_index and key_parts[-2] != map_index:
            continue
        inputs_name.extend(f"{key}.input.{i}" for i in range(len(value.get(args_key, {}))))
        outputs_name.extend(f"{key}.output.{i}" for i in range(len(value.get('output', {}))))
    return inputs_name, outputs_name
282
+
283
+
284
def generate_file_mapping(npu_json_path, bench_json_path, mapping_list):
    """
    Build an NPU-to-bench tensor-name mapping from a list of mapped layer pairs.

    The forward pass runs first and the backward pass second. Before each pass,
    ``mapping_list`` is re-sorted by execution order for that direction; the backward
    sort starts from the forward-sorted list. For every pair, NPU and bench input names
    are paired positionally, then output names. Missing partners become ``None``.

    Parameters:
        npu_json_path (str): Path of the NPU dump JSON.
        bench_json_path (str): Path of the bench dump JSON.
        mapping_list (list): Pairs of (NPU layer name, bench layer name).
    Returns:
        dict: NPU tensor name -> bench tensor name (or ``None``). Pairs with no NPU name
        are dropped, and a later pair overwrites an earlier one with the same NPU name.
    """
    npu_data = load_json(npu_json_path).get("data", {})
    bench_data = load_json(bench_json_path).get("data", {})

    kernel_data = []
    for sort_flag, direction in ((Const.FORWARD, "forward"), (Const.BACKWARD, "backward")):
        mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, sort_flag)
        for map_value in mapping_list:
            npu_inputs, npu_outputs = generate_kernel_data(map_value[0], npu_data, direction)
            bench_inputs, bench_outputs = generate_kernel_data(map_value[1], bench_data, direction)
            kernel_data.extend(zip_longest(npu_inputs, bench_inputs))
            kernel_data.extend(zip_longest(npu_outputs, bench_outputs))

    return {npu_name: bench_name for npu_name, bench_name in kernel_data if npu_name is not None}
313
+
314
+
315
def check_cross_framework(bench_json_path):
    """
    Report whether the bench dump JSON references a ``.pt`` data file.

    The file is scanned line by line and scanning stops at the first
    ``"data_name": "<...>.pt"`` match. Presumably a ``.pt`` entry means the bench dump
    came from PyTorch rather than MindSpore (confirm with the dump format).

    Parameters:
        bench_json_path (str): Path of the bench dump JSON.
    Returns:
        bool: True if any line matches, otherwise False.
    """
    pt_data_name = re.compile(r'"data_name":\s*"[^"]+\.pt"')
    with FileOpen(bench_json_path, 'r') as bench_file:
        return any(pt_data_name.search(line) for line in bench_file)
322
+
323
+
324
def ms_compare(input_param, output_path, **kwargs):
    """
    Validate arguments and run the MindSpore comparison.

    Parameters:
        input_param (dict): Compare inputs. This function reads ``'npu_json_path'``,
            ``'bench_json_path'`` and ``'is_print_compare_log'`` (default True) and also
            passes the dict to the helpers it calls.
        output_path (str): Output directory; it is created if needed.
        **kwargs: ``stack_mode`` (default False), ``auto_analyze`` (default True),
            ``fuzzy_match`` (default False), ``cell_mapping``, ``api_mapping``,
            ``data_mapping`` and ``layer_mapping`` (all default None). When
            ``layer_mapping`` is given, ``data_mapping`` is regenerated from it and
            saved as a timestamped YAML file in ``output_path``.
    Raises:
        CompareException: when argument/path validation fails (the original error is chained).
    """
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    cell_mapping = kwargs.get('cell_mapping', None)
    api_mapping = kwargs.get('api_mapping', None)
    data_mapping = kwargs.get('data_mapping', None)
    layer_mapping = kwargs.get('layer_mapping', None)

    try:
        summary_compare, md5_compare = task_dumppath_get(input_param)
        print_compare_log = input_param.get('is_print_compare_log', True)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match, print_compare_log)
        create_directory(output_path)
        check_compare_param(input_param, output_path, summary_compare, md5_compare)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error

    if layer_mapping:
        pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK)
        ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK)
        user_layer_mapping = load_yaml(layer_mapping)
        ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct)
        pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct)
        layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, user_layer_mapping)
        npu_json_path = input_param.get("npu_json_path")
        bench_json_path = input_param.get("bench_json_path")
        data_mapping = generate_file_mapping(npu_json_path, bench_json_path, layer_mapping)

        yaml_file_name = add_time_with_yaml("data_mapping")
        save_yaml(os.path.join(os.path.realpath(output_path), yaml_file_name), data_mapping)

    is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
    comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
    comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
                            auto_analyze=auto_analyze, fuzzy_match=fuzzy_match,
                            summary_compare=summary_compare, md5_compare=md5_compare)