mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -0,0 +1,113 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ import os
19
+
20
+ from mindspore import Tensor, ops, mint
21
+ from mindspore.mint.nn import functional
22
+ from mindspore.common._stub_tensor import StubTensor
23
+
24
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
25
+ from msprobe.core.common.const import Const
26
+ from msprobe.mindspore.common.const import Const as MsConst
27
+ from msprobe.core.common.file_utils import load_yaml
28
+
29
+
30
# Directory of this module after resolving symlinks; the supported-API YAML is looked up next to it.
cur_path = os.path.split(os.path.realpath(__file__))[0]
# YAML file whose sections list the MindSpore APIs that get_wrap_api_list() considers for wrapping.
yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
32
+
33
+
34
class HOOKTensor:
    """Empty namespace that setup_hooks() fills with hooked wrappers of ``Tensor`` methods via ``setattr``."""
36
+
37
+
38
class HOOKStubTensor:
    """Empty namespace that setup_hooks() fills with hooked wrappers of ``StubTensor`` methods via ``setattr``."""
40
+
41
+
42
class HOOKFunctionalOP:
    """Empty namespace that setup_hooks() fills with hooked wrappers of ``mindspore.ops`` functions via ``setattr``."""
44
+
45
+
46
class HOOKMintOP:
    """Empty namespace that setup_hooks() fills with hooked wrappers of ``mindspore.mint`` functions via ``setattr``."""
48
+
49
+
50
class HOOKMintNNFunctionalOP:
    """Empty namespace that setup_hooks() fills with hooked wrappers of ``mint.nn.functional`` functions."""
52
+
53
+
54
class ApiTemplate(HOOKCell):
    """Hook cell that forwards a call to one original MindSpore API.

    Args:
        api_name: key of the API in ``api_dict``.
        api_dict: mapping from API name to the original callable.
        prefix: dump-name prefix; the last ``Const.SEP``-separated component of
            ``api_name`` and a trailing ``Const.SEP`` are appended to it.
        hook: hook object handed to ``HOOKCell.__init__``.
    """

    def __init__(self, api_name, api_dict, prefix, hook):
        # These attributes are assigned before HOOKCell.__init__ runs; presumably the base
        # initializer reads prefix_api_name (TODO confirm), so keep this ordering.
        self.api_name = api_name
        self.api_func = api_dict[api_name]
        short_name = api_name.split(Const.SEP)[-1]
        self.prefix_api_name = prefix + str(short_name) + Const.SEP
        super().__init__(hook)

    def construct(self, *args, **kwargs):
        """Call the wrapped API, except dropout-prefixed APIs, which return their input unchanged."""
        if not self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
            return self.api_func(*args, **kwargs)
        # Dropout bypass: first positional argument if present, else the Const.INPUT keyword.
        if args:
            return args[0]
        return kwargs.get(Const.INPUT)
65
+
66
+
67
class WrapApiName:
    """Plain record of the API names to wrap, grouped by the namespace they belong to."""

    def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names,
                 mint_nn_func_api_names):
        # Tensor methods, and the same names as found on StubTensor.
        self.tensor_api_names, self.stub_tensor_api_names = tensor_api_names, stub_tensor_api_names
        # mindspore.ops functions.
        self.ops_api_names = ops_api_names
        # mindspore.mint and mindspore.mint.nn.functional functions.
        self.mint_api_names, self.mint_nn_func_api_names = mint_api_names, mint_nn_func_api_names
74
+
75
+
76
def get_wrap_api_list():
    """Load the supported-API YAML and keep only the names the installed MindSpore provides.

    Each YAML section is intersected with ``dir()`` of its MindSpore namespace, so
    listed APIs that do not exist in this MindSpore build are dropped. The tensor
    section is matched against both ``Tensor`` and ``StubTensor``.

    Returns:
        WrapApiName: per-namespace sets of API names to wrap.
    """
    # An empty YAML document loads as None; treat it as "no sections".
    api_list = load_yaml(yaml_path) or {}

    def _section(key):
        # A missing key, or a key present with no entries (parsed as None), means
        # "nothing to wrap" rather than crashing inside set(None).
        return set(api_list.get(key) or ())

    tensor_api = _section(MsConst.SUPPORTED_TENSOR_LIST_KEY)
    return WrapApiName(tensor_api & set(dir(Tensor)),
                       tensor_api & set(dir(StubTensor)),
                       _section(MsConst.SUPPORTED_OPS_LIST_KEY) & set(dir(ops)),
                       _section(MsConst.SUPPORTED_MINT_LIST_KEY) & set(dir(mint)),
                       _section(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY) & set(dir(functional)))
88
+
89
+
90
def wrap_api_func(api_name, api_dict, prefix, hook):
    """Return a function that runs ``api_dict[api_name]`` through a freshly built ``ApiTemplate`` cell.

    A new ``ApiTemplate`` is constructed on every call of the returned function.
    """
    def api_function(*args, **kwargs):
        template_cell = ApiTemplate(api_name, api_dict, prefix, hook)
        return template_cell(*args, **kwargs)

    return api_function
94
+
95
+
96
def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
    """Attach a hooked wrapper to ``hook_class`` for every callable API named in ``api_list``.

    Each wrapper is stored as ``Const.ATTR_NAME_PREFIX + <api name>``. Names whose
    ``api_dict`` entry is not callable are skipped.
    """
    for name in filter(lambda candidate: callable(api_dict[candidate]), api_list):
        wrapper = wrap_api_func(name, api_dict, prefix, hook)
        setattr(hook_class, Const.ATTR_NAME_PREFIX + name, wrapper)
100
+
101
+
102
def setup_hooks(hook):
    """Wrap every supported MindSpore API and bind the wrappers onto the HOOK* namespace classes.

    For each namespace (``Tensor``, ``StubTensor``, ``ops``, ``mint`` and
    ``mint.nn.functional``), the supported names from ``get_wrap_api_list()`` are
    resolved to their current attributes. Those originals are what the wrappers
    call, and the wrappers are then bound onto the matching HOOK* class.

    Args:
        hook: hook object forwarded to every ``ApiTemplate`` the wrappers create.
    """
    wrap_api_name = get_wrap_api_list()
    bindings = (
        (wrap_api_name.tensor_api_names, Tensor, MsConst.TENSOR_DATA_PREFIX, HOOKTensor),
        (wrap_api_name.stub_tensor_api_names, StubTensor, MsConst.STUB_TENSOR_DATA_PREFIX, HOOKStubTensor),
        (wrap_api_name.ops_api_names, ops, MsConst.OPS_DATA_PREFIX, HOOKFunctionalOP),
        (wrap_api_name.mint_api_names, mint, MsConst.MINT_DATA_PREFIX, HOOKMintOP),
        (wrap_api_name.mint_nn_func_api_names, functional, MsConst.MINT_NN_FUNC_DATA_PREFIX,
         HOOKMintNNFunctionalOP),
    )
    for api_names, namespace, prefix, hook_class in bindings:
        # Resolve only the attributes that will actually be wrapped. Calling getattr on
        # every entry of dir(namespace) is wasteful and touches unrelated attributes.
        api_dict = {name: getattr(namespace, name) for name in api_names}
        wrap_api_func_and_bind(api_names, api_dict, prefix, hook, hook_class)
@@ -0,0 +1,72 @@
1
+ import os
2
+
3
+ from mindspore import Tensor
4
+ from mindspore.common.api import _MindsporeFunctionExecutor
5
+ from mindspore._c_expression import PyNativeExecutor_
6
+
7
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
8
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
9
+ from msprobe.core.common.const import Const
10
+
11
+
12
def dump_jit(name, in_feat, out_feat, is_forward):
    """Record one jit call (forward or backward) through ``JitDump.data_collector``.

    Args:
        name: object whose ``str()`` names the entry. If that string contains ``"<"`` at a
            position after 0, the text before it is used; otherwise the name is ``"JitFunction"``.
        in_feat: recorded as the positional inputs of the call.
        out_feat: recorded as the output of the call (``JitDump.grad`` passes ``None``).
        is_forward: selects the ``.forward`` or ``.backward`` suffix of the entry name.
    """
    process_id = os.getpid()
    text = str(name)
    cut = text.find("<")
    # find() gives -1 when "<" is absent; a cut at position 0 would produce an empty name.
    label = text[:cut] if cut > 0 else "JitFunction"
    direction = "forward" if is_forward else "backward"
    name_template = "Jit." + label + "." + direction
    if not JitDump.need_dump():
        return
    collector = JitDump.data_collector
    collector.update_api_or_module_name(name_template)
    collected = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
    collector.forward_data_collect(name_template, {}, process_id, collected)
28
+
29
+
30
class JitDump(_MindsporeFunctionExecutor):
    """Function executor that dumps the inputs and outputs of jit-compiled calls.

    Around each jitted forward call (and, once a forward dump has happened, each
    ``grad`` call), ``api_register`` is switched to the original APIs and switched
    back to the hook APIs afterwards. Presumably this keeps the dump hooks out of
    the compiled graph (TODO confirm).
    """

    # Debugger config; need_dump() reads its ``task``. Set via set_config().
    dump_config = None
    # Collector that receives the dumped data. Set via set_data_collector().
    data_collector = None
    # Becomes True after the first successful forward call; gates grad() dumping.
    jit_enable = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # PyNative executor singleton used by grad().
        self._executor = PyNativeExecutor_.get_instance()

    def __call__(self, *args, **kwargs):
        """Run the jitted call with original APIs active, then dump it as a forward entry."""
        api_register.api_set_ori_func()
        try:
            out = super().__call__(*args, **kwargs)
            if args and not isinstance(args[0], Tensor):
                # A non-tensor first argument names the dump and is excluded from the inputs.
                dump_jit(args[0], args[1:], out, True)
            else:
                # No arguments, or a tensor first argument: use the generic name.
                dump_jit({}, args, out, True)
            JitDump.jit_enable = True
        finally:
            # Restore the hooks even if compilation/execution or dumping raised.
            api_register.api_set_hook_func()
        return out

    @classmethod
    def set_config(cls, value):
        cls.dump_config = value

    @classmethod
    def set_data_collector(cls, value):
        cls.data_collector = value

    @classmethod
    def need_dump(cls):
        """Return True when a tensor/statistics dump is configured and the collector is still active."""
        if cls.dump_config is None or cls.dump_config.task not in (Const.TENSOR, Const.STATISTICS):
            return False
        if not cls.data_collector or cls.data_collector.data_processor.is_terminated:
            return False
        return True

    def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
        """Run the executor's grad and, once jit dumping is enabled, dump it as a backward entry."""
        if JitDump.jit_enable:
            api_register.api_set_ori_func()
        try:
            # Keyword arguments are forwarded positionally, in insertion order.
            output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
            if JitDump.jit_enable:
                dump_jit(obj, args, None, False)
        finally:
            if JitDump.jit_enable:
                api_register.api_set_hook_func()
        return output
@@ -1,60 +1,59 @@
1
- import os
2
- import json
3
- from msprobe.core.common.utils import make_dump_path_if_not_exists
4
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
5
- from msprobe.core.common.log import logger
6
- from msprobe.core.common.file_check import FileOpen
7
-
8
-
9
class KernelGraphDump:
    """Kernel-graph dump configuration builder as released in mindstudio-probe 1.0.1.

    This is the removed side of the diff hunk. The 1.0.4 version below it calls
    ``create_directory`` where this one calls ``make_dump_path_if_not_exists``.
    """

    def __init__(self, config: DebuggerConfig):
        # Defaults: statistics-only dump of all iterations on devices 0-7, npy format.
        self.dump_json = dict()
        self.dump_json["common_dump_settings"] = dict()
        self.dump_json["common_dump_settings"]["dump_mode"] = 0
        self.dump_json["common_dump_settings"]["path"] = ""
        self.dump_json["common_dump_settings"]["net_name"] = "Net"
        self.dump_json["common_dump_settings"]["iteration"] = "all"
        self.dump_json["common_dump_settings"]["saved_data"] = "statistic"
        self.dump_json["common_dump_settings"]["input_output"] = 0
        self.dump_json["common_dump_settings"]["kernels"] = []
        self.dump_json["common_dump_settings"]["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
        self.dump_json["common_dump_settings"]["op_debug_mode"] = 0
        self.dump_json["common_dump_settings"]["file_format"] = "npy"

        # A non-empty kernel list switches dump_mode from 0 to 1 and records the kernels.
        if len(config.list) > 0:
            self.dump_json["common_dump_settings"]["dump_mode"] = 1
            self.dump_json["common_dump_settings"]["kernels"] = config.list
        self.dump_json["common_dump_settings"]["path"] = config.dump_path
        # Steps are joined with '|' (the trailing separator is sliced off).
        if len(config.step) > 0:
            step_str = ""
            for s in config.step:
                step_str += (str(s) + '|')
            self.dump_json["common_dump_settings"]["iteration"] = step_str[:-1]
        if len(config.rank) > 0:
            self.dump_json["common_dump_settings"]["support_device"] = config.rank
        if config.task == "tensor":
            self.dump_json["common_dump_settings"]["saved_data"] = "tensor"
            self.dump_json["common_dump_settings"]["file_format"] = config.file_format
        # input_output: 0 = both (default), 1 = input only, 2 = output only.
        if len(config.data_mode) == 1:
            if config.data_mode[0] == "input":
                self.dump_json["common_dump_settings"]["input_output"] = 1
            if config.data_mode[0] == "output":
                self.dump_json["common_dump_settings"]["input_output"] = 2

    def handle(self):
        """Write kernel_graph_dump.json under the dump path and export it through environment variables."""
        # Rejected when GRAPH_OP_RUN=1 (kernel-by-kernel mode).
        if os.getenv("GRAPH_OP_RUN") == "1":
            raise Exception("Must run in graph mode, not kbk mode")
        json_path = self.dump_json["common_dump_settings"]["path"]
        make_dump_path_if_not_exists(json_path)
        json_path = os.path.join(json_path, "kernel_graph_dump.json")
        with FileOpen(json_path, 'w') as f:
            json.dump(self.dump_json, f)
        logger.info(json_path + " has been created.")
        os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
        # MS_ACL_DUMP_CFG_PATH is set only in dump_mode 0 and cleared otherwise.
        if self.dump_json["common_dump_settings"]["dump_mode"] == 0:
            if self.dump_json["common_dump_settings"]["iteration"] != "all" or \
                    len(self.dump_json["common_dump_settings"]["kernels"]) == 0:
                os.environ["MS_ACL_DUMP_CFG_PATH"] = json_path
        else:
            if "MS_ACL_DUMP_CFG_PATH" in os.environ:
                del os.environ["MS_ACL_DUMP_CFG_PATH"]
1
+ import os
2
+ import json
3
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
4
+ from msprobe.mindspore.common.log import logger
5
+ from msprobe.core.common.file_utils import FileOpen, create_directory
6
+
7
+
8
class KernelGraphDump:
    """Build a MindSpore kernel-graph dump configuration and publish it via environment variables.

    The configuration lives in ``self.dump_json["common_dump_settings"]``. ``handle()``
    writes it to ``<dump_path>/kernel_graph_dump.json``.
    """

    def __init__(self, config: "DebuggerConfig"):
        # Defaults: statistics-only dump of all iterations on devices 0-7, npy format.
        # Insertion order is kept identical to earlier versions so the written JSON is unchanged.
        settings = {
            "dump_mode": 0,
            "path": "",
            "net_name": "Net",
            "iteration": "all",
            "saved_data": "statistic",
            "input_output": 0,
            "kernels": [],
            "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
            "op_debug_mode": 0,
            "file_format": "npy",
        }
        self.dump_json = {"common_dump_settings": settings}

        # Truthiness (as in KernelKbykDump) also tolerates None for list/step/rank.
        # A non-empty kernel list switches dump_mode from 0 to 1.
        if config.list:
            settings["dump_mode"] = 1
            settings["kernels"] = config.list
        settings["path"] = config.dump_path
        if config.step:
            settings["iteration"] = "|".join(str(s) for s in config.step)
        if config.rank:
            settings["support_device"] = config.rank
        if config.task == "tensor":
            settings["saved_data"] = "tensor"
            settings["file_format"] = config.file_format
        # input_output: 0 = both (default), 1 = input only, 2 = output only.
        if len(config.data_mode) == 1:
            if config.data_mode[0] == "input":
                settings["input_output"] = 1
            elif config.data_mode[0] == "output":
                settings["input_output"] = 2

    def handle(self):
        """Write kernel_graph_dump.json and point MindSpore at it through environment variables.

        Raises:
            RuntimeError: if ``GRAPH_OP_RUN`` is ``"1"`` (kernel-by-kernel mode).
        """
        if os.getenv("GRAPH_OP_RUN") == "1":
            raise RuntimeError("Must run in graph mode, not kbk mode")
        settings = self.dump_json["common_dump_settings"]
        dump_dir = settings["path"]
        create_directory(dump_dir)
        json_path = os.path.join(dump_dir, "kernel_graph_dump.json")
        with FileOpen(json_path, 'w') as f:
            json.dump(self.dump_json, f)
        logger.info(json_path + " has been created.")
        os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
        # MS_ACL_DUMP_CFG_PATH is set only in dump_mode 0 and cleared otherwise.
        if settings["dump_mode"] == 0:
            if settings["iteration"] != "all" or not settings["kernels"]:
                os.environ["MS_ACL_DUMP_CFG_PATH"] = json_path
        else:
            os.environ.pop("MS_ACL_DUMP_CFG_PATH", None)
@@ -0,0 +1,64 @@
1
+ import os
2
+ import json
3
+
4
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
5
+ from msprobe.mindspore.common.log import logger
6
+ from msprobe.core.common.file_utils import FileOpen, create_directory
7
+ from msprobe.core.common.const import Const
8
+
9
+
10
class KernelKbykDump:
    """Build a kernel-by-kernel dump config for MindSpore and activate it.

    ``__init__`` translates a ``DebuggerConfig`` into a two-section JSON dict
    (``common_dump_settings`` and ``e2e_dump_settings``). ``handle`` writes that
    dict to disk and exports its location via ``MINDSPORE_DUMP_CONFIG``.
    """

    COMMON_SETTINGS = "common_dump_settings"
    E2E_SETTINGS = "e2e_dump_settings"

    def __init__(self, config: DebuggerConfig):
        # Defaults; key order is preserved in the emitted JSON.
        common_set = {
            "dump_mode": 0,
            "path": "",
            "net_name": "Net",
            "iteration": "all",
            "saved_data": "statistic",
            "input_output": 0,
            "kernels": [],
            "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
        }
        e2e_set = {
            "enable": True,
            "trans_flag": True,
        }

        if config.list:
            # An explicit kernel list switches dump_mode from 0 to 1.
            common_set["dump_mode"] = 1
            common_set["kernels"] = config.list
        common_set["path"] = config.dump_path
        if config.step:
            # Steps are serialized as a "|"-separated string, e.g. [1, 3] -> "1|3".
            common_set["iteration"] = "|".join(str(s) for s in config.step)
        if config.rank:
            common_set["support_device"] = config.rank
        if config.task == Const.TENSOR:
            common_set["saved_data"] = Const.TENSOR
        if len(config.data_mode) == 1:
            # Only a single-element data_mode narrows input_output (1 = input, 2 = output).
            if config.data_mode[0] == Const.INPUT:
                common_set["input_output"] = 1
            elif config.data_mode[0] == Const.OUTPUT:
                common_set["input_output"] = 2

        self.dump_json = {
            KernelKbykDump.COMMON_SETTINGS: common_set,
            KernelKbykDump.E2E_SETTINGS: e2e_set,
        }

    def handle(self):
        """Write the config to ``<path>/kernel_kbyk_dump.json`` and export it.

        Sets ``MINDSPORE_DUMP_CONFIG`` to the written file and removes
        ``MS_ACL_DUMP_CFG_PATH`` from the environment if it is present.
        """
        json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
        create_directory(json_path)
        json_path = os.path.join(json_path, "kernel_kbyk_dump.json")
        with FileOpen(json_path, 'w') as f:
            json.dump(self.dump_json, f)
        logger.info(json_path + " has been created.")

        os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
        # MS_ACL_DUMP_CFG_PATH may have been set elsewhere (the graph-mode dumper sets it);
        # remove it here.
        os.environ.pop("MS_ACL_DUMP_CFG_PATH", None)
File without changes
@@ -0,0 +1,116 @@
1
+ import os
2
+ import inspect
3
+ import importlib
4
+
5
+ import mindspore as ms
6
+ from mindspore.communication import comm_func
7
+
8
+ from msprobe.core.common.file_utils import load_yaml, check_path_length
9
+ from msprobe.core.common.const import Const
10
+ from msprobe.mindspore.common.const import Const as MsConst
11
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
12
+ from msprobe.mindspore.free_benchmark.common.config import Config
13
+ from msprobe.mindspore.common.log import logger
14
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
15
+ from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
16
+
17
+
18
class ApiPyNativeSelFCheck:
    """Free-benchmark self-check that wraps selected MindSpore APIs.

    ``__init__`` copies settings from a ``DebuggerConfig`` onto the class-level
    ``Config`` and decides which APIs to wrap. ``handle`` replaces each selected
    API with its decorated version via ``hijack``.
    """

    def __init__(self, config: DebuggerConfig):
        Config.is_enable = True
        Config.handler_type = config.handler_type
        Config.pert_type = config.pert_type
        Config.stage = config.stage
        Config.dump_level = config.dump_level
        Config.steps = config.step
        Config.ranks = config.rank
        Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv")
        check_path_length(Config.dump_path)

        all_api = get_supported_ops()
        if not config.list:
            # No explicit list: wrap every supported API.
            self.api_list = all_api
        else:
            requested = set(config.list)
            self.api_list = requested & all_api
            # Report requested names that will be skipped instead of dropping them silently.
            unsupported = requested - all_api
            if unsupported:
                skipped = sorted(str(name) for name in unsupported)
                logger.warning(
                    f"The following APIs are not supported by free benchmark and will be skipped: {skipped}"
                )

    def handle(self):
        """Install the free-benchmark wrapper on every selected API."""
        for api_name in self.api_list:
            hijack(api_name)
40
+
41
+
42
def get_supported_ops():
    """Return the API names that are both listed in the YAML and present in MindSpore.

    Names from ``data/support_wrap_ops.yaml`` (next to this file) are prefixed
    according to ``FreeBenchmarkConst.API_PREFIX_DICT``. They are then intersected
    with the prefixed ``dir()`` listings of ``ms.ops``, ``ms.Tensor``, ``ms.mint``,
    ``ms.mint.nn.functional`` and ``comm_func``.

    Returns:
        set: fully-prefixed API names that can be wrapped.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    yaml_data = load_yaml(os.path.join(here, "data", "support_wrap_ops.yaml"))

    listed = set()
    for section, prefix in FreeBenchmarkConst.API_PREFIX_DICT.items():
        names = yaml_data.get(section)
        if names:
            listed.update(prefix + name for name in names)

    namespaces = (
        (MsConst.OPS_PREFIX, ms.ops),
        (MsConst.Tensor_PREFIX, ms.Tensor),
        (MsConst.MINT_PREFIX, ms.mint),
        (MsConst.MINT_NN_FUNC_PREFIX, ms.mint.nn.functional),
        (MsConst.COMM_PREFIX, comm_func),
    )
    available = {prefix + attr for prefix, owner in namespaces for attr in dir(owner)}

    return listed & available
76
+
77
+
78
def get_decorate_func():
    """Return the decorator factory applied to wrapped APIs (``decorate_forward_function``)."""
    decorator = decorate_forward_function
    return decorator
80
+
81
+
82
def is_func_support_decorate(orig_func):
    """Return True when ``orig_func`` is callable and is not a class."""
    if inspect.isclass(orig_func):
        return False
    return callable(orig_func)
84
+
85
+
86
def get_wrapper_obj(orig_func, api_name):
    """Return ``orig_func`` decorated for ``api_name``, or unchanged if it cannot be decorated.

    Classes and non-callables are returned as-is.
    """
    if not is_func_support_decorate(orig_func):
        return orig_func
    decorator = get_decorate_func()
    return decorator(orig_func, api_name)
92
+
93
+
94
def get_module(api_name):
    """Resolve a ``Const.SEP``-separated API path to ``(owner, attribute)``.

    The first component is imported as a module. Each intermediate component is
    read as an attribute, and is imported as a submodule first when the parent
    does not expose it yet. The last component is looked up on the final owner.

    Returns:
        tuple: ``(owner_object, original_attribute)``.
    """
    parts = api_name.split(Const.SEP)
    owner = importlib.import_module(parts[0])
    for depth in range(1, len(parts) - 1):
        attr = parts[depth]
        if not hasattr(owner, attr):
            # Submodules only show up as parent attributes once imported.
            importlib.import_module(Const.SEP.join(parts[:depth + 1]))
        owner = getattr(owner, attr)
    return owner, getattr(owner, parts[-1])
105
+
106
+
107
def hijack(api_name):
    """Replace the API at ``api_name`` with its free-benchmark wrapper.

    Blank names are ignored. Any failure during lookup or wrapping is logged
    as an error rather than raised.
    """
    if api_name.strip():
        try:
            attr_name = api_name.split(Const.SEP)[-1]
            owner, original = get_module(api_name)
            setattr(owner, attr_name, get_wrapper_obj(original, api_name))
        except Exception as e:
            logger.error(f"Failed decorator {api_name}: {e}")
File without changes
@@ -0,0 +1,12 @@
1
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
2
+
3
+
4
class Config:
    """Free-benchmark settings stored as class attributes.

    The class is used as a namespace: callers (e.g. the PyNative self-check)
    assign ``Config.<attr>`` directly instead of creating instances.
    """

    # Set to True by the self-check setup.
    is_enable: bool = False
    # Defaults taken from FreeBenchmarkConst.
    handler_type = FreeBenchmarkConst.DEFAULT_HANDLER_TYPE
    pert_type = FreeBenchmarkConst.DEFAULT_PERT_TYPE
    stage = FreeBenchmarkConst.DEFAULT_STAGE
    dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
    # Copied from DebuggerConfig.step / DebuggerConfig.rank by the self-check.
    steps: list = list()
    ranks: list = list()
    # Output CSV path; empty until configured.
    dump_path: str = ""
@@ -0,0 +1,17 @@
1
+ from typing import Optional, Any, Tuple, Dict, Callable
2
+
3
+
4
class HandlerParams:
    """
    Parameter bundle: a plain holder for call arguments, results and flags.

    All fields are class-level defaults.
    """

    # Call inputs.
    args: Optional[Tuple] = None
    kwargs: Optional[Dict] = None
    index: Optional[int] = None
    # Results of the original and the perturbed call.
    original_result: Optional[Any] = None
    fuzzed_result: Optional[Any] = None
    # Flags; both default to True.
    is_consistent: Optional[bool] = True
    save_flag: Optional[bool] = True
    # Perturbed value and the un-wrapped callable.
    fuzzed_value: Optional[Any] = None
    original_func: Optional[Callable] = None
@@ -0,0 +1,71 @@
1
+ from typing import Any
2
+ from typing import Optional
3
+ from dataclasses import dataclass
4
+
5
+ import mindspore as ms
6
+ from mindspore import Tensor
7
+
8
+ from msprobe.mindspore.runtime import Runtime
9
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
10
+ from .config import Config
11
+ from .handler_params import HandlerParams
12
+
13
+
14
class Tools:
    """Stateless helpers for free-benchmark handlers."""

    @staticmethod
    def get_first_tensor_dtype(tensor_seq: Any):
        """Return the dtype of ``tensor_seq`` or of the first Tensor in a list/tuple.

        Only the top level of a list/tuple is searched.

        Raises:
            Exception: if no Tensor is found.
        """
        if isinstance(tensor_seq, Tensor):
            return tensor_seq.dtype
        items = tensor_seq if isinstance(tensor_seq, (list, tuple)) else ()
        first = next((item for item in items if isinstance(item, Tensor)), None)
        if first is None:
            raise Exception("The sequence does not contain tensors.")
        return first.dtype

    @staticmethod
    def get_default_error_threshold(dtype):
        """Return the error threshold for ``dtype``.

        If ``Config.pert_type`` is ``NO_CHANGE``, ``NO_CHANGE_ERROR_THRESHOLD`` is
        returned. Otherwise the threshold is looked up in ``ERROR_THRESHOLD``,
        falling back to the float32 entry.
        """
        if Config.pert_type == FreeBenchmarkConst.NO_CHANGE:
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
        thresholds = FreeBenchmarkConst.ERROR_THRESHOLD
        fallback = thresholds.get(ms.float32)
        return thresholds.get(dtype, fallback)
31
+
32
+
33
@dataclass
class UnequalRow:
    """One mismatch record produced by ``make_unequal_row``.

    Field order is the dataclass field order (and therefore the positional
    constructor order); keep it stable.
    """

    rank: Optional[int] = None
    pert_type: Optional[str] = None
    stage: Optional[str] = None
    step: Optional[int] = None
    api_name: Optional[str] = None
    max_rel: Optional[float] = None
    # make_unequal_row stores Tensor.dtype / Tensor.shape here (not strings).
    dtype: Optional[Any] = None
    shape: Optional[Any] = None
    output_index: Optional[int] = None
44
+
45
+
46
def make_unequal_row(
    api_name: str,
    params: HandlerParams,
    ratio: Optional[float] = None,
    index: Optional[int] = None,
):
    """Build an ``UnequalRow`` describing a mismatch for ``api_name``.

    Args:
        api_name: name recorded in the row.
        params: handler parameters; only ``original_result`` is read.
        ratio: if a float, ``max_rel`` is set to ``ratio - 1``.
        index: optional output index; when given, ``original_result[index]`` is
            inspected and the index is recorded as ``output_index``.

    Returns:
        UnequalRow: the row. ``dtype``/``shape`` are filled only when the
        (indexed) original result is a Tensor. ``rank`` is None when
        ``Runtime.rank_id`` is -1.
    """
    row = UnequalRow(
        api_name=api_name,
        pert_type=Config.pert_type,
        output_index=index,
        stage=Config.stage,
        step=Runtime.step_count,
    )
    if isinstance(ratio, float):
        row.max_rel = ratio - 1
    original_tensor = params.original_result
    if index is not None:
        # Only the original result is used; fuzzed_result is deliberately not indexed
        # (it was never read and could raise spuriously).
        original_tensor = original_tensor[index]
    if isinstance(original_tensor, Tensor):
        row.dtype = original_tensor.dtype
        row.shape = original_tensor.shape
    row.rank = Runtime.rank_id if Runtime.rank_id != -1 else None
    return row