mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,105 +1,104 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
- import torch
20
-
21
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
22
- from msprobe.pytorch.common.utils import torch_device_guard
23
- from msprobe.core.common.const import Const
24
- from msprobe.pytorch.common.log import logger
25
- from msprobe.core.common.utils import load_yaml
26
-
27
-
28
- def remove_dropout():
29
- if torch.__version__ > "1.8":
30
- logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
31
- import torch.nn.functional as F
32
- from torch import _VF
33
- from torch.overrides import has_torch_function_unary, handle_torch_function
34
-
35
- def function_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True,
36
- inplace: bool = False) -> torch.Tensor:
37
- if has_torch_function_unary(input):
38
- return handle_torch_function(function_dropout, (input,), input, p=0., training=training, inplace=inplace)
39
- if p < 0.0 or p > 1.0:
40
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
41
- return _VF.dropout_(input, 0., training) if inplace else _VF.dropout(input, 0., training)
42
-
43
-
44
- def function_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True,
45
- inplace: bool = False) -> torch.Tensor:
46
- if has_torch_function_unary(input):
47
- return handle_torch_function(function_dropout2d, (input,), input, p=0., training=training, inplace=inplace)
48
- if p < 0.0 or p > 1.0:
49
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
50
- return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
51
-
52
-
53
- def function_dropout3d(input: torch.Tensor, p: float = 0.5, training: bool = True,
54
- inplace: bool = False) -> torch.Tensor:
55
- if has_torch_function_unary(input):
56
- return handle_torch_function(function_dropout3d, (input,), input, p=0., training=training, inplace=inplace)
57
- if p < 0.0 or p > 1.0:
58
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
59
- return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
60
-
61
- F.dropout = function_dropout
62
- F.dropout2d = function_dropout2d
63
- F.dropout3d = function_dropout3d
64
-
65
- cur_path = os.path.dirname(os.path.realpath(__file__))
66
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
67
-
68
-
69
- def get_functional_ops():
70
- yaml_data = load_yaml(yaml_path)
71
- wrap_functional_ops = yaml_data.get('functional')
72
- _all_functional_ops = dir(torch.nn.functional)
73
- return set(wrap_functional_ops) & set(_all_functional_ops)
74
-
75
-
76
- TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()}
77
-
78
-
79
- class HOOKFunctionalOP(object):
80
- pass
81
-
82
-
83
- class FunctionalOPTemplate(HOOKModule):
84
- def __init__(self, op_name, hook, need_hook=True):
85
- self.op_name_ = op_name
86
- self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP
87
- if need_hook:
88
- super().__init__(hook)
89
-
90
- @torch_device_guard
91
- def forward(self, *args, **kwargs):
92
- return TorchFunctions[str(self.op_name_)](*args, **kwargs)
93
-
94
-
95
- def wrap_functional_op(op_name, hook):
96
- def functional_op_template(*args, **kwargs):
97
- return FunctionalOPTemplate(op_name, hook)(*args, **kwargs)
98
-
99
- return functional_op_template
100
-
101
-
102
- def wrap_functional_ops_and_bind(hook):
103
- _functional_ops = get_functional_ops()
104
- for op_name in _functional_ops:
105
- setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook))
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import torch
18
+
19
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
20
+ from msprobe.pytorch.common.utils import torch_device_guard
21
+ from msprobe.core.common.const import Const
22
+ from msprobe.pytorch.common.log import logger
23
+ from msprobe.core.common.file_utils import load_yaml
24
+
25
+
26
+ def remove_dropout():
27
+ if torch.__version__ > "1.8":
28
+ logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
29
+ import torch.nn.functional as F
30
+ from torch import _VF
31
+ from torch.overrides import has_torch_function_unary, handle_torch_function
32
+
33
+ def function_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True,
34
+ inplace: bool = False) -> torch.Tensor:
35
+ if has_torch_function_unary(input):
36
+ return handle_torch_function(
37
+ function_dropout, (input,), input, p=0., training=training, inplace=inplace)
38
+ if p < 0.0 or p > 1.0:
39
+ raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
40
+ return _VF.dropout_(input, 0., training) if inplace else _VF.dropout(input, 0., training)
41
+
42
+ def function_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True,
43
+ inplace: bool = False) -> torch.Tensor:
44
+ if has_torch_function_unary(input):
45
+ return handle_torch_function(
46
+ function_dropout2d, (input,), input, p=0., training=training, inplace=inplace)
47
+ if p < 0.0 or p > 1.0:
48
+ raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
49
+ return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
50
+
51
+ def function_dropout3d(input: torch.Tensor, p: float = 0.5, training: bool = True,
52
+ inplace: bool = False) -> torch.Tensor:
53
+ if has_torch_function_unary(input):
54
+ return handle_torch_function(
55
+ function_dropout3d, (input,), input, p=0., training=training, inplace=inplace)
56
+ if p < 0.0 or p > 1.0:
57
+ raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
58
+ return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
59
+
60
+ F.dropout = function_dropout
61
+ F.dropout2d = function_dropout2d
62
+ F.dropout3d = function_dropout3d
63
+
64
+ cur_path = os.path.dirname(os.path.realpath(__file__))
65
+ yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
66
+
67
+
68
+ def get_functional_ops():
69
+ yaml_data = load_yaml(yaml_path)
70
+ wrap_functional_ops = yaml_data.get('functional')
71
+ _all_functional_ops = dir(torch.nn.functional)
72
+ return set(wrap_functional_ops) & set(_all_functional_ops)
73
+
74
+
75
+ TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()}
76
+
77
+
78
+ class HOOKFunctionalOP(object):
79
+ pass
80
+
81
+
82
+ class FunctionalOPTemplate(HOOKModule):
83
+ def __init__(self, op_name, hook, need_hook=True):
84
+ self.op_name_ = op_name
85
+ self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP
86
+ if need_hook:
87
+ super().__init__(hook)
88
+
89
+ @torch_device_guard
90
+ def forward(self, *args, **kwargs):
91
+ return TorchFunctions[str(self.op_name_)](*args, **kwargs)
92
+
93
+
94
+ def wrap_functional_op(op_name, hook):
95
+ def functional_op_template(*args, **kwargs):
96
+ return FunctionalOPTemplate(op_name, hook)(*args, **kwargs)
97
+
98
+ return functional_op_template
99
+
100
+
101
+ def wrap_functional_ops_and_bind(hook):
102
+ _functional_ops = get_functional_ops()
103
+ for op_name in _functional_ops:
104
+ setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook))
@@ -1,84 +1,85 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
- import torch
20
-
21
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
22
- from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
23
- from msprobe.core.common.const import Const
24
- from msprobe.core.common.utils import load_yaml
25
- from msprobe.pytorch.function_factory import npu_custom_functions
26
-
27
- cur_path = os.path.dirname(os.path.realpath(__file__))
28
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
29
-
30
-
31
- try:
32
- import torch_npu
33
- except ImportError:
34
- is_gpu = True
35
- else:
36
- is_gpu = False
37
-
38
-
39
- def get_npu_ops():
40
- if torch_without_guard_version:
41
- _npu_ops = dir(torch.ops.npu)
42
- else:
43
- _npu_ops = dir(torch_npu._C._VariableFunctionsClass)
44
- yaml_data = load_yaml(yaml_path)
45
- wrap_npu_ops = yaml_data.get('torch_npu')
46
- return set(wrap_npu_ops) & set(_npu_ops)
47
-
48
-
49
- class HOOKNpuOP(object):
50
- pass
51
-
52
-
53
- class NpuOPTemplate(HOOKModule):
54
-
55
- def __init__(self, op_name, hook, need_hook=True):
56
- self.op_name_ = op_name
57
- self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
58
- self.need_hook = need_hook
59
- if need_hook:
60
- super().__init__(hook)
61
-
62
- @torch_device_guard
63
- def forward(self, *args, **kwargs):
64
- if not self.need_hook:
65
- if self.op_name_ not in npu_custom_functions:
66
- raise Exception(f'There is not bench function {self.op_name_}')
67
- return npu_custom_functions[self.op_name_](*args, **kwargs)
68
- if torch_without_guard_version:
69
- return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
70
- else:
71
- return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs)
72
-
73
-
74
- def wrap_npu_op(op_name, hook):
75
- def npu_op_template(*args, **kwargs):
76
- return NpuOPTemplate(op_name, hook)(*args, **kwargs)
77
-
78
- return npu_op_template
79
-
80
-
81
- def wrap_npu_ops_and_bind(hook):
82
- _npu_ops = get_npu_ops()
83
- for op_name in _npu_ops:
84
- setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook))
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import torch
18
+
19
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
20
+ from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
21
+ from msprobe.core.common.const import Const
22
+ from msprobe.core.common.log import logger
23
+ from msprobe.core.common.file_utils import load_yaml
24
+ from msprobe.pytorch.function_factory import npu_custom_functions
25
+
26
+ try:
27
+ import torch_npu
28
+ except ImportError:
29
+ logger.info("Failing to import torch_npu.")
30
+
31
+
32
+ cur_path = os.path.dirname(os.path.realpath(__file__))
33
+ yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
34
+ cuda_func_mapping = {"npu_fusion_attention" : "gpu_fusion_attention"}
35
+
36
+
37
+ def get_npu_ops():
38
+ if torch_without_guard_version:
39
+ _npu_ops = dir(torch.ops.npu)
40
+ else:
41
+ _npu_ops = dir(torch_npu._C._VariableFunctionsClass)
42
+ yaml_data = load_yaml(yaml_path)
43
+ wrap_npu_ops = yaml_data.get('torch_npu')
44
+ return set(wrap_npu_ops) & set(_npu_ops)
45
+
46
+
47
+ class HOOKNpuOP(object):
48
+ pass
49
+
50
+
51
+ class NpuOPTemplate(HOOKModule):
52
+
53
+ def __init__(self, op_name, hook, need_hook=True, device=Const.CPU_LOWERCASE):
54
+ self.op_name_ = op_name
55
+ self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
56
+ self.need_hook = need_hook
57
+ self.device = device
58
+ if need_hook:
59
+ super().__init__(hook)
60
+
61
+ @torch_device_guard
62
+ def forward(self, *args, **kwargs):
63
+ if not self.need_hook:
64
+ if self.op_name_ not in npu_custom_functions:
65
+ raise Exception(f'There is not bench function {self.op_name_}')
66
+ if self.device == Const.CUDA_LOWERCASE:
67
+ self.op_name_ = cuda_func_mapping.get(self.op_name_, self.op_name_)
68
+ if self.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]:
69
+ return npu_custom_functions[self.op_name_](*args, **kwargs)
70
+ if torch_without_guard_version:
71
+ return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
72
+ else:
73
+ return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs)
74
+
75
+
76
+ def wrap_npu_op(op_name, hook):
77
+ def npu_op_template(*args, **kwargs):
78
+ return NpuOPTemplate(op_name, hook)(*args, **kwargs)
79
+ return npu_op_template
80
+
81
+
82
+ def wrap_npu_ops_and_bind(hook):
83
+ _npu_ops = get_npu_ops()
84
+ for op_name in _npu_ops:
85
+ setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook))
@@ -1,71 +1,69 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
-
18
- import os
19
-
20
- import torch
21
-
22
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
23
- from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter
24
- from msprobe.core.common.const import Const
25
- from msprobe.core.common.utils import load_yaml
26
-
27
-
28
- cur_path = os.path.dirname(os.path.realpath(__file__))
29
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
30
-
31
-
32
- def get_tensor_ops():
33
- _tensor_ops = dir(torch.Tensor)
34
- yaml_data = load_yaml(yaml_path)
35
- wrap_tensor_ops = yaml_data.get('tensor')
36
- return set(wrap_tensor_ops) & set(_tensor_ops)
37
-
38
-
39
- TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()}
40
-
41
-
42
- class HOOKTensor(object):
43
- pass
44
-
45
-
46
- class TensorOPTemplate(HOOKModule):
47
-
48
- def __init__(self, op_name, hook, need_hook=True):
49
- self.op_name_ = op_name
50
- self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP
51
- if need_hook:
52
- super().__init__(hook)
53
-
54
- @torch_device_guard
55
- @parameter_adapter
56
- def forward(self, *args, **kwargs):
57
- return TensorOps[str(self.op_name_)](*args, **kwargs)
58
-
59
-
60
- def wrap_tensor_op(op_name, hook):
61
-
62
- def tensor_op_template(*args, **kwargs):
63
- return TensorOPTemplate(op_name, hook)(*args, **kwargs)
64
-
65
- return tensor_op_template
66
-
67
-
68
- def wrap_tensor_ops_and_bind(hook):
69
- _tensor_ops = get_tensor_ops()
70
- for op_name in _tensor_ops:
71
- setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook))
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import torch
19
+
20
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
21
+ from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter
22
+ from msprobe.core.common.const import Const
23
+ from msprobe.core.common.file_utils import load_yaml
24
+
25
+
26
# Directory containing this module; support_wrap_ops.yaml sits next to it and
# lists the torch.Tensor ops eligible for hooking.
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
28
+
29
+
30
def get_tensor_ops():
    """Return the torch.Tensor method names that should be wrapped.

    Intersects the 'tensor' op list from support_wrap_ops.yaml with the
    methods actually present on torch.Tensor in this environment.

    Returns:
        set[str]: wrappable torch.Tensor method names (possibly empty).
    """
    _tensor_ops = dir(torch.Tensor)
    yaml_data = load_yaml(yaml_path)
    # Default to an empty list: a YAML file without a 'tensor' section would
    # otherwise make set(None) raise TypeError.
    wrap_tensor_ops = yaml_data.get('tensor', [])
    return set(wrap_tensor_ops) & set(_tensor_ops)
35
+
36
+
37
+ TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()}
38
+
39
+
40
class HOOKTensor:
    """Namespace class; wrapped tensor ops are attached to it as 'wrap_<op>' attributes."""
42
+
43
+
44
class TensorOPTemplate(HOOKModule):
    """Hook wrapper around a single torch.Tensor method.

    Registers dump hooks through HOOKModule.__init__ (when need_hook) and
    forwards calls to the original, unpatched implementation captured in
    TensorOps.
    """

    def __init__(self, op_name, hook, need_hook=True):
        # Name of the wrapped torch.Tensor method, e.g. "add".
        self.op_name_ = op_name
        # Dump-name prefix: "Tensor.<op_name>." (Const.SEP is the separator).
        self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP
        if need_hook:
            super().__init__(hook)

    @torch_device_guard
    @parameter_adapter
    def forward(self, *args, **kwargs):
        # Dispatch to the original method snapshotted at import time.
        return TensorOps[str(self.op_name_)](*args, **kwargs)
56
+
57
+
58
def wrap_tensor_op(op_name, hook):
    """Return a callable that runs *op_name* through a hooked TensorOPTemplate."""
    def hooked_call(*args, **kwargs):
        template = TensorOPTemplate(op_name, hook)
        return template(*args, **kwargs)
    return hooked_call
64
+
65
+
66
def wrap_tensor_ops_and_bind(hook):
    """Attach a 'wrap_<op>' wrapper to HOOKTensor for every supported tensor op."""
    for name in get_tensor_ops():
        attr_name = "wrap_" + str(name)
        setattr(HOOKTensor, attr_name, wrap_tensor_op(name, hook))