mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -0,0 +1,106 @@
1
+ import numpy as np
2
+ import mindspore as ms
3
+
4
+ from msprobe.core.common.const import Const as CoreConst
5
+
6
+
7
class Const:
    """MindSpore-specific constants used by msprobe.

    Groups dump-level names, execution-mode / JIT-level identifiers, the
    module prefixes of wrapped MindSpore APIs, and the names used to look up
    entries of ``support_wrap_ops.yaml`` (presumably its top-level keys --
    verify against the YAML loader).
    """

    # Dump granularity names, keyed by the framework-agnostic level
    # identifiers defined in ``msprobe.core.common.const.Const``.
    CELL = "cell"
    API = "api"
    KERNEL = "kernel"
    TOOL_LEVEL_DICT = {
        CoreConst.LEVEL_L0: CELL,
        CoreConst.LEVEL_L1: API,
        CoreConst.LEVEL_L2: KERNEL
    }

    # Execution modes, JIT levels and device identifiers.
    PYNATIVE_MODE = "pynative"
    GRAPH_GE_MODE = "graph_ge"
    GRAPH_KBYK_MODE = "graph_kbyk"
    JIT_LEVEL = "jit_level"
    JIT_LEVEL_O0 = "O0"
    JIT_LEVEL_O1 = "O1"
    JIT_LEVEL_O2 = "O2"
    ASCEND_910A = "ascend910"

    # Fully-qualified module prefixes of the MindSpore API families.
    OPS_PREFIX = "mindspore.ops."
    Tensor_PREFIX = "mindspore.Tensor."
    MINT_PREFIX = "mindspore.mint."
    MINT_NN_FUNC_PREFIX = "mindspore.mint.nn.functional."
    COMM_PREFIX = "mindspore.communication.comm_func."
    COMMUNICATION_API_LIST = [
        "mindspore.communication.comm_func.all_gather_into_tensor",
        "mindspore.communication.comm_func.gather_into_tensor",
        "mindspore.communication.comm_func.all_reduce",
        "mindspore.communication.comm_func.reduce",
        "mindspore.communication.comm_func.reduce_scatter_tensor"
    ]

    # Short name prefixes per API family (presumably used when naming dumped
    # records -- confirm against the hook/wrap code).
    TENSOR_DATA_PREFIX = "Tensor."
    STUB_TENSOR_DATA_PREFIX = "Tensor."
    OPS_DATA_PREFIX = "Functional."
    MINT_DATA_PREFIX = "Mint."
    MINT_NN_FUNC_DATA_PREFIX = "MintFunctional."

    # Supported-API list file and the names of its per-family sections.
    SUPPORTED_API_LIST_FILE = "support_wrap_ops.yaml"
    SUPPORTED_TENSOR_LIST_KEY = "tensor"
    SUPPORTED_OPS_LIST_KEY = "ops"
    SUPPORTED_MINT_LIST_KEY = "mint.ops"
    SUPPORTED_MINT_NN_FUNC_LIST_KEY = "mint.nn.functional"
    # The original (misspelled, double-underscore) name is kept as an alias so
    # existing callers keep working.
    SUPPORTED__MINT_NN_FUNC_LIST_KEY = SUPPORTED_MINT_NN_FUNC_LIST_KEY

    DROPOUT_API_NAME_PREFIX = "dropout"
50
+
51
+
52
class FreeBenchmarkConst:
    """Option names, defaults and per-dtype tables for MindSpore free-benchmark checking."""

    # Perturbation type identifiers.
    ADD_NOISE = "add_noise"
    BIT_NOISE = "bit_noise"
    NO_CHANGE = "no_change"
    EXCHANGE_VALUE = "change_value"
    IMPROVE_PRECISION = "improve_precision"

    # Handler type identifiers.
    CHECK = "check"
    FIX = "fix"

    # Each option: its default value, then the list of accepted values.
    DEFAULT_DEVICE = "npu"
    DEVICE_LIST = [DEFAULT_DEVICE]

    DEFAULT_STAGE = CoreConst.FORWARD
    STAGE_LIST = [DEFAULT_STAGE]

    DEFAULT_DUMP_LEVEL = "L1"
    DUMP_LEVEL_LIST = [DEFAULT_DUMP_LEVEL]

    DEFAULT_PERT_TYPE = IMPROVE_PRECISION
    PERT_TYPE_LIST = [IMPROVE_PRECISION, ADD_NOISE, BIT_NOISE, NO_CHANGE, EXCHANGE_VALUE]

    DEFAULT_HANDLER_TYPE = CHECK
    HANDLER_TYPE_LIST = [CHECK, FIX]

    # Scalar thresholds.
    NO_CHANGE_ERROR_THRESHOLD = 1.0
    SYMBOL_FLIPPING_RATIO = 8.0

    # API family name -> module prefix of that family's MindSpore APIs.
    API_PREFIX_DICT = {
        "ops": Const.OPS_PREFIX,
        "Tensor": Const.Tensor_PREFIX,
        "mint": Const.MINT_PREFIX,
        "mint.nn.functional": Const.MINT_NN_FUNC_PREFIX,
        "communication": Const.COMM_PREFIX,
    }

    # Per-dtype perturbation value (presumably the noise magnitude -- verify
    # against the perturbation implementations).
    PERT_VALUE_DICT = {
        ms.bfloat16: 1e-4,
        ms.float16: 1e-6,
        ms.float32: 1e-8,
        ms.float64: 1e-16,
    }

    # Per-dtype error threshold; only float16 and float32 are listed.
    ERROR_THRESHOLD = {
        ms.float16: 1.002,
        ms.float32: 1.0002,
    }

    # Floating dtype -> NumPy signed integer type of the same bit width.
    PERT_BIT_DICT = {
        ms.float16: np.int16,
        ms.float32: np.int32,
        ms.float64: np.int64,
    }

    # MindSpore dtype -> equivalent NumPy dtype.
    MS_NUMPY_DTYPE_DICT = {
        ms.int16: np.int16,
        ms.int32: np.int32,
        ms.int64: np.int64,
        ms.float16: np.float16,
        ms.float32: np.float32,
        ms.float64: np.float64,
    }
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+ import time
18
+ import sys
19
+
20
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
21
+ from msprobe.core.common.log import BaseLogger
22
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
23
+
24
+
25
class MindsporeLogger(BaseLogger):
    """BaseLogger subclass that obtains its rank from the MindSpore distributed environment."""

    def __init__(self):
        super().__init__()

    def get_rank(self):
        """Return the current MindSpore rank, or None when distributed is not initialized."""
        try:
            return get_rank_if_initialized()
        except DistributedNotInitializedError:
            return None


# Module-level logger instance.
logger = MindsporeLogger()
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ import os
16
+ import mindspore as ms
17
+
18
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
19
+ from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
20
+ from msprobe.core.common.log import logger
21
+
22
+
23
def get_rank_if_initialized():
    """Return this process's rank from ``mindspore.communication``.

    Raises:
        DistributedNotInitializedError: if ``GlobalComm.INITED`` is falsy.
    """
    if not ms.communication.GlobalComm.INITED:
        raise DistributedNotInitializedError("mindspore distributed environment is not initialized")
    return ms.communication.get_rank()
28
+
29
+
30
def convert_bf16_to_fp32(tensor):
    """Return *tensor* cast to ``ms.float32`` when its dtype is ``ms.bfloat16``; otherwise return it unchanged."""
    return tensor.to(ms.float32) if tensor.dtype == ms.bfloat16 else tensor
34
+
35
+
36
def save_tensor_as_npy(tensor, file_path):
    """Write *tensor* to *file_path* via ``save_npy``.

    bfloat16 tensors are converted to float32 before ``asnumpy()``. When
    ``path_len_exceeds_limit(file_path)`` is true, nothing is written and a
    warning is logged instead.
    """
    if path_len_exceeds_limit(file_path):
        logger.warning(f'The file path {file_path} length exceeds limit.')
        return
    array = convert_bf16_to_fp32(tensor).asnumpy()
    save_npy(array, file_path)
43
+
44
+
45
def convert_to_int(value):
    """Convert *value* with ``int()``, returning -1 when it cannot be converted.

    Only the exceptions ``int()`` raises for unconvertible input are caught:
    ``ValueError`` (e.g. ``"abc"`` or NaN), ``TypeError`` (e.g. ``None``) and
    ``OverflowError`` (infinite floats). Any other exception propagates
    instead of being silently reported as -1.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return -1
50
+
51
+
52
def list_lowest_level_directories(root_dir):
    """Return the paths of all leaf directories below *root_dir*.

    A leaf is a directory with no sub-directories. *root_dir* itself is never
    included. The search is depth-first, in ``os.listdir`` order, and follows
    whatever ``os.path.isdir`` reports as a directory.
    """
    check_path_exists(root_dir)
    leaves = []

    def _has_subdirectory(path):
        # True if any entry directly inside *path* is a directory.
        return any(os.path.isdir(os.path.join(path, name)) for name in os.listdir(path))

    def _walk(directory):
        # Descend into non-leaf sub-directories and collect leaf ones.
        for name in os.listdir(directory):
            candidate = os.path.join(directory, name)
            if not os.path.isdir(candidate):
                continue
            if _has_subdirectory(candidate):
                _walk(candidate)
            else:
                leaves.append(candidate)

    _walk(root_dir)
    return leaves
67
+
68
+
69
+
70
class MsprobeStep(ms.train.Callback):
    """MindSpore training callback that brackets each training step with debugger calls.

    At step begin it calls ``debugger.start()``. At step end it calls
    ``debugger.stop()`` and then ``debugger.step()``.
    """

    def __init__(self, debugger):
        super().__init__()
        self.debugger = debugger

    def on_train_step_begin(self, run_context):
        self.debugger.start()

    def on_train_step_end(self, run_context):
        debugger = self.debugger
        debugger.stop()
        debugger.step()
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import os
18
+ from msprobe.core.common.utils import CompareException, check_compare_param, \
19
+ check_configuration_param, task_dumppath_get
20
+ from msprobe.core.common.file_utils import create_directory
21
+ from msprobe.core.common.exceptions import FileCheckException
22
+ from msprobe.mindspore.common.log import logger
23
+ from msprobe.mindspore.compare.ms_compare import MSComparator
24
+ from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
25
+ from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
26
+
27
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """Compare two multi-rank MindSpore dumps, pairing ``rank*`` directories by sorted order.

    Args:
        npu_dump_dir: dump root whose ``rank*`` sub-directories are compared.
        bench_dump_dir: benchmark dump root with the same layout.
        output_path: directory that receives the comparison results.
        **kwargs: ``stack_mode`` (default False), ``auto_analyze`` (default True)
            and ``fuzzy_match`` (default False) are validated here. All kwargs
            are also forwarded to ``MSComparator.compare_core``. ``suffix`` is
            rejected.

    Raises:
        CompareException: if ``suffix`` is given, the rank counts differ, or a
            parameter/path check fails.
    """
    if kwargs.get('suffix'):
        logger.error("Argument 'suffix' is not supported for compare_distributed.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)

    # Ranks are matched purely by position after sorting.
    npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
    bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
    if len(npu_ranks) != len(bench_ranks):
        logger.error('The number of ranks in the two runs are different. '
                     'Unable to match the ranks. Please use another folder to compare '
                     'or use compare() api and manually match the ranks.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    for npu_rank, bench_rank in zip(npu_ranks, bench_ranks):
        npu_rank_dir = os.path.join(npu_dump_dir, npu_rank)
        bench_rank_dir = os.path.join(bench_dump_dir, bench_rank)
        # Dict values are evaluated left to right, so extract_json runs in the same order as before.
        dump_result_param = {
            'npu_json_path': extract_json(npu_rank_dir, stack_json=False),
            'bench_json_path': extract_json(bench_rank_dir, stack_json=False),
            'stack_json_path': extract_json(npu_rank_dir, stack_json=True),
            'is_print_compare_log': True
        }
        try:
            summary_compare, md5_compare = task_dumppath_get(dump_result_param)
            check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
            create_directory(output_path)
            check_compare_param(dump_result_param, output_path,
                                summary_compare=summary_compare, md5_compare=md5_compare)
        except (CompareException, FileCheckException) as error:
            logger.error('Compare failed. Please check the arguments and do it again!')
            raise CompareException(error.code) from error
        MSComparator().compare_core(dump_result_param, output_path,
                                    suffix=f'_{npu_rank}-{bench_rank}',
                                    summary_compare=summary_compare,
                                    md5_compare=md5_compare, **kwargs)
66
+
67
+
68
def ms_graph_compare(inputs, outputs):
    """Run a graph-mode comparison of *inputs*, writing results under *outputs*.

    If creating the output directory raises ``CompareException`` or
    ``FileCheckException``, an error is logged and nothing is compared.
    """
    try:
        create_directory(outputs)
    except (CompareException, FileCheckException):
        logger.error('Compare failed. Please check the arguments and do it again!')
        return
    GraphMSComparator(inputs, outputs).compare_core()
@@ -0,0 +1,219 @@
1
+ import os
2
+ import copy
3
+ from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
4
+ task_dumppath_get
5
+ from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy
6
+ from msprobe.core.common.const import Const, CompareConst
7
+ from msprobe.core.common.log import logger
8
+ from msprobe.core.common.exceptions import FileCheckException
9
+ from msprobe.core.compare.acc_compare import Comparator
10
+ from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
11
+
12
+
13
class MSComparator(Comparator):
    """Comparator for MindSpore dumps, with optional cross-framework name mapping.

    The "npu" side is the MindSpore dump. When a cell mapping or an API mapping is
    supplied, op names on the npu side are rewritten so they can be matched against a
    benchmark dump whose names use "Module"/"Torch"/"Functional" prefixes.
    """

    def __init__(self, cell_mapping=None, api_mapping=None):
        """
        Args:
            cell_mapping: path to a user YAML cell-name mapping, or None.
            api_mapping: path to a user YAML API mapping (enables both the built-in and
                the user table), another non-None value (built-in table only), or None.
        """
        self.frame_name = MSComparator.__name__
        self.cell_mapping = cell_mapping
        self.api_mapping = api_mapping
        # Any mapping means the two dumps come from different frameworks.
        self.cross_frame = cell_mapping is not None or api_mapping is not None
        self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
        self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
        # The built-in table is only read when API mapping is enabled; default to an
        # empty mapping otherwise so the attribute always exists.
        self.ms_to_pt_mapping = self.load_internal_api() if api_mapping is not None else {}

    def load_internal_api(self):
        """Load the built-in ``ms_to_pt_api.yaml`` table located next to this module."""
        cur_path = os.path.dirname(os.path.realpath(__file__))
        yaml_path = os.path.join(cur_path, "ms_to_pt_api.yaml")
        return load_yaml(yaml_path)

    def load_mapping_file(self, mapping_file):
        """Load ``mapping_file`` as YAML when it is a path string; otherwise return {}."""
        if isinstance(mapping_file, str):
            mapping_dict = load_yaml(mapping_file)
        else:
            mapping_dict = {}
        return mapping_dict

    def process_cell_mapping(self, npu_op_name):
        """Return a new list of npu op names with cell names mapped to module names.

        The first "Cell" in each name becomes "Module". If a user cell mapping is loaded,
        the cell path is also replaced by its mapped value. The cell path is the text
        between the first '.' and the last four '.'-fields, e.g. "fc1.Dense" in
        "Cell.fc1.Dense.forward.0.input.0".
        """
        npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
        if self.cell_mapping_dict:
            for index, op_name in enumerate(npu_op_name):
                # get cell name & class name from op_name
                # Cell.fc1.Dense.forward.0.input.0
                cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
                if cell_name in self.cell_mapping_dict:
                    npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
        return npu_op_name

    def check_op(self, npu_dict, bench_dict, fuzzy_match):
        """Return True when the npu and bench op entries match by name and structure.

        Both dicts are deep-copied before any mapping is applied, so the caller's dicts
        are not mutated. With ``fuzzy_match`` the names are compared by
        ``fuzzy_check_op``; any exception from it counts as a mismatch.
        """
        npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
        npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
        if self.cell_mapping is not None:
            npu_op_name = self.process_cell_mapping(npu_op_name)
        if self.api_mapping is not None:
            npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
            # A string api_mapping is a user YAML file: apply its explicit arg/output mapping too.
            if isinstance(self.api_mapping, str):
                npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new, bench_dict_new)
                if target_dict:
                    # NOTE(review): this only rebinds the local name ``bench_dict``; the result is
                    # never used afterwards and does not reach the caller. Confirm the intent.
                    bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
                    npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
        struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
        if not fuzzy_match:
            return npu_op_name == bench_op_name and struct_match
        try:
            is_match = fuzzy_check_op(npu_op_name, bench_op_name)
        except Exception:
            # Best effort: a failing fuzzy check is reported and treated as "no match".
            logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
            is_match = False
        return is_match and struct_match

    def read_npy_data(self, dir_path, file_name, load_pt_file=False):
        """Load ``dir_path/file_name`` and return it as a numpy array.

        With ``load_pt_file`` the file is read via msprobe's torch loader. bfloat16
        tensors are upcast to float32 first, because numpy has no bfloat16 dtype.
        """
        data_path = os.path.join(dir_path, file_name)
        if load_pt_file:
            import torch
            from msprobe.pytorch.common.utils import load_pt
            data_value = load_pt(data_path).detach()
            if data_value.dtype == torch.bfloat16:
                data_value = data_value.to(torch.float32)
            data_value = data_value.numpy()
        else:
            data_value = load_npy(data_path)
        return data_value

    def api_replace(self, npu_op_name, target, para):
        """Replace every ``target`` with ``para`` in each name, in place; return the list."""
        for idx, _ in enumerate(npu_op_name):
            npu_op_name[idx] = npu_op_name[idx].replace(target, para)
        return npu_op_name

    def process_internal_api_mapping(self, npu_op_name, bench_op_name):
        """Map MindSpore API prefixes in ``npu_op_name`` to their PyTorch equivalents.

        "Mint" becomes "Torch" and "MintFunctional" becomes "Functional". Otherwise the
        built-in table is consulted, and a rename is only applied when it maps this API
        to exactly the bench API. Returns a new list; the input list is not modified.
        Assumes both name lists are non-empty.
        """
        # get api name & class name from op_name
        # Functional.addcmul.0.forward.input.0
        npu_op_name = npu_op_name.copy()
        ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
        pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
        class_name = ms_api_name.split(Const.SEP)[0]
        if class_name == "Mint":
            return self.api_replace(npu_op_name, "Mint", "Torch")
        elif class_name == "MintFunctional":
            return self.api_replace(npu_op_name, "MintFunctional", "Functional")
        elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
            return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
        else:
            return npu_op_name

    def remove_element(self, op_name, struct, summary, idx):
        """Delete position ``idx`` from the three parallel lists, in place."""
        del op_name[idx]
        del struct[idx]
        del summary[idx]

    def get_api_name(self, api_list):
        """Join the first two name fields, e.g. ["Functional", "addcmul", ...] -> "Functional.addcmul"."""
        return api_list[0] + Const.SEP + api_list[1]

    def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
        """
        Apply the first matching user API-mapping entry to both op dictionaries.

        An entry matches when its "ms_api"/"pt_api" equal the API names of the two
        dumps. Its "ms_args"/"pt_args" and "ms_output"/"pt_output" lists give the input
        and output indices to keep; every other input/output is removed from the
        corresponding op-name, struct and summary lists. NPU names are then renamed to
        the PyTorch API and renumbered to match the bench names.
        Parameters:
            new_npu_dict (dict): New NPU operation dictionary (modified in place).
            new_bench_dict (dict): New benchmark operation dictionary (modified in place).
        Returns:
            tuple: Updated NPU and benchmark dictionaries, along with the matched mapping
            entry (an empty dict when nothing matched or the entry was inconsistent).
        """
        npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
        npu_struct_in, bench_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT), new_bench_dict.get(CompareConst.INPUT_STRUCT)
        npu_struct_out, bench_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT), new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
        npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
        npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
        npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
        ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
        pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
        target_dict = {}
        for api_dict in self.api_mapping_dict:
            if api_dict.get("pt_api") != pt_api_name or api_dict.get("ms_api") != ms_api_name:
                continue
            # Missing/empty keys default to [] (previously len(None) raised TypeError).
            ms_para_list = api_dict.get("ms_args") or []
            pt_para_list = api_dict.get("pt_args") or []
            ms_out_list = api_dict.get("ms_output") or []
            pt_out_list = api_dict.get("pt_output") or []
            if len(ms_para_list) != len(pt_para_list) or len(ms_out_list) != len(pt_out_list):
                logger.warning("The user-defined mapping table is incorrect, make sure that the number of parameters is equal")
                break
            # Outputs are deleted before inputs: output entries in the summary/op-name lists sit
            # after the inputs at offset idx + *_in_len, which is only valid while all inputs remain.
            # Deleting in reverse keeps the not-yet-visited indices valid.
            for idx in reversed(range(npu_out_len)):
                if idx not in ms_out_list:
                    del npu_struct_out[idx]
                    del npu_summary[idx + npu_in_len]
                    del npu_op_name[idx + npu_in_len]
            for idx in reversed(range(bench_out_len)):
                if idx not in pt_out_list:
                    del bench_struct_out[idx]
                    del bench_summary[idx + bench_in_len]
                    del bench_op_name[idx + bench_in_len]
            for idx in reversed(range(npu_in_len)):
                if idx not in ms_para_list:
                    self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
            for idx in reversed(range(bench_in_len)):
                if idx not in pt_para_list:
                    self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
            npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
            npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
            target_dict = api_dict
            break
        if target_dict:
            new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in, CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
            new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
        return new_npu_dict, new_bench_dict, target_dict

    def para_sequence_update(self, npu_op_name, bench_op_name):
        """Renumber npu names in place so each trailing '.'-field equals the bench name's.

        Positions are paired with ``zip``, so extra npu names beyond the bench list are
        left unchanged. Returns the same list.
        """
        for idx, (npu_name, bench_name) in enumerate(zip(npu_op_name, bench_op_name)):
            # Replace the whole trailing field, not just its last character; otherwise
            # multi-digit indices are corrupted (e.g. "input.10" + "3" -> "input.13").
            npu_prefix, sep, _ = npu_name.rpartition(Const.SEP)
            bench_suffix = bench_name.rsplit(Const.SEP, 1)[-1]
            npu_op_name[idx] = npu_prefix + sep + bench_suffix
        return npu_op_name

    def reconstitution_bench_dict(self, npu_dict, del_bench_dict, api_dict):
        """Pad ``del_bench_dict`` with CompareConst.NAN at npu positions the mapping dropped.

        When the mapping entry keeps every npu input and output, the dict is returned
        unchanged. Otherwise placeholders are inserted so that the bench lists line up
        positionally with the unfiltered npu lists. The dict is modified in place and
        returned.
        """
        ms_user_args_list = api_dict.get("ms_args", [])
        ms_user_output_list = api_dict.get("ms_output", [])
        npu_in_len = len(npu_dict.get(CompareConst.INPUT_STRUCT))
        npu_out_len = len(npu_dict.get(CompareConst.OUTPUT_STRUCT))
        if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
            return del_bench_dict
        # Sorted ascending: each positional insert shifts the later items, so lower
        # indices must be filled first. Set iteration order is not guaranteed.
        input_sub_list = sorted(set(range(npu_in_len)) - set(ms_user_args_list))
        output_sub_list = sorted(set(range(npu_out_len)) - set(ms_user_output_list))
        bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
        bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
        bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
        bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
        for idx in input_sub_list:  # Fill in the blank value field in the pt dictionary
            bench_op_name.insert(idx, CompareConst.NAN)
            bench_struct_in.insert(idx, CompareConst.NAN)
            bench_summary.insert(idx, CompareConst.NAN)
        for idx in output_sub_list:  # Fill in the blank value field in the pt dictionary
            bench_op_name.insert(npu_in_len + idx, CompareConst.NAN)
            bench_struct_out.insert(idx, CompareConst.NAN)
            bench_summary.insert(npu_in_len + idx, CompareConst.NAN)
        del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
        return del_bench_dict
200
+
201
+
202
def ms_compare(input_param, output_path, **kwargs):
    """Validate the arguments, then run ``MSComparator.compare_core`` on ``input_param``.

    Recognized keyword options and their defaults: ``stack_mode`` (False),
    ``auto_analyze`` (True), ``fuzzy_match`` (False), ``cell_mapping`` (None),
    ``api_mapping`` (None).

    Raises:
        CompareException: when dump-path detection, configuration checks, output
            directory creation or parameter checks fail with a CompareException or
            FileCheckException (the original is chained as the cause).
    """
    # dict.get cannot raise, so option extraction does not need the try below.
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    cell_mapping = kwargs.get('cell_mapping', None)
    api_mapping = kwargs.get('api_mapping', None)
    try:
        summary_compare, md5_compare = task_dumppath_get(input_param)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
        create_directory(output_path)
        check_compare_param(input_param, output_path, summary_compare, md5_compare)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error
    comparator = MSComparator(cell_mapping, api_mapping)
    comparator.compare_core(
        input_param,
        output_path,
        stack_mode=stack_mode,
        auto_analyze=auto_analyze,
        fuzzy_match=fuzzy_match,
        summary_compare=summary_compare,
        md5_compare=md5_compare,
    )