mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,98 +1,102 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
3
-
4
-
5
- class Tools:
6
-
7
- @staticmethod
8
- def is_float_tensor(tensor) -> bool:
9
- if isinstance(tensor, torch.Tensor) and torch.is_floating_point(tensor):
10
- return True
11
- if isinstance(tensor, (list, tuple)):
12
- for value in tensor:
13
- if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
14
- return True
15
- return False
16
-
17
- @staticmethod
18
- def get_dist_rank():
19
- try:
20
- return torch.distributed.get_rank()
21
- except RuntimeError:
22
- return 0
23
-
24
- @staticmethod
25
- def get_first_tensor_dtype(tensor_seq):
26
- if isinstance(tensor_seq, torch.Tensor):
27
- return tensor_seq.dtype
28
- if isinstance(tensor_seq, (list, tuple)):
29
- for object_ in tensor_seq:
30
- if isinstance(object_, torch.Tensor):
31
- return object_.dtype
32
- raise RuntimeError("The sequence does not contain tensors.")
33
-
34
- @staticmethod
35
- def get_pure_api_name(api_name: str):
36
- return api_name.rsplit(".", 2)[0]
37
-
38
- @staticmethod
39
- def convert_device_and_dtype(
40
- tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
41
- ):
42
- if isinstance(tensor_seq, torch.Tensor):
43
- if change_dtype and tensor_seq.dtype in [torch.float16, torch.bfloat16]:
44
- return tensor_seq.detach().to(device).to(torch.float32)
45
- return tensor_seq.detach().to(device)
46
- if isinstance(tensor_seq, dict):
47
- return {
48
- key: Tools.convert_device_and_dtype(value, device, change_dtype)
49
- for key, value in tensor_seq.items()
50
- }
51
- if isinstance(tensor_seq, (tuple, list)):
52
- return type(tensor_seq)(
53
- [
54
- Tools.convert_device_and_dtype(value, device, change_dtype)
55
- for value in tensor_seq
56
- ]
57
- )
58
- return tensor_seq
59
-
60
- @staticmethod
61
- def convert_fuzz_output_to_origin(origin, perturbed):
62
- if isinstance(origin, torch.Tensor):
63
- origin.data = perturbed.to(origin.dtype).to(origin.device)
64
- return origin
65
- if isinstance(origin, dict):
66
- output = dict()
67
- for key, value in origin.items():
68
- output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
69
- return output
70
- if isinstance(origin, (tuple, list)):
71
- result = list()
72
- for index_, value in enumerate(origin):
73
- result.append(
74
- Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
75
- )
76
- return type(origin)(result)
77
- return origin
78
-
79
- class TorchC:
80
- sum = torch._C._VariableFunctionsClass.sum
81
- isinf = torch._C._VariableFunctionsClass.isinf
82
- isfinite = torch._C._VariableFunctionsClass.isfinite
83
- isnan = torch._C._VariableFunctionsClass.isnan
84
- logical_not = torch._C._VariableFunctionsClass.logical_not
85
- subtract = torch._C._VariableFunctionsClass.subtract
86
- abs = torch._C._VariableFunctionsClass.abs
87
- where = torch._C._VariableFunctionsClass.where
88
- div = torch._C._VariableFunctionsClass.div
89
- max = torch._C._VariableFunctionsClass.max
90
- min = torch._C._VariableFunctionsClass.min
91
- gt = torch._C._VariableFunctionsClass.gt
92
- ge = torch._C._VariableFunctionsClass.ge
93
- lt = torch._C._VariableFunctionsClass.lt
94
- mean = torch._C._VariableFunctionsClass.mean
95
- full = torch._C._VariableFunctionsClass.full
96
- add = torch._C._VariableFunctionsClass.add
97
- bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor
98
- clone = torch._C._VariableFunctionsClass.clone
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
3
+
4
+
5
+ class Tools:
6
+
7
+ @staticmethod
8
+ def is_float_tensor(tensor) -> bool:
9
+ if isinstance(tensor, torch.Tensor) and torch.is_floating_point(tensor):
10
+ return True
11
+ if isinstance(tensor, (list, tuple)):
12
+ for value in tensor:
13
+ if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
14
+ return True
15
+ return False
16
+
17
+ @staticmethod
18
+ def get_dist_rank():
19
+ try:
20
+ return torch.distributed.get_rank()
21
+ except RuntimeError:
22
+ return 0
23
+
24
+ @staticmethod
25
+ def get_first_tensor_dtype(tensor_seq):
26
+ if isinstance(tensor_seq, torch.Tensor):
27
+ return tensor_seq.dtype
28
+ if isinstance(tensor_seq, (list, tuple)):
29
+ for object_ in tensor_seq:
30
+ if isinstance(object_, torch.Tensor):
31
+ return object_.dtype
32
+ raise RuntimeError("The sequence does not contain tensors.")
33
+
34
+ @staticmethod
35
+ def get_pure_api_name(api_name: str):
36
+ return api_name.rsplit(".", 2)[0]
37
+
38
+ @staticmethod
39
+ def convert_device_and_dtype(
40
+ tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
41
+ ):
42
+ if isinstance(tensor_seq, torch.Tensor):
43
+ if change_dtype and tensor_seq.dtype in [torch.float16, torch.bfloat16]:
44
+ return tensor_seq.detach().to(device).to(torch.float32)
45
+ return tensor_seq.detach().to(device)
46
+ if isinstance(tensor_seq, dict):
47
+ return {
48
+ key: Tools.convert_device_and_dtype(value, device, change_dtype)
49
+ for key, value in tensor_seq.items()
50
+ }
51
+ if isinstance(tensor_seq, (tuple, list)):
52
+ return type(tensor_seq)(
53
+ [
54
+ Tools.convert_device_and_dtype(value, device, change_dtype)
55
+ for value in tensor_seq
56
+ ]
57
+ )
58
+ return tensor_seq
59
+
60
+ @staticmethod
61
+ def convert_fuzz_output_to_origin(origin, perturbed):
62
+ if isinstance(origin, torch.Tensor):
63
+ origin.data = perturbed.to(origin.dtype).to(origin.device)
64
+ return origin
65
+ if isinstance(origin, dict):
66
+ output = dict()
67
+ for key, value in origin.items():
68
+ output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
69
+ return output
70
+ if isinstance(origin, (tuple, list)):
71
+ result = list()
72
+ for index_, value in enumerate(origin):
73
+ result.append(
74
+ Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
75
+ )
76
+ return type(origin)(result)
77
+ return origin
78
+
79
+ class TorchC:
80
+ sum = torch._C._VariableFunctionsClass.sum
81
+ isinf = torch._C._VariableFunctionsClass.isinf
82
+ isfinite = torch._C._VariableFunctionsClass.isfinite
83
+ isnan = torch._C._VariableFunctionsClass.isnan
84
+ logical_not = torch._C._VariableFunctionsClass.logical_not
85
+ subtract = torch._C._VariableFunctionsClass.subtract
86
+ abs = torch._C._VariableFunctionsClass.abs
87
+ where = torch._C._VariableFunctionsClass.where
88
+ div = torch._C._VariableFunctionsClass.div
89
+ max = torch._C._VariableFunctionsClass.max
90
+ min = torch._C._VariableFunctionsClass.min
91
+ gt = torch._C._VariableFunctionsClass.gt
92
+ ge = torch._C._VariableFunctionsClass.ge
93
+ lt = torch._C._VariableFunctionsClass.lt
94
+ mean = torch._C._VariableFunctionsClass.mean
95
+ full = torch._C._VariableFunctionsClass.full
96
+ add = torch._C._VariableFunctionsClass.add
97
+ bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor
98
+ clone = torch._C._VariableFunctionsClass.clone
99
+ clamp = torch._C._VariableFunctionsClass.clamp
100
+ tensor_split = torch._C._VariableFunctionsClass.tensor_split
101
+ stack = torch._C._VariableFunctionsClass.stack
102
+ reshape = torch._C._VariableFunctionsClass.reshape
@@ -1,183 +1,179 @@
1
- import torch
2
- from msprobe.core.common.exceptions import FreeBenchmarkException
3
- from msprobe.pytorch.free_benchmark import logger
4
- from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
6
- from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
7
- from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
8
- FuzzHandlerFactory,
9
- )
10
-
11
-
12
class GradSaver:
    """Compare an API's input gradients with gradients recomputed under perturbation.

    Roles of the methods (the caller that sequences them is not visible here):

    * ``register_compare_func_for_inputs`` installs per-input gradient hooks that
      compare each incoming gradient with ``self.perturbed_grad_input``.
    * ``cache_backward_input`` snapshots the API inputs (tensors moved to CPU).
    * ``get_vjp_input`` rebuilds those inputs as fresh leaf tensors for vjp.
    * ``get_grad_input_from_vjp`` recomputes input gradients of ``origin_func``
      with ``torch.autograd.functional.vjp``.
    * ``calculate_perturbed_grad_input`` runs that recomputation through a
      perturbation layer and stores the result (on CPU) in
      ``self.perturbed_grad_input``.
    """

    def __init__(self, origin_func, handler_params: HandlerParams):

        self.handler_params = handler_params
        self.api_name = handler_params.api_name
        # Callable re-executed inside vjp to recompute input gradients.
        self.origin_func = origin_func
        # Scratch parameters; replaced with a fresh DataParams on every comparison /
        # perturbation so no field leaks from one use into the next.
        self.data_params = DataParams()
        # Set to False after the first inconsistent gradient; hooks then pass gradients through unchecked.
        self.is_compare = True
        # Keyword arguments passed to origin_func during recomputation (presumably set by the caller).
        self.kwargs = dict()
        # Perturbed input gradients on CPU, ordered like the inputs that require grad.
        self.perturbed_grad_input = tuple()
        # Unperturbed input gradients; not assigned in this class — presumably set by the caller.
        self.origin_grad_input = tuple()
        # One flag per cached input: True when it was rebuilt as a tensor requiring grad.
        self.need_grad_flag = list()
        # Snapshot produced by cache_backward_input.
        self.backward_input = tuple()

    def register_compare_func_for_inputs(self, inputs, data_processor):
        """Register a gradient hook on every input tensor that requires grad.

        Each hook compares the gradient autograd delivers for that input with the
        matching entry of ``self.perturbed_grad_input`` and forwards the handler's
        unequal rows to ``data_processor.update_unequal_rows``. The hook always
        returns ``grad`` unchanged; comparison failures are logged, not raised.

        :param inputs: positional inputs of the API call (NOTE(review): assumed to
            be the same sequence later given to ``cache_backward_input``).
        :param data_processor: object exposing ``update_unequal_rows(rows)``.
        """
        grad_index = 0  # position among the inputs that require grad
        for input_index, obj in enumerate(inputs):
            if not (torch.is_tensor(obj) and obj.requires_grad):
                continue

            # Default arguments freeze the current indices; a plain closure would
            # only ever see the values from the final loop iteration.
            def compare_func(grad, new_grad_index=grad_index, input_index=input_index):
                if not self.is_compare:
                    return grad
                try:
                    perturbed_grad = self.check_grad_input(grad, new_grad_index)
                    handler = FuzzHandlerFactory.create(self.handler_params)
                    self.compare_grad_results(
                        handler, grad, perturbed_grad, index=input_index
                    )
                    data_processor.update_unequal_rows(handler.get_unequal_rows())
                except IndexError:
                    logger.warning_on_rank_0(
                        f"[msprobe] Free benchmark: grad index out of range. api:{self.handler_params.api_name}."
                        f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}"
                    )
                except FreeBenchmarkException as e:
                    logger.warning_on_rank_0(
                        f"[msprobe] Free benchmark: grad input check error: {e}"
                    )
                except Exception as e:
                    logger.warning_on_rank_0(
                        f"[msprobe] Free benchmark: grad compare error: {e}"
                    )
                return grad

            obj.register_hook(compare_func)
            grad_index += 1

    def compare_grad_results(self, handler, origin_grad, perturbed_grad, index):
        """Compare one input gradient with its perturbed counterpart.

        If ``handler`` reports an inconsistency, further hook comparisons are
        disabled (``self.is_compare = False``) and the handler runs once more on
        the full ``origin_grad_input`` / ``perturbed_grad_input`` tuples with
        ``grad_unequal_flag`` set. Failures are logged, not raised.

        :param handler: result handler created by ``FuzzHandlerFactory``.
        :param origin_grad: gradient delivered by autograd for this input.
        :param perturbed_grad: matching perturbed gradient.
        :param index: position of this input among the API's positional inputs.
        """
        # TODO get dtype?
        # Start from a fresh DataParams so state from calculate_perturbed_grad_input
        # or from a previous hook's comparison cannot leak into this one.
        self.data_params = DataParams()
        self.data_params.original_result = origin_grad
        self.data_params.perturbed_result = perturbed_grad
        self.data_params.grad_unequal_flag = False
        self.data_params.valid_input_index = index
        try:
            handler.handle(self.data_params)
            if not self.data_params.is_consistent:
                self.is_compare = False
                self.data_params.grad_unequal_flag = True
                self.data_params.is_consistent = True
                self.data_params.perturbed_result = self.perturbed_grad_input
                self.data_params.original_result = self.origin_grad_input
                handler.handle(self.data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}."
                f"{e}"
            )
        # Release the references to both outputs once the before/after comparison is done.
        self.data_params.perturbed_result = None
        self.data_params.original_result = None

    def check_grad_input(self, origin_grad, new_grad_index):
        """Return the perturbed gradient at ``new_grad_index`` on ``origin_grad``'s device.

        :raises FreeBenchmarkException: if no perturbed gradients are stored, or if
            the perturbed gradient's shape differs from ``origin_grad``'s.
        :raises IndexError: if ``new_grad_index`` is outside ``perturbed_grad_input``.
        """
        # perturbed_grad_input starts as an empty tuple (never None), so test for
        # emptiness; an `is None` test let this case surface as a misleading IndexError.
        if not self.perturbed_grad_input:
            raise FreeBenchmarkException(
                FreeBenchmarkException.InvalidGrad,
                f"grad not exists : {self.api_name}."
            )
        with torch.no_grad():
            perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
                origin_grad.device
            )
        if origin_grad.shape != perturbed_grad.shape:
            raise FreeBenchmarkException(
                FreeBenchmarkException.InvalidGrad,
                f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
                f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}"
            )
        return perturbed_grad

    def cache_backward_input(self, backward_input_list):
        """Snapshot the API inputs for a later vjp recomputation.

        Tensors are stored as dicts holding their original device, a CPU copy
        (``.cpu()``; for tensors already on CPU this is the same object) and
        their ``requires_grad`` flag. Non-tensor inputs are stored unchanged.
        """
        _inputs = []
        with torch.no_grad():
            for backward_input in backward_input_list:
                if torch.is_tensor(backward_input):
                    _inputs.append(
                        {
                            CommonField.DEVICE: backward_input.device,
                            CommonField.FUZZ_TENSOR: backward_input.cpu(),
                            CommonField.REQUIRES_GRAD: backward_input.requires_grad,
                        }
                    )
                else:
                    _inputs.append(backward_input)
        self.backward_input = _inputs

    def get_vjp_input(self):
        """Rebuild the cached inputs for ``torch.autograd.functional.vjp``.

        :return: ``(need_grad_tensors, inner_args)``; ``need_grad_tensors`` lists
            freshly built leaf tensors that require grad, and ``inner_args`` is a
            tuple of all inputs in their original order with each of those tensors
            replaced by ``CommonField.HOLD_PLACE``.
        """
        inner_args_tmp = []
        need_grad_tensors = []
        # Rebuild the flags on every call so repeated calls cannot accumulate stale entries.
        self.need_grad_flag = []
        for object_ in self.backward_input:
            if isinstance(object_, dict) and CommonField.FUZZ_TENSOR in object_:
                cached = object_.get(CommonField.FUZZ_TENSOR)
                # Detached copy on the original device -> new leaf tensor. Same result as
                # torch.tensor(cached, ...) without PyTorch's copy-construct warning.
                tensor_ = cached.detach().to(
                    device=object_.get(CommonField.DEVICE),
                    dtype=cached.dtype,
                    copy=True,
                ).requires_grad_(object_.get(CommonField.REQUIRES_GRAD))

                if tensor_.requires_grad:
                    inner_args_tmp.append(CommonField.HOLD_PLACE)
                    need_grad_tensors.append(tensor_)
                    self.need_grad_flag.append(True)
                else:
                    self.need_grad_flag.append(False)
                    inner_args_tmp.append(tensor_)
            else:
                self.need_grad_flag.append(False)
                inner_args_tmp.append(object_)

        return need_grad_tensors, tuple(inner_args_tmp)

    def get_grad_input_from_vjp(self, need_grad_tensors, grad_output, inner_args):
        """Recompute the input gradients of ``origin_func`` with ``torch.autograd.functional.vjp``.

        :param need_grad_tensors: tensors to differentiate with respect to.
        :param grad_output: vector for the vector-Jacobian product.
        :param inner_args: all positional inputs, with ``CommonField.HOLD_PLACE``
            marking, in order, where each tensor of ``need_grad_tensors`` goes.
        :return: the vjp result for ``tuple(need_grad_tensors)``.
        """
        def vjp_func(*inputs):
            # Re-insert the differentiable inputs into their original positions.
            _real_input = []
            index_ = 0
            for object_ in inner_args:
                if object_ is CommonField.HOLD_PLACE:
                    _real_input.append(inputs[index_])
                    index_ += 1
                else:
                    _real_input.append(object_)
            kwargs = self.kwargs.copy()
            # Force out-of-place execution (presumably so recomputation leaves its inputs untouched).
            if 'inplace' in kwargs:
                kwargs['inplace'] = False
            return self.origin_func(*_real_input, **kwargs)

        _, grad_input = torch.autograd.functional.vjp(
            vjp_func, tuple(need_grad_tensors), grad_output
        )
        return grad_input

    def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args):
        """Recompute the input gradients through the configured perturbation layer.

        On success the perturbed gradients are moved to CPU and stored in
        ``self.perturbed_grad_input``; otherwise that attribute is left unchanged.
        """
        # Fresh parameters so nothing left over from an earlier comparison leaks in.
        self.data_params = DataParams()
        self.data_params.args = [need_grad_tensors, grad_output, inner_args]
        self.data_params.kwargs = {}
        self.data_params.valid_input_index = 0
        self.data_params.origin_func = self.get_grad_input_from_vjp
        layer = LayerFactory.create(
            self.handler_params.api_name,
            self.handler_params.fuzz_device,
            self.handler_params.pert_mode,
        )
        layer.handle(self.data_params)
        # Drop the reference to the inputs once the perturbed output has been computed.
        self.data_params.args = None
        # Only store the result once the perturbation is confirmed to have succeeded.
        if self.data_params.perturbed_result:
            self.perturbed_grad_input = tuple(
                [x.cpu() for x in self.data_params.perturbed_result]
            )
1
+ import torch
2
+ from msprobe.core.common.exceptions import FreeBenchmarkException
3
+ from msprobe.pytorch.free_benchmark import logger
4
+ from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams, data_pre_deal
6
+ from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
7
+ from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
8
+ FuzzHandlerFactory,
9
+ )
10
+
11
+
12
class GradSaver:
    """Compare an API's input gradients with gradients recomputed under perturbation.

    Roles of the methods (the caller that sequences them is not visible here):

    * ``register_compare_func_for_inputs`` installs per-input gradient hooks that
      compare each incoming gradient with ``self.perturbed_grad_input``.
    * ``cache_backward_input`` snapshots the API inputs (tensors moved to CPU).
    * ``get_vjp_input`` rebuilds those inputs as fresh leaf tensors for vjp.
    * ``get_grad_input_from_vjp`` recomputes input gradients of ``origin_func``
      with ``torch.autograd.functional.vjp``.
    * ``calculate_perturbed_grad_input`` runs that recomputation through a
      perturbation layer and stores the result (on CPU) in
      ``self.perturbed_grad_input``.
    """

    def __init__(self, origin_func, handler_params: HandlerParams):

        self.handler_params = handler_params
        self.api_name = handler_params.api_name
        # Callable re-executed inside vjp to recompute input gradients.
        self.origin_func = origin_func
        # Set to False after the first inconsistent gradient; hooks then pass gradients through unchecked.
        self.is_compare = True
        # Keyword arguments passed to origin_func during recomputation (presumably set by the caller).
        self.kwargs = dict()
        # Perturbed input gradients on CPU, ordered like the inputs that require grad.
        self.perturbed_grad_input = tuple()
        # Unperturbed input gradients; not assigned in this class — presumably set by the caller.
        self.origin_grad_input = tuple()
        # One flag per cached input: True when it was rebuilt as a tensor requiring grad.
        self.need_grad_flag = list()
        # Snapshot produced by cache_backward_input.
        self.backward_input = tuple()

    def register_compare_func_for_inputs(self, inputs, data_processor):
        """Register a gradient hook on every input tensor that requires grad.

        Each hook compares the gradient autograd delivers for that input with the
        matching entry of ``self.perturbed_grad_input`` and forwards the handler's
        unequal rows to ``data_processor.update_unequal_rows``. The hook always
        returns ``grad`` unchanged; comparison failures are logged, not raised.

        :param inputs: positional inputs of the API call (NOTE(review): assumed to
            be the same sequence later given to ``cache_backward_input``).
        :param data_processor: object exposing ``update_unequal_rows(rows)``.
        """
        grad_index = 0  # position among the inputs that require grad
        for input_index, obj in enumerate(inputs):
            if not (torch.is_tensor(obj) and obj.requires_grad):
                continue

            # Default arguments freeze the current indices; a plain closure would
            # only ever see the values from the final loop iteration.
            def compare_func(grad, new_grad_index=grad_index, input_index=input_index):
                if not self.is_compare:
                    return grad
                try:
                    perturbed_grad = self.check_grad_input(grad, new_grad_index)
                    handler = FuzzHandlerFactory.create(self.handler_params)
                    self.compare_grad_results(
                        handler, grad, perturbed_grad, index=input_index
                    )
                    data_processor.update_unequal_rows(handler.get_unequal_rows())
                except IndexError:
                    logger.warning_on_rank_0(
                        f"[msprobe] Free benchmark: grad index out of range. api:{self.handler_params.api_name}."
                        f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}"
                    )
                except FreeBenchmarkException as e:
                    logger.warning_on_rank_0(
                        f"[msprobe] Free benchmark: grad input check error: {e}"
                    )
                except Exception as e:
                    logger.warning_on_rank_0(
                        f"[msprobe] Free benchmark: grad compare error: {e}"
                    )
                return grad

            obj.register_hook(compare_func)
            grad_index += 1

    def compare_grad_results(self, handler, origin_grad, perturbed_grad, index):
        """Compare one input gradient with its perturbed counterpart.

        If ``handler`` reports an inconsistency, further hook comparisons are
        disabled (``self.is_compare = False``) and the handler runs once more on
        the full ``origin_grad_input`` / ``perturbed_grad_input`` tuples with
        ``grad_unequal_flag`` set. Failures are logged, not raised.

        :param handler: result handler created by ``FuzzHandlerFactory``.
        :param origin_grad: gradient delivered by autograd for this input.
        :param perturbed_grad: matching perturbed gradient.
        :param index: position of this input among the API's positional inputs.
        """
        # A fresh DataParams per comparison, so no state leaks between hooks.
        data_params = DataParams()
        data_params.original_result = origin_grad
        data_params.perturbed_result = perturbed_grad
        data_params.grad_unequal_flag = False
        data_params.valid_input_index = index
        try:
            handler.handle(data_params)
            if not data_params.is_consistent:
                self.is_compare = False
                data_params.grad_unequal_flag = True
                data_params.is_consistent = True
                data_params.perturbed_result = self.perturbed_grad_input
                data_params.original_result = self.origin_grad_input
                handler.handle(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}."
                f"{e}"
            )

    def check_grad_input(self, origin_grad, new_grad_index):
        """Return the perturbed gradient at ``new_grad_index`` on ``origin_grad``'s device.

        :raises FreeBenchmarkException: if no perturbed gradients are stored, or if
            the perturbed gradient's shape differs from ``origin_grad``'s.
        :raises IndexError: if ``new_grad_index`` is outside ``perturbed_grad_input``.
        """
        # perturbed_grad_input starts as an empty tuple (never None), so test for
        # emptiness; an `is None` test let this case surface as a misleading IndexError.
        if not self.perturbed_grad_input:
            raise FreeBenchmarkException(
                FreeBenchmarkException.InvalidGrad,
                f"grad not exists : {self.api_name}."
            )
        with torch.no_grad():
            perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
                origin_grad.device
            )
        if origin_grad.shape != perturbed_grad.shape:
            raise FreeBenchmarkException(
                FreeBenchmarkException.InvalidGrad,
                f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
                f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}"
            )
        return perturbed_grad

    def cache_backward_input(self, backward_input_list):
        """Snapshot the API inputs for a later vjp recomputation.

        Tensors are stored as dicts holding their original device, a CPU copy
        (``.cpu()``; for tensors already on CPU this is the same object) and
        their ``requires_grad`` flag. Non-tensor inputs are stored unchanged.
        """
        _inputs = []
        with torch.no_grad():
            for backward_input in backward_input_list:
                if torch.is_tensor(backward_input):
                    _inputs.append(
                        {
                            CommonField.DEVICE: backward_input.device,
                            CommonField.FUZZ_TENSOR: backward_input.cpu(),
                            CommonField.REQUIRES_GRAD: backward_input.requires_grad,
                        }
                    )
                else:
                    _inputs.append(backward_input)
        self.backward_input = _inputs

    def get_vjp_input(self):
        """Rebuild the cached inputs for ``torch.autograd.functional.vjp``.

        :return: ``(need_grad_tensors, inner_args)``; ``need_grad_tensors`` lists
            freshly built leaf tensors that require grad, and ``inner_args`` is a
            tuple of all inputs in their original order with each of those tensors
            replaced by ``CommonField.HOLD_PLACE``.
        """
        inner_args_tmp = []
        need_grad_tensors = []
        # Rebuild the flags on every call so repeated calls cannot accumulate stale entries.
        self.need_grad_flag = []
        for object_ in self.backward_input:
            if isinstance(object_, dict) and CommonField.FUZZ_TENSOR in object_:
                cached = object_.get(CommonField.FUZZ_TENSOR)
                # Detached copy on the original device -> new leaf tensor. Same result as
                # torch.tensor(cached, ...) without PyTorch's copy-construct warning.
                tensor_ = cached.detach().to(
                    device=object_.get(CommonField.DEVICE),
                    dtype=cached.dtype,
                    copy=True,
                ).requires_grad_(object_.get(CommonField.REQUIRES_GRAD))

                if tensor_.requires_grad:
                    inner_args_tmp.append(CommonField.HOLD_PLACE)
                    need_grad_tensors.append(tensor_)
                    self.need_grad_flag.append(True)
                else:
                    self.need_grad_flag.append(False)
                    inner_args_tmp.append(tensor_)
            else:
                self.need_grad_flag.append(False)
                inner_args_tmp.append(object_)

        return need_grad_tensors, tuple(inner_args_tmp)

    def get_grad_input_from_vjp(self, need_grad_tensors, grad_output, inner_args):
        """Recompute the input gradients of ``origin_func`` with ``torch.autograd.functional.vjp``.

        :param need_grad_tensors: tensors to differentiate with respect to.
        :param grad_output: vector for the vector-Jacobian product.
        :param inner_args: all positional inputs, with ``CommonField.HOLD_PLACE``
            marking, in order, where each tensor of ``need_grad_tensors`` goes.
        :return: the vjp result for ``tuple(need_grad_tensors)``.
        """
        def vjp_func(*inputs):
            # Re-insert the differentiable inputs into their original positions.
            _real_input = []
            index_ = 0
            for object_ in inner_args:
                if object_ is CommonField.HOLD_PLACE:
                    _real_input.append(inputs[index_])
                    index_ += 1
                else:
                    _real_input.append(object_)
            kwargs = self.kwargs.copy()
            # Force out-of-place execution (presumably so recomputation leaves its inputs untouched).
            if 'inplace' in kwargs:
                kwargs['inplace'] = False
            return self.origin_func(*_real_input, **kwargs)

        _, grad_input = torch.autograd.functional.vjp(
            vjp_func, tuple(need_grad_tensors), grad_output
        )
        return grad_input

    def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args):
        """Recompute the input gradients through the configured perturbation layer.

        On success the perturbed gradients are moved to CPU and stored in
        ``self.perturbed_grad_input``; otherwise that attribute is left unchanged.
        """
        data_params = data_pre_deal(
            self.handler_params.api_name,
            self.get_grad_input_from_vjp,
            [need_grad_tensors, grad_output, inner_args],
            {}
        )
        layer = LayerFactory.create(
            self.handler_params.api_name,
            self.handler_params.fuzz_device,
            self.handler_params.pert_mode,
        )
        layer.handle(data_params)
        # Only store the result once the perturbation is confirmed to have succeeded.
        if data_params.perturbed_result:
            self.perturbed_grad_input = tuple(
                [x.cpu() for x in data_params.perturbed_result]
            )