mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -0,0 +1,223 @@
1
+ import math
2
+ import abc
3
+ from collections import namedtuple
4
+ import numpy as np
5
+ import openpyxl
6
+ from openpyxl.styles import PatternFill
7
+ from msprobe.core.common.utils import get_header_index
8
+ from msprobe.core.common.file_utils import save_workbook
9
+ from msprobe.core.common.log import logger
10
+ from msprobe.core.common.const import CompareConst
11
+
12
+
13
class HighlightCheck(abc.ABC):
    """Abstract base class for one highlighting rule over compare-result rows.

    Concrete rules unpack the row data carried by ``info`` and append the
    affected row number to ``color_columns.red`` or ``color_columns.yellow``.
    """

    @abc.abstractmethod
    def apply(self, info, color_columns, summary_compare):
        """Evaluate this rule and record rows that must be highlighted."""
        raise NotImplementedError
17
+
18
+
19
class CheckOrderMagnitude(HighlightCheck):
    """Check how many orders of magnitude the max diff grew from input to output."""

    @staticmethod
    def _order_of(magnitude):
        # magnitudes below 1 all count as order 0; otherwise use log10
        return 0 if magnitude < 1 else math.log10(magnitude)

    def apply(self, info, color_columns, summary_compare=True):
        input_row, output_row, row_num = info
        diff_column = 'Max diff' if summary_compare else 'MaxAbsErr'
        diff_idx = get_header_index(diff_column, summary_compare)
        input_mag = abs(input_row[diff_idx])
        output_mag = abs(output_row[diff_idx])
        # an output error smaller than the input error is never flagged
        if input_mag > output_mag:
            return
        growth = self._order_of(output_mag) - self._order_of(input_mag)
        if growth >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
            color_columns.yellow.append(row_num)
30
+
31
+
32
class CheckOneThousandErrorRatio(HighlightCheck):
    """Compare the one-thousandth error ratio of an input row and an output row."""

    def apply(self, info, color_columns, summary_compare=True):
        input_row, output_row, row_num = info
        ratio_idx = get_header_index('One Thousandth Err Ratio', summary_compare)
        input_ratio = input_row[ratio_idx]
        output_ratio = output_row[ratio_idx]
        numeric = (float, int)
        # rows with non-numeric ratios (e.g. placeholder text) are skipped
        if not (isinstance(input_ratio, numeric) and isinstance(output_ratio, numeric)):
            return
        # input above the IN_RED threshold while output is below the OUT_RED threshold -> red
        crossed_red = (input_ratio > CompareConst.ONE_THOUSAND_ERROR_IN_RED
                       and output_ratio < CompareConst.ONE_THOUSAND_ERROR_OUT_RED)
        if crossed_red:
            color_columns.red.append(row_num)
        elif input_ratio - output_ratio > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
            color_columns.yellow.append(row_num)
43
+
44
+
45
class CheckCosineSimilarity(HighlightCheck):
    """Flag an output whose cosine similarity dropped too far below an input's."""

    def apply(self, info, color_columns, summary_compare=True):
        input_row, output_row, row_num = info
        cos_idx = get_header_index('Cosine', summary_compare)
        input_cos = input_row[cos_idx]
        output_cos = output_row[cos_idx]
        both_numeric = isinstance(input_cos, (float, int)) and isinstance(output_cos, (float, int))
        if both_numeric and input_cos - output_cos > CompareConst.COSINE_DIFF_YELLOW:
            color_columns.yellow.append(row_num)
54
+
55
+
56
class CheckMaxRelativeDiff(HighlightCheck):
    """Check the max relative difference of an output row against an input row.

    For a row the relative difference is
    ``|Max diff| / max(BENCH_MAX_FLOOR, |Bench max|)``. The output row is marked
    red when its value exceeds ``CompareConst.MAX_RELATIVE_OUT_RED``, or yellow
    when it exceeds ``CompareConst.MAX_RELATIVE_OUT_YELLOW`` while the input's
    value stays below ``CompareConst.MAX_RELATIVE_IN_YELLOW``.
    """

    # lower bound of the denominator so tiny bench values do not blow up the ratio
    BENCH_MAX_FLOOR = 0.01

    @classmethod
    def _relative_diff(cls, row, max_diff_index, bench_max_index):
        """Return the relative max diff of ``row``, or None if it cannot be computed.

        The magnitude of ``Bench max`` is used as the denominator: a negative
        bench max would otherwise collapse to the floor value and inflate the
        ratio. Non-numeric cells make ``abs``/``np.divide`` raise TypeError;
        that is turned into None instead of propagating.
        """
        try:
            denominator = max(cls.BENCH_MAX_FLOOR, abs(row[bench_max_index]))
            relative = np.abs(np.divide(row[max_diff_index], denominator))
        except TypeError:
            return None
        return relative if isinstance(relative, (float, int)) else None

    def apply(self, info, color_columns, summary_compare=True):
        api_in, api_out, num = info
        max_diff_index = get_header_index('Max diff', summary_compare)
        bench_max_index = get_header_index('Bench max', summary_compare)
        input_max_relative_diff = self._relative_diff(api_in, max_diff_index, bench_max_index)
        output_max_relative_diff = self._relative_diff(api_out, max_diff_index, bench_max_index)
        if input_max_relative_diff is None or output_max_relative_diff is None:
            return
        if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
            color_columns.red.append(num)
        elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW
              and input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
            color_columns.yellow.append(num)
71
+
72
+
73
class CheckOverflow(HighlightCheck):
    """Mark a single row red when the NPU side overflowed or its max diff is huge."""

    def apply(self, info, color_columns, summary_compare=True):
        row, row_num = info
        npu_max_idx = get_header_index('NPU max', summary_compare)
        npu_min_idx = get_header_index('NPU min', summary_compare)
        diff_idx = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)

        # NPU max/min rendered as one of the overflow markers -> red, nothing else to check
        if any(str(row[idx]) in CompareConst.OVERFLOW_LIST for idx in (npu_max_idx, npu_min_idx)):
            color_columns.red.append(row_num)
            return

        max_diff = row[diff_idx]
        # a numeric max diff above CompareConst.MAX_DIFF_RED is also red
        if isinstance(max_diff, (float, int)) and max_diff > CompareConst.MAX_DIFF_RED:
            color_columns.red.append(row_num)
87
+
88
+
89
class HighlightRules:
    """Collection of highlight rules used to find erroneous API rows."""

    # Rules evaluated on every row individually.
    basic_rules = {"check_overflow": CheckOverflow()}

    # Rules comparing each output row with each input row (non-summary mode).
    compare_rules = {
        "check_order_magnitude": CheckOrderMagnitude(),
        "check_one_thousand_error": CheckOneThousandErrorRatio(),
        "check_cosine_similarity": CheckCosineSimilarity(),
    }

    # Rules comparing each output row with each input row (summary mode).
    summary_compare_rules = {
        "check_order_magnitude": CheckOrderMagnitude(),
        "check_max_relative_diff": CheckMaxRelativeDiff(),
    }
106
+
107
+
108
def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
    """Find the rows of a single API that need to be highlighted.

    Every row is first checked on its own with ``HighlightRules.basic_rules``.
    Then each output row is compared with each input row using
    ``HighlightRules.summary_compare_rules`` (summary mode) or
    ``HighlightRules.compare_rules``.

    Args:
        result: rows of one API; the first ``n_num_input`` rows are its inputs,
            the remaining rows its outputs.
        last_len: global row number of ``result[0]``; local positions are
            offset by it before being recorded.
        n_num_input: number of input rows at the start of ``result``.
        highlight_dict: collects global row numbers under ``'red_rows'`` and
            ``'yellow_rows'``. Missing keys are created so that results are not
            silently discarded.
        summary_compare: use summary-mode columns and rules.
        md5_compare: nothing is highlighted; the function returns immediately.
    """
    if md5_compare:
        return
    npu_max_index = get_header_index('NPU max', summary_compare)
    bench_max_index = get_header_index('Bench max', summary_compare)
    max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)

    def _has_numeric_stats(row):
        # only rows whose NPU max, Bench max and max diff are numbers take part in comparisons
        return all(isinstance(row[index], (float, int))
                   for index in (npu_max_index, bench_max_index, max_diff_index))

    red_lines, yellow_lines = [], []
    LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
    ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
    ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
    color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)

    # Per-row checks on every input and output row.
    for i, line in enumerate(result):
        line_info = LineInfo(line_data=line, num_pointer=last_len + i)
        for rule in HighlightRules.basic_rules.values():
            rule.apply(line_info, color_columns, summary_compare)

    # Compare every output row with every comparable input row.
    pair_rules = HighlightRules.summary_compare_rules if summary_compare else HighlightRules.compare_rules
    comparable_inputs = [api_in for api_in in result[:n_num_input] if _has_numeric_stats(api_in)]
    for n, api_out in enumerate(result[n_num_input:]):
        num = last_len + n_num_input + n
        # already red from the per-row checks, or not comparable
        if num in red_lines or not _has_numeric_stats(api_out):
            continue
        for api_in in comparable_inputs:
            api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
            for rule in pair_rules.values():
                rule.apply(api_info, color_columns, summary_compare)

    red_rows = set(red_lines)
    # setdefault (not get) so results are kept even when the caller passed an empty dict
    highlight_dict.setdefault('red_rows', []).extend(list(red_rows))
    highlight_dict.setdefault('yellow_rows', []).extend(list(set(yellow_lines) - red_rows))
154
+
155
+
156
def get_name_and_state(name):
    """Get the api/module name prefix and the state ("input"/"output") of a row name.

    The state is looked up as the whole dot-delimited segment that follows a
    ``forward``/``backward`` segment (optionally followed by a numeric index), e.g.

        'Functional.conv2d.0.forward.input.0'   -> ('Functional.conv2d.0.forward.', 'input')
        'Module.conv1.Conv2d.forward.0.output.0' -> ('Module.conv1.Conv2d.forward.0.', 'output')

    This keeps module names that merely contain "input"/"output" (such as
    ``input_layernorm`` or ``attention.output``) from being split in the wrong
    place. Names without that pattern fall back to splitting on the first
    "input" substring, or otherwise on "output" with state "output".
    """
    import re

    match = re.match(r'^(.*\.(?:forward|backward)\.(?:\d+\.)?)(input|output)(?:\.|$)', name)
    if match:
        return match.group(1), match.group(2)
    if "input" in name:
        return name.split("input")[0], "input"
    return name.split("output")[0], "output"
165
+
166
+
167
def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
    """Group the compare-result rows by API and collect the rows to highlight.

    Consecutive rows whose names (first column) share the same api/module prefix,
    as returned by ``get_name_and_state``, form one API group. The leading run of
    "input" rows of a group is passed as its inputs and the rest as its outputs;
    each group is handed to ``find_error_rows`` together with the global row
    number of its first row.

    An empty dataframe is a no-op, and groups that have only inputs or only
    outputs are sliced correctly.

    Args:
        result_df: compare-result DataFrame; column 0 holds the row names.
        highlight_dict: receives ``'red_rows'`` / ``'yellow_rows'`` row numbers.
        summary_compare: forwarded to ``find_error_rows``.
        md5_compare: forwarded to ``find_error_rows``.
    """
    from itertools import groupby

    result = result_df.values
    names_and_states = [get_name_and_state(row[0]) for row in result]
    start = 0
    for _, group in groupby(names_and_states, key=lambda name_state: name_state[0]):
        states = [state for _, state in group]
        group_len = len(states)
        # number of leading "input" rows in this API group
        input_num = next((i for i, state in enumerate(states) if state != "input"), group_len)
        find_error_rows(result[start:start + group_len], start, input_num, highlight_dict,
                        summary_compare, md5_compare)
        start += group_len
197
+
198
+
199
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
    """Write the compare result to an Excel workbook and fill highlighted rows.

    Args:
        result_df: compare-result DataFrame; its columns become the header row.
        highlight_dict: ``'red_rows'`` / ``'yellow_rows'`` lists of 0-based data
            row indices; red takes precedence over yellow. Missing keys mean
            no rows of that colour.
        file_path: destination passed to ``save_workbook``.
    """
    logger.info('Compare result is %s' % file_path)

    # sets give O(1) membership instead of scanning the lists for every cell
    red_rows = set(highlight_dict.get('red_rows', []))
    yellow_rows = set(highlight_dict.get('yellow_rows', []))
    red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid")
    yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid")
    special_values = ('inf', '-inf', 'nan')

    wb = openpyxl.Workbook()
    ws = wb.active

    # write header
    for col_idx, col_name in enumerate(result_df.columns, start=1):
        ws.cell(row=1, column=col_idx, value=col_name)

    for data_idx, (_, row) in enumerate(result_df.iterrows()):
        excel_row = data_idx + 2  # row 1 holds the header
        if data_idx in red_rows:
            row_fill = red_fill
        elif data_idx in yellow_rows:
            row_fill = yellow_fill
        else:
            row_fill = None

        for col_idx, value in enumerate(row, start=1):
            text = str(value)
            if text in special_values:
                # NOTE(review): inf/nan get a trailing tab, presumably so Excel keeps them as text
                cell_value = f'{text}\t'
            elif isinstance(value, (float, int)):
                cell_value = value
            else:
                cell_value = text
            cell = ws.cell(row=excel_row, column=col_idx, value=cell_value)
            if row_fill is not None:
                cell.fill = row_fill

    save_workbook(wb, file_path)
@@ -0,0 +1,149 @@
1
+
2
+ import multiprocessing
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ import numpy as np
6
+ import pandas as pd
7
+ from msprobe.core.common.log import logger
8
+ from msprobe.core.common.utils import CompareException
9
+ from msprobe.core.common.const import CompareConst
10
+
11
+
12
def _handle_multi_process(func, input_parma, result_df, lock):
    """Run ``func`` over row chunks of ``result_df`` in a process pool and concatenate the outputs.

    Args:
        func: worker callable, invoked as
            ``func(idx, op_name_mapping_dict, df_chunk, lock, input_parma)`` where
            ``idx`` is the starting row position of ``df_chunk`` inside ``result_df``.
            Its return values are combined with ``pd.concat``.
        input_parma: extra argument forwarded unchanged to every worker call.
        result_df: DataFrame to split into chunks. It is also passed to
            ``read_dump_data`` to build the mapping shared by all workers.
        lock: lock object forwarded unchanged to every worker call.

    Returns:
        pd.DataFrame: worker outputs concatenated in chunk order with a fresh index.
    """
    process_num = (multiprocessing.cpu_count() + 1) // 2
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]

    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError:
            logger.error("pool terminate failed")

    try:
        results = []
        for process_idx, df_chunk in enumerate(df_chunks):
            # Every chunk except possibly the last has exactly df_chunk_size rows,
            # so this is the chunk's starting row position.
            idx = df_chunk_size * process_idx
            results.append(pool.apply_async(func,
                                            args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                                            error_callback=err_call))
        final_results = [r.get() for r in results]
    except BaseException:
        # A task failed or the caller was interrupted: stop the workers rather than
        # waiting for outstanding tasks (terminate is idempotent).
        pool.terminate()
        raise
    finally:
        # Always release the pool; close() is a no-op after terminate().
        pool.close()
        pool.join()
    return pd.concat(final_results, ignore_index=True)
42
+
43
+
44
def _ms_graph_handle_multi_process(func, result_df, mode):
    """Run ``func(df_chunk, mode)`` over row chunks of ``result_df`` in a process pool.

    Args:
        func: worker callable invoked as ``func(df_chunk, mode)``; its return values
            are combined with ``pd.concat``.
        result_df: DataFrame to split into chunks.
        mode: forwarded unchanged to every worker call.

    Returns:
        pd.DataFrame: worker outputs concatenated in chunk order with a fresh index.
    """
    process_num = (multiprocessing.cpu_count() + 1) // 2
    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]

    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError:
            logger.error("pool terminate failed")

    try:
        results = [pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
                   for df_chunk in df_chunks]
        final_results = [r.get() for r in results]
    except BaseException:
        # A task failed or the caller was interrupted: stop the workers rather than
        # waiting for outstanding tasks (terminate is idempotent).
        pool.terminate()
        raise
    finally:
        # Always release the pool; close() is a no-op after terminate().
        pool.close()
        pool.join()
    return pd.concat(final_results, ignore_index=True)
69
+
70
+
71
def read_dump_data(result_df):
    """Map each value of ``result_df``'s first column to its row's last-column value.

    Args:
        result_df: comparison-result DataFrame. The first column is used as the key
            (presumably the NPU op/tensor name — confirm) and the last column as the
            value (presumably a dump-data reference — confirm).

    Returns:
        dict: ``{first_col_value: [last_col_value, last_col_value]}``. When a key
        appears in several rows, the last such row wins.

    Raises:
        CompareException: INVALID_DATA_ERROR if pandas raises ValueError, or
            INDEX_OUT_OF_BOUNDS_ERROR if the columns cannot be indexed (IndexError).
    """
    try:
        npu_dump_name_list = result_df.iloc[0:, 0].tolist()
        npu_dump_tensor_list = result_df.iloc[0:, -1].tolist()
        # The same value is stored twice; NOTE(review): presumably consumers read the
        # pair as two paths (e.g. NPU / bench) — confirm before changing the shape.
        return {
            npu_dump_name: [npu_dump_tensor, npu_dump_tensor]
            for npu_dump_name, npu_dump_tensor in zip(npu_dump_name_list, npu_dump_tensor_list)
        }
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
87
+
88
@dataclass
class ComparisonResult:
    """Per-row comparison metrics stored as parallel lists.

    Entry ``i`` of every field belongs to the same compared row; ``_save_cmp_result``
    iterates over ``cos_result`` and reads entry ``i`` of each field together.
    Field order is the positional constructor order.
    """
    cos_result: list  # written to the CompareConst.COSINE column
    max_err_result: list  # written to the CompareConst.MAX_ABS_ERR column
    max_relative_err_result: list  # written to the CompareConst.MAX_RELATIVE_ERR column
    err_msgs: list  # written to the CompareConst.ERROR_MESSAGE column
    one_thousand_err_ratio_result: list  # written to CompareConst.ONE_THOUSANDTH_ERR_RATIO
    five_thousand_err_ratio_result: list  # written to CompareConst.FIVE_THOUSANDTHS_ERR_RATIO
96
+
97
+
98
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Save comparison results into the result DataFrame while holding ``lock``.
    Args:
        offset: index label of the first row to update; entry i of each result list
            is written to row label ``offset + i``
        result: data struct of ComparisonResult; len(result.cos_result) rows are written
        result_df: DataFrame updated in place via ``.loc``
        lock: lock held for the whole update (any object usable in a ``with`` statement)

    Returns:
        the updated result_df (the same object that was passed in)

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, or INDEX_OUT_OF_BOUNDS_ERROR
            on IndexError (e.g. a result list shorter than cos_result).
    """
    with lock:
        try:
            for i in range(len(result.cos_result)):
                process_index = i + offset
                result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
                result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
                result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
                result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
                result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(
                    result.cos_result[i], result.max_err_result[i])
                result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = \
                    result.one_thousand_err_ratio_result[i]
                result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = \
                    result.five_thousand_err_ratio_result[i]
            return result_df
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
131
+
132
+
133
def check_accuracy(cos, max_abs_err):
    """Classify one compared row from its cosine similarity and maximum absolute error.

    Args:
        cos: cosine similarity — a number, a numeric string, or one of the markers
            CompareConst.SHAPE_UNMATCH, CompareConst.NONE or "N/A".
        max_abs_err: maximum absolute error — a number, a numeric string, or one of
            the markers CompareConst.NONE or "N/A".

    Returns:
        The result depends on the inputs:
        - CompareConst.ACCURACY_CHECK_UNMATCH when ``cos`` is SHAPE_UNMATCH.
        - CompareConst.NONE when either metric is NONE or cannot be converted to float.
        - CompareConst.ACCURACY_CHECK_NO when either metric is "N/A" or NaN, or when
          the thresholds below are violated.
        - CompareConst.ACCURACY_CHECK_YES otherwise.
    """
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE:
        return CompareConst.NONE
    if cos == "N/A" or max_abs_err == "N/A":
        return CompareConst.ACCURACY_CHECK_NO
    try:
        cos, max_abs_err = float(cos), float(max_abs_err)
    except (TypeError, ValueError):
        # TypeError covers non-string, non-numeric inputs such as None.
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE
    # NaN compares False against every threshold below, so without this check a NaN
    # metric would fall through and be reported as ACCURACY_CHECK_YES.
    if np.isnan(cos) or np.isnan(max_abs_err):
        return CompareConst.ACCURACY_CHECK_NO
    # Fail when both metrics miss their primary thresholds...
    if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    # ...or when either metric alone misses its MAX threshold.
    if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    return CompareConst.ACCURACY_CHECK_YES