mindstudio-probe: mindstudio_probe-1.0.3-py3-none-any.whl → mindstudio_probe-1.0.4-py3-none-any.whl

This diff represents the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
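
The most disruptive change for downstream code is the rename of msprobe/core/common/file_check.py to file_utils.py (entry 13), which moves file helpers such as load_yaml out of msprobe.core.common.utils; the import-only rewrites in the diffs below are all fallout from this move. A minimal compatibility sketch for code that must run against both versions follows (the try/except pattern is a suggestion, not something shipped in the package):

# Hypothetical version-tolerant import: 1.0.4 exposes load_yaml from
# msprobe.core.common.file_utils, while 1.0.3 exposed it from
# msprobe.core.common.utils.
try:
    from msprobe.core.common.file_utils import load_yaml   # 1.0.4+
except ImportError:
    from msprobe.core.common.utils import load_yaml        # 1.0.3

ops = load_yaml("support_wrap_ops.yaml")  # any YAML file readable by the helper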
msprobe/pytorch/hook_module/utils.py (+30 -29): load_yaml now comes from the new file_utils module, and get_ops() additionally collects the torch_npu category.

@@ -1,29 +1,30 @@
  #!/usr/bin/env python3
  # -*- coding: utf-8 -*-
  """
  # Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """

  import os
- from msprobe.core.common.utils import load_yaml
+ from msprobe.core.common.file_utils import load_yaml


  def get_ops():
      cur_path = os.path.dirname(os.path.realpath(__file__))
      yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
      ops = load_yaml(yaml_path)
      wrap_functional = ops.get('functional')
      wrap_tensor = ops.get('tensor')
      wrap_torch = ops.get('torch')
-     return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch)
+     wrap_npu_ops = ops.get('torch_npu')
+     return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch) | set(wrap_npu_ops)
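
The shape of the YAML that get_ops() consumes follows from the four .get() calls above. A standalone sketch with made-up op names (using PyYAML directly, where the package itself goes through file_utils.load_yaml):

import yaml  # PyYAML

SAMPLE = """
functional: [relu, softmax]
tensor: [add, mul]
torch: [matmul]
torch_npu: [npu_fusion_attention]
"""

ops = yaml.safe_load(SAMPLE)
names = (set(ops.get('functional')) | set(ops.get('tensor'))
         | set(ops.get('torch')) | set(ops.get('torch_npu')))
print(sorted(names))
# ['add', 'matmul', 'mul', 'npu_fusion_attention', 'relu', 'softmax']
# 1.0.4 adds the torch_npu category to this union.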
msprobe/pytorch/hook_module/wrap_aten.py (+110 -110): the registry diff re-emits all 110 lines of this file, but the only textual change is the load_yaml import:

@@ -21,7 +21,7 @@
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard
  from msprobe.core.common.const import Const
- from msprobe.core.common.utils import load_yaml
+ from msprobe.core.common.file_utils import load_yaml
  from msprobe.pytorch.function_factory import npu_custom_grad_functions

  cur_path = os.path.dirname(os.path.realpath(__file__))

The resulting 1.0.4 file:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import torch

from msprobe.pytorch.hook_module.hook_module import HOOKModule
from msprobe.pytorch.common.utils import torch_device_guard
from msprobe.core.common.const import Const
from msprobe.core.common.file_utils import load_yaml
from msprobe.pytorch.function_factory import npu_custom_grad_functions

cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
ops = load_yaml(yaml_path)
wrap_aten_ops = ops.get('aten')
white_aten_ops = ops.get('white_aten_ops', [])


aten_func = {}
for f in dir(torch.ops.aten):
    aten_func[f] = getattr(torch.ops.aten, f)


def get_aten_ops():
    global wrap_aten_ops
    _all_aten_ops = dir(torch.ops.aten)
    return set(wrap_aten_ops) & set(_all_aten_ops)


class HOOKAtenOP(object):
    pass


class AtenOPTemplate(HOOKModule):
    def __init__(self, op, hook, need_hook=True):
        if isinstance(op, torch._ops.OpOverloadPacket):
            op_name_ = op._qualified_op_name.split("::")[-1]
        else:
            op_name_ = op.name().split("::")[-1]
            overload_name = op._overloadname
            if not '.' + overload_name in op_name_:
                op_name_ = op_name_ + '.' + overload_name
        self.op = op
        self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP
        self.need_hook = need_hook
        if self.need_hook:
            super().__init__(hook)

    @torch_device_guard
    def forward(self, *args, **kwargs):
        if isinstance(self.op, str):
            if self.op in npu_custom_grad_functions:
                return npu_custom_grad_functions[self.op](*args, **kwargs)
            if self.op in white_aten_ops:
                return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs)
            if self.op not in aten_func:
                raise Exception(f"Skip op[{self.op}] accuracy check, because the op is not "
                                f"in dir(torch.ops.aten) and support yaml.")
            return aten_func[self.op](*args, **kwargs)
        return self.op(*args, **kwargs)


class AtenOPPacketTemplate():
    def __init__(self, opPacket, hook):
        self.opPacket = opPacket
        self.hook = hook

    def __getattr__(self, key):
        try:
            attr = getattr(self.opPacket, key)
        except AttributeError as e:
            raise AttributeError(f"AtenOPPacketTemplate or OpOverloadPacket does not have attribute '{key}'.") from e
        if isinstance(attr, torch._ops.OpOverload):
            return AtenOPTemplate(attr, self.hook)
        else:
            return attr

    @torch_device_guard
    def __call__(self, *args, **kwargs):
        return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs)

    def overloads(self):
        return self.opPacket.overloads()


def wrap_aten_op(op, hook):
    return AtenOPPacketTemplate(op, hook)


def wrap_aten_ops_and_bind(hook):
    _aten_ops = get_aten_ops()
    for op_name in _aten_ops:
        if not isinstance(aten_func.get(op_name), torch._ops.OpOverloadPacket):
            continue
        setattr(HOOKAtenOP, "wrap_" + str(op_name), wrap_aten_op(aten_func.get(op_name), hook))
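
AtenOPPacketTemplate wraps a torch.ops.aten OpOverloadPacket so that attribute access yields hooked overloads and direct calls route through the template. The delegation pattern itself can be checked in isolation; the sketch below is a simplification (a print stands in for the dump hook, and PacketProxy is an illustrative name, not msprobe's class):

import torch

class PacketProxy:
    """Simplified stand-in for AtenOPPacketTemplate: delegate to an
    OpOverloadPacket while intercepting calls."""
    def __init__(self, packet):
        self.packet = packet

    def __getattr__(self, key):
        # Only reached when normal lookup fails; forwards overloads
        # (e.g. .default) and metadata to the real packet.
        return getattr(self.packet, key)

    def __call__(self, *args, **kwargs):
        print(f"hook fired for {self.packet}")  # real code records inputs/outputs
        return self.packet(*args, **kwargs)

add = PacketProxy(torch.ops.aten.add)
print(add(torch.ones(2), torch.ones(2)))  # hook line, then tensor([2., 2.])
print(add.overloads())                    # delegated to the underlying packet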
msprobe/pytorch/hook_module/wrap_distributed.py (+78 -78): again the whole file is re-emitted, with the load_yaml import as the only textual change:

@@ -22,7 +22,7 @@
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard
  from msprobe.core.common.const import Const
- from msprobe.core.common.utils import load_yaml
+ from msprobe.core.common.file_utils import load_yaml


  cur_path = os.path.dirname(os.path.realpath(__file__))

The resulting 1.0.4 file:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
from functools import wraps
import torch.distributed as dist

from msprobe.pytorch.hook_module.hook_module import HOOKModule
from msprobe.pytorch.common.utils import torch_device_guard
from msprobe.core.common.const import Const
from msprobe.core.common.file_utils import load_yaml


cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")


distributed_func = {}
for f in dir(dist):
    distributed_func[f] = getattr(dist, f)


def get_distributed_ops():
    _all_distributed_ops = dir(dist)
    yaml_data = load_yaml(yaml_path)
    wrap_distributed_ops = yaml_data.get('distributed')
    return set(wrap_distributed_ops) & set(_all_distributed_ops)


class HOOKDistributedOP(object):
    pass


class DistributedOPTemplate(HOOKModule):
    def __init__(self, op_name, build_hook):
        self.op_name_ = op_name
        self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
        super().__init__(build_hook)
        if not self.stop_hook and self.op_name_ in Const.INPLACE_LIST:
            self.op_is_inplace = True

    @torch_device_guard
    def forward(self, *args, **kwargs):
        if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
            handle = distributed_func.get(self.op_name_)(*args, **kwargs)
            handle.wait()
            return handle
        else:
            return distributed_func.get(self.op_name_)(*args, **kwargs)


def wrap_distributed_op(op_name, hook):
    @wraps(DistributedOPTemplate)
    def distributed_op_template(*args, **kwargs):
        return DistributedOPTemplate(op_name, hook)(*args, **kwargs)

    distributed_op_template.__name__ = op_name
    return distributed_op_template


def wrap_distributed_ops_and_bind(hook):
    _distributed_ops = get_distributed_ops()
    for op_name in _distributed_ops:
        setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook))
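
One detail worth noting in DistributedOPTemplate.forward(): when async_op is set (or the op is isend/irecv), it calls handle.wait() before returning, presumably so the hook observes completed communication buffers rather than in-flight ones. The surrounding factory pattern, one named wrapper per YAML-listed op bound onto a holder class, can be sketched without a process group; the ops table below stands in for torch.distributed, and all names are illustrative:

def make_wrapper(op_name, table):
    # Mirrors wrap_distributed_op: a closure that routes through the table.
    def wrapper(*args, **kwargs):
        print(f"hook: Distributed.{op_name}")  # stand-in for the dump hook
        return table[op_name](*args, **kwargs)
    wrapper.__name__ = op_name  # same renaming trick the real code uses
    return wrapper

class HookHolder:  # plays the role of HOOKDistributedOP
    pass

table = {"all_reduce": lambda t: t, "broadcast": lambda t, src=0: t}
for name in table:
    setattr(HookHolder, "wrap_" + name, staticmethod(make_wrapper(name, table)))

print(HookHolder.wrap_all_reduce([1, 2, 3]))  # hook line, then [1, 2, 3]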
msprobe/pytorch/hook_module/wrap_functional.py (+105 -105): the whole file is re-emitted once more, with the load_yaml import as the only textual change:

@@ -22,7 +22,7 @@
  from msprobe.pytorch.common.utils import torch_device_guard
  from msprobe.core.common.const import Const
  from msprobe.pytorch.common.log import logger
- from msprobe.core.common.utils import load_yaml
+ from msprobe.core.common.file_utils import load_yaml


  def remove_dropout():

The resulting 1.0.4 file:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import torch

from msprobe.pytorch.hook_module.hook_module import HOOKModule
from msprobe.pytorch.common.utils import torch_device_guard
from msprobe.core.common.const import Const
from msprobe.pytorch.common.log import logger
from msprobe.core.common.file_utils import load_yaml


def remove_dropout():
    if torch.__version__ > "1.8":
        logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
        import torch.nn.functional as F
        from torch import _VF
        from torch.overrides import has_torch_function_unary, handle_torch_function

        def function_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True,
                             inplace: bool = False) -> torch.Tensor:
            if has_torch_function_unary(input):
                return handle_torch_function(function_dropout, (input,), input, p=0., training=training, inplace=inplace)
            if p < 0.0 or p > 1.0:
                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
            return _VF.dropout_(input, 0., training) if inplace else _VF.dropout(input, 0., training)

        def function_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True,
                               inplace: bool = False) -> torch.Tensor:
            if has_torch_function_unary(input):
                return handle_torch_function(function_dropout2d, (input,), input, p=0., training=training, inplace=inplace)
            if p < 0.0 or p > 1.0:
                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
            return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)

        def function_dropout3d(input: torch.Tensor, p: float = 0.5, training: bool = True,
                               inplace: bool = False) -> torch.Tensor:
            if has_torch_function_unary(input):
                return handle_torch_function(function_dropout3d, (input,), input, p=0., training=training, inplace=inplace)
            if p < 0.0 or p > 1.0:
                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
            return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)

        F.dropout = function_dropout
        F.dropout2d = function_dropout2d
        F.dropout3d = function_dropout3d

cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")


def get_functional_ops():
    yaml_data = load_yaml(yaml_path)
    wrap_functional_ops = yaml_data.get('functional')
    _all_functional_ops = dir(torch.nn.functional)
    return set(wrap_functional_ops) & set(_all_functional_ops)


TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()}


class HOOKFunctionalOP(object):
    pass


class FunctionalOPTemplate(HOOKModule):
    def __init__(self, op_name, hook, need_hook=True):
        self.op_name_ = op_name
        self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP
        if need_hook:
            super().__init__(hook)

    @torch_device_guard
    def forward(self, *args, **kwargs):
        return TorchFunctions[str(self.op_name_)](*args, **kwargs)


def wrap_functional_op(op_name, hook):
    def functional_op_template(*args, **kwargs):
        return FunctionalOPTemplate(op_name, hook)(*args, **kwargs)

    return functional_op_template


def wrap_functional_ops_and_bind(hook):
    _functional_ops = get_functional_ops()
    for op_name in _functional_ops:
        setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook))
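
remove_dropout() makes runs reproducible for comparison by rebinding F.dropout, F.dropout2d and F.dropout3d to variants that force p to 0. The effect can be verified with a small standalone check (patching F.dropout directly rather than going through msprobe):

import torch
import torch.nn.functional as F

_orig_dropout = F.dropout

def zero_p_dropout(input, p=0.5, training=True, inplace=False):
    # Ignore the requested p and run with p=0, as remove_dropout() does.
    return _orig_dropout(input, 0.0, training, inplace)

F.dropout = zero_p_dropout
x = torch.randn(4)
assert torch.equal(F.dropout(x, p=0.9, training=True), x)  # nothing is zeroed
F.dropout = _orig_dropout  # restore the original behavior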