mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,58 +1,85 @@
1
- from msprobe.core.common.const import Const
2
- from msprobe.core.common.log import logger
3
- from msprobe.core.common.exceptions import MsaccException
4
-
5
-
6
- class CommonConfig:
7
- def __init__(self, json_config):
8
- self.task = json_config.get('task')
9
- self.dump_path = json_config.get('dump_path')
10
- self.rank = json_config.get('rank')
11
- self.step = json_config.get('step')
12
- self.level = json_config.get('level')
13
- self.seed = json_config.get('seed')
14
- self.acl_config = json_config.get('acl_config')
15
- self.is_deterministic = json_config.get('is_deterministic', False)
16
- self.enable_dataloader = json_config.get('enable_dataloader', False)
17
- self._check_config()
18
-
19
- def _check_config(self):
20
- if self.task and self.task not in Const.TASK_LIST:
21
- logger.error_log_with_exp(
22
- "task is invalid, it should be one of {}".format(Const.TASK_LIST), MsaccException(MsaccException.INVALID_PARAM_ERROR))
23
- if self.rank is not None and not isinstance(self.rank, list):
24
- logger.error_log_with_exp("rank is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
25
- if self.step is not None and not isinstance(self.step, list):
26
- logger.error_log_with_exp("step is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
27
- if self.level and self.level not in Const.LEVEL_LIST:
28
- logger.error_log_with_exp(
29
- "level is invalid, it should be one of {}".format(Const.LEVEL_LIST), MsaccException(MsaccException.INVALID_PARAM_ERROR))
30
- if self.seed is not None and not isinstance(self.seed, int):
31
- logger.error_log_with_exp("seed is invalid, it should be an integer", MsaccException(MsaccException.INVALID_PARAM_ERROR))
32
- if not isinstance(self.is_deterministic, bool):
33
- logger.error_log_with_exp(
34
- "is_deterministic is invalid, it should be a boolean", MsaccException(MsaccException.INVALID_PARAM_ERROR))
35
- if not isinstance(self.enable_dataloader, bool):
36
- logger.error_log_with_exp(
37
- "enable_dataloader is invalid, it should be a boolean", MsaccException(MsaccException.INVALID_PARAM_ERROR))
38
-
39
-
40
- class BaseConfig:
41
- def __init__(self, json_config):
42
- self.scope = json_config.get('scope')
43
- self.list = json_config.get('list')
44
- self.data_mode = json_config.get('data_mode')
45
- self.backward_input = json_config.get("backward_input")
46
- self.file_format = json_config.get("file_format")
47
- self.summary_mode = json_config.get("summary_mode")
48
- self.overflow_num = json_config.get("overflow_num")
49
- self.check_mode = json_config.get("check_mode")
50
-
51
- def check_config(self):
52
- if self.scope is not None and not isinstance(self.scope, list):
53
- logger.error_log_with_exp("scope is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
54
- if self.list is not None and not isinstance(self.list, list):
55
- logger.error_log_with_exp("list is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
56
- if self.data_mode is not None and not isinstance(self.data_mode, list):
57
- logger.error_log_with_exp("data_mode is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR))
58
-
1
+ from msprobe.core.common.const import Const, FileCheckConst
2
+ from msprobe.core.common.log import logger
3
+ from msprobe.core.common.exceptions import MsprobeException
4
+ from msprobe.core.common.file_utils import FileChecker
5
+
6
+
7
+ class CommonConfig:
8
+ def __init__(self, json_config):
9
+ self.task = json_config.get('task')
10
+ self.dump_path = json_config.get('dump_path')
11
+ self.rank = json_config.get('rank')
12
+ self.step = json_config.get('step')
13
+ self.level = json_config.get('level')
14
+ self.seed = json_config.get('seed')
15
+ self.acl_config = json_config.get('acl_config')
16
+ self.is_deterministic = json_config.get('is_deterministic', False)
17
+ self.enable_dataloader = json_config.get('enable_dataloader', False)
18
+ self._check_config()
19
+
20
+ def _check_config(self):
21
+ if self.task and self.task not in Const.TASK_LIST:
22
+ logger.error_log_with_exp("task is invalid, it should be one of {}".format(Const.TASK_LIST),
23
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
24
+ if self.dump_path is not None and not isinstance(self.dump_path, str):
25
+ logger.error_log_with_exp("dump_path is invalid, it should be a string",
26
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
27
+ if self.rank is not None and not isinstance(self.rank, list):
28
+ logger.error_log_with_exp("rank is invalid, it should be a list",
29
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
30
+ if self.step is not None and not isinstance(self.step, list):
31
+ logger.error_log_with_exp("step is invalid, it should be a list",
32
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
33
+ if self.level and self.level not in Const.LEVEL_LIST:
34
+ logger.error_log_with_exp("level is invalid, it should be one of {}".format(Const.LEVEL_LIST),
35
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
36
+ if self.seed is not None and not isinstance(self.seed, int):
37
+ logger.error_log_with_exp("seed is invalid, it should be an integer",
38
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
39
+ if not isinstance(self.is_deterministic, bool):
40
+ logger.error_log_with_exp("is_deterministic is invalid, it should be a boolean",
41
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
42
+ if not isinstance(self.enable_dataloader, bool):
43
+ logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
44
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
45
+ if self.acl_config:
46
+ self._check_acl_config()
47
+
48
+ def _check_acl_config(self):
49
+ if not isinstance(self.acl_config, str):
50
+ logger.error_log_with_exp("acl_config is invalid, it should be a string",
51
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
52
+ file_checker = FileChecker(
53
+ file_path=self.acl_config, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
54
+ file_checker.common_check()
55
+
56
+
57
+ class BaseConfig:
58
+ def __init__(self, json_config):
59
+ self.scope = json_config.get('scope')
60
+ self.list = json_config.get('list')
61
+ self.data_mode = json_config.get('data_mode')
62
+ self.backward_input = json_config.get("backward_input")
63
+ self.file_format = json_config.get("file_format")
64
+ self.summary_mode = json_config.get("summary_mode")
65
+ self.overflow_nums = json_config.get("overflow_nums")
66
+ self.check_mode = json_config.get("check_mode")
67
+ self.fuzz_device = json_config.get("fuzz_device")
68
+ self.pert_mode = json_config.get("pert_mode")
69
+ self.handler_type = json_config.get("handler_type")
70
+ self.fuzz_level = json_config.get("fuzz_level")
71
+ self.fuzz_stage = json_config.get("fuzz_stage")
72
+ self.if_preheat = json_config.get("if_preheat")
73
+ self.preheat_step = json_config.get("preheat_step")
74
+ self.max_sample = json_config.get("max_sample")
75
+
76
+ def check_config(self):
77
+ if self.scope is not None and not isinstance(self.scope, list):
78
+ logger.error_log_with_exp("scope is invalid, it should be a list",
79
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
80
+ if self.list is not None and not isinstance(self.list, list):
81
+ logger.error_log_with_exp("list is invalid, it should be a list",
82
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
83
+ if self.data_mode is not None and not isinstance(self.data_mode, list):
84
+ logger.error_log_with_exp("data_mode is invalid, it should be a list",
85
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
@@ -0,0 +1,300 @@
1
+ import multiprocessing
2
+ import os
3
+ import json
4
+ import pandas as pd
5
+ from msprobe.core.common.file_utils import FileOpen
6
+ from msprobe.core.common.const import CompareConst, Const
7
+ from msprobe.core.common.exceptions import FileCheckException
8
+ from msprobe.core.common.log import logger
9
+ from msprobe.core.common.utils import add_time_with_xlsx, CompareException
10
+ from msprobe.core.common.file_utils import remove_path
11
+ from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op
12
+ from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
13
+ from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
14
+ from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
15
+ from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
16
+ get_error_message
17
+ from msprobe.core.advisor.advisor import Advisor
18
+
19
+
20
class Comparator:
    """Framework-independent core of the NPU-versus-benchmark dump comparison.

    The methods below read ``self.frame_name`` and ``self.read_npy_data`` (and
    ``self.cross_frame`` when ``frame_name`` is ``"MSComparator"``). None of
    them is set here, so they are presumably provided by framework-specific
    subclasses.
    """

    def __init__(self):
        pass

    @classmethod
    def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
        """Build the result DataFrame with the header matching the compare mode.

        Note: when ``stack_mode`` is false the rows of ``result`` are modified
        in place (one trailing column is deleted from every row).

        Args:
            result (list[list]): One row per compared tensor.
            md5_compare (bool): Use the MD5 header.
            summary_compare (bool): Use the summary header (ignored if md5_compare).
            stack_mode (bool): Keep the stack column in the table.

        Returns:
            pandas.DataFrame: ``result`` labelled with the selected header.
        """
        if md5_compare:
            header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
        elif summary_compare:
            header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
        else:
            header = CompareConst.COMPARE_RESULT_HEADER[:]

        # "all" mode is the full comparison: neither summary nor md5.
        all_mode_bool = not (summary_compare or md5_compare)
        if stack_mode:
            header.append(CompareConst.STACK)
            if all_mode_bool:
                header.append(CompareConst.DATA_NAME)
        elif all_mode_bool:
            # Rows end with two extra columns here; drop the second-to-last
            # (presumably the stack info) and keep the data name.
            for row in result:
                del row[-2]
            header.append(CompareConst.DATA_NAME)
        else:
            for row in result:
                del row[-1]
        return pd.DataFrame(result, columns=header)

    @classmethod
    def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
        """Parse one op from the dump and merge its tensors into a single record.

        The stack entry for ``op_name`` (or ``None`` when the op is missing from
        ``stack_json_data``) is appended to the parsed list before merging.

        Returns:
            The value produced by ``merge_tensor`` for this op.
        """
        op_data = json_data['data'][op_name]
        op_parsed_list = read_op(op_data, op_name)
        stack_info = stack_json_data[op_name] if op_name in stack_json_data else None
        op_parsed_list.append({'full_op_name': op_name, 'full_info': stack_info})
        return merge_tensor(op_parsed_list, summary_compare, md5_compare)

    def check_op(self, npu_dict, bench_dict, fuzzy_match):
        """Decide whether an NPU op record and a benchmark op record correspond.

        Args:
            npu_dict (dict): Merged NPU op record; must contain ``"op_name"``.
            bench_dict (dict): Merged benchmark op record; must contain ``"op_name"``.
            fuzzy_match (bool): Allow fuzzy name matching instead of exact equality.

        Returns:
            The graph-mapping result when ``frame_name`` is ``"PTComparator"`` and
            ``check_graph_mode`` flags the pair; otherwise whether both the names
            (exact or fuzzy) and the structures match.
        """
        a_op_name = npu_dict["op_name"]
        b_op_name = bench_dict["op_name"]
        graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])

        if self.frame_name == "PTComparator":
            # Imported lazily so that the PyTorch package is only needed for PyTorch runs.
            from msprobe.pytorch.compare.match import graph_mapping
            if graph_mode:
                return graph_mapping.match(a_op_name[0], b_op_name[0])
        struct_match = check_struct_match(npu_dict, bench_dict)
        if not fuzzy_match:
            return a_op_name == b_op_name and struct_match
        try:
            is_match = fuzzy_check_op(a_op_name, b_op_name)
        except Exception:
            # Best effort: a name that cannot be fuzzy-compared is treated as a mismatch.
            logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
            is_match = False
        return is_match and struct_match

    def match_op(self, npu_queue, bench_queue, fuzzy_match):
        """Find a matching (npu_index, bench_index) pair involving a newest entry.

        The newest NPU op is tried against every benchmark op (older ones
        first), then the newest benchmark op is tried against the older NPU ops.

        Returns:
            tuple[int, int]: Indices of the matched pair, or ``(-1, -1)``.
        """
        newest_npu = len(npu_queue) - 1
        newest_bench = len(bench_queue) - 1
        for b_index, b_op in enumerate(bench_queue[0: -1]):
            if self.check_op(npu_queue[-1], b_op, fuzzy_match):
                return newest_npu, b_index
        if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
            return newest_npu, newest_bench
        for n_index, n_op in enumerate(npu_queue[0: -1]):
            if self.check_op(n_op, bench_queue[-1], fuzzy_match):
                return n_index, newest_bench
        return -1, -1

    def compare_process(self, file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
        """Pair up NPU and benchmark ops from the dump files and build the result table.

        Args:
            file_handles: Three open JSON file objects: NPU dump, benchmark dump, stack info.
            stack_mode (bool): Forwarded to ``make_result_table``.
            fuzzy_match (bool): Forwarded to ``match_op``.
            summary_compare (bool): Summary comparison mode.
            md5_compare (bool): MD5 comparison mode.

        Returns:
            pandas.DataFrame: Matched rows plus rows for NPU ops left unmatched.
        """
        npu_json_handle, bench_json_handle, stack_json_handle = file_handles
        npu_json_data = json.load(npu_json_handle)
        bench_json_data = json.load(bench_json_handle)
        stack_json_data = json.load(stack_json_handle)

        if fuzzy_match:
            logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")

        npu_ops_queue = []
        bench_ops_queue = []
        result = []

        ops_npu_iter = iter(npu_json_data['data'])
        ops_bench_iter = iter(bench_json_data['data'])
        npu_has_more = True
        bench_has_more = True

        # Read one op from each side per round until both dumps are exhausted.
        while npu_has_more or bench_has_more:
            last_npu_ops_len = len(npu_ops_queue)
            try:
                op_name_npu = next(ops_npu_iter)
            except StopIteration:
                npu_has_more = False
            else:
                npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
                                                     summary_compare, md5_compare)
                if npu_merge_list:
                    npu_ops_queue.append(npu_merge_list)

            last_bench_ops_len = len(bench_ops_queue)
            try:
                op_name_bench = next(ops_bench_iter)
            except StopIteration:
                bench_has_more = False
            else:
                bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
                                                       summary_compare, md5_compare)
                if bench_merge_list:
                    bench_ops_queue.append(bench_merge_list)

            # Nothing to match if both queues are empty or neither grew this round.
            both_empty = not npu_ops_queue and not bench_ops_queue
            no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
            if both_empty or no_change:
                continue

            # Exactly one queue is empty: the API sequences of the two models differ.
            if bool(npu_ops_queue) ^ bool(bench_ops_queue):
                logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
                break

            n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
            if n_match_point == -1 and b_match_point == -1:
                continue
            n_match_data = npu_ops_queue[n_match_point]
            b_match_data = bench_ops_queue[b_match_point]
            # NPU ops queued before the match point have no counterpart.
            for npu_data in npu_ops_queue[0: n_match_point]:
                get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
            get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
            del npu_ops_queue[0: n_match_point + 1]
            del bench_ops_queue[0: b_match_point + 1]

        # Whatever is still queued on the NPU side never found a match.
        for npu_data in npu_ops_queue:
            get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)

        return self.make_result_table(result, md5_compare, summary_compare, stack_mode)

    def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
        """Load the dumped tensors of one op pair and compute the comparison metrics.

        Args:
            npu_op_name (str): NPU tensor name; also the key into ``op_name_mapping_dict``.
            bench_op_name (str): Benchmark tensor name.
            op_name_mapping_dict (dict): Maps the NPU name to a sequence whose
                second item is the data name (``'-1'`` / ``-1`` when no data was dumped).
            input_param (dict): Provides ``"npu_dump_data_dir"`` and ``"bench_dump_data_dir"``.

        Returns:
            list: The metrics from ``compare_ops_apply`` followed by the error message.
        """
        npu_bench_name_list = op_name_mapping_dict[npu_op_name]
        data_name = npu_bench_name_list[1]
        error_file, relative_err, error_flag = None, None, False
        if data_name == '-1' or data_name == -1:  # no real data path was recorded
            n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
            error_flag = True
        else:
            try:
                read_npy_data = self.read_npy_data
                npu_dir = input_param.get("npu_dump_data_dir")
                bench_dir = input_param.get("bench_dump_data_dir")
                if self.frame_name == "MSComparator":
                    n_value = read_npy_data(npu_dir, npu_op_name + Const.NUMPY_SUFFIX)
                    if self.cross_frame:
                        # Cross-framework run: the benchmark side is a .pt dump.
                        b_value = read_npy_data(bench_dir, bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
                    else:
                        b_value = read_npy_data(bench_dir, bench_op_name + Const.NUMPY_SUFFIX)
                else:
                    n_value = read_npy_data(npu_dir, npu_op_name + Const.PT_SUFFIX)
                    b_value = read_npy_data(bench_dir, bench_op_name + Const.PT_SUFFIX)
            except IOError as error:
                error_file = error.filename
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True
            except FileCheckException:
                error_file = data_name
                n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
                error_flag = True

        n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
        if not error_flag:
            relative_err = get_relative_err(n_value, b_value)
            n_value, b_value = reshape_value(n_value, b_value)

        err_msg = get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=error_file)
        result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)

        if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
            err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
        result_list.append(err_msg)
        return result_list

    def compare_core(self, input_parma, output_path, **kwargs):
        """
        Compare an NPU dump with a benchmark dump and write an Excel report.

        Args:
            input_parma (dict): Comparison inputs. This method reads the keys
                "npu_json_path", "bench_json_path" and "stack_json_path"; the dict is
                also forwarded to the per-tensor comparison in full (non md5/summary) mode.
            output_path (str): Directory in which the report file is created.
            **kwargs: Additional keyword arguments including:
                - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
                - auto_analyze (bool, optional): If True, runs the Advisor after comparison. Defaults to True.
                - suffix (str, optional): Suffix appended to the output file name. Defaults to ''.
                - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
                - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
                - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.

        Returns:
            None. The report is written to disk.
        """
        # get kwargs or set default value
        stack_mode = kwargs.get('stack_mode', False)
        auto_analyze = kwargs.get('auto_analyze', True)
        suffix = kwargs.get('suffix', '')
        fuzzy_match = kwargs.get('fuzzy_match', False)
        summary_compare = kwargs.get('summary_compare', False)
        md5_compare = kwargs.get('md5_compare', False)

        logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
        file_name = add_time_with_xlsx("compare_result" + suffix)
        file_path = os.path.join(os.path.realpath(output_path), file_name)
        remove_path(file_path)
        highlight_dict = {'red_rows': [], 'yellow_rows': []}

        with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
                FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
                FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
            result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
                                             summary_compare, md5_compare)

        # Only the full comparison needs the real tensor data.
        if not md5_compare and not summary_compare:
            result_df = self._do_multi_process(input_parma, result_df)
        find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
        highlight_rows_xlsx(result_df, highlight_dict, file_path)
        if auto_analyze:
            advisor = Advisor(result_df, output_path)
            advisor.analysis()

    def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
        """Compute the metrics for every row of ``result_df`` and store them.

        Args:
            idx: Passed through to ``_save_cmp_result``.
            dump_path_dict (dict): Name mapping handed to ``compare_by_op``.
            result_df (pandas.DataFrame): Rows to process; column 0 is the NPU
                name and column 1 the benchmark name.
            lock: Passed through to ``_save_cmp_result``.
            input_param (dict): Dump directories; ``"is_print_compare_log"`` toggles logging.

        Returns:
            Whatever ``_save_cmp_result`` returns.
        """
        cos_result = []
        max_err_result = []
        max_relative_err_result = []
        err_mess = []
        one_thousand_err_ratio_result = []
        five_thousand_err_ratio_result = []
        is_print_compare_log = input_param.get("is_print_compare_log")
        for npu_op_name, bench_op_name in zip(result_df.iloc[:, 0], result_df.iloc[:, 1]):
            if is_print_compare_log:
                logger.info("start compare: {}".format(npu_op_name))
            (cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio,
             five_thousand_err_ratio, err_msg) = self.compare_by_op(
                npu_op_name, bench_op_name, dump_path_dict, input_param)
            if is_print_compare_log:
                logger.info(
                    "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, "
                    "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
                                                        one_thousand_err_ratio, five_thousand_err_ratio))
            cos_result.append(cos_sim)
            max_err_result.append(max_abs_err)
            max_relative_err_result.append(max_relative_err)
            err_mess.append(err_msg)
            one_thousand_err_ratio_result.append(one_thousand_err_ratio)
            five_thousand_err_ratio_result.append(five_thousand_err_ratio)

        cr = ComparisonResult(
            cos_result=cos_result,
            max_err_result=max_err_result,
            max_relative_err_result=max_relative_err_result,
            err_msgs=err_mess,
            one_thousand_err_ratio_result=one_thousand_err_ratio_result,
            five_thousand_err_ratio_result=five_thousand_err_ratio_result
        )

        return _save_cmp_result(idx, cr, result_df, lock)

    def _do_multi_process(self, input_parma, result_df):
        """Run ``compare_ops`` through ``_handle_multi_process`` with a manager lock.

        Raises:
            CompareException: ``INVALID_DATA_ERROR`` when a ``ValueError`` is raised.
        """
        try:
            return _handle_multi_process(self.compare_ops, input_parma, result_df,
                                         multiprocessing.Manager().RLock())
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
300
+
@@ -0,0 +1,95 @@
1
+ from msprobe.core.common.log import logger
2
+ from msprobe.core.compare.utils import rename_api
3
+
4
+
5
# Maps capitalised dtype names (presumably the MindSpore spelling) to the
# matching torch dtype string; every value is "torch." plus the lower-cased key.
dtype_mapping = {
    dtype_name: "torch." + dtype_name.lower()
    for dtype_name in (
        "Int8",
        "UInt8",
        "Int16",
        "UInt16",
        "Int32",
        "UInt32",
        "Int64",
        "UInt64",
        "Float16",
        "Float32",
        "Float64",
        "Bool",
        "BFloat16",
        "Complex64",
        "Complex128",
    )
}
22
+
23
+
24
def check_struct_match(npu_dict, bench_dict, cross_frame=False):
    """Tell whether the input/output structures of two op records agree.

    Identical structures match outright. Otherwise the input lists must be
    non-empty and of equal length, and both inputs and outputs must pass
    ``check_type_shape_match``.

    Args:
        npu_dict (dict): NPU record with ``"input_struct"`` / ``"output_struct"``.
        bench_dict (dict): Benchmark record with the same keys.
        cross_frame (bool): Translate the NPU dtype names through
            ``dtype_mapping`` before comparing.

    Returns:
        bool: True when the structures are considered equivalent.
    """
    npu_in = npu_dict.get("input_struct")
    bench_in = bench_dict.get("input_struct")
    npu_out = npu_dict.get("output_struct")
    bench_out = bench_dict.get("output_struct")

    if cross_frame:
        # Only the dtype (first field) is translated; unknown names pass through.
        npu_in = [(dtype_mapping.get(entry[0], entry[0]), entry[1]) for entry in npu_in]
        npu_out = [(dtype_mapping.get(entry[0], entry[0]), entry[1]) for entry in npu_out]

    exact_match = npu_in == bench_in and npu_out == bench_out
    if exact_match:
        return exact_match

    if len(npu_in) == 0 or len(bench_in) == 0 or len(npu_in) != len(bench_in):
        return False
    inputs_compatible = check_type_shape_match(npu_in, bench_in)
    outputs_compatible = check_type_shape_match(npu_out, bench_out)
    return inputs_compatible and outputs_compatible
41
+
42
+
43
def check_type_shape_match(npu_struct, bench_struct):
    """Compare two struct lists element by element on dtype and shape.

    Shapes must be equal. Dtypes must be equal or form one of the tolerated
    mixed-precision pairs (float16 with float32, float16 with bfloat16, in
    either spelling used below).

    Lists of different lengths never match. Previously ``zip`` silently
    truncated to the shorter list, so a struct with extra or missing entries
    could be reported as matching.

    Args:
        npu_struct (list): Entries whose item 0 is the dtype and item 1 the shape.
        bench_struct (list): Same layout for the benchmark side.

    Returns:
        bool: True only if the lists are non-empty, equally long and every
        pair of entries is compatible.
    """
    # Built once per call instead of on every mismatching element.
    tolerated_type_pairs = (
        ("Float16", "Float32"), ("Float32", "Float16"),
        ("Float16", "BFloat16"), ("BFloat16", "Float16"),
        ("torch.float16", "torch.float32"), ("torch.float32", "torch.float16"),
        ("torch.float16", "torch.bfloat16"), ("torch.bfloat16", "torch.float16"),
    )
    if len(npu_struct) != len(bench_struct):
        return False

    # Stays False for empty input, as before.
    shape_type_match = False
    for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
        # Index access rather than unpacking: entries may carry extra fields.
        npu_type, npu_shape = npu_type_shape[0], npu_type_shape[1]
        bench_type, bench_shape = bench_type_shape[0], bench_type_shape[1]
        shape_match = npu_shape == bench_shape
        type_match = npu_type == bench_type or (npu_type, bench_type) in tolerated_type_pairs
        shape_type_match = shape_match and type_match
        if not shape_type_match:
            return False
    return shape_type_match
64
+
65
+
66
def check_graph_mode(a_op_name, b_op_name):
    """Return True when exactly one of the two op names contains ``"Aten"``."""
    a_has_aten = "Aten" in a_op_name
    b_has_aten = "Aten" in b_op_name
    return a_has_aten != b_has_aten
72
+
73
+
74
def fuzzy_check_op(npu_name_list, bench_name_list):
    """Fuzzy-compare two name lists position by position.

    Returns False when either list is empty or the lengths differ; otherwise
    True only if every aligned pair passes ``fuzzy_check_name``.
    """
    if len(npu_name_list) == 0 or len(bench_name_list) == 0:
        return False
    if len(npu_name_list) != len(bench_name_list):
        return False
    # all() stops at the first mismatching pair.
    return all(
        fuzzy_check_name(npu_name, bench_name)
        for npu_name, bench_name in zip(npu_name_list, bench_name_list)
    )
83
+
84
+
85
def fuzzy_check_name(npu_name, bench_name):
    """Compare two tensor names, normalising them when both share a direction.

    If both names contain ``"forward"`` (checked first) or both contain
    ``"backward"``, the names are compared after ``rename_api`` with that
    keyword; otherwise they must be exactly equal.
    """
    for direction in ("forward", "backward"):
        if direction in npu_name and direction in bench_name:
            return rename_api(npu_name, direction) == rename_api(bench_name, direction)
    return npu_name == bench_name
93
+
94
+
95
+
@@ -0,0 +1,49 @@
1
+ import json
2
+ from msprobe.core.common.file_utils import FileOpen, check_file_type
3
+ from msprobe.core.common.const import FileCheckConst, Const
4
+ from msprobe.core.common.utils import CompareException
5
+ from msprobe.core.common.log import logger
6
+
7
+
8
def compare_cli(args):
    """Command-line entry point: load the input JSON and dispatch the comparison.

    Two files select the single-dump comparison and two directories select the
    distributed comparison; with directories, a ``rank_id`` in the input
    selects the MindSpore graph comparison instead.

    Args:
        args: Parsed CLI arguments. Reads ``input_path``, ``output_path``,
            ``framework``, ``compare_only``, ``stack_mode`` and ``fuzzy_match``,
            plus ``cell_mapping`` / ``api_mapping`` for the MindSpore
            single-dump comparison.

    Raises:
        CompareException: ``INVALID_COMPARE_MODE`` when ``npu_path`` and
            ``bench_path`` are not both files or both directories, or when
            ``rank_id`` is supplied for a PyTorch directory comparison.
    """
    with FileOpen(args.input_path, "r") as file:
        input_param = json.load(file)
    npu_path = input_param.get("npu_path", None)
    bench_path = input_param.get("bench_path", None)
    frame_name = args.framework
    auto_analyze = not args.compare_only
    # Import only the selected framework's comparators so the other one need not be installed.
    if frame_name == Const.PT_FRAMEWORK:
        from msprobe.pytorch.compare.pt_compare import compare
        from msprobe.pytorch.compare.distributed_compare import compare_distributed
    else:
        from msprobe.mindspore.compare.ms_compare import ms_compare
        from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
    if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
        # The comparators expect the "*_json_path" key names.
        input_param["npu_json_path"] = input_param.pop("npu_path")
        input_param["bench_json_path"] = input_param.pop("bench_path")
        input_param["stack_json_path"] = input_param.pop("stack_path")
        if frame_name == Const.PT_FRAMEWORK:
            compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
                    fuzzy_match=args.fuzzy_match)
        else:
            kwargs = {
                "stack_mode": args.stack_mode,
                "auto_analyze": auto_analyze,
                "fuzzy_match": args.fuzzy_match,
                "cell_mapping": args.cell_mapping,
                "api_mapping": args.api_mapping,
            }

            ms_compare(input_param, args.output_path, **kwargs)
    elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
        kwargs = {"stack_mode": args.stack_mode, "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match}
        if input_param.get("rank_id") is not None:
            if frame_name == Const.PT_FRAMEWORK:
                # ms_graph_compare is only imported for MindSpore; fail clearly
                # instead of with an UnboundLocalError.
                logger.error("rank_id is only supported by the MindSpore graph comparison.")
                raise CompareException(CompareException.INVALID_COMPARE_MODE)
            ms_graph_compare(input_param, args.output_path)
            return
        if frame_name == Const.PT_FRAMEWORK:
            compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
        else:
            ms_compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
    else:
        logger.error("The npu_path and bench_path need to be of the same type.")
        raise CompareException(CompareException.INVALID_COMPARE_MODE)