mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,73 +1,93 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
- import torch
20
- import torch_npu
21
- import yaml
22
-
23
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
24
- from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
25
- from msprobe.core.common.const import Const
26
- from msprobe.core.common.file_check import FileOpen
27
-
28
- cur_path = os.path.dirname(os.path.realpath(__file__))
29
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
30
- with FileOpen(yaml_path, 'r') as f:
31
- WrapNpuOps = yaml.safe_load(f).get('torch_npu')
32
-
33
-
34
- def get_npu_ops():
35
- global WrapNpuOps
36
- if torch_without_guard_version:
37
- _npu_ops = dir(torch.ops.npu)
38
- else:
39
- _npu_ops = dir(torch_npu._C._VariableFunctionsClass)
40
- return set(WrapNpuOps) & set(_npu_ops)
41
-
42
-
43
- class HOOKNpuOP(object):
44
- pass
45
-
46
-
47
- class NpuOPTemplate(HOOKModule):
48
-
49
- def __init__(self, op_name, hook):
50
- self.op_name_ = op_name
51
- self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
52
- super().__init__(hook)
53
-
54
- @torch_device_guard
55
- def forward(self, *args, **kwargs):
56
- if torch_without_guard_version:
57
- return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
58
- else:
59
- return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs)
60
-
61
-
62
- def wrap_npu_op(op_name, hook):
63
-
64
- def npu_op_template(*args, **kwargs):
65
- return NpuOPTemplate(op_name, hook)(*args, **kwargs)
66
-
67
- return npu_op_template
68
-
69
-
70
- def wrap_npu_ops_and_bind(hook):
71
- _npu_ops = get_npu_ops()
72
- for op_name in _npu_ops:
73
- setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook))
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ import os
19
+ import torch
20
+
21
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
22
+ from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import load_yaml
25
+ from msprobe.pytorch.function_factory import npu_custom_functions
26
+
27
+ cur_path = os.path.dirname(os.path.realpath(__file__))
28
+ yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
29
+
30
+
31
+ try:
32
+ import torch_npu
33
+ except ImportError:
34
+ is_gpu = True
35
+ else:
36
+ is_gpu = False
37
+
38
+
39
+ cuda_func_mapping = {
40
+ "npu_fusion_attention" : "gpu_fusion_attention"
41
+ }
42
+
43
+
44
+ def get_npu_ops():
45
+ if torch_without_guard_version:
46
+ _npu_ops = dir(torch.ops.npu)
47
+ else:
48
+ _npu_ops = dir(torch_npu._C._VariableFunctionsClass)
49
+ yaml_data = load_yaml(yaml_path)
50
+ wrap_npu_ops = yaml_data.get('torch_npu')
51
+ return set(wrap_npu_ops) & set(_npu_ops)
52
+
53
+
54
+ class HOOKNpuOP(object):
55
+ pass
56
+
57
+
58
+ class NpuOPTemplate(HOOKModule):
59
+
60
+ def __init__(self, op_name, hook, need_hook=True, device=Const.CPU_LOWERCASE):
61
+ self.op_name_ = op_name
62
+ self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
63
+ self.need_hook = need_hook
64
+ self.device = device
65
+ if need_hook:
66
+ super().__init__(hook)
67
+
68
+ @torch_device_guard
69
+ def forward(self, *args, **kwargs):
70
+ if not self.need_hook:
71
+ if self.op_name_ not in npu_custom_functions:
72
+ raise Exception(f'There is not bench function {self.op_name_}')
73
+ if self.device == Const.CUDA_LOWERCASE:
74
+ self.op_name_ = cuda_func_mapping.get(self.op_name_, self.op_name_)
75
+ if self.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]:
76
+ return npu_custom_functions[self.op_name_](*args, **kwargs)
77
+ if torch_without_guard_version:
78
+ return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
79
+ else:
80
+ return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs)
81
+
82
+
83
+ def wrap_npu_op(op_name, hook):
84
+ def npu_op_template(*args, **kwargs):
85
+ return NpuOPTemplate(op_name, hook)(*args, **kwargs)
86
+
87
+ return npu_op_template
88
+
89
+
90
+ def wrap_npu_ops_and_bind(hook):
91
+ _npu_ops = get_npu_ops()
92
+ for op_name in _npu_ops:
93
+ setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook))
@@ -1,72 +1,71 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
-
20
- import torch
21
- import yaml
22
-
23
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
24
- from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter
25
- from msprobe.core.common.const import Const
26
- from msprobe.core.common.file_check import FileOpen
27
-
28
- cur_path = os.path.dirname(os.path.realpath(__file__))
29
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
30
- with FileOpen(yaml_path, 'r') as f:
31
- WrapTensorOps = yaml.safe_load(f).get('tensor')
32
-
33
-
34
- def get_tensor_ops():
35
- global WrapTensorOps
36
- _tensor_ops = dir(torch.Tensor)
37
- return set(WrapTensorOps) & set(_tensor_ops)
38
-
39
-
40
- TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()}
41
-
42
-
43
- class HOOKTensor(object):
44
- pass
45
-
46
-
47
- class TensorOPTemplate(HOOKModule):
48
-
49
- def __init__(self, op_name, hook, need_hook=True):
50
- self.op_name_ = op_name
51
- self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP
52
- if need_hook:
53
- super().__init__(hook)
54
-
55
- @torch_device_guard
56
- @parameter_adapter
57
- def forward(self, *args, **kwargs):
58
- return TensorOps[str(self.op_name_)](*args, **kwargs)
59
-
60
-
61
- def wrap_tensor_op(op_name, hook):
62
-
63
- def tensor_op_template(*args, **kwargs):
64
- return TensorOPTemplate(op_name, hook)(*args, **kwargs)
65
-
66
- return tensor_op_template
67
-
68
-
69
- def wrap_tensor_ops_and_bind(hook):
70
- _tensor_ops = get_tensor_ops()
71
- for op_name in _tensor_ops:
72
- setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook))
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ import os
19
+
20
+ import torch
21
+
22
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
23
+ from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter
24
+ from msprobe.core.common.const import Const
25
+ from msprobe.core.common.file_utils import load_yaml
26
+
27
+
28
+ cur_path = os.path.dirname(os.path.realpath(__file__))
29
+ yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
30
+
31
+
32
+ def get_tensor_ops():
33
+ _tensor_ops = dir(torch.Tensor)
34
+ yaml_data = load_yaml(yaml_path)
35
+ wrap_tensor_ops = yaml_data.get('tensor')
36
+ return set(wrap_tensor_ops) & set(_tensor_ops)
37
+
38
+
39
+ TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()}
40
+
41
+
42
+ class HOOKTensor(object):
43
+ pass
44
+
45
+
46
+ class TensorOPTemplate(HOOKModule):
47
+
48
+ def __init__(self, op_name, hook, need_hook=True):
49
+ self.op_name_ = op_name
50
+ self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP
51
+ if need_hook:
52
+ super().__init__(hook)
53
+
54
+ @torch_device_guard
55
+ @parameter_adapter
56
+ def forward(self, *args, **kwargs):
57
+ return TensorOps[str(self.op_name_)](*args, **kwargs)
58
+
59
+
60
+ def wrap_tensor_op(op_name, hook):
61
+
62
+ def tensor_op_template(*args, **kwargs):
63
+ return TensorOPTemplate(op_name, hook)(*args, **kwargs)
64
+
65
+ return tensor_op_template
66
+
67
+
68
+ def wrap_tensor_ops_and_bind(hook):
69
+ _tensor_ops = get_tensor_ops()
70
+ for op_name in _tensor_ops:
71
+ setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook))
@@ -1,88 +1,86 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
-
20
- import torch
21
- import yaml
22
-
23
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
24
- from msprobe.pytorch.common.utils import torch_device_guard
25
- from msprobe.core.common.const import Const
26
- from msprobe.core.common.file_check import FileOpen
27
-
28
- cur_path = os.path.dirname(os.path.realpath(__file__))
29
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
30
- with FileOpen(yaml_path, 'r') as f:
31
- WrapTorchOps = yaml.safe_load(f).get('torch')
32
-
33
-
34
- def get_torch_ops():
35
- global WrapTorchOps
36
- _torch_ops = []
37
- for operation in WrapTorchOps:
38
- if '.' in operation:
39
- operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1)
40
- operation_sub_module = getattr(torch, operation_sub_module_name)
41
- if operation_sub_op in dir(operation_sub_module):
42
- _torch_ops.append(operation)
43
- else:
44
- if hasattr(torch, operation):
45
- _torch_ops.append(operation)
46
- return set(_torch_ops)
47
-
48
-
49
- TorchOps = {}
50
- for op in get_torch_ops():
51
- if '.' in op:
52
- sub_module_name, sub_op = op.rsplit('.', 1)
53
- sub_module = getattr(torch, sub_module_name)
54
- TorchOps[op] = getattr(sub_module, sub_op)
55
- else:
56
- TorchOps[op] = getattr(torch, op)
57
-
58
-
59
-
60
- class HOOKTorchOP(object):
61
- pass
62
-
63
-
64
- class TorchOPTemplate(HOOKModule):
65
-
66
- def __init__(self, op_name, hook, need_hook=True):
67
- self.op_name_ = op_name
68
- self.prefix_op_name_ = "Torch" + Const.SEP + str(op_name) + Const.SEP
69
- if need_hook:
70
- super().__init__(hook)
71
-
72
- @torch_device_guard
73
- def forward(self, *args, **kwargs):
74
- return TorchOps[str(self.op_name_)](*args, **kwargs)
75
-
76
-
77
- def wrap_torch_op(op_name, hook):
78
-
79
- def torch_op_template(*args, **kwargs):
80
- return TorchOPTemplate(op_name, hook)(*args, **kwargs)
81
-
82
- return torch_op_template
83
-
84
-
85
- def wrap_torch_ops_and_bind(hook):
86
- _torch_ops = get_torch_ops()
87
- for op_name in _torch_ops:
88
- setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook))
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ import os
19
+ import torch
20
+
21
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
22
+ from msprobe.pytorch.common.utils import torch_device_guard
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import load_yaml
25
+
26
+
27
# Directory that contains this module (symlinks resolved by realpath).
cur_path = os.path.dirname(os.path.realpath(__file__))
# YAML file, located next to this module, that get_torch_ops() loads for its 'torch' entry.
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
29
+
30
+
31
def get_torch_ops():
    """Return the set of YAML-listed torch ops that the installed torch provides.

    The op names come from the ``'torch'`` entry of the YAML file at
    ``yaml_path``. A dotted name is split at its last dot into
    ``<submodule>.<op>`` and kept when ``<op>`` appears in ``dir()`` of
    ``torch.<submodule>``; an undotted name is kept when ``torch`` has an
    attribute of that name.

    A dotted entry whose submodule is not an attribute of ``torch`` is skipped
    like any other unavailable op instead of raising ``AttributeError``.

    Returns:
        set: the names from the YAML list that passed the availability check.
    """
    yaml_data = load_yaml(yaml_path)
    wrap_torch_ops = yaml_data.get('torch')
    _torch_ops = set()
    for operation in wrap_torch_ops:
        if '.' in operation:
            sub_module_name, sub_op = operation.rsplit('.', 1)
            # Default to None so a missing submodule filters the op out
            # rather than aborting the whole scan.
            sub_module = getattr(torch, sub_module_name, None)
            if sub_module is not None and sub_op in dir(sub_module):
                _torch_ops.add(operation)
        elif hasattr(torch, operation):
            _torch_ops.add(operation)
    return _torch_ops
45
+
46
+
47
# Maps each op name returned by get_torch_ops() to the original callable,
# looked up once when this module is imported.
TorchOps = {}
for op in get_torch_ops():
    if '.' not in op:
        # Undotted name: the op is an attribute of torch itself.
        TorchOps[op] = getattr(torch, op)
        continue
    # Dotted name: split at the last dot and fetch the op from that torch attribute.
    sub_module_name, sub_op = op.rsplit('.', 1)
    sub_module = getattr(torch, sub_module_name)
    TorchOps[op] = getattr(sub_module, sub_op)
55
+
56
+
57
+
58
class HOOKTorchOP(object):
    """Empty holder class; ``wrap_torch_ops_and_bind`` attaches ``wrap_<op>`` attributes to it."""
60
+
61
+
62
class TorchOPTemplate(HOOKModule):
    """HOOKModule subclass whose ``forward`` calls one original torch op from ``TorchOps``.

    Attributes set by ``__init__``:
        op_name_: the op name exactly as passed in.
        prefix_op_name_: ``"Torch" + Const.SEP + str(op_name) + Const.SEP``.
    """

    def __init__(self, op_name, hook, need_hook=True):
        """Record the op name and, unless ``need_hook`` is false, initialise HOOKModule.

        Args:
            op_name: key into the module-level ``TorchOps`` table.
            hook: forwarded unchanged to ``HOOKModule.__init__``.
            need_hook: when false, ``HOOKModule.__init__`` is not called.
        """
        self.op_name_ = op_name
        # Assigned before the base initialiser runs.
        # NOTE(review): presumably HOOKModule.__init__ reads prefix_op_name_ - confirm before reordering.
        self.prefix_op_name_ = Const.SEP.join(("Torch", str(op_name), ""))
        if not need_hook:
            return
        super().__init__(hook)

    @torch_device_guard
    def forward(self, *args, **kwargs):
        """Call the original torch op stored under ``op_name_`` with the given arguments."""
        original_op = TorchOps[str(self.op_name_)]
        return original_op(*args, **kwargs)
73
+
74
+
75
def wrap_torch_op(op_name, hook):
    """Build a plain function that runs ``op_name`` through a fresh ``TorchOPTemplate``.

    Args:
        op_name: torch op name handed to ``TorchOPTemplate``.
        hook: hook handed to ``TorchOPTemplate``.

    Returns:
        A function accepting ``*args, **kwargs``; each call constructs a new
        ``TorchOPTemplate(op_name, hook)`` and invokes it with those arguments.
    """

    def torch_op_template(*args, **kwargs):
        # A new template instance is created on every call.
        template = TorchOPTemplate(op_name, hook)
        return template(*args, **kwargs)

    return torch_op_template
81
+
82
+
83
def wrap_torch_ops_and_bind(hook):
    """Attach a ``wrap_<op>`` wrapper to ``HOOKTorchOP`` for every op from ``get_torch_ops()``.

    Args:
        hook: passed to ``wrap_torch_op`` for each op.
    """
    for torch_op in get_torch_ops():
        wrapper = wrap_torch_op(torch_op, hook)
        setattr(HOOKTorchOP, "wrap_" + torch_op, wrapper)
@@ -1,64 +1,62 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
-
20
- import torch
21
- import yaml
22
-
23
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
24
- from msprobe.core.common.file_check import FileOpen
25
- from msprobe.pytorch.common.utils import torch_device_guard
26
- from msprobe.core.common.const import Const
27
-
28
- cur_path = os.path.dirname(os.path.realpath(__file__))
29
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
30
- with FileOpen(yaml_path, 'r') as f:
31
- WrapVfOps = yaml.safe_load(f).get('_VF')
32
-
33
-
34
- def get_vf_ops():
35
- global WrapVfOps
36
- return WrapVfOps
37
-
38
-
39
- class HOOKVfOP(object):
40
- pass
41
-
42
-
43
- class VfOPTemplate(HOOKModule):
44
- def __init__(self, op_name, hook):
45
- self.op_name_ = op_name
46
- self.prefix_op_name_ = "VF" + Const.SEP + str(op_name) + Const.SEP
47
- super().__init__(hook)
48
-
49
- @torch_device_guard
50
- def forward(self, *args, **kwargs):
51
- return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs)
52
-
53
-
54
- def wrap_vf_op(op_name, hook):
55
- def vf_op_template(*args, **kwargs):
56
- return VfOPTemplate(op_name, hook)(*args, **kwargs)
57
-
58
- return vf_op_template
59
-
60
-
61
- def wrap_vf_ops_and_bind(hook):
62
- _vf_ops = get_vf_ops()
63
- for op_name in _vf_ops:
64
- setattr(HOOKVfOP, "wrap_" + op_name, wrap_vf_op(op_name, hook))
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ import os
19
+ import torch
20
+
21
+ from msprobe.core.common.const import Const
22
+ from msprobe.core.common.file_utils import load_yaml
23
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
24
+ from msprobe.pytorch.common.utils import torch_device_guard
25
+
26
+
27
# Directory that contains this module (symlinks resolved by realpath).
cur_path = os.path.dirname(os.path.realpath(__file__))
# YAML file, located next to this module, that get_vf_ops() loads for its '_VF' entry.
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
29
+
30
+
31
def get_vf_ops():
    """Return the ``'_VF'`` entry of the YAML file at ``yaml_path``.

    The value is returned exactly as loaded; it is ``None`` when the key is absent.
    """
    return load_yaml(yaml_path).get('_VF')
35
+
36
+
37
class HOOKVfOP(object):
    """Empty holder class; ``wrap_vf_ops_and_bind`` attaches ``wrap_<op>`` attributes to it."""
39
+
40
+
41
class VfOPTemplate(HOOKModule):
    """HOOKModule subclass whose ``forward`` calls one ``torch._C._VariableFunctionsClass`` op.

    Attributes set by ``__init__``:
        op_name_: the op name exactly as passed in.
        prefix_op_name_: ``"VF" + Const.SEP + str(op_name) + Const.SEP``.
    """

    def __init__(self, op_name, hook):
        """Record the op name, then initialise HOOKModule with ``hook``.

        Args:
            op_name: attribute name looked up on ``torch._C._VariableFunctionsClass``.
            hook: forwarded unchanged to ``HOOKModule.__init__``.
        """
        self.op_name_ = op_name
        # Assigned before the base initialiser runs.
        # NOTE(review): presumably HOOKModule.__init__ reads prefix_op_name_ - confirm before reordering.
        self.prefix_op_name_ = Const.SEP.join(("VF", str(op_name), ""))
        super().__init__(hook)

    @torch_device_guard
    def forward(self, *args, **kwargs):
        """Look up the op by name on ``torch._C._VariableFunctionsClass`` and call it."""
        vf_function = getattr(torch._C._VariableFunctionsClass, str(self.op_name_))
        return vf_function(*args, **kwargs)
50
+
51
+
52
def wrap_vf_op(op_name, hook):
    """Build a plain function that runs ``op_name`` through a fresh ``VfOPTemplate``.

    Args:
        op_name: op name handed to ``VfOPTemplate``.
        hook: hook handed to ``VfOPTemplate``.

    Returns:
        A function accepting ``*args, **kwargs``; each call constructs a new
        ``VfOPTemplate(op_name, hook)`` and invokes it with those arguments.
    """

    def vf_op_template(*args, **kwargs):
        # A new template instance is created on every call.
        template = VfOPTemplate(op_name, hook)
        return template(*args, **kwargs)

    return vf_op_template
57
+
58
+
59
def wrap_vf_ops_and_bind(hook):
    """Attach a ``wrap_<op>`` wrapper to ``HOOKVfOP`` for every op from ``get_vf_ops()``.

    Args:
        hook: passed to ``wrap_vf_op`` for each op.
    """
    for vf_op in get_vf_ops():
        wrapper = wrap_vf_op(vf_op, hook)
        setattr(HOOKVfOP, "wrap_" + vf_op, wrapper)